#include <linux/rcupdate.h>
#include <linux/spinlock.h>
#include <linux/jiffies.h>
#include <linux/bootmem.h>
#include <linux/module.h>
#include <linux/cache.h>
#include <linux/slab.h>
#include <linux/init.h>
#include <linux/tcp.h>

#include <net/inet_connection_sock.h>
#include <net/net_namespace.h>
#include <net/request_sock.h>
#include <net/inetpeer.h>
#include <net/sock.h>
#include <net/ipv6.h>
#include <net/dst.h>
#include <net/tcp.h>

int sysctl_tcp_nometrics_save __read_mostly;

enum tcp_metric_index {
        TCP_METRIC_RTT,
        TCP_METRIC_RTTVAR,
        TCP_METRIC_SSTHRESH,
        TCP_METRIC_CWND,
        TCP_METRIC_REORDERING,

        TCP_METRIC_MAX,         /* always last */
};

struct tcp_metrics_block {
        struct tcp_metrics_block __rcu  *tcpm_next;
        struct inetpeer_addr            tcpm_addr;
        unsigned long                   tcpm_stamp;
        u32                             tcpm_lock;
        u32                             tcpm_vals[TCP_METRIC_MAX];
};
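
/* tcpm_lock is a bitmask of metrics the route has marked read-only;
 * tcpm_vals holds the cached values themselves, with RTT and RTTVAR
 * stored in milliseconds and converted to/from jiffies by the helpers
 * below.
 */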
static bool tcp_metric_locked(struct tcp_metrics_block *tm,
                              enum tcp_metric_index idx)
{
        return tm->tcpm_lock & (1 << idx);
}

static u32 tcp_metric_get(struct tcp_metrics_block *tm,
                          enum tcp_metric_index idx)
{
        return tm->tcpm_vals[idx];
}

static u32 tcp_metric_get_jiffies(struct tcp_metrics_block *tm,
                                  enum tcp_metric_index idx)
{
        return msecs_to_jiffies(tm->tcpm_vals[idx]);
}

static void tcp_metric_set(struct tcp_metrics_block *tm,
                           enum tcp_metric_index idx,
                           u32 val)
{
        tm->tcpm_vals[idx] = val;
}

static void tcp_metric_set_msecs(struct tcp_metrics_block *tm,
                                 enum tcp_metric_index idx,
                                 u32 val)
{
        tm->tcpm_vals[idx] = jiffies_to_msecs(val);
}

static bool addr_same(const struct inetpeer_addr *a,
                      const struct inetpeer_addr *b)
{
        const struct in6_addr *a6, *b6;

        if (a->family != b->family)
                return false;
        if (a->family == AF_INET)
                return a->addr.a4 == b->addr.a4;

        a6 = (const struct in6_addr *) &a->addr.a6[0];
        b6 = (const struct in6_addr *) &b->addr.a6[0];

        return ipv6_addr_equal(a6, b6);
}
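
/* Each bucket anchors an RCU-protected, singly linked chain of metrics
 * blocks; writers serialize on tcp_metrics_lock.
 */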
struct tcpm_hash_bucket {
        struct tcp_metrics_block __rcu  *chain;
};

static DEFINE_SPINLOCK(tcp_metrics_lock);
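
/* Re-seed a metrics block from the routing entry: remember which
 * metrics the route has locked and copy in its raw metric values.
 */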
static void tcpm_suck_dst(struct tcp_metrics_block *tm, struct dst_entry *dst)
{
        u32 val;

        val = 0;
        if (dst_metric_locked(dst, RTAX_RTT))
                val |= 1 << TCP_METRIC_RTT;
        if (dst_metric_locked(dst, RTAX_RTTVAR))
                val |= 1 << TCP_METRIC_RTTVAR;
        if (dst_metric_locked(dst, RTAX_SSTHRESH))
                val |= 1 << TCP_METRIC_SSTHRESH;
        if (dst_metric_locked(dst, RTAX_CWND))
                val |= 1 << TCP_METRIC_CWND;
        if (dst_metric_locked(dst, RTAX_REORDERING))
                val |= 1 << TCP_METRIC_REORDERING;
        tm->tcpm_lock = val;

        tm->tcpm_vals[TCP_METRIC_RTT] = dst_metric_raw(dst, RTAX_RTT);
        tm->tcpm_vals[TCP_METRIC_RTTVAR] = dst_metric_raw(dst, RTAX_RTTVAR);
        tm->tcpm_vals[TCP_METRIC_SSTHRESH] = dst_metric_raw(dst, RTAX_SSTHRESH);
        tm->tcpm_vals[TCP_METRIC_CWND] = dst_metric_raw(dst, RTAX_CWND);
        tm->tcpm_vals[TCP_METRIC_REORDERING] = dst_metric_raw(dst, RTAX_REORDERING);
}
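
/* Allocate a metrics block for @addr and link it into bucket @hash, or,
 * when @reclaim is set, recycle the entry with the oldest stamp instead
 * of growing the chain.
 */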
static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst,
                                          struct inetpeer_addr *addr,
                                          unsigned int hash,
                                          bool reclaim)
{
        struct tcp_metrics_block *tm;
        struct net *net;

        spin_lock_bh(&tcp_metrics_lock);
        net = dev_net(dst->dev);
        if (unlikely(reclaim)) {
                struct tcp_metrics_block *oldest;

                oldest = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain);
                for (tm = rcu_dereference(oldest->tcpm_next); tm;
                     tm = rcu_dereference(tm->tcpm_next)) {
                        if (time_before(tm->tcpm_stamp, oldest->tcpm_stamp))
                                oldest = tm;
                }
                tm = oldest;
        } else {
                tm = kmalloc(sizeof(*tm), GFP_ATOMIC);
                if (!tm)
                        goto out_unlock;
        }
        tm->tcpm_addr = *addr;
        tm->tcpm_stamp = jiffies;

        tcpm_suck_dst(tm, dst);

        if (likely(!reclaim)) {
                tm->tcpm_next = net->ipv4.tcp_metrics_hash[hash].chain;
                rcu_assign_pointer(net->ipv4.tcp_metrics_hash[hash].chain, tm);
        }

out_unlock:
        spin_unlock_bh(&tcp_metrics_lock);
        return tm;
}

#define TCP_METRICS_TIMEOUT             (60 * 60 * HZ)

static void tcpm_check_stamp(struct tcp_metrics_block *tm, struct dst_entry *dst)
{
        if (tm && unlikely(time_after(jiffies, tm->tcpm_stamp + TCP_METRICS_TIMEOUT)))
                tcpm_suck_dst(tm, dst);
}
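
/* Lookups walk at most TCP_METRICS_RECLAIM_DEPTH entries; a miss on a
 * longer chain is reported as TCP_METRICS_RECLAIM_PTR, which tells the
 * caller to reclaim an existing entry rather than grow the chain.
 */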
#define TCP_METRICS_RECLAIM_DEPTH       5
#define TCP_METRICS_RECLAIM_PTR         (struct tcp_metrics_block *) 0x1UL

static struct tcp_metrics_block *tcp_get_encode(struct tcp_metrics_block *tm, int depth)
{
        if (tm)
                return tm;
        if (depth > TCP_METRICS_RECLAIM_DEPTH)
                return TCP_METRICS_RECLAIM_PTR;
        return NULL;
}

static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *addr,
                                                   struct net *net, unsigned int hash)
{
        struct tcp_metrics_block *tm;
        int depth = 0;

        for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_addr, addr))
                        break;
                depth++;
        }
        return tcp_get_encode(tm, depth);
}
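
/* Look up the metrics block for a request sock's peer, keyed by the
 * remote address recorded in the request.
 */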
static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req,
                                                       struct dst_entry *dst)
{
        struct tcp_metrics_block *tm;
        struct inetpeer_addr addr;
        unsigned int hash;
        struct net *net;

        addr.family = req->rsk_ops->family;
        switch (addr.family) {
        case AF_INET:
                addr.addr.a4 = inet_rsk(req)->rmt_addr;
                hash = (__force unsigned int) addr.addr.a4;
                break;
        case AF_INET6:
                *(struct in6_addr *)addr.addr.a6 = inet6_rsk(req)->rmt_addr;
                hash = ((__force unsigned int) addr.addr.a6[0] ^
                        (__force unsigned int) addr.addr.a6[1] ^
                        (__force unsigned int) addr.addr.a6[2] ^
                        (__force unsigned int) addr.addr.a6[3]);
                break;
        default:
                return NULL;
        }

        hash ^= (hash >> 24) ^ (hash >> 16) ^ (hash >> 8);

        net = dev_net(dst->dev);
        hash &= net->ipv4.tcp_metrics_hash_mask;

        for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_addr, &addr))
                        break;
        }
        tcpm_check_stamp(tm, dst);
        return tm;
}
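
/* Look up the metrics block for a full socket's peer.  When @create is
 * true and no entry exists, a new (or reclaimed) block is installed.
 */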
static struct tcp_metrics_block *tcp_get_metrics(struct sock *sk,
                                                 struct dst_entry *dst,
                                                 bool create)
{
        struct tcp_metrics_block *tm;
        struct inetpeer_addr addr;
        unsigned int hash;
        struct net *net;
        bool reclaim;

        addr.family = sk->sk_family;
        switch (addr.family) {
        case AF_INET:
                addr.addr.a4 = inet_sk(sk)->inet_daddr;
                hash = (__force unsigned int) addr.addr.a4;
                break;
        case AF_INET6:
                *(struct in6_addr *)addr.addr.a6 = inet6_sk(sk)->daddr;
                hash = ((__force unsigned int) addr.addr.a6[0] ^
                        (__force unsigned int) addr.addr.a6[1] ^
                        (__force unsigned int) addr.addr.a6[2] ^
                        (__force unsigned int) addr.addr.a6[3]);
                break;
        default:
                return NULL;
        }

        hash ^= (hash >> 24) ^ (hash >> 16) ^ (hash >> 8);
        net = dev_net(dst->dev);
        hash &= net->ipv4.tcp_metrics_hash_mask;

        tm = __tcp_get_metrics(&addr, net, hash);
        reclaim = false;
        if (tm == TCP_METRICS_RECLAIM_PTR) {
                reclaim = true;
                tm = NULL;
        }
        if (!tm && create)
                tm = tcpm_new(dst, &addr, hash, reclaim);
        else
                tcpm_check_stamp(tm, dst);

        return tm;
}

/* Save metrics learned by this TCP session.  This function is called
 * only when TCP finishes successfully, i.e. when it enters TIME-WAIT
 * or goes from LAST-ACK to CLOSE.
 */
void tcp_update_metrics(struct sock *sk)
{
        const struct inet_connection_sock *icsk = inet_csk(sk);
        struct dst_entry *dst = __sk_dst_get(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        struct tcp_metrics_block *tm;
        unsigned long rtt;
        u32 val;
        int m;

        if (sysctl_tcp_nometrics_save || !dst)
                return;

        if (dst->flags & DST_HOST)
                dst_confirm(dst);

        rcu_read_lock();
        if (icsk->icsk_backoff || !tp->srtt) {
                /* This session failed to estimate rtt. Why?
                 * Probably no packets returned in time.  Reset our
                 * results.
                 */
                tm = tcp_get_metrics(sk, dst, false);
                if (tm && !tcp_metric_locked(tm, TCP_METRIC_RTT))
                        tcp_metric_set(tm, TCP_METRIC_RTT, 0);
                goto out_unlock;
        } else
                tm = tcp_get_metrics(sk, dst, true);

        if (!tm)
                goto out_unlock;

        rtt = tcp_metric_get_jiffies(tm, TCP_METRIC_RTT);
        m = rtt - tp->srtt;

        /* If the newly calculated rtt is larger than the stored one,
         * store the new one. Otherwise, use EWMA. Remember, rtt
         * overestimation is always better than underestimation.
         */
        if (!tcp_metric_locked(tm, TCP_METRIC_RTT)) {
                if (m <= 0)
                        rtt = tp->srtt;
                else
                        rtt -= (m >> 3);
                tcp_metric_set_msecs(tm, TCP_METRIC_RTT, rtt);
        }

        if (!tcp_metric_locked(tm, TCP_METRIC_RTTVAR)) {
                unsigned long var;

                if (m < 0)
                        m = -m;

                /* Scale deviation to rttvar fixed point */
                m >>= 1;
                if (m < tp->mdev)
                        m = tp->mdev;

                var = tcp_metric_get_jiffies(tm, TCP_METRIC_RTTVAR);
                if (m >= var)
                        var = m;
                else
                        var -= (var - m) >> 2;

                tcp_metric_set_msecs(tm, TCP_METRIC_RTTVAR, var);
        }

        if (tcp_in_initial_slowstart(tp)) {
                /* Slow start still did not finish. */
                if (!tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) {
                        val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
                        if (val && (tp->snd_cwnd >> 1) > val)
                                tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
                                               tp->snd_cwnd >> 1);
                }
                if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
                        val = tcp_metric_get(tm, TCP_METRIC_CWND);
                        if (tp->snd_cwnd > val)
                                tcp_metric_set(tm, TCP_METRIC_CWND,
                                               tp->snd_cwnd);
                }
        } else if (tp->snd_cwnd > tp->snd_ssthresh &&
                   icsk->icsk_ca_state == TCP_CA_Open) {
                /* Cong. avoidance phase, cwnd is reliable. */
                if (!tcp_metric_locked(tm, TCP_METRIC_SSTHRESH))
                        tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
                                       max(tp->snd_cwnd >> 1, tp->snd_ssthresh));
                if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
                        val = tcp_metric_get(tm, TCP_METRIC_CWND);
                        tcp_metric_set(tm, TCP_METRIC_CWND, (val + tp->snd_cwnd) >> 1);
                }
        } else {
                /* Slow start did not finish: cwnd is not reliable, and
                 * ssthresh may be invalid as well.
                 */
                if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
                        val = tcp_metric_get(tm, TCP_METRIC_CWND);
                        tcp_metric_set(tm, TCP_METRIC_CWND,
                                       (val + tp->snd_ssthresh) >> 1);
                }
                if (!tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) {
                        val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
                        if (val && tp->snd_ssthresh > val)
                                tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
                                               tp->snd_ssthresh);
                }
                if (!tcp_metric_locked(tm, TCP_METRIC_REORDERING)) {
                        val = tcp_metric_get(tm, TCP_METRIC_REORDERING);
                        if (val < tp->reordering &&
                            tp->reordering != sysctl_tcp_reordering)
                                tcp_metric_set(tm, TCP_METRIC_REORDERING,
                                               tp->reordering);
                }
        }
        tm->tcpm_stamp = jiffies;
out_unlock:
        rcu_read_unlock();
}

/* Initialize metrics on socket. */
void tcp_init_metrics(struct sock *sk)
{
        struct dst_entry *dst = __sk_dst_get(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        struct tcp_metrics_block *tm;
        u32 val;

        if (dst == NULL)
                goto reset;

        dst_confirm(dst);

        rcu_read_lock();
        tm = tcp_get_metrics(sk, dst, true);
        if (!tm) {
                rcu_read_unlock();
                goto reset;
        }

        if (tcp_metric_locked(tm, TCP_METRIC_CWND))
                tp->snd_cwnd_clamp = tcp_metric_get(tm, TCP_METRIC_CWND);

        val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
        if (val) {
                tp->snd_ssthresh = val;
                if (tp->snd_ssthresh > tp->snd_cwnd_clamp)
                        tp->snd_ssthresh = tp->snd_cwnd_clamp;
        } else {
                /* ssthresh may have been reduced unnecessarily during
                 * the 3WHS. Restore it to its initial default.
                 */
                tp->snd_ssthresh = TCP_INFINITE_SSTHRESH;
        }
        val = tcp_metric_get(tm, TCP_METRIC_REORDERING);
        if (val && tp->reordering != val) {
                tcp_disable_fack(tp);
                tcp_disable_early_retrans(tp);
                tp->reordering = val;
        }

        val = tcp_metric_get(tm, TCP_METRIC_RTT);
        if (val == 0 || tp->srtt == 0) {
                rcu_read_unlock();
                goto reset;
        }
        /* The initial rtt is determined from the SYN,SYN-ACK exchange.
         * Those segments are small and the rtt may appear much smaller
         * than the real one. Use per-dst memory to make it more realistic.
         *
         * A bit of theory. RTT is the time that passes after a "normal"
         * sized packet is sent until it is ACKed. In normal circumstances
         * sending small packets forces the peer to delay ACKs, so the
         * calculation stays correct. The algorithm is adaptive and,
         * provided we follow the specs, it NEVER underestimates RTT.
         * BUT! If the peer tries clever tricks, such as "quick acks" for
         * long enough to drive the RTT estimate down, and then abruptly
         * stops doing that and starts delaying ACKs, expect trouble.
         */
        val = msecs_to_jiffies(val);
        if (val > tp->srtt) {
                tp->srtt = val;
                tp->rtt_seq = tp->snd_nxt;
        }
        val = tcp_metric_get_jiffies(tm, TCP_METRIC_RTTVAR);
        if (val > tp->mdev) {
                tp->mdev = val;
                tp->mdev_max = tp->rttvar = max(tp->mdev, tcp_rto_min(sk));
        }
        rcu_read_unlock();

        tcp_set_rto(sk);
reset:
        if (tp->srtt == 0) {
                /* RFC6298: 5.7 We've failed to get a valid RTT sample from
                 * the 3WHS. This is most likely due to retransmission,
                 * including a spurious one. Reset the RTO back to 3 secs
                 * from the more aggressive 1 sec to avoid more spurious
                 * retransmissions.
                 */
                tp->mdev = tp->mdev_max = tp->rttvar = TCP_TIMEOUT_FALLBACK;
                inet_csk(sk)->icsk_rto = TCP_TIMEOUT_FALLBACK;
        }
        /* Cut cwnd down to 1 per RFC5681 if a SYN or SYN-ACK has been
         * retransmitted. In light of the more aggressive 1 sec initRTO
         * from RFC6298, we only reset cwnd when more than one SYN/SYN-ACK
         * retransmission has occurred.
         */
        if (tp->total_retrans > 1)
                tp->snd_cwnd = 1;
        else
                tp->snd_cwnd = tcp_init_cwnd(tp, dst);
        tp->snd_cwnd_stamp = tcp_time_stamp;
}
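
/* A peer is "proven" if we already hold a non-zero RTT sample for it,
 * i.e. an earlier connection to this address left an RTT estimate behind.
 */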
bool tcp_peer_is_proven(struct request_sock *req, struct dst_entry *dst)
{
        struct tcp_metrics_block *tm;
        bool ret = false;

        if (!dst)
                return false;

        rcu_read_lock();
        tm = __tcp_get_metrics_req(req, dst);
        if (tm && tcp_metric_get(tm, TCP_METRIC_RTT))
                ret = true;
        rcu_read_unlock();

        return ret;
}
EXPORT_SYMBOL_GPL(tcp_peer_is_proven);
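
/* The hash table size may be overridden at boot with the
 * "tcpmhash_entries=" parameter; it should be a power of two, since the
 * bucket index is obtained by masking with (slots - 1).
 */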
static unsigned long tcpmhash_entries;
static int __init set_tcpmhash_entries(char *str)
{
        ssize_t ret;

        if (!str)
                return 0;
        ret = kstrtoul(str, 0, &tcpmhash_entries);
        if (ret)
                return 0;
        return 1;
}
__setup("tcpmhash_entries=", set_tcpmhash_entries);
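
/* Per-namespace setup: size the bucket array from the boot parameter,
 * or pick a default based on available memory.
 */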
static int __net_init tcp_net_metrics_init(struct net *net)
{
        int slots, size;

        slots = tcpmhash_entries;
        if (!slots) {
                if (totalram_pages >= 128 * 1024)
                        slots = 16 * 1024;
                else
                        slots = 8 * 1024;
        }

        size = slots * sizeof(struct tcpm_hash_bucket);
        net->ipv4.tcp_metrics_hash = kzalloc(size, GFP_KERNEL);
        if (!net->ipv4.tcp_metrics_hash)
                return -ENOMEM;

        net->ipv4.tcp_metrics_hash_mask = (slots - 1);
        return 0;
}

static void __net_exit tcp_net_metrics_exit(struct net *net)
{
        kfree(net->ipv4.tcp_metrics_hash);
}

static __net_initdata struct pernet_operations tcp_net_metrics_ops = {
        .init   =       tcp_net_metrics_init,
        .exit   =       tcp_net_metrics_exit,
};

void __init tcp_metrics_init(void)
{
        register_pernet_subsys(&tcp_net_metrics_ops);
}