// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */

/* WARNING: This implementation is not necessarily the same
 * as the tcp_dctcp.c.  The purpose is mainly for testing
 * the kernel BPF logic.
 */

#include <linux/bpf.h>
#include <linux/types.h>
#include <bpf/bpf_helpers.h>
#include "bpf_trace_helpers.h"
#include "bpf_tcp_helpers.h"

char _license[] SEC("license") = "GPL";

#define DCTCP_MAX_ALPHA	1024U

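/* Per-connection DCTCP state, kept in the socket's private congestion
 * control area and looked up with inet_csk_ca().
 */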
struct dctcp {
	__u32 old_delivered;
	__u32 old_delivered_ce;
	__u32 prior_rcv_nxt;
	__u32 dctcp_alpha;
	__u32 next_seq;
	__u32 ce_state;
	__u32 loss_cwnd;
};

static unsigned int dctcp_shift_g = 4; /* g = 1/2^4 */
static unsigned int dctcp_alpha_on_init = DCTCP_MAX_ALPHA;

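/* Start a new observation window: remember where the window ends
 * (snd_nxt) and snapshot the delivered/delivered_ce counters so the
 * next alpha update can compute the fraction of CE-marked packets.
 */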
static __always_inline void dctcp_reset(const struct tcp_sock *tp,
					struct dctcp *ca)
{
	ca->next_seq = tp->snd_nxt;

	ca->old_delivered = tp->delivered;
	ca->old_delivered_ce = tp->delivered_ce;
}

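/* struct_ops "init": seed alpha and the CE/ACK tracking state, then open
 * the first observation window.
 */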
41 SEC("struct_ops/dctcp_init")
42 void BPF_PROG(dctcp_init
, struct sock
*sk
)
44 const struct tcp_sock
*tp
= tcp_sk(sk
);
45 struct dctcp
*ca
= inet_csk_ca(sk
);
47 ca
->prior_rcv_nxt
= tp
->rcv_nxt
;
48 ca
->dctcp_alpha
= min(dctcp_alpha_on_init
, DCTCP_MAX_ALPHA
);
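/* struct_ops "ssthresh": DCTCP's cwnd reduction on congestion.  Instead
 * of halving cwnd, shrink it in proportion to alpha: cwnd - cwnd * alpha / 2,
 * with alpha scaled to DCTCP_MAX_ALPHA (1024), hence the >> 11
 * (divide by 1024 * 2).  Never go below 2 segments.
 */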
55 SEC("struct_ops/dctcp_ssthresh")
56 __u32
BPF_PROG(dctcp_ssthresh
, struct sock
*sk
)
58 struct dctcp
*ca
= inet_csk_ca(sk
);
59 struct tcp_sock
*tp
= tcp_sk(sk
);
61 ca
->loss_cwnd
= tp
->snd_cwnd
;
62 return max(tp
->snd_cwnd
- ((tp
->snd_cwnd
* ca
->dctcp_alpha
) >> 11U), 2U);
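/* struct_ops "in_ack_event": once per observation window (roughly one
 * RTT), fold the fraction of CE-marked packets F into the EWMA
 * alpha = (1 - g) * alpha + g * F, with g = 1/2^dctcp_shift_g.
 */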
65 SEC("struct_ops/dctcp_update_alpha")
66 void BPF_PROG(dctcp_update_alpha
, struct sock
*sk
, __u32 flags
)
68 const struct tcp_sock
*tp
= tcp_sk(sk
);
69 struct dctcp
*ca
= inet_csk_ca(sk
);
72 if (!before(tp
->snd_una
, ca
->next_seq
)) {
73 __u32 delivered_ce
= tp
->delivered_ce
- ca
->old_delivered_ce
;
74 __u32 alpha
= ca
->dctcp_alpha
;
76 /* alpha = (1 - g) * alpha + g * F */
78 alpha
-= min_not_zero(alpha
, alpha
>> dctcp_shift_g
);
80 __u32 delivered
= tp
->delivered
- ca
->old_delivered
;
82 /* If dctcp_shift_g == 1, a 32bit value would overflow
85 delivered_ce
<<= (10 - dctcp_shift_g
);
86 delivered_ce
/= max(1U, delivered
);
88 alpha
= min(alpha
+ delivered_ce
, DCTCP_MAX_ALPHA
);
90 ca
->dctcp_alpha
= alpha
;
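/* On loss, fall back to the standard Reno-style reaction: remember cwnd
 * for a possible undo and halve ssthresh (never below 2 segments).
 */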
static __always_inline void dctcp_react_to_loss(struct sock *sk)
{
	struct dctcp *ca = inet_csk_ca(sk);
	struct tcp_sock *tp = tcp_sk(sk);

	ca->loss_cwnd = tp->snd_cwnd;
	tp->snd_ssthresh = max(tp->snd_cwnd >> 1U, 2U);
}

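/* struct_ops "set_state": apply the loss reaction when the connection
 * newly enters TCP_CA_Recovery.
 */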
104 SEC("struct_ops/dctcp_state")
105 void BPF_PROG(dctcp_state
, struct sock
*sk
, __u8 new_state
)
107 if (new_state
== TCP_CA_Recovery
&&
108 new_state
!= BPF_CORE_READ_BITFIELD(inet_csk(sk
), icsk_ca_state
))
109 dctcp_react_to_loss(sk
);
110 /* We handle RTO in dctcp_cwnd_event to ensure that we perform only
111 * one loss-adjustment per RTT.
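/* Mirror the CE state into the ECN flags: while the last received packet
 * was CE-marked, keep demanding CWR (i.e. keep echoing ECE on ACKs).
 */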
static __always_inline void dctcp_ece_ack_cwr(struct sock *sk, __u32 ce_state)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (ce_state == 1)
		tp->ecn_flags |= TCP_ECN_DEMAND_CWR;
	else
		tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR;
}

/* Minimal DCTCP CE state machine:
 *
 * S:	0 <- last pkt was non-CE
 *	1 <- last pkt was CE
 */
static __always_inline
void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt,
			  __u32 *prior_rcv_nxt, __u32 *ce_state)
{
	__u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0;

	if (*ce_state != new_ce_state) {
		/* CE state has changed, force an immediate ACK to
		 * reflect the new CE state. If an ACK was delayed,
		 * send that first to reflect the prior CE state.
		 */
		if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) {
			dctcp_ece_ack_cwr(sk, *ce_state);
			bpf_tcp_send_ack(sk, *prior_rcv_nxt);
		}
		inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
	}
	*prior_rcv_nxt = tcp_sk(sk)->rcv_nxt;
	*ce_state = new_ce_state;
	dctcp_ece_ack_cwr(sk, new_ce_state);
}

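/* struct_ops "cwnd_event": track CE <-> non-CE transitions and react to
 * loss events; other events are ignored.
 */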
152 SEC("struct_ops/dctcp_cwnd_event")
153 void BPF_PROG(dctcp_cwnd_event
, struct sock
*sk
, enum tcp_ca_event ev
)
155 struct dctcp
*ca
= inet_csk_ca(sk
);
158 case CA_EVENT_ECN_IS_CE
:
159 case CA_EVENT_ECN_NO_CE
:
160 dctcp_ece_ack_update(sk
, ev
, &ca
->prior_rcv_nxt
, &ca
->ce_state
);
163 dctcp_react_to_loss(sk
);
166 /* Don't care for the rest. */
171 SEC("struct_ops/dctcp_cwnd_undo")
172 __u32
BPF_PROG(dctcp_cwnd_undo
, struct sock
*sk
)
174 const struct dctcp
*ca
= inet_csk_ca(sk
);
176 return max(tcp_sk(sk
)->snd_cwnd
, ca
->loss_cwnd
);
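/* struct_ops "cong_avoid": plain Reno-style growth re-implemented in BPF,
 * i.e. slow start below ssthresh, then one segment per cwnd of ACKed data.
 */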
179 SEC("struct_ops/tcp_reno_cong_avoid")
180 void BPF_PROG(tcp_reno_cong_avoid
, struct sock
*sk
, __u32 ack
, __u32 acked
)
182 struct tcp_sock
*tp
= tcp_sk(sk
);
184 if (!tcp_is_cwnd_limited(sk
))
187 /* In "safe" area, increase. */
188 if (tcp_in_slow_start(tp
)) {
189 acked
= tcp_slow_start(tp
, acked
);
193 /* In dangerous area, increase slowly. */
194 tcp_cong_avoid_ai(tp
, tp
->snd_cwnd
, acked
);
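/* Two struct_ops instances are exposed: "bpf_dctcp_nouse" below appears to
 * exist only to exercise struct_ops registration, while "bpf_dctcp" wires up
 * the full set of callbacks defined above.
 */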
SEC(".struct_ops")
struct tcp_congestion_ops dctcp_nouse = {
	.init		= (void *)dctcp_init,
	.set_state	= (void *)dctcp_state,
	.flags		= TCP_CONG_NEEDS_ECN,
	.name		= "bpf_dctcp_nouse",
};

SEC(".struct_ops")
struct tcp_congestion_ops dctcp = {
	.init		= (void *)dctcp_init,
	.in_ack_event	= (void *)dctcp_update_alpha,
	.cwnd_event	= (void *)dctcp_cwnd_event,
	.ssthresh	= (void *)dctcp_ssthresh,
	.cong_avoid	= (void *)tcp_reno_cong_avoid,
	.undo_cwnd	= (void *)dctcp_cwnd_undo,
	.set_state	= (void *)dctcp_state,
	.flags		= TCP_CONG_NEEDS_ECN,
	.name		= "bpf_dctcp",
};