// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook  */

#include <linux/init.h>
#include <linux/types.h>
#include <linux/bpf_verifier.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <net/tcp.h>
#include <net/bpf_sk_storage.h>

/* Forward declaration; the definition is at the bottom of this file. */
static struct bpf_struct_ops bpf_tcp_congestion_ops;

static const struct btf_type *tcp_sock_type;
static u32 tcp_sock_id, sock_id;
static const struct btf_type *tcp_congestion_ops_type;
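
/* Cache the BTF type ids of "sock", "tcp_sock" and "tcp_congestion_ops"
 * once at struct_ops registration time, so the verifier callbacks below
 * can compare ids instead of looking up names.
 */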
static int bpf_tcp_ca_init(struct btf *btf)
{
	s32 type_id;

	type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	sock_id = type_id;

	type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	tcp_sock_id = type_id;
	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);

	type_id = btf_find_by_name_kind(btf, "tcp_congestion_ops", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	tcp_congestion_ops_type = btf_type_by_id(btf, type_id);

	return 0;
}
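
/* Allow loads from the prog ctx (the struct_ops callback arguments) and
 * promote a plain "struct sock *" argument to "struct tcp_sock *", so
 * the prog can read tcp_sock members directly.
 */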
static bool bpf_tcp_ca_is_valid_access(int off, int size,
				       enum bpf_access_type type,
				       const struct bpf_prog *prog,
				       struct bpf_insn_access_aux *info)
{
	if (!bpf_tracing_btf_ctx_access(off, size, type, prog, info))
		return false;

	if (base_type(info->reg_type) == PTR_TO_BTF_ID &&
	    !bpf_type_has_unsafe_modifiers(info->reg_type) &&
	    info->btf_id == sock_id)
		/* promote it to tcp_sock */
		info->btf_id = tcp_sock_id;

	return true;
}
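
/* Verify a write through a tcp_sock pointer: only the members listed in
 * the switch below may be written, and the access must not run past the
 * end of the member.
 */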
static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
					const struct bpf_reg_state *reg,
					int off, int size)
{
	const struct btf_type *t;
	size_t end;

	t = btf_type_by_id(reg->btf, reg->btf_id);
	if (t != tcp_sock_type) {
		bpf_log(log, "only read is supported\n");
		return -EACCES;
	}

	switch (off) {
	case offsetof(struct sock, sk_pacing_rate):
		end = offsetofend(struct sock, sk_pacing_rate);
		break;
	case offsetof(struct sock, sk_pacing_status):
		end = offsetofend(struct sock, sk_pacing_status);
		break;
	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
		break;
	case offsetof(struct inet_connection_sock, icsk_ack.pending):
		end = offsetofend(struct inet_connection_sock,
				  icsk_ack.pending);
		break;
	case offsetof(struct tcp_sock, snd_cwnd):
		end = offsetofend(struct tcp_sock, snd_cwnd);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_cnt):
		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_stamp):
		end = offsetofend(struct tcp_sock, snd_cwnd_stamp);
		break;
	case offsetof(struct tcp_sock, snd_ssthresh):
		end = offsetofend(struct tcp_sock, snd_ssthresh);
		break;
	case offsetof(struct tcp_sock, ecn_flags):
		end = offsetofend(struct tcp_sock, ecn_flags);
		break;
	case offsetof(struct tcp_sock, app_limited):
		end = offsetofend(struct tcp_sock, app_limited);
		break;
	default:
		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
		return -EACCES;
	}

	if (off + size > end) {
		bpf_log(log,
			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
			off, size, end);
		return -EACCES;
	}

	return 0;
}
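
/* Helper exposing __tcp_send_ack() with a caller-chosen rcv_nxt to a
 * bpf-tcp-cc prog; the verifier guarantees tp is a valid tcp_sock.
 */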
BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt)
{
	/* bpf_tcp_ca prog cannot have NULL tp */
	__tcp_send_ack((struct sock *)tp, rcv_nxt);
	return 0;
}

static const struct bpf_func_proto bpf_tcp_send_ack_proto = {
	.func		= bpf_tcp_send_ack,
	.gpl_only	= false,
	/* In case we want to report error later */
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_BTF_ID,
	.arg1_btf_id	= &tcp_sock_id,
	.arg2_type	= ARG_ANYTHING,
};
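
/* For a struct_ops prog, expected_attach_type holds the index of the
 * tcp_congestion_ops member being implemented; translate it to the
 * member's byte offset.
 */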
static u32 prog_ops_moff(const struct bpf_prog *prog)
{
	const struct btf_member *m;
	const struct btf_type *t;
	u32 midx;

	midx = prog->expected_attach_type;
	t = tcp_congestion_ops_type;
	m = &btf_type_member(t)[midx];

	return __btf_member_bit_offset(t, m) / 8;
}
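
/* Helpers available to a bpf-tcp-cc prog.  Note that setsockopt() and
 * getsockopt() are withheld from the release() hook; see the comments
 * in the switch below.
 */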
static const struct bpf_func_proto *
bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
			  const struct bpf_prog *prog)
{
	switch (func_id) {
	case BPF_FUNC_tcp_send_ack:
		return &bpf_tcp_send_ack_proto;
	case BPF_FUNC_sk_storage_get:
		return &bpf_sk_storage_get_proto;
	case BPF_FUNC_sk_storage_delete:
		return &bpf_sk_storage_delete_proto;
	case BPF_FUNC_setsockopt:
		/* Does not allow release() to call setsockopt.
		 * release() is called when the current bpf-tcp-cc
		 * is retiring.  It is not allowed to call
		 * setsockopt() to make further changes which
		 * may potentially allocate new resources.
		 */
		if (prog_ops_moff(prog) !=
		    offsetof(struct tcp_congestion_ops, release))
			return &bpf_sk_setsockopt_proto;
		return NULL;
	case BPF_FUNC_getsockopt:
		/* Since get/setsockopt is usually expected to
		 * be available together, disable getsockopt for
		 * release also to avoid usage surprise.
		 * The bpf-tcp-cc already has a more powerful way
		 * to read tcp_sock from the PTR_TO_BTF_ID.
		 */
		if (prog_ops_moff(prog) !=
		    offsetof(struct tcp_congestion_ops, release))
			return &bpf_sk_getsockopt_proto;
		return NULL;
	case BPF_FUNC_ktime_get_coarse_ns:
		return &bpf_ktime_get_coarse_ns_proto;
	default:
		return bpf_base_func_proto(func_id, prog);
	}
}
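
/* Kernel functions a bpf-tcp-cc prog may call directly, e.g. to fall
 * back to the reno defaults for a hook it does not want to reimplement.
 */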
BTF_KFUNCS_START(bpf_tcp_ca_check_kfunc_ids)
BTF_ID_FLAGS(func, tcp_reno_ssthresh)
BTF_ID_FLAGS(func, tcp_reno_cong_avoid)
BTF_ID_FLAGS(func, tcp_reno_undo_cwnd)
BTF_ID_FLAGS(func, tcp_slow_start)
BTF_ID_FLAGS(func, tcp_cong_avoid_ai)
BTF_KFUNCS_END(bpf_tcp_ca_check_kfunc_ids)

static const struct btf_kfunc_id_set bpf_tcp_ca_kfunc_set = {
	.owner = THIS_MODULE,
	.set   = &bpf_tcp_ca_check_kfunc_ids,
};

static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
	.get_func_proto		= bpf_tcp_ca_get_func_proto,
	.is_valid_access	= bpf_tcp_ca_is_valid_access,
	.btf_struct_access	= bpf_tcp_ca_btf_struct_access,
};
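
/* Copy the members that can be set directly from the user-supplied
 * image (udata) into the kernel copy (kdata).  Returning 1 tells the
 * struct_ops core the member has been handled here.
 */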
static int bpf_tcp_ca_init_member(const struct btf_type *t,
				  const struct btf_member *member,
				  void *kdata, const void *udata)
{
	const struct tcp_congestion_ops *utcp_ca;
	struct tcp_congestion_ops *tcp_ca;
	u32 moff;

	utcp_ca = (const struct tcp_congestion_ops *)udata;
	tcp_ca = (struct tcp_congestion_ops *)kdata;

	moff = __btf_member_bit_offset(t, member) / 8;
	switch (moff) {
	case offsetof(struct tcp_congestion_ops, flags):
		if (utcp_ca->flags & ~TCP_CONG_MASK)
			return -EINVAL;
		tcp_ca->flags = utcp_ca->flags;
		return 1;
	case offsetof(struct tcp_congestion_ops, name):
		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
				     sizeof(tcp_ca->name)) <= 0)
			return -EINVAL;
		return 1;
	}

	return 0;
}
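
/* struct_ops lifecycle hooks: a bpf-tcp-cc is (un)registered, updated
 * and validated through the same APIs as any other kernel congestion
 * control module.
 */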
static int bpf_tcp_ca_reg(void *kdata, struct bpf_link *link)
{
	return tcp_register_congestion_control(kdata);
}

static void bpf_tcp_ca_unreg(void *kdata, struct bpf_link *link)
{
	tcp_unregister_congestion_control(kdata);
}

static int bpf_tcp_ca_update(void *kdata, void *old_kdata, struct bpf_link *link)
{
	return tcp_update_congestion_control(kdata, old_kdata);
}

static int bpf_tcp_ca_validate(void *kdata)
{
	return tcp_validate_congestion_control(kdata);
}
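
/* CFI stubs: these empty definitions only give the CFI machinery valid
 * function signatures for the BPF trampoline calls; they are never
 * executed.
 */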
static u32 bpf_tcp_ca_ssthresh(struct sock *sk)
{
	return 0;
}

static void bpf_tcp_ca_cong_avoid(struct sock *sk, u32 ack, u32 acked)
{
}

static void bpf_tcp_ca_set_state(struct sock *sk, u8 new_state)
{
}

static void bpf_tcp_ca_cwnd_event(struct sock *sk, enum tcp_ca_event ev)
{
}

static void bpf_tcp_ca_in_ack_event(struct sock *sk, u32 flags)
{
}

static void bpf_tcp_ca_pkts_acked(struct sock *sk, const struct ack_sample *sample)
{
}

static u32 bpf_tcp_ca_min_tso_segs(struct sock *sk)
{
	return 0;
}

static void bpf_tcp_ca_cong_control(struct sock *sk, u32 ack, int flag,
				    const struct rate_sample *rs)
{
}

static u32 bpf_tcp_ca_undo_cwnd(struct sock *sk)
{
	return 0;
}

static u32 bpf_tcp_ca_sndbuf_expand(struct sock *sk)
{
	return 0;
}

static void __bpf_tcp_ca_init(struct sock *sk)
{
}

static void __bpf_tcp_ca_release(struct sock *sk)
{
}

static struct tcp_congestion_ops __bpf_ops_tcp_congestion_ops = {
	.ssthresh = bpf_tcp_ca_ssthresh,
	.cong_avoid = bpf_tcp_ca_cong_avoid,
	.set_state = bpf_tcp_ca_set_state,
	.cwnd_event = bpf_tcp_ca_cwnd_event,
	.in_ack_event = bpf_tcp_ca_in_ack_event,
	.pkts_acked = bpf_tcp_ca_pkts_acked,
	.min_tso_segs = bpf_tcp_ca_min_tso_segs,
	.cong_control = bpf_tcp_ca_cong_control,
	.undo_cwnd = bpf_tcp_ca_undo_cwnd,
	.sndbuf_expand = bpf_tcp_ca_sndbuf_expand,

	.init = __bpf_tcp_ca_init,
	.release = __bpf_tcp_ca_release,
};

static struct bpf_struct_ops bpf_tcp_congestion_ops = {
	.verifier_ops = &bpf_tcp_ca_verifier_ops,
	.reg = bpf_tcp_ca_reg,
	.unreg = bpf_tcp_ca_unreg,
	.update = bpf_tcp_ca_update,
	.init_member = bpf_tcp_ca_init_member,
	.init = bpf_tcp_ca_init,
	.validate = bpf_tcp_ca_validate,
	.name = "tcp_congestion_ops",
	.cfi_stubs = &__bpf_ops_tcp_congestion_ops,
	.owner = THIS_MODULE,
};
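
/* Register the kfunc id set and the tcp_congestion_ops struct_ops type;
 * an error from either registration is returned.
 */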
static int __init bpf_tcp_ca_kfunc_init(void)
{
	int ret;

	ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &bpf_tcp_ca_kfunc_set);
	ret = ret ?: register_bpf_struct_ops(&bpf_tcp_congestion_ops, tcp_congestion_ops);

	return ret;
}
late_initcall(bpf_tcp_ca_kfunc_init);