// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */

#include <linux/types.h>
#include <linux/bpf_verifier.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/filter.h>
#include <net/tcp.h>
#include <net/bpf_sk_storage.h>

/* Ops that a bpf struct_ops CA may leave unset (NULL prog). */
static u32 optional_ops[] = {
	offsetof(struct tcp_congestion_ops, init),
	offsetof(struct tcp_congestion_ops, release),
	offsetof(struct tcp_congestion_ops, set_state),
	offsetof(struct tcp_congestion_ops, cwnd_event),
	offsetof(struct tcp_congestion_ops, in_ack_event),
	offsetof(struct tcp_congestion_ops, pkts_acked),
	offsetof(struct tcp_congestion_ops, min_tso_segs),
	offsetof(struct tcp_congestion_ops, sndbuf_expand),
	offsetof(struct tcp_congestion_ops, cong_control),
};

/* Ops that cannot be implemented by a bpf CA. */
static u32 unsupported_ops[] = {
	offsetof(struct tcp_congestion_ops, get_info),
};

static const struct btf_type *tcp_sock_type;
static u32 tcp_sock_id, sock_id;

/* btf_ids for the converted sk_storage protos: one slot per helper
 * argument (a bpf helper takes at most five arguments).
 */
static int btf_sk_storage_get_ids[5];
static struct bpf_func_proto btf_sk_storage_get_proto __read_mostly;

static int btf_sk_storage_delete_ids[5];
static struct bpf_func_proto btf_sk_storage_delete_proto __read_mostly;

/* Copy "from" and retarget any ARG_PTR_TO_SOCKET argument to a
 * BTF-typed tcp_sock pointer.
 */
static void convert_sk_func_proto(struct bpf_func_proto *to, int *to_btf_ids,
				  const struct bpf_func_proto *from)
{
	int i;

	*to = *from;
	to->btf_id = to_btf_ids;
	for (i = 0; i < ARRAY_SIZE(to->arg_type); i++) {
		if (to->arg_type[i] == ARG_PTR_TO_SOCKET) {
			to->arg_type[i] = ARG_PTR_TO_BTF_ID;
			to->btf_id[i] = tcp_sock_id;
		}
	}
}

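/* Effect of the conversion (a sketch, not kernel code): a bpf_tcp_ca
 * program can pass the BTF-typed tcp_sock pointer it already holds to
 * the sk_storage helpers.  A hypothetical bpf-side use, assuming a
 * program-defined "struct ca_priv" map value type:
 *
 *	struct ca_priv *priv;
 *
 *	priv = bpf_sk_storage_get(&sk_stg_map, tp, NULL,
 *				  BPF_SK_STORAGE_GET_F_CREATE);
 *
 * where "tp" is a struct tcp_sock * argument of a congestion-control op.
 */
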
static int bpf_tcp_ca_init(struct btf *btf)
{
	s32 type_id;

	type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	sock_id = type_id;

	type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	tcp_sock_id = type_id;
	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);

	convert_sk_func_proto(&btf_sk_storage_get_proto,
			      btf_sk_storage_get_ids,
			      &bpf_sk_storage_get_proto);
	convert_sk_func_proto(&btf_sk_storage_delete_proto,
			      btf_sk_storage_delete_ids,
			      &bpf_sk_storage_delete_proto);

	return 0;
}

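/* Note: ->init is invoked while the vmlinux BTF is parsed (see
 * bpf_struct_ops_init()), so sock_id/tcp_sock_id and the converted
 * helper protos are resolved before any bpf_tcp_ca program is verified.
 */
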
static bool is_optional(u32 member_offset)
{
	unsigned int i;

	for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
		if (member_offset == optional_ops[i])
			return true;
	}

	return false;
}

static bool is_unsupported(u32 member_offset)
{
	unsigned int i;

	for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
		if (member_offset == unsupported_ops[i])
			return true;
	}

	return false;
}

extern struct btf *btf_vmlinux;

static bool bpf_tcp_ca_is_valid_access(int off, int size,
				       enum bpf_access_type type,
				       const struct bpf_prog *prog,
				       struct bpf_insn_access_aux *info)
{
	if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
		return false;
	if (type != BPF_READ)
		return false;
	if (off % size != 0)
		return false;

	if (!btf_ctx_access(off, size, type, prog, info))
		return false;

	if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
		/* promote it to tcp_sock */
		info->btf_id = tcp_sock_id;

	return true;
}

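/* With the promotion above, an op that receives a struct sock *sk in
 * its kernel signature can cast it directly in the bpf program, e.g.
 * (a sketch patterned after the kernel selftests):
 *
 *	struct tcp_sock *tp = (struct tcp_sock *)sk;
 *
 * and the verifier tracks tp as a BTF-typed tcp_sock pointer.
 */
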
static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
					const struct btf_type *t, int off,
					int size, enum bpf_access_type atype,
					u32 *next_btf_id)
{
	size_t end;

	if (atype == BPF_READ)
		return btf_struct_access(log, t, off, size, atype, next_btf_id);

	if (t != tcp_sock_type) {
		bpf_log(log, "only read is supported\n");
		return -EACCES;
	}

	switch (off) {
	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
		break;
	case offsetof(struct inet_connection_sock, icsk_ack.pending):
		end = offsetofend(struct inet_connection_sock,
				  icsk_ack.pending);
		break;
	case offsetof(struct tcp_sock, snd_cwnd):
		end = offsetofend(struct tcp_sock, snd_cwnd);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_cnt):
		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
		break;
	case offsetof(struct tcp_sock, snd_ssthresh):
		end = offsetofend(struct tcp_sock, snd_ssthresh);
		break;
	case offsetof(struct tcp_sock, ecn_flags):
		end = offsetofend(struct tcp_sock, ecn_flags);
		break;
	default:
		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
		return -EACCES;
	}

	if (off + size > end) {
		bpf_log(log,
			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
			off, size, end);
		return -EACCES;
	}

	return NOT_INIT;
}

BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt)
{
	/* bpf_tcp_ca prog cannot have NULL tp */
	__tcp_send_ack((struct sock *)tp, rcv_nxt);
	return 0;
}

static const struct bpf_func_proto bpf_tcp_send_ack_proto = {
	.func		= bpf_tcp_send_ack,
	.gpl_only	= false,
	/* In case we want to report error later */
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_BTF_ID,
	.arg2_type	= ARG_ANYTHING,
	.btf_id		= &tcp_sock_id,
};

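/* Bpf-side sketch: a CA that mimics DCTCP's immediate ACK on a CE state
 * change could call, assuming tp is its BTF-typed tcp_sock pointer:
 *
 *	bpf_tcp_send_ack(tp, prior_rcv_nxt);
 *
 * where prior_rcv_nxt is state the program tracks itself.
 */
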
static const struct bpf_func_proto *
bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
			  const struct bpf_prog *prog)
{
	switch (func_id) {
	case BPF_FUNC_tcp_send_ack:
		return &bpf_tcp_send_ack_proto;
	case BPF_FUNC_sk_storage_get:
		return &btf_sk_storage_get_proto;
	case BPF_FUNC_sk_storage_delete:
		return &btf_sk_storage_delete_proto;
	default:
		return bpf_base_func_proto(func_id);
	}
}

static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
	.get_func_proto		= bpf_tcp_ca_get_func_proto,
	.is_valid_access	= bpf_tcp_ca_is_valid_access,
	.btf_struct_access	= bpf_tcp_ca_btf_struct_access,
};

static int bpf_tcp_ca_init_member(const struct btf_type *t,
				  const struct btf_member *member,
				  void *kdata, const void *udata)
{
	const struct tcp_congestion_ops *utcp_ca;
	struct tcp_congestion_ops *tcp_ca;
	int prog_fd;
	u32 moff;

	utcp_ca = (const struct tcp_congestion_ops *)udata;
	tcp_ca = (struct tcp_congestion_ops *)kdata;

	moff = btf_member_bit_offset(t, member) / 8;
	switch (moff) {
	case offsetof(struct tcp_congestion_ops, flags):
		if (utcp_ca->flags & ~TCP_CONG_MASK)
			return -EINVAL;
		tcp_ca->flags = utcp_ca->flags;
		return 1;
	case offsetof(struct tcp_congestion_ops, name):
		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
				     sizeof(tcp_ca->name)) <= 0)
			return -EINVAL;
		if (tcp_ca_find(utcp_ca->name))
			return -EEXIST;
		return 1;
	}

	if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
		return 0;

	/* Ensure bpf_prog is provided for compulsory func ptr */
	prog_fd = (int)(*(unsigned long *)(udata + moff));
	if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
		return -EINVAL;

	return 0;
}

static int bpf_tcp_ca_check_member(const struct btf_type *t,
				   const struct btf_member *member)
{
	if (is_unsupported(btf_member_bit_offset(t, member) / 8))
		return -ENOTSUPP;
	return 0;
}

static int bpf_tcp_ca_reg(void *kdata)
{
	return tcp_register_congestion_control(kdata);
}

static void bpf_tcp_ca_unreg(void *kdata)
{
	tcp_unregister_congestion_control(kdata);
}

/* Avoid sparse warning. It is only used in bpf_struct_ops.c. */
extern struct bpf_struct_ops bpf_tcp_congestion_ops;

struct bpf_struct_ops bpf_tcp_congestion_ops = {
	.verifier_ops = &bpf_tcp_ca_verifier_ops,
	.reg = bpf_tcp_ca_reg,
	.unreg = bpf_tcp_ca_unreg,
	.check_member = bpf_tcp_ca_check_member,
	.init_member = bpf_tcp_ca_init_member,
	.init = bpf_tcp_ca_init,
	.name = "tcp_congestion_ops",
};