Revert "netfilter: xt_qtaguid: fix crash on non-full sks"
[firefly-linux-kernel-4.4.55.git] / net / netfilter / nf_nat_core.c
index 038eee5c8f8548787bff468c40256d52bb6655fd..06a9f45771ab613bdd529195f64583e6e4316a7c 100644 (file)
@@ -25,6 +25,7 @@
 #include <net/netfilter/nf_nat_core.h>
 #include <net/netfilter/nf_nat_helper.h>
 #include <net/netfilter/nf_conntrack_helper.h>
+#include <net/netfilter/nf_conntrack_seqadj.h>
 #include <net/netfilter/nf_conntrack_l3proto.h>
 #include <net/netfilter/nf_conntrack_zones.h>
 #include <linux/netfilter/nf_nat.h>
@@ -82,7 +83,7 @@ out:
        rcu_read_unlock();
 }
 
-int nf_xfrm_me_harder(struct sk_buff *skb, unsigned int family)
+int nf_xfrm_me_harder(struct net *net, struct sk_buff *skb, unsigned int family)
 {
        struct flowi fl;
        unsigned int hh_len;
@@ -98,7 +99,7 @@ int nf_xfrm_me_harder(struct sk_buff *skb, unsigned int family)
                dst = ((struct xfrm_dst *)dst)->route;
        dst_hold(dst);
 
-       dst = xfrm_lookup(dev_net(dst->dev), dst, &fl, skb->sk, 0);
+       dst = xfrm_lookup(net, dst, &fl, skb->sk, 0);
        if (IS_ERR(dst))
                return PTR_ERR(dst);
 
@@ -117,15 +118,15 @@ EXPORT_SYMBOL(nf_xfrm_me_harder);
 
 /* We keep an extra hash for each conntrack, for fast searching. */
 static inline unsigned int
-hash_by_src(const struct net *net, u16 zone,
-           const struct nf_conntrack_tuple *tuple)
+hash_by_src(const struct net *net, const struct nf_conntrack_tuple *tuple)
 {
        unsigned int hash;
 
        /* Original src, to ensure we map it consistently if poss. */
        hash = jhash2((u32 *)&tuple->src, sizeof(tuple->src) / sizeof(u32),
-                     tuple->dst.protonum ^ zone ^ nf_conntrack_hash_rnd);
-       return ((u64)hash * net->ct.nat_htable_size) >> 32;
+                     tuple->dst.protonum ^ nf_conntrack_hash_rnd);
+
+       return reciprocal_scale(hash, net->ct.nat_htable_size);
 }
 
 /* Is this tuple already taken? (not by us) */
@@ -183,20 +184,22 @@ same_src(const struct nf_conn *ct,
 
 /* Only called for SRC manip */
 static int
-find_appropriate_src(struct net *net, u16 zone,
+find_appropriate_src(struct net *net,
+                    const struct nf_conntrack_zone *zone,
                     const struct nf_nat_l3proto *l3proto,
                     const struct nf_nat_l4proto *l4proto,
                     const struct nf_conntrack_tuple *tuple,
                     struct nf_conntrack_tuple *result,
                     const struct nf_nat_range *range)
 {
-       unsigned int h = hash_by_src(net, zone, tuple);
+       unsigned int h = hash_by_src(net, tuple);
        const struct nf_conn_nat *nat;
        const struct nf_conn *ct;
 
        hlist_for_each_entry_rcu(nat, &net->ct.nat_bysource[h], bysource) {
                ct = nat->ct;
-               if (same_src(ct, tuple) && nf_ct_zone(ct) == zone) {
+               if (same_src(ct, tuple) &&
+                   nf_ct_zone_equal(ct, zone, IP_CT_DIR_ORIGINAL)) {
                        /* Copy source part from reply tuple. */
                        nf_ct_invert_tuplepr(result,
                                       &ct->tuplehash[IP_CT_DIR_REPLY].tuple);
@@ -216,7 +219,8 @@ find_appropriate_src(struct net *net, u16 zone,
  * the ip with the lowest src-ip/dst-ip/proto usage.
  */
 static void
-find_best_ips_proto(u16 zone, struct nf_conntrack_tuple *tuple,
+find_best_ips_proto(const struct nf_conntrack_zone *zone,
+                   struct nf_conntrack_tuple *tuple,
                    const struct nf_nat_range *range,
                    const struct nf_conn *ct,
                    enum nf_nat_manip_type maniptype)
@@ -256,7 +260,7 @@ find_best_ips_proto(u16 zone, struct nf_conntrack_tuple *tuple,
         */
        j = jhash2((u32 *)&tuple->src.u3, sizeof(tuple->src.u3) / sizeof(u32),
                   range->flags & NF_NAT_RANGE_PERSISTENT ?
-                       0 : (__force u32)tuple->dst.u3.all[max] ^ zone);
+                       0 : (__force u32)tuple->dst.u3.all[max] ^ zone->id);
 
        full_range = false;
        for (i = 0; i <= max; i++) {
@@ -273,7 +277,7 @@ find_best_ips_proto(u16 zone, struct nf_conntrack_tuple *tuple,
                }
 
                var_ipp->all[i] = (__force __u32)
-                       htonl(minip + (((u64)j * dist) >> 32));
+                       htonl(minip + reciprocal_scale(j, dist));
                if (var_ipp->all[i] != range->max_addr.all[i])
                        full_range = true;
 
@@ -295,10 +299,12 @@ get_unique_tuple(struct nf_conntrack_tuple *tuple,
                 struct nf_conn *ct,
                 enum nf_nat_manip_type maniptype)
 {
+       const struct nf_conntrack_zone *zone;
        const struct nf_nat_l3proto *l3proto;
        const struct nf_nat_l4proto *l4proto;
        struct net *net = nf_ct_net(ct);
-       u16 zone = nf_ct_zone(ct);
+
+       zone = nf_ct_zone(ct);
 
        rcu_read_lock();
        l3proto = __nf_nat_l3proto_find(orig_tuple->src.l3num);
@@ -314,7 +320,7 @@ get_unique_tuple(struct nf_conntrack_tuple *tuple,
         * manips not an issue.
         */
        if (maniptype == NF_NAT_MANIP_SRC &&
-           !(range->flags & NF_NAT_RANGE_PROTO_RANDOM)) {
+           !(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) {
                /* try the original tuple first */
                if (in_range(l3proto, l4proto, orig_tuple, range)) {
                        if (!nf_nat_used_tuple(orig_tuple, ct)) {
@@ -338,7 +344,7 @@ get_unique_tuple(struct nf_conntrack_tuple *tuple,
         */
 
        /* Only bother mapping if it's not already in range and unique */
-       if (!(range->flags & NF_NAT_RANGE_PROTO_RANDOM)) {
+       if (!(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) {
                if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) {
                        if (l4proto->in_range(tuple, maniptype,
                                              &range->min_proto,
@@ -357,6 +363,19 @@ out:
        rcu_read_unlock();
 }
 
+struct nf_conn_nat *nf_ct_nat_ext_add(struct nf_conn *ct)
+{
+       struct nf_conn_nat *nat = nfct_nat(ct);
+       if (nat)
+               return nat;
+
+       if (!nf_ct_is_confirmed(ct))
+               nat = nf_ct_ext_add(ct, NF_CT_EXT_NAT, GFP_ATOMIC);
+
+       return nat;
+}
+EXPORT_SYMBOL_GPL(nf_ct_nat_ext_add);
+
 unsigned int
 nf_nat_setup_info(struct nf_conn *ct,
                  const struct nf_nat_range *range,
@@ -367,14 +386,9 @@ nf_nat_setup_info(struct nf_conn *ct,
        struct nf_conn_nat *nat;
 
        /* nat helper or nfctnetlink also setup binding */
-       nat = nfct_nat(ct);
-       if (!nat) {
-               nat = nf_ct_ext_add(ct, NF_CT_EXT_NAT, GFP_ATOMIC);
-               if (nat == NULL) {
-                       pr_debug("failed to add NAT extension\n");
-                       return NF_ACCEPT;
-               }
-       }
+       nat = nf_ct_nat_ext_add(ct);
+       if (nat == NULL)
+               return NF_ACCEPT;
 
        NF_CT_ASSERT(maniptype == NF_NAT_MANIP_SRC ||
                     maniptype == NF_NAT_MANIP_DST);
@@ -402,12 +416,15 @@ nf_nat_setup_info(struct nf_conn *ct,
                        ct->status |= IPS_SRC_NAT;
                else
                        ct->status |= IPS_DST_NAT;
+
+               if (nfct_help(ct))
+                       nfct_seqadj_ext_add(ct);
        }
 
        if (maniptype == NF_NAT_MANIP_SRC) {
                unsigned int srchash;
 
-               srchash = hash_by_src(net, nf_ct_zone(ct),
+               srchash = hash_by_src(net,
                                      &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
                spin_lock_bh(&nf_nat_lock);
                /* nf_conntrack_alter_reply might re-allocate extension aera */
@@ -428,6 +445,32 @@ nf_nat_setup_info(struct nf_conn *ct,
 }
 EXPORT_SYMBOL(nf_nat_setup_info);
 
+static unsigned int
+__nf_nat_alloc_null_binding(struct nf_conn *ct, enum nf_nat_manip_type manip)
+{
+       /* Force range to this IP; let proto decide mapping for
+        * per-proto parts (hence not NF_NAT_RANGE_PROTO_SPECIFIED).
+        * Use reply in case it's already been mangled (eg local packet).
+        */
+       union nf_inet_addr ip =
+               (manip == NF_NAT_MANIP_SRC ?
+               ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3 :
+               ct->tuplehash[IP_CT_DIR_REPLY].tuple.src.u3);
+       struct nf_nat_range range = {
+               .flags          = NF_NAT_RANGE_MAP_IPS,
+               .min_addr       = ip,
+               .max_addr       = ip,
+       };
+       return nf_nat_setup_info(ct, &range, manip);
+}
+
+unsigned int
+nf_nat_alloc_null_binding(struct nf_conn *ct, unsigned int hooknum)
+{
+       return __nf_nat_alloc_null_binding(ct, HOOK2MANIP(hooknum));
+}
+EXPORT_SYMBOL_GPL(nf_nat_alloc_null_binding);
+
 /* Do packet manipulations according to nf_nat_setup_info. */
 unsigned int nf_nat_packet(struct nf_conn *ct,
                           enum ip_conntrack_info ctinfo,
@@ -487,6 +530,39 @@ static int nf_nat_proto_remove(struct nf_conn *i, void *data)
        return i->status & IPS_NAT_MASK ? 1 : 0;
 }
 
+static int nf_nat_proto_clean(struct nf_conn *ct, void *data)
+{
+       struct nf_conn_nat *nat = nfct_nat(ct);
+
+       if (nf_nat_proto_remove(ct, data))
+               return 1;
+
+       if (!nat || !nat->ct)
+               return 0;
+
+       /* This netns is being destroyed, and conntrack has nat null binding.
+        * Remove it from bysource hash, as the table will be freed soon.
+        *
+        * Else, when the conntrack is destroyed, nf_nat_cleanup_conntrack()
+        * will delete entry from already-freed table.
+        */
+       if (!del_timer(&ct->timeout))
+               return 1;
+
+       spin_lock_bh(&nf_nat_lock);
+       hlist_del_rcu(&nat->bysource);
+       ct->status &= ~IPS_NAT_DONE_MASK;
+       nat->ct = NULL;
+       spin_unlock_bh(&nf_nat_lock);
+
+       add_timer(&ct->timeout);
+
+       /* don't delete conntrack.  Although that would make things a lot
+        * simpler, we'd end up flushing all conntracks on nat rmmod.
+        */
+       return 0;
+}
+
 static void nf_nat_l4proto_clean(u8 l3proto, u8 l4proto)
 {
        struct nf_nat_proto_clean clean = {
@@ -497,7 +573,7 @@ static void nf_nat_l4proto_clean(u8 l3proto, u8 l4proto)
 
        rtnl_lock();
        for_each_net(net)
-               nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean);
+               nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean, 0, 0);
        rtnl_unlock();
 }
 
@@ -511,7 +587,7 @@ static void nf_nat_l3proto_clean(u8 l3proto)
        rtnl_lock();
 
        for_each_net(net)
-               nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean);
+               nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean, 0, 0);
        rtnl_unlock();
 }
 
@@ -639,7 +715,7 @@ static struct nf_ct_ext_type nat_extend __read_mostly = {
        .flags          = NF_CT_EXT_F_PREALLOC,
 };
 
-#if defined(CONFIG_NF_CT_NETLINK) || defined(CONFIG_NF_CT_NETLINK_MODULE)
+#if IS_ENABLED(CONFIG_NF_CT_NETLINK)
 
 #include <linux/netfilter/nfnetlink.h>
 #include <linux/netfilter/nfnetlink_conntrack.h>
@@ -678,9 +754,9 @@ static const struct nla_policy nat_nla_policy[CTA_NAT_MAX+1] = {
 
 static int
 nfnetlink_parse_nat(const struct nlattr *nat,
-                   const struct nf_conn *ct, struct nf_nat_range *range)
+                   const struct nf_conn *ct, struct nf_nat_range *range,
+                   const struct nf_nat_l3proto *l3proto)
 {
-       const struct nf_nat_l3proto *l3proto;
        struct nlattr *tb[CTA_NAT_MAX+1];
        int err;
 
@@ -690,38 +766,46 @@ nfnetlink_parse_nat(const struct nlattr *nat,
        if (err < 0)
                return err;
 
-       rcu_read_lock();
-       l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct));
-       if (l3proto == NULL) {
-               err = -EAGAIN;
-               goto out;
-       }
        err = l3proto->nlattr_to_range(tb, range);
        if (err < 0)
-               goto out;
+               return err;
 
        if (!tb[CTA_NAT_PROTO])
-               goto out;
+               return 0;
 
-       err = nfnetlink_parse_nat_proto(tb[CTA_NAT_PROTO], ct, range);
-out:
-       rcu_read_unlock();
-       return err;
+       return nfnetlink_parse_nat_proto(tb[CTA_NAT_PROTO], ct, range);
 }
 
+/* This function is called under rcu_read_lock() */
 static int
 nfnetlink_parse_nat_setup(struct nf_conn *ct,
                          enum nf_nat_manip_type manip,
                          const struct nlattr *attr)
 {
        struct nf_nat_range range;
+       const struct nf_nat_l3proto *l3proto;
        int err;
 
-       err = nfnetlink_parse_nat(attr, ct, &range);
+       /* Should not happen, restricted to creating new conntracks
+        * via ctnetlink.
+        */
+       if (WARN_ON_ONCE(nf_nat_initialized(ct, manip)))
+               return -EEXIST;
+
+       /* Make sure that L3 NAT is there by the time we call nf_nat_setup_info to
+        * attach the null binding, otherwise this may oops.
+        */
+       l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct));
+       if (l3proto == NULL)
+               return -EAGAIN;
+
+       /* No NAT information has been passed, allocate the null-binding */
+       if (attr == NULL)
+               return __nf_nat_alloc_null_binding(ct, manip);
+
+       err = nfnetlink_parse_nat(attr, ct, &range, l3proto);
        if (err < 0)
                return err;
-       if (nf_nat_initialized(ct, manip))
-               return -EEXIST;
 
        return nf_nat_setup_info(ct, &range, manip);
 }
@@ -749,7 +833,7 @@ static void __net_exit nf_nat_net_exit(struct net *net)
 {
        struct nf_nat_proto_clean clean = {};
 
-       nf_ct_iterate_cleanup(net, &nf_nat_proto_remove, &clean);
+       nf_ct_iterate_cleanup(net, nf_nat_proto_clean, &clean, 0, 0);
        synchronize_rcu();
        nf_ct_free_hashtable(net->ct.nat_bysource, net->ct.nat_htable_size);
 }
@@ -764,10 +848,6 @@ static struct nf_ct_helper_expectfn follow_master_nat = {
        .expectfn       = nf_nat_follow_master,
 };
 
-static struct nfq_ct_nat_hook nfq_ct_nat = {
-       .seq_adjust     = nf_nat_tcp_seq_adjust,
-};
-
 static int __init nf_nat_init(void)
 {
        int ret;
@@ -787,14 +867,9 @@ static int __init nf_nat_init(void)
        /* Initialize fake conntrack so that NAT will skip it */
        nf_ct_untracked_status_or(IPS_NAT_DONE_MASK);
 
-       BUG_ON(nf_nat_seq_adjust_hook != NULL);
-       RCU_INIT_POINTER(nf_nat_seq_adjust_hook, nf_nat_seq_adjust);
        BUG_ON(nfnetlink_parse_nat_setup_hook != NULL);
        RCU_INIT_POINTER(nfnetlink_parse_nat_setup_hook,
                           nfnetlink_parse_nat_setup);
-       BUG_ON(nf_ct_nat_offset != NULL);
-       RCU_INIT_POINTER(nf_ct_nat_offset, nf_nat_get_offset);
-       RCU_INIT_POINTER(nfq_ct_nat_hook, &nfq_ct_nat);
 #ifdef CONFIG_XFRM
        BUG_ON(nf_nat_decode_session_hook != NULL);
        RCU_INIT_POINTER(nf_nat_decode_session_hook, __nf_nat_decode_session);
@@ -813,10 +888,7 @@ static void __exit nf_nat_cleanup(void)
        unregister_pernet_subsys(&nf_nat_net_ops);
        nf_ct_extend_unregister(&nat_extend);
        nf_ct_helper_expectfn_unregister(&follow_master_nat);
-       RCU_INIT_POINTER(nf_nat_seq_adjust_hook, NULL);
        RCU_INIT_POINTER(nfnetlink_parse_nat_setup_hook, NULL);
-       RCU_INIT_POINTER(nf_ct_nat_offset, NULL);
-       RCU_INIT_POINTER(nfq_ct_nat_hook, NULL);
 #ifdef CONFIG_XFRM
        RCU_INIT_POINTER(nf_nat_decode_session_hook, NULL);
 #endif