net: diag: slightly refactor the inet_diag_bc_audit error checks.
[firefly-linux-kernel-4.4.55.git] / net / ipv4 / inet_diag.c
1 /*
2  * inet_diag.c  Module for monitoring INET transport protocols sockets.
3  *
4  * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5  *
6  *      This program is free software; you can redistribute it and/or
7  *      modify it under the terms of the GNU General Public License
8  *      as published by the Free Software Foundation; either version
9  *      2 of the License, or (at your option) any later version.
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/types.h>
15 #include <linux/fcntl.h>
16 #include <linux/random.h>
17 #include <linux/slab.h>
18 #include <linux/cache.h>
19 #include <linux/init.h>
20 #include <linux/time.h>
21
22 #include <net/icmp.h>
23 #include <net/tcp.h>
24 #include <net/ipv6.h>
25 #include <net/inet_common.h>
26 #include <net/inet_connection_sock.h>
27 #include <net/inet_hashtables.h>
28 #include <net/inet_timewait_sock.h>
29 #include <net/inet6_hashtables.h>
30 #include <net/netlink.h>
31
32 #include <linux/inet.h>
33 #include <linux/stddef.h>
34
35 #include <linux/inet_diag.h>
36 #include <linux/sock_diag.h>
37
38 static const struct inet_diag_handler **inet_diag_table;
39
40 struct inet_diag_entry {
41         const __be32 *saddr;
42         const __be32 *daddr;
43         u16 sport;
44         u16 dport;
45         u16 family;
46         u16 userlocks;
47         u32 ifindex;
48 };
49
50 static DEFINE_MUTEX(inet_diag_table_mutex);
51
52 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
53 {
54         if (!inet_diag_table[proto])
55                 request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
56                                NETLINK_SOCK_DIAG, AF_INET, proto);
57
58         mutex_lock(&inet_diag_table_mutex);
59         if (!inet_diag_table[proto])
60                 return ERR_PTR(-ENOENT);
61
62         return inet_diag_table[proto];
63 }
64
65 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
66 {
67         mutex_unlock(&inet_diag_table_mutex);
68 }
69
70 static void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
71 {
72         r->idiag_family = sk->sk_family;
73
74         r->id.idiag_sport = htons(sk->sk_num);
75         r->id.idiag_dport = sk->sk_dport;
76         r->id.idiag_if = sk->sk_bound_dev_if;
77         sock_diag_save_cookie(sk, r->id.idiag_cookie);
78
79 #if IS_ENABLED(CONFIG_IPV6)
80         if (sk->sk_family == AF_INET6) {
81                 *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
82                 *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
83         } else
84 #endif
85         {
86         memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
87         memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
88
89         r->id.idiag_src[0] = sk->sk_rcv_saddr;
90         r->id.idiag_dst[0] = sk->sk_daddr;
91         }
92 }
93
94 static size_t inet_sk_attr_size(void)
95 {
96         return    nla_total_size(sizeof(struct tcp_info))
97                 + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
98                 + nla_total_size(1) /* INET_DIAG_TOS */
99                 + nla_total_size(1) /* INET_DIAG_TCLASS */
100                 + nla_total_size(sizeof(struct inet_diag_meminfo))
101                 + nla_total_size(sizeof(struct inet_diag_msg))
102                 + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
103                 + nla_total_size(TCP_CA_NAME_MAX)
104                 + nla_total_size(sizeof(struct tcpvegas_info))
105                 + 64;
106 }
107
108 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
109                       struct sk_buff *skb, const struct inet_diag_req_v2 *req,
110                       struct user_namespace *user_ns,
111                       u32 portid, u32 seq, u16 nlmsg_flags,
112                       const struct nlmsghdr *unlh)
113 {
114         const struct inet_sock *inet = inet_sk(sk);
115         const struct tcp_congestion_ops *ca_ops;
116         const struct inet_diag_handler *handler;
117         int ext = req->idiag_ext;
118         struct inet_diag_msg *r;
119         struct nlmsghdr  *nlh;
120         struct nlattr *attr;
121         void *info = NULL;
122
123         handler = inet_diag_table[req->sdiag_protocol];
124         BUG_ON(!handler);
125
126         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
127                         nlmsg_flags);
128         if (!nlh)
129                 return -EMSGSIZE;
130
131         r = nlmsg_data(nlh);
132         BUG_ON(!sk_fullsock(sk));
133
134         inet_diag_msg_common_fill(r, sk);
135         r->idiag_state = sk->sk_state;
136         r->idiag_timer = 0;
137         r->idiag_retrans = 0;
138
139         if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
140                 goto errout;
141
142         /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
143          * hence this needs to be included regardless of socket family.
144          */
145         if (ext & (1 << (INET_DIAG_TOS - 1)))
146                 if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
147                         goto errout;
148
149 #if IS_ENABLED(CONFIG_IPV6)
150         if (r->idiag_family == AF_INET6) {
151                 if (ext & (1 << (INET_DIAG_TCLASS - 1)))
152                         if (nla_put_u8(skb, INET_DIAG_TCLASS,
153                                        inet6_sk(sk)->tclass) < 0)
154                                 goto errout;
155
156                 if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
157                     nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
158                         goto errout;
159         }
160 #endif
161
162         r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
163         r->idiag_inode = sock_i_ino(sk);
164
165         if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
166                 struct inet_diag_meminfo minfo = {
167                         .idiag_rmem = sk_rmem_alloc_get(sk),
168                         .idiag_wmem = sk->sk_wmem_queued,
169                         .idiag_fmem = sk->sk_forward_alloc,
170                         .idiag_tmem = sk_wmem_alloc_get(sk),
171                 };
172
173                 if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
174                         goto errout;
175         }
176
177         if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
178                 if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
179                         goto errout;
180
181         if (!icsk) {
182                 handler->idiag_get_info(sk, r, NULL);
183                 goto out;
184         }
185
186 #define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
187
188         if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
189             icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
190             icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
191                 r->idiag_timer = 1;
192                 r->idiag_retrans = icsk->icsk_retransmits;
193                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
194         } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
195                 r->idiag_timer = 4;
196                 r->idiag_retrans = icsk->icsk_probes_out;
197                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
198         } else if (timer_pending(&sk->sk_timer)) {
199                 r->idiag_timer = 2;
200                 r->idiag_retrans = icsk->icsk_probes_out;
201                 r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
202         } else {
203                 r->idiag_timer = 0;
204                 r->idiag_expires = 0;
205         }
206 #undef EXPIRES_IN_MS
207
208         if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
209                 attr = nla_reserve(skb, INET_DIAG_INFO,
210                                    handler->idiag_info_size);
211                 if (!attr)
212                         goto errout;
213
214                 info = nla_data(attr);
215         }
216
217         if (ext & (1 << (INET_DIAG_CONG - 1))) {
218                 int err = 0;
219
220                 rcu_read_lock();
221                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
222                 if (ca_ops)
223                         err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
224                 rcu_read_unlock();
225                 if (err < 0)
226                         goto errout;
227         }
228
229         handler->idiag_get_info(sk, r, info);
230
231         if (sk->sk_state < TCP_TIME_WAIT) {
232                 union tcp_cc_info info;
233                 size_t sz = 0;
234                 int attr;
235
236                 rcu_read_lock();
237                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
238                 if (ca_ops && ca_ops->get_info)
239                         sz = ca_ops->get_info(sk, ext, &attr, &info);
240                 rcu_read_unlock();
241                 if (sz && nla_put(skb, attr, sz, &info) < 0)
242                         goto errout;
243         }
244
245 out:
246         nlmsg_end(skb, nlh);
247         return 0;
248
249 errout:
250         nlmsg_cancel(skb, nlh);
251         return -EMSGSIZE;
252 }
253 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
254
255 static int inet_csk_diag_fill(struct sock *sk,
256                               struct sk_buff *skb,
257                               const struct inet_diag_req_v2 *req,
258                               struct user_namespace *user_ns,
259                               u32 portid, u32 seq, u16 nlmsg_flags,
260                               const struct nlmsghdr *unlh)
261 {
262         return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
263                                  user_ns, portid, seq, nlmsg_flags, unlh);
264 }
265
266 static int inet_twsk_diag_fill(struct sock *sk,
267                                struct sk_buff *skb,
268                                u32 portid, u32 seq, u16 nlmsg_flags,
269                                const struct nlmsghdr *unlh)
270 {
271         struct inet_timewait_sock *tw = inet_twsk(sk);
272         struct inet_diag_msg *r;
273         struct nlmsghdr *nlh;
274         long tmo;
275
276         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
277                         nlmsg_flags);
278         if (!nlh)
279                 return -EMSGSIZE;
280
281         r = nlmsg_data(nlh);
282         BUG_ON(tw->tw_state != TCP_TIME_WAIT);
283
284         tmo = tw->tw_timer.expires - jiffies;
285         if (tmo < 0)
286                 tmo = 0;
287
288         inet_diag_msg_common_fill(r, sk);
289         r->idiag_retrans      = 0;
290
291         r->idiag_state        = tw->tw_substate;
292         r->idiag_timer        = 3;
293         r->idiag_expires      = jiffies_to_msecs(tmo);
294         r->idiag_rqueue       = 0;
295         r->idiag_wqueue       = 0;
296         r->idiag_uid          = 0;
297         r->idiag_inode        = 0;
298
299         nlmsg_end(skb, nlh);
300         return 0;
301 }
302
303 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
304                               u32 portid, u32 seq, u16 nlmsg_flags,
305                               const struct nlmsghdr *unlh)
306 {
307         struct inet_diag_msg *r;
308         struct nlmsghdr *nlh;
309         long tmo;
310
311         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
312                         nlmsg_flags);
313         if (!nlh)
314                 return -EMSGSIZE;
315
316         r = nlmsg_data(nlh);
317         inet_diag_msg_common_fill(r, sk);
318         r->idiag_state = TCP_SYN_RECV;
319         r->idiag_timer = 1;
320         r->idiag_retrans = inet_reqsk(sk)->num_retrans;
321
322         BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
323                      offsetof(struct sock, sk_cookie));
324
325         tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
326         r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
327         r->idiag_rqueue = 0;
328         r->idiag_wqueue = 0;
329         r->idiag_uid    = 0;
330         r->idiag_inode  = 0;
331
332         nlmsg_end(skb, nlh);
333         return 0;
334 }
335
336 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
337                         const struct inet_diag_req_v2 *r,
338                         struct user_namespace *user_ns,
339                         u32 portid, u32 seq, u16 nlmsg_flags,
340                         const struct nlmsghdr *unlh)
341 {
342         if (sk->sk_state == TCP_TIME_WAIT)
343                 return inet_twsk_diag_fill(sk, skb, portid, seq,
344                                            nlmsg_flags, unlh);
345
346         if (sk->sk_state == TCP_NEW_SYN_RECV)
347                 return inet_req_diag_fill(sk, skb, portid, seq,
348                                           nlmsg_flags, unlh);
349
350         return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
351                                   nlmsg_flags, unlh);
352 }
353
354 struct sock *inet_diag_find_one_icsk(struct net *net,
355                                      struct inet_hashinfo *hashinfo,
356                                      const struct inet_diag_req_v2 *req)
357 {
358         struct sock *sk;
359
360         if (req->sdiag_family == AF_INET)
361                 sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
362                                  req->id.idiag_dport, req->id.idiag_src[0],
363                                  req->id.idiag_sport, req->id.idiag_if);
364 #if IS_ENABLED(CONFIG_IPV6)
365         else if (req->sdiag_family == AF_INET6) {
366                 if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
367                     ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
368                         sk = inet_lookup(net, hashinfo, req->id.idiag_dst[3],
369                                          req->id.idiag_dport, req->id.idiag_src[3],
370                                          req->id.idiag_sport, req->id.idiag_if);
371                 else
372                         sk = inet6_lookup(net, hashinfo,
373                                           (struct in6_addr *)req->id.idiag_dst,
374                                           req->id.idiag_dport,
375                                           (struct in6_addr *)req->id.idiag_src,
376                                           req->id.idiag_sport,
377                                           req->id.idiag_if);
378         }
379 #endif
380         else
381                 return ERR_PTR(-EINVAL);
382
383         if (!sk)
384                 return ERR_PTR(-ENOENT);
385
386         if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
387                 sock_gen_put(sk);
388                 return ERR_PTR(-ENOENT);
389         }
390
391         return sk;
392 }
393 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
394
395 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
396                             struct sk_buff *in_skb,
397                             const struct nlmsghdr *nlh,
398                             const struct inet_diag_req_v2 *req)
399 {
400         struct net *net = sock_net(in_skb->sk);
401         struct sk_buff *rep;
402         struct sock *sk;
403         int err;
404
405         sk = inet_diag_find_one_icsk(net, hashinfo, req);
406         if (IS_ERR(sk))
407                 return PTR_ERR(sk);
408
409         rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
410         if (!rep) {
411                 err = -ENOMEM;
412                 goto out;
413         }
414
415         err = sk_diag_fill(sk, rep, req,
416                            sk_user_ns(NETLINK_CB(in_skb).sk),
417                            NETLINK_CB(in_skb).portid,
418                            nlh->nlmsg_seq, 0, nlh);
419         if (err < 0) {
420                 WARN_ON(err == -EMSGSIZE);
421                 nlmsg_free(rep);
422                 goto out;
423         }
424         err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
425                               MSG_DONTWAIT);
426         if (err > 0)
427                 err = 0;
428
429 out:
430         if (sk)
431                 sock_gen_put(sk);
432
433         return err;
434 }
435 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
436
437 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
438                                const struct nlmsghdr *nlh,
439                                const struct inet_diag_req_v2 *req)
440 {
441         const struct inet_diag_handler *handler;
442         int err;
443
444         handler = inet_diag_lock_handler(req->sdiag_protocol);
445         if (IS_ERR(handler))
446                 err = PTR_ERR(handler);
447         else if (cmd == SOCK_DIAG_BY_FAMILY)
448                 err = handler->dump_one(in_skb, nlh, req);
449         else if (cmd == SOCK_DESTROY_BACKPORT && handler->destroy)
450                 err = handler->destroy(in_skb, req);
451         else
452                 err = -EOPNOTSUPP;
453         inet_diag_unlock_handler(handler);
454
455         return err;
456 }
457
458 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
459 {
460         int words = bits >> 5;
461
462         bits &= 0x1f;
463
464         if (words) {
465                 if (memcmp(a1, a2, words << 2))
466                         return 0;
467         }
468         if (bits) {
469                 __be32 w1, w2;
470                 __be32 mask;
471
472                 w1 = a1[words];
473                 w2 = a2[words];
474
475                 mask = htonl((0xffffffff) << (32 - bits));
476
477                 if ((w1 ^ w2) & mask)
478                         return 0;
479         }
480
481         return 1;
482 }
483
484 static int inet_diag_bc_run(const struct nlattr *_bc,
485                             const struct inet_diag_entry *entry)
486 {
487         const void *bc = nla_data(_bc);
488         int len = nla_len(_bc);
489
490         while (len > 0) {
491                 int yes = 1;
492                 const struct inet_diag_bc_op *op = bc;
493
494                 switch (op->code) {
495                 case INET_DIAG_BC_NOP:
496                         break;
497                 case INET_DIAG_BC_JMP:
498                         yes = 0;
499                         break;
500                 case INET_DIAG_BC_S_GE:
501                         yes = entry->sport >= op[1].no;
502                         break;
503                 case INET_DIAG_BC_S_LE:
504                         yes = entry->sport <= op[1].no;
505                         break;
506                 case INET_DIAG_BC_D_GE:
507                         yes = entry->dport >= op[1].no;
508                         break;
509                 case INET_DIAG_BC_D_LE:
510                         yes = entry->dport <= op[1].no;
511                         break;
512                 case INET_DIAG_BC_AUTO:
513                         yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
514                         break;
515                 case INET_DIAG_BC_S_COND:
516                 case INET_DIAG_BC_D_COND: {
517                         const struct inet_diag_hostcond *cond;
518                         const __be32 *addr;
519
520                         cond = (const struct inet_diag_hostcond *)(op + 1);
521                         if (cond->port != -1 &&
522                             cond->port != (op->code == INET_DIAG_BC_S_COND ?
523                                              entry->sport : entry->dport)) {
524                                 yes = 0;
525                                 break;
526                         }
527
528                         if (op->code == INET_DIAG_BC_S_COND)
529                                 addr = entry->saddr;
530                         else
531                                 addr = entry->daddr;
532
533                         if (cond->family != AF_UNSPEC &&
534                             cond->family != entry->family) {
535                                 if (entry->family == AF_INET6 &&
536                                     cond->family == AF_INET) {
537                                         if (addr[0] == 0 && addr[1] == 0 &&
538                                             addr[2] == htonl(0xffff) &&
539                                             bitstring_match(addr + 3,
540                                                             cond->addr,
541                                                             cond->prefix_len))
542                                                 break;
543                                 }
544                                 yes = 0;
545                                 break;
546                         }
547
548                         if (cond->prefix_len == 0)
549                                 break;
550                         if (bitstring_match(addr, cond->addr,
551                                             cond->prefix_len))
552                                 break;
553                         yes = 0;
554                         break;
555                 }
556                 case INET_DIAG_BC_DEV_COND: {
557                         u32 ifindex;
558
559                         ifindex = *((const u32 *)(op + 1));
560                         if (ifindex != entry->ifindex)
561                                 yes = 0;
562                         break;
563                 }
564                 }
565
566                 if (yes) {
567                         len -= op->yes;
568                         bc += op->yes;
569                 } else {
570                         len -= op->no;
571                         bc += op->no;
572                 }
573         }
574         return len == 0;
575 }
576
577 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
578  */
579 static void entry_fill_addrs(struct inet_diag_entry *entry,
580                              const struct sock *sk)
581 {
582 #if IS_ENABLED(CONFIG_IPV6)
583         if (sk->sk_family == AF_INET6) {
584                 entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
585                 entry->daddr = sk->sk_v6_daddr.s6_addr32;
586         } else
587 #endif
588         {
589                 entry->saddr = &sk->sk_rcv_saddr;
590                 entry->daddr = &sk->sk_daddr;
591         }
592 }
593
594 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
595 {
596         struct inet_sock *inet = inet_sk(sk);
597         struct inet_diag_entry entry;
598
599         if (!bc)
600                 return 1;
601
602         entry.family = sk->sk_family;
603         entry_fill_addrs(&entry, sk);
604         entry.sport = inet->inet_num;
605         entry.dport = ntohs(inet->inet_dport);
606         entry.ifindex = sk->sk_bound_dev_if;
607         entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
608
609         return inet_diag_bc_run(bc, &entry);
610 }
611 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
612
613 static int valid_cc(const void *bc, int len, int cc)
614 {
615         while (len >= 0) {
616                 const struct inet_diag_bc_op *op = bc;
617
618                 if (cc > len)
619                         return 0;
620                 if (cc == len)
621                         return 1;
622                 if (op->yes < 4 || op->yes & 3)
623                         return 0;
624                 len -= op->yes;
625                 bc  += op->yes;
626         }
627         return 0;
628 }
629
630 /* data is u32 ifindex */
631 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
632                           int *min_len)
633 {
634         /* Check ifindex space. */
635         *min_len += sizeof(u32);
636         if (len < *min_len)
637                 return false;
638
639         return true;
640 }
641 /* Validate an inet_diag_hostcond. */
642 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
643                            int *min_len)
644 {
645         struct inet_diag_hostcond *cond;
646         int addr_len;
647
648         /* Check hostcond space. */
649         *min_len += sizeof(struct inet_diag_hostcond);
650         if (len < *min_len)
651                 return false;
652         cond = (struct inet_diag_hostcond *)(op + 1);
653
654         /* Check address family and address length. */
655         switch (cond->family) {
656         case AF_UNSPEC:
657                 addr_len = 0;
658                 break;
659         case AF_INET:
660                 addr_len = sizeof(struct in_addr);
661                 break;
662         case AF_INET6:
663                 addr_len = sizeof(struct in6_addr);
664                 break;
665         default:
666                 return false;
667         }
668         *min_len += addr_len;
669         if (len < *min_len)
670                 return false;
671
672         /* Check prefix length (in bits) vs address length (in bytes). */
673         if (cond->prefix_len > 8 * addr_len)
674                 return false;
675
676         return true;
677 }
678
679 /* Validate a port comparison operator. */
680 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
681                                   int len, int *min_len)
682 {
683         /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
684         *min_len += sizeof(struct inet_diag_bc_op);
685         if (len < *min_len)
686                 return false;
687         return true;
688 }
689
690 static int inet_diag_bc_audit(const struct nlattr *attr)
691 {
692         const void *bytecode, *bc;
693         int bytecode_len, len;
694
695         if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
696                 return -EINVAL;
697
698         bytecode = bc = nla_data(attr);
699         len = bytecode_len = nla_len(attr);
700
701         while (len > 0) {
702                 int min_len = sizeof(struct inet_diag_bc_op);
703                 const struct inet_diag_bc_op *op = bc;
704
705                 switch (op->code) {
706                 case INET_DIAG_BC_S_COND:
707                 case INET_DIAG_BC_D_COND:
708                         if (!valid_hostcond(bc, len, &min_len))
709                                 return -EINVAL;
710                         break;
711                 case INET_DIAG_BC_DEV_COND:
712                         if (!valid_devcond(bc, len, &min_len))
713                                 return -EINVAL;
714                         break;
715                 case INET_DIAG_BC_S_GE:
716                 case INET_DIAG_BC_S_LE:
717                 case INET_DIAG_BC_D_GE:
718                 case INET_DIAG_BC_D_LE:
719                         if (!valid_port_comparison(bc, len, &min_len))
720                                 return -EINVAL;
721                         break;
722                 case INET_DIAG_BC_AUTO:
723                 case INET_DIAG_BC_JMP:
724                 case INET_DIAG_BC_NOP:
725                         break;
726                 default:
727                         return -EINVAL;
728                 }
729
730                 if (op->code != INET_DIAG_BC_NOP) {
731                         if (op->no < min_len || op->no > len + 4 || op->no & 3)
732                                 return -EINVAL;
733                         if (op->no < len &&
734                             !valid_cc(bytecode, bytecode_len, len - op->no))
735                                 return -EINVAL;
736                 }
737
738                 if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
739                         return -EINVAL;
740                 bc  += op->yes;
741                 len -= op->yes;
742         }
743         return len == 0 ? 0 : -EINVAL;
744 }
745
746 static int inet_csk_diag_dump(struct sock *sk,
747                               struct sk_buff *skb,
748                               struct netlink_callback *cb,
749                               const struct inet_diag_req_v2 *r,
750                               const struct nlattr *bc)
751 {
752         if (!inet_diag_bc_sk(bc, sk))
753                 return 0;
754
755         return inet_csk_diag_fill(sk, skb, r,
756                                   sk_user_ns(NETLINK_CB(cb->skb).sk),
757                                   NETLINK_CB(cb->skb).portid,
758                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
759 }
760
761 static void twsk_build_assert(void)
762 {
763         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
764                      offsetof(struct sock, sk_family));
765
766         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
767                      offsetof(struct inet_sock, inet_num));
768
769         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
770                      offsetof(struct inet_sock, inet_dport));
771
772         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
773                      offsetof(struct inet_sock, inet_rcv_saddr));
774
775         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
776                      offsetof(struct inet_sock, inet_daddr));
777
778 #if IS_ENABLED(CONFIG_IPV6)
779         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
780                      offsetof(struct sock, sk_v6_rcv_saddr));
781
782         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
783                      offsetof(struct sock, sk_v6_daddr));
784 #endif
785 }
786
787 void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
788                          struct netlink_callback *cb,
789                          const struct inet_diag_req_v2 *r, struct nlattr *bc)
790 {
791         struct net *net = sock_net(skb->sk);
792         int i, num, s_i, s_num;
793         u32 idiag_states = r->idiag_states;
794
795         if (idiag_states & TCPF_SYN_RECV)
796                 idiag_states |= TCPF_NEW_SYN_RECV;
797         s_i = cb->args[1];
798         s_num = num = cb->args[2];
799
800         if (cb->args[0] == 0) {
801                 if (!(idiag_states & TCPF_LISTEN))
802                         goto skip_listen_ht;
803
804                 for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
805                         struct inet_listen_hashbucket *ilb;
806                         struct hlist_nulls_node *node;
807                         struct sock *sk;
808
809                         num = 0;
810                         ilb = &hashinfo->listening_hash[i];
811                         spin_lock_bh(&ilb->lock);
812                         sk_nulls_for_each(sk, node, &ilb->head) {
813                                 struct inet_sock *inet = inet_sk(sk);
814
815                                 if (!net_eq(sock_net(sk), net))
816                                         continue;
817
818                                 if (num < s_num) {
819                                         num++;
820                                         continue;
821                                 }
822
823                                 if (r->sdiag_family != AF_UNSPEC &&
824                                     sk->sk_family != r->sdiag_family)
825                                         goto next_listen;
826
827                                 if (r->id.idiag_sport != inet->inet_sport &&
828                                     r->id.idiag_sport)
829                                         goto next_listen;
830
831                                 if (r->id.idiag_dport ||
832                                     cb->args[3] > 0)
833                                         goto next_listen;
834
835                                 if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
836                                         spin_unlock_bh(&ilb->lock);
837                                         goto done;
838                                 }
839
840 next_listen:
841                                 cb->args[3] = 0;
842                                 cb->args[4] = 0;
843                                 ++num;
844                         }
845                         spin_unlock_bh(&ilb->lock);
846
847                         s_num = 0;
848                         cb->args[3] = 0;
849                         cb->args[4] = 0;
850                 }
851 skip_listen_ht:
852                 cb->args[0] = 1;
853                 s_i = num = s_num = 0;
854         }
855
856         if (!(idiag_states & ~TCPF_LISTEN))
857                 goto out;
858
859         for (i = s_i; i <= hashinfo->ehash_mask; i++) {
860                 struct inet_ehash_bucket *head = &hashinfo->ehash[i];
861                 spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
862                 struct hlist_nulls_node *node;
863                 struct sock *sk;
864
865                 num = 0;
866
867                 if (hlist_nulls_empty(&head->chain))
868                         continue;
869
870                 if (i > s_i)
871                         s_num = 0;
872
873                 spin_lock_bh(lock);
874                 sk_nulls_for_each(sk, node, &head->chain) {
875                         int state, res;
876
877                         if (!net_eq(sock_net(sk), net))
878                                 continue;
879                         if (num < s_num)
880                                 goto next_normal;
881                         state = (sk->sk_state == TCP_TIME_WAIT) ?
882                                 inet_twsk(sk)->tw_substate : sk->sk_state;
883                         if (!(idiag_states & (1 << state)))
884                                 goto next_normal;
885                         if (r->sdiag_family != AF_UNSPEC &&
886                             sk->sk_family != r->sdiag_family)
887                                 goto next_normal;
888                         if (r->id.idiag_sport != htons(sk->sk_num) &&
889                             r->id.idiag_sport)
890                                 goto next_normal;
891                         if (r->id.idiag_dport != sk->sk_dport &&
892                             r->id.idiag_dport)
893                                 goto next_normal;
894                         twsk_build_assert();
895
896                         if (!inet_diag_bc_sk(bc, sk))
897                                 goto next_normal;
898
899                         res = sk_diag_fill(sk, skb, r,
900                                            sk_user_ns(NETLINK_CB(cb->skb).sk),
901                                            NETLINK_CB(cb->skb).portid,
902                                            cb->nlh->nlmsg_seq, NLM_F_MULTI,
903                                            cb->nlh);
904                         if (res < 0) {
905                                 spin_unlock_bh(lock);
906                                 goto done;
907                         }
908 next_normal:
909                         ++num;
910                 }
911
912                 spin_unlock_bh(lock);
913         }
914
915 done:
916         cb->args[1] = i;
917         cb->args[2] = num;
918 out:
919         ;
920 }
921 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
922
923 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
924                             const struct inet_diag_req_v2 *r,
925                             struct nlattr *bc)
926 {
927         const struct inet_diag_handler *handler;
928         int err = 0;
929
930         handler = inet_diag_lock_handler(r->sdiag_protocol);
931         if (!IS_ERR(handler))
932                 handler->dump(skb, cb, r, bc);
933         else
934                 err = PTR_ERR(handler);
935         inet_diag_unlock_handler(handler);
936
937         return err ? : skb->len;
938 }
939
940 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
941 {
942         int hdrlen = sizeof(struct inet_diag_req_v2);
943         struct nlattr *bc = NULL;
944
945         if (nlmsg_attrlen(cb->nlh, hdrlen))
946                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
947
948         return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
949 }
950
951 static int inet_diag_type2proto(int type)
952 {
953         switch (type) {
954         case TCPDIAG_GETSOCK:
955                 return IPPROTO_TCP;
956         case DCCPDIAG_GETSOCK:
957                 return IPPROTO_DCCP;
958         default:
959                 return 0;
960         }
961 }
962
963 static int inet_diag_dump_compat(struct sk_buff *skb,
964                                  struct netlink_callback *cb)
965 {
966         struct inet_diag_req *rc = nlmsg_data(cb->nlh);
967         int hdrlen = sizeof(struct inet_diag_req);
968         struct inet_diag_req_v2 req;
969         struct nlattr *bc = NULL;
970
971         req.sdiag_family = AF_UNSPEC; /* compatibility */
972         req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
973         req.idiag_ext = rc->idiag_ext;
974         req.idiag_states = rc->idiag_states;
975         req.id = rc->id;
976
977         if (nlmsg_attrlen(cb->nlh, hdrlen))
978                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
979
980         return __inet_diag_dump(skb, cb, &req, bc);
981 }
982
983 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
984                                       const struct nlmsghdr *nlh)
985 {
986         struct inet_diag_req *rc = nlmsg_data(nlh);
987         struct inet_diag_req_v2 req;
988
989         req.sdiag_family = rc->idiag_family;
990         req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
991         req.idiag_ext = rc->idiag_ext;
992         req.idiag_states = rc->idiag_states;
993         req.id = rc->id;
994
995         return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
996 }
997
998 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
999 {
1000         int hdrlen = sizeof(struct inet_diag_req);
1001         struct net *net = sock_net(skb->sk);
1002
1003         if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1004             nlmsg_len(nlh) < hdrlen)
1005                 return -EINVAL;
1006
1007         if (nlh->nlmsg_flags & NLM_F_DUMP) {
1008                 if (nlmsg_attrlen(nlh, hdrlen)) {
1009                         struct nlattr *attr;
1010                         int err;
1011
1012                         attr = nlmsg_find_attr(nlh, hdrlen,
1013                                                INET_DIAG_REQ_BYTECODE);
1014                         err = inet_diag_bc_audit(attr);
1015                         if (err)
1016                                 return err;
1017                 }
1018                 {
1019                         struct netlink_dump_control c = {
1020                                 .dump = inet_diag_dump_compat,
1021                         };
1022                         return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1023                 }
1024         }
1025
1026         return inet_diag_get_exact_compat(skb, nlh);
1027 }
1028
1029 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1030 {
1031         int hdrlen = sizeof(struct inet_diag_req_v2);
1032         struct net *net = sock_net(skb->sk);
1033
1034         if (nlmsg_len(h) < hdrlen)
1035                 return -EINVAL;
1036
1037         if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1038             h->nlmsg_flags & NLM_F_DUMP) {
1039                 if (nlmsg_attrlen(h, hdrlen)) {
1040                         struct nlattr *attr;
1041                         int err;
1042
1043                         attr = nlmsg_find_attr(h, hdrlen,
1044                                                INET_DIAG_REQ_BYTECODE);
1045                         err = inet_diag_bc_audit(attr);
1046                         if (err)
1047                                 return err;
1048                 }
1049                 {
1050                         struct netlink_dump_control c = {
1051                                 .dump = inet_diag_dump,
1052                         };
1053                         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1054                 }
1055         }
1056
1057         return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1058 }
1059
1060 static
1061 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1062 {
1063         const struct inet_diag_handler *handler;
1064         struct nlmsghdr *nlh;
1065         struct nlattr *attr;
1066         struct inet_diag_msg *r;
1067         void *info = NULL;
1068         int err = 0;
1069
1070         nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1071         if (!nlh)
1072                 return -ENOMEM;
1073
1074         r = nlmsg_data(nlh);
1075         memset(r, 0, sizeof(*r));
1076         inet_diag_msg_common_fill(r, sk);
1077         if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1078                 r->id.idiag_sport = inet_sk(sk)->inet_sport;
1079         r->idiag_state = sk->sk_state;
1080
1081         if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1082                 nlmsg_cancel(skb, nlh);
1083                 return err;
1084         }
1085
1086         handler = inet_diag_lock_handler(sk->sk_protocol);
1087         if (IS_ERR(handler)) {
1088                 inet_diag_unlock_handler(handler);
1089                 nlmsg_cancel(skb, nlh);
1090                 return PTR_ERR(handler);
1091         }
1092
1093         attr = handler->idiag_info_size
1094                 ? nla_reserve(skb, INET_DIAG_INFO, handler->idiag_info_size)
1095                 : NULL;
1096         if (attr)
1097                 info = nla_data(attr);
1098
1099         handler->idiag_get_info(sk, r, info);
1100         inet_diag_unlock_handler(handler);
1101
1102         nlmsg_end(skb, nlh);
1103         return 0;
1104 }
1105
1106 static const struct sock_diag_handler inet_diag_handler = {
1107         .family = AF_INET,
1108         .dump = inet_diag_handler_cmd,
1109         .get_info = inet_diag_handler_get_info,
1110         .destroy = inet_diag_handler_cmd,
1111 };
1112
1113 static const struct sock_diag_handler inet6_diag_handler = {
1114         .family = AF_INET6,
1115         .dump = inet_diag_handler_cmd,
1116         .get_info = inet_diag_handler_get_info,
1117         .destroy = inet_diag_handler_cmd,
1118 };
1119
1120 int inet_diag_register(const struct inet_diag_handler *h)
1121 {
1122         const __u16 type = h->idiag_type;
1123         int err = -EINVAL;
1124
1125         if (type >= IPPROTO_MAX)
1126                 goto out;
1127
1128         mutex_lock(&inet_diag_table_mutex);
1129         err = -EEXIST;
1130         if (!inet_diag_table[type]) {
1131                 inet_diag_table[type] = h;
1132                 err = 0;
1133         }
1134         mutex_unlock(&inet_diag_table_mutex);
1135 out:
1136         return err;
1137 }
1138 EXPORT_SYMBOL_GPL(inet_diag_register);
1139
1140 void inet_diag_unregister(const struct inet_diag_handler *h)
1141 {
1142         const __u16 type = h->idiag_type;
1143
1144         if (type >= IPPROTO_MAX)
1145                 return;
1146
1147         mutex_lock(&inet_diag_table_mutex);
1148         inet_diag_table[type] = NULL;
1149         mutex_unlock(&inet_diag_table_mutex);
1150 }
1151 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1152
1153 static int __init inet_diag_init(void)
1154 {
1155         const int inet_diag_table_size = (IPPROTO_MAX *
1156                                           sizeof(struct inet_diag_handler *));
1157         int err = -ENOMEM;
1158
1159         inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1160         if (!inet_diag_table)
1161                 goto out;
1162
1163         err = sock_diag_register(&inet_diag_handler);
1164         if (err)
1165                 goto out_free_nl;
1166
1167         err = sock_diag_register(&inet6_diag_handler);
1168         if (err)
1169                 goto out_free_inet;
1170
1171         sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1172 out:
1173         return err;
1174
1175 out_free_inet:
1176         sock_diag_unregister(&inet_diag_handler);
1177 out_free_nl:
1178         kfree(inet_diag_table);
1179         goto out;
1180 }
1181
1182 static void __exit inet_diag_exit(void)
1183 {
1184         sock_diag_unregister(&inet6_diag_handler);
1185         sock_diag_unregister(&inet_diag_handler);
1186         sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
1187         kfree(inet_diag_table);
1188 }
1189
1190 module_init(inet_diag_init);
1191 module_exit(inet_diag_exit);
1192 MODULE_LICENSE("GPL");
1193 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1194 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);