aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/net/tcp.h2
-rw-r--r--net/ipv4/tcp_input.c4
2 files changed, 5 insertions, 1 deletions
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 349204130d84..1aa9628ae608 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -917,6 +917,8 @@ struct tcp_congestion_ops {
917 void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample); 917 void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample);
918 /* suggest number of segments for each skb to transmit (optional) */ 918 /* suggest number of segments for each skb to transmit (optional) */
919 u32 (*tso_segs_goal)(struct sock *sk); 919 u32 (*tso_segs_goal)(struct sock *sk);
920 /* returns the multiplier used in tcp_sndbuf_expand (optional) */
921 u32 (*sndbuf_expand)(struct sock *sk);
920 /* get info for inet_diag (optional) */ 922 /* get info for inet_diag (optional) */
921 size_t (*get_info)(struct sock *sk, u32 ext, int *attr, 923 size_t (*get_info)(struct sock *sk, u32 ext, int *attr,
922 union tcp_cc_info *info); 924 union tcp_cc_info *info);
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index d9ed4bb96f74..13a2e70141f5 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -289,6 +289,7 @@ static bool tcp_ecn_rcv_ecn_echo(const struct tcp_sock *tp, const struct tcphdr
289static void tcp_sndbuf_expand(struct sock *sk) 289static void tcp_sndbuf_expand(struct sock *sk)
290{ 290{
291 const struct tcp_sock *tp = tcp_sk(sk); 291 const struct tcp_sock *tp = tcp_sk(sk);
292 const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops;
292 int sndmem, per_mss; 293 int sndmem, per_mss;
293 u32 nr_segs; 294 u32 nr_segs;
294 295
@@ -309,7 +310,8 @@ static void tcp_sndbuf_expand(struct sock *sk)
309 * Cubic needs 1.7 factor, rounded to 2 to include 310 * Cubic needs 1.7 factor, rounded to 2 to include
310 * extra cushion (application might react slowly to POLLOUT) 311 * extra cushion (application might react slowly to POLLOUT)
311 */ 312 */
312 sndmem = 2 * nr_segs * per_mss; 313 sndmem = ca_ops->sndbuf_expand ? ca_ops->sndbuf_expand(sk) : 2;
314 sndmem *= nr_segs * per_mss;
313 315
314 if (sk->sk_sndbuf < sndmem) 316 if (sk->sk_sndbuf < sndmem)
315 sk->sk_sndbuf = min(sndmem, sysctl_tcp_wmem[2]); 317 sk->sk_sndbuf = min(sndmem, sysctl_tcp_wmem[2]);