nf: xt_socket: export the fancy sock finder code
authorJP Abgrall <jpa@google.com>
Wed, 15 Jun 2011 23:52:40 +0000 (16:52 -0700)
committerJP Abgrall <jpa@google.com>
Mon, 20 Jun 2011 20:15:08 +0000 (13:15 -0700)
The socket matching function has some nifty logic to get the struct sock
from the skb or from the connection tracker.
We export this so other xt_* can use it, similarly to ho how
xt_socket uses nf_tproxy_get_sock.

Change-Id: I11c58f59087e7f7ae09e4abd4b937cd3370fa2fd
Signed-off-by: JP Abgrall <jpa@google.com>
include/linux/netfilter/xt_socket.h
net/netfilter/xt_socket.c

index 26d7217bd4f1cf1337a7c1d5d5643801d23b4e8a..63594564831c4b26f8e793d6d9b52c18406dc79c 100644 (file)
@@ -11,4 +11,10 @@ struct xt_socket_mtinfo1 {
        __u8 flags;
 };
 
+void xt_socket_put_sk(struct sock *sk);
+struct sock *xt_socket_get4_sk(const struct sk_buff *skb,
+                              struct xt_action_param *par);
+struct sock *xt_socket_get6_sk(const struct sk_buff *skb,
+                              struct xt_action_param *par);
+
 #endif /* _XT_SOCKET_H */
index fe39f7e913dff490e948ca47b1ce6c14844013d6..ddf5e0507f5f5805e5225317206e66915a1abc12 100644 (file)
@@ -35,7 +35,7 @@
 #include <net/netfilter/nf_conntrack.h>
 #endif
 
-static void
+void
 xt_socket_put_sk(struct sock *sk)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
@@ -43,6 +43,7 @@ xt_socket_put_sk(struct sock *sk)
        else
                sock_put(sk);
 }
+EXPORT_SYMBOL(xt_socket_put_sk);
 
 static int
 extract_icmp4_fields(const struct sk_buff *skb,
@@ -101,9 +102,8 @@ extract_icmp4_fields(const struct sk_buff *skb,
        return 0;
 }
 
-static bool
-socket_match(const struct sk_buff *skb, struct xt_action_param *par,
-            const struct xt_socket_mtinfo1 *info)
+struct sock*
+xt_socket_get4_sk(const struct sk_buff *skb, struct xt_action_param *par)
 {
        const struct iphdr *iph = ip_hdr(skb);
        struct udphdr _hdr, *hp = NULL;
@@ -120,7 +120,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par,
                hp = skb_header_pointer(skb, ip_hdrlen(skb),
                                        sizeof(_hdr), &_hdr);
                if (hp == NULL)
-                       return false;
+                       return NULL;
 
                protocol = iph->protocol;
                saddr = iph->saddr;
@@ -131,9 +131,9 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par,
        } else if (iph->protocol == IPPROTO_ICMP) {
                if (extract_icmp4_fields(skb, &protocol, &saddr, &daddr,
                                        &sport, &dport))
-                       return false;
+                       return NULL;
        } else {
-               return false;
+               return NULL;
        }
 
 #ifdef XT_SOCKET_HAVE_CONNTRACK
@@ -157,6 +157,23 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par,
 
        sk = nf_tproxy_get_sock_v4(dev_net(skb->dev), protocol,
                                   saddr, daddr, sport, dport, par->in, NFT_LOOKUP_ANY);
+
+       pr_debug("proto %hhu %pI4:%hu -> %pI4:%hu (orig %pI4:%hu) sock %p\n",
+                protocol, &saddr, ntohs(sport),
+                &daddr, ntohs(dport),
+                &iph->daddr, hp ? ntohs(hp->dest) : 0, sk);
+
+       return sk;
+}
+EXPORT_SYMBOL(xt_socket_get4_sk);
+
+static bool
+socket_match(const struct sk_buff *skb, struct xt_action_param *par,
+            const struct xt_socket_mtinfo1 *info)
+{
+       struct sock *sk;
+
+       sk = xt_socket_get4_sk(skb, par);
        if (sk != NULL) {
                bool wildcard;
                bool transparent = true;
@@ -179,11 +196,6 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par,
                        sk = NULL;
        }
 
-       pr_debug("proto %hhu %pI4:%hu -> %pI4:%hu (orig %pI4:%hu) sock %p\n",
-                protocol, &saddr, ntohs(sport),
-                &daddr, ntohs(dport),
-                &iph->daddr, hp ? ntohs(hp->dest) : 0, sk);
-
        return (sk != NULL);
 }
 
@@ -253,8 +265,8 @@ extract_icmp6_fields(const struct sk_buff *skb,
        return 0;
 }
 
-static bool
-socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
+struct sock*
+xt_socket_get6_sk(const struct sk_buff *skb, struct xt_action_param *par)
 {
        struct ipv6hdr *iph = ipv6_hdr(skb);
        struct udphdr _hdr, *hp = NULL;
@@ -262,7 +274,6 @@ socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
        struct in6_addr *daddr, *saddr;
        __be16 dport, sport;
        int thoff, tproto;
-       const struct xt_socket_mtinfo1 *info = (struct xt_socket_mtinfo1 *) par->matchinfo;
 
        tproto = ipv6_find_hdr(skb, &thoff, -1, NULL);
        if (tproto < 0) {
@@ -274,7 +285,7 @@ socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
                hp = skb_header_pointer(skb, thoff,
                                        sizeof(_hdr), &_hdr);
                if (hp == NULL)
-                       return false;
+                       return NULL;
 
                saddr = &iph->saddr;
                sport = hp->source;
@@ -284,13 +295,30 @@ socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
        } else if (tproto == IPPROTO_ICMPV6) {
                if (extract_icmp6_fields(skb, thoff, &tproto, &saddr, &daddr,
                                         &sport, &dport))
-                       return false;
+                       return NULL;
        } else {
-               return false;
+               return NULL;
        }
 
        sk = nf_tproxy_get_sock_v6(dev_net(skb->dev), tproto,
                                   saddr, daddr, sport, dport, par->in, NFT_LOOKUP_ANY);
+       pr_debug("proto %hhd %pI6:%hu -> %pI6:%hu "
+                "(orig %pI6:%hu) sock %p\n",
+                tproto, saddr, ntohs(sport),
+                daddr, ntohs(dport),
+                &iph->daddr, hp ? ntohs(hp->dest) : 0, sk);
+       return sk;
+}
+EXPORT_SYMBOL(xt_socket_get6_sk);
+
+static bool
+socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
+{
+       struct sock *sk;
+       const struct xt_socket_mtinfo1 *info;
+
+       info = (struct xt_socket_mtinfo1 *) par->matchinfo;
+       sk = xt_socket_get6_sk(skb, par);
        if (sk != NULL) {
                bool wildcard;
                bool transparent = true;
@@ -313,12 +341,6 @@ socket_mt6_v1(const struct sk_buff *skb, struct xt_action_param *par)
                        sk = NULL;
        }
 
-       pr_debug("proto %hhd %pI6:%hu -> %pI6:%hu "
-                "(orig %pI6:%hu) sock %p\n",
-                tproto, saddr, ntohs(sport),
-                daddr, ntohs(dport),
-                &iph->daddr, hp ? ntohs(hp->dest) : 0, sk);
-
        return (sk != NULL);
 }
 #endif