// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */

/* WARNING: This implementation is not necessarily the same
 * as tcp_dctcp.c.  The purpose is mainly for testing
 * the kernel BPF logic.
 */

#include <linux/bpf.h>
#include <linux/types.h>
#include <linux/stddef.h>
#include <linux/tcp.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include "bpf_tcp_helpers.h"

char _license[] SEC("license") = "GPL";

int stg_result = 0;
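
/* Per-socket storage used by the selftest: if the userspace side stored a
 * value for this socket, dctcp_init() below copies it into stg_result and
 * then deletes the entry again (inferred from how the map is used here,
 * not from the userspace test itself).
 */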
struct {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, int);
} sk_stg_map SEC(".maps");

#define DCTCP_MAX_ALPHA	1024U
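
/* alpha is kept in fixed point: DCTCP_MAX_ALPHA (2^10) stands for 1.0,
 * i.e. every packet delivered in the last observation window was CE-marked.
 */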

struct dctcp {
	__u32 old_delivered;
	__u32 old_delivered_ce;
	__u32 prior_rcv_nxt;
	__u32 dctcp_alpha;
	__u32 next_seq;
	__u32 ce_state;
	__u32 loss_cwnd;
};

static unsigned int dctcp_shift_g = 4; /* g = 1/2^4 */
static unsigned int dctcp_alpha_on_init = DCTCP_MAX_ALPHA;
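
/* Start a new observation window: remember snd_nxt and snapshot the
 * delivered/delivered_ce counters so dctcp_update_alpha() can later compute
 * the CE fraction over roughly one RTT's worth of ACKed data.
 */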
static __always_inline void dctcp_reset(const struct tcp_sock *tp,
					struct dctcp *ca)
{
	ca->next_seq = tp->snd_nxt;
	ca->old_delivered = tp->delivered;
	ca->old_delivered_ce = tp->delivered_ce;
}
53 SEC("struct_ops/dctcp_init")
54 void BPF_PROG(dctcp_init
, struct sock
*sk
)
56 const struct tcp_sock
*tp
= tcp_sk(sk
);
57 struct dctcp
*ca
= inet_csk_ca(sk
);
60 ca
->prior_rcv_nxt
= tp
->rcv_nxt
;
61 ca
->dctcp_alpha
= min(dctcp_alpha_on_init
, DCTCP_MAX_ALPHA
);
65 stg
= bpf_sk_storage_get(&sk_stg_map
, (void *)tp
, NULL
, 0);
68 bpf_sk_storage_delete(&sk_stg_map
, (void *)tp
);
SEC("struct_ops/dctcp_ssthresh")
__u32 BPF_PROG(dctcp_ssthresh, struct sock *sk)
{
	struct dctcp *ca = inet_csk_ca(sk);
	struct tcp_sock *tp = tcp_sk(sk);

	ca->loss_cwnd = tp->snd_cwnd;
	return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->dctcp_alpha) >> 11U), 2U);
}
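
/* Once per window, alpha is moved towards F, the fraction of CE-marked
 * packets, via alpha <- (1 - g) * alpha + g * F with g = 1/2^dctcp_shift_g.
 * Worked example with the default g = 1/16: if 4 of 16 delivered packets
 * were CE-marked, the added term is (4 << 6) / 16 = 16, i.e. g * F =
 * (1/16) * (1/4) = 16/1024 in alpha's fixed-point scale.
 */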
SEC("struct_ops/dctcp_update_alpha")
void BPF_PROG(dctcp_update_alpha, struct sock *sk, __u32 flags)
{
	const struct tcp_sock *tp = tcp_sk(sk);
	struct dctcp *ca = inet_csk_ca(sk);

	/* Expired RTT */
	if (!before(tp->snd_una, ca->next_seq)) {
		__u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce;
		__u32 alpha = ca->dctcp_alpha;

		/* alpha = (1 - g) * alpha + g * F */

		alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g);
		if (delivered_ce) {
			__u32 delivered = tp->delivered - ca->old_delivered;

			/* If dctcp_shift_g == 1, a 32bit value would overflow
			 * after 8 M packets.
			 */
			delivered_ce <<= (10 - dctcp_shift_g);
			delivered_ce /= max(1U, delivered);

			alpha = min(alpha + delivered_ce, DCTCP_MAX_ALPHA);
		}
		ca->dctcp_alpha = alpha;
		dctcp_reset(tp, ca);
	}
}
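
/* On loss, fall back to a Reno-style response: remember cwnd for a possible
 * undo and halve ssthresh (never going below 2 segments).
 */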
static __always_inline void dctcp_react_to_loss(struct sock *sk)
{
	struct dctcp *ca = inet_csk_ca(sk);
	struct tcp_sock *tp = tcp_sk(sk);

	ca->loss_cwnd = tp->snd_cwnd;
	tp->snd_ssthresh = max(tp->snd_cwnd >> 1U, 2U);
}

SEC("struct_ops/dctcp_state")
void BPF_PROG(dctcp_state, struct sock *sk, __u8 new_state)
{
	if (new_state == TCP_CA_Recovery &&
	    new_state != BPF_CORE_READ_BITFIELD(inet_csk(sk), icsk_ca_state))
		dctcp_react_to_loss(sk);
	/* We handle RTO in dctcp_cwnd_event to ensure that we perform only
	 * one loss-adjustment per RTT.
	 */
}
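
/* TCP_ECN_DEMAND_CWR keeps the receive path echoing ECE on outgoing ACKs
 * until the peer signals CWR; it is kept in sync with whether the last
 * received packet was CE-marked.
 */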
static __always_inline void dctcp_ece_ack_cwr(struct sock *sk, __u32 ce_state)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (ce_state == 1)
		tp->ecn_flags |= TCP_ECN_DEMAND_CWR;
	else
		tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR;
}

/* Minimal DCTCP CE state machine:
 *
 * S:	0 <- last pkt was non-CE
 *	1 <- last pkt was CE
 */
static __always_inline
void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt,
			  __u32 *prior_rcv_nxt, __u32 *ce_state)
{
	__u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0;

	if (*ce_state != new_ce_state) {
		/* CE state has changed, force an immediate ACK to
		 * reflect the new CE state. If an ACK was delayed,
		 * send that first to reflect the prior CE state.
		 */
		if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) {
			dctcp_ece_ack_cwr(sk, *ce_state);
			bpf_tcp_send_ack(sk, *prior_rcv_nxt);
		}
		inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
	}
	*prior_rcv_nxt = tcp_sk(sk)->rcv_nxt;
	*ce_state = new_ce_state;
	dctcp_ece_ack_cwr(sk, new_ce_state);
}

SEC("struct_ops/dctcp_cwnd_event")
void BPF_PROG(dctcp_cwnd_event, struct sock *sk, enum tcp_ca_event ev)
{
	struct dctcp *ca = inet_csk_ca(sk);

	switch (ev) {
	case CA_EVENT_ECN_IS_CE:
	case CA_EVENT_ECN_NO_CE:
		dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state);
		break;
	case CA_EVENT_LOSS:
		dctcp_react_to_loss(sk);
		break;
	default:
		/* Don't care for the rest. */
		break;
	}
}

SEC("struct_ops/dctcp_cwnd_undo")
__u32 BPF_PROG(dctcp_cwnd_undo, struct sock *sk)
{
	const struct dctcp *ca = inet_csk_ca(sk);

	return max(tcp_sk(sk)->snd_cwnd, ca->loss_cwnd);
}
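
/* A BPF re-implementation of the classic Reno cong_avoid logic: slow start
 * below ssthresh, additive increase afterwards. It is wired up as
 * .cong_avoid below because this DCTCP variant does not implement its own
 * cwnd growth.
 */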
SEC("struct_ops/tcp_reno_cong_avoid")
void BPF_PROG(tcp_reno_cong_avoid, struct sock *sk, __u32 ack, __u32 acked)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (!tcp_is_cwnd_limited(sk))
		return;

	/* In "safe" area, increase. */
	if (tcp_in_slow_start(tp)) {
		acked = tcp_slow_start(tp, acked);
		if (!acked)
			return;
	}
	/* In dangerous area, increase slowly. */
	tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked);
}
SEC(".struct_ops")
struct tcp_congestion_ops dctcp_nouse = {
	.init		= (void *)dctcp_init,
	.set_state	= (void *)dctcp_state,
	.flags		= TCP_CONG_NEEDS_ECN,
	.name		= "bpf_dctcp_nouse",
};

SEC(".struct_ops")
struct tcp_congestion_ops dctcp = {
	.init		= (void *)dctcp_init,
	.in_ack_event	= (void *)dctcp_update_alpha,
	.cwnd_event	= (void *)dctcp_cwnd_event,
	.ssthresh	= (void *)dctcp_ssthresh,
	.cong_avoid	= (void *)tcp_reno_cong_avoid,
	.undo_cwnd	= (void *)dctcp_cwnd_undo,
	.set_state	= (void *)dctcp_state,
	.flags		= TCP_CONG_NEEDS_ECN,
	.name		= "bpf_dctcp",
};