1 // SPDX-License-Identifier: GPL-2.0
10 #include <linux/filter.h>
11 #include <linux/bpf.h>
12 #include <linux/if_packet.h>
13 #include <linux/if_vlan.h>
14 #include <linux/virtio_net.h>
16 #include <net/ethernet.h>
17 #include <netinet/ip.h>
18 #include <netinet/udp.h>
27 #include <sys/socket.h>
29 #include <sys/types.h>
32 #include "psock_lib.h"
/*
 * Boolean test options. The option-conflict messages in parse_opts()
 * associate them with command-line flags: dgram (-d), vlan (-V),
 * vnet (-v), csum offload (-c), csum bad (-C) and gso (-g).
 */
34 static bool cfg_use_bind
;
35 static bool cfg_use_csum_off
;
36 static bool cfg_use_csum_off_bad
;
37 static bool cfg_use_dgram
;
38 static bool cfg_use_gso
;
39 static bool cfg_use_qdisc_bypass
;
40 static bool cfg_use_vlan
;
41 static bool cfg_use_vnet
;
/* Interface name resolved with if_nametoindex() in do_bind() and do_send(). */
43 static char *cfg_ifname
= "lo";
/* build_vnet_header() derives gso_size as cfg_mtu - sizeof(struct iphdr). */
44 static int cfg_mtu
= 1500;
/* Payload length handed to build_packet(); parse_opts() can overwrite it from optarg. */
45 static int cfg_payload_len
= DATA_LEN
;
/* do_tx() caps the built packet length at this value when it is smaller. */
46 static int cfg_truncate_len
= INT_MAX
;
/* UDP destination port of the test packet and the port setup_rx() binds to. */
47 static uint16_t cfg_port
= 8000;
49 /* test sending up to max mtu + 1 */
/* Buffer size: virtio-net header + Ethernet header + ETH_MAX_MTU + 1 bytes. */
50 #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
/*
 * tbuf holds the packet assembled by build_packet(); rbuf is the receive
 * buffer filled by recv() in do_rx().
 */
52 static char tbuf
[TEST_SZ
], rbuf
[TEST_SZ
];
/*
 * Checksum helper over num_u16 16-bit words beginning at start.
 * The visible code zero-initializes an unsigned long accumulator and
 * iterates i over [0, num_u16).
 * NOTE(review): the loop body and the return statement are not visible in
 * this listing; presumably each start[i] is added to sum and sum is
 * returned -- confirm against the full file.
 */
54 static unsigned long add_csum_hword(const uint16_t *start
, int num_u16
)
56 unsigned long sum
= 0;
59 for (i
= 0; i
< num_u16
; i
++)
/*
 * Compute a 16-bit checksum over num_u16 16-bit words at start.
 * Adds add_csum_hword(start, num_u16) into sum, then folds the bits above
 * the low 16 back into the low 16 bits.
 * NOTE(review): the third parameter and the return statement are not
 * visible here. Callers pass a third argument (0 in build_ipv4_header(),
 * protocol + UDP length in build_udp_header()), presumably the initial
 * value of sum -- confirm.
 */
65 static uint16_t build_ip_csum(const uint16_t *start
, int num_u16
,
68 sum
+= add_csum_hword(start
, num_u16
);
/* fold the carry (bits 16 and up) back into the low 16 bits */
71 sum
= (sum
& 0xffff) + (sum
>> 16);
/*
 * Fill in the struct virtio_net_hdr located at header.
 *
 * hdr_len is set to the combined size of the Ethernet, IPv4 and UDP
 * headers. When cfg_use_csum_off is set, checksum offload is requested:
 * csum_start points just past the IPv4 header and csum_offset is the
 * offset of the check field inside struct udphdr. With
 * cfg_use_csum_off_bad, csum_start is moved further out (see the existing
 * comment below).
 *
 * NOTE(review): the condition guarding the gso_type/gso_size assignments
 * and the return statement are not visible in this listing; presumably the
 * GSO fields depend on cfg_use_gso -- confirm.
 */
76 static int build_vnet_header(void *header
)
78 struct virtio_net_hdr
*vh
= header
;
80 vh
->hdr_len
= ETH_HLEN
+ sizeof(struct iphdr
) + sizeof(struct udphdr
);
82 if (cfg_use_csum_off
) {
83 vh
->flags
|= VIRTIO_NET_HDR_F_NEEDS_CSUM
;
84 vh
->csum_start
= ETH_HLEN
+ sizeof(struct iphdr
);
85 vh
->csum_offset
= __builtin_offsetof(struct udphdr
, check
);
87 /* position check field exactly one byte beyond end of packet */
88 if (cfg_use_csum_off_bad
)
89 vh
->csum_start
+= sizeof(struct udphdr
) + cfg_payload_len
-
/* UDP GSO: segment size is the configured MTU minus the IPv4 header */
94 vh
->gso_type
= VIRTIO_NET_HDR_GSO_UDP
;
95 vh
->gso_size
= cfg_mtu
- sizeof(struct iphdr
);
/*
 * Write the Ethernet header at header.
 *
 * Two variants are visible: one sets h_proto to ETH_P_8021Q and stores
 * ETH_P_IP in the second 16-bit word following the ETH_HLEN-byte header
 * (tag[1]); the other sets h_proto to ETH_P_IP directly.
 *
 * NOTE(review): the condition selecting between the two and the return
 * value are not visible in this listing; presumably the 802.1Q variant is
 * used when cfg_use_vlan is set -- confirm.
 */
101 static int build_eth_header(void *header
)
103 struct ethhdr
*eth
= header
;
/* tag points at the bytes immediately after the Ethernet header */
106 uint16_t *tag
= header
+ ETH_HLEN
;
108 eth
->h_proto
= htons(ETH_P_8021Q
);
109 tag
[1] = htons(ETH_P_IP
);
113 eth
->h_proto
= htons(ETH_P_IP
);
/*
 * Fill in the IPv4 header at header for a UDP packet carrying payload_len
 * bytes of data.
 *
 * tot_len covers the IPv4 header, the UDP header and the payload. The
 * packet uses a fixed id (1337), protocol UDP, source 172.17.0.2 and
 * destination 172.17.0.1. The header checksum is computed over
 * ihl << 1 16-bit words with an initial value of 0.
 *
 * Returns iph->ihl << 2, i.e. the header length in bytes.
 *
 * NOTE(review): assignments to other header fields (such as ihl itself)
 * are not visible in this listing.
 */
117 static int build_ipv4_header(void *header
, int payload_len
)
119 struct iphdr
*iph
= header
;
124 iph
->tot_len
= htons(sizeof(*iph
) + sizeof(struct udphdr
) + payload_len
);
125 iph
->id
= htons(1337);
126 iph
->protocol
= IPPROTO_UDP
;
/* 172.17.0.2 -> 172.17.0.1 */
127 iph
->saddr
= htonl((172 << 24) | (17 << 16) | 2);
128 iph
->daddr
= htonl((172 << 24) | (17 << 16) | 1);
129 iph
->check
= build_ip_csum((void *) iph
, iph
->ihl
<< 1, 0);
131 return iph
->ihl
<< 2;
/*
 * Fill in the UDP header at header for payload_len bytes of data.
 *
 * Source port is 9, destination port is cfg_port, and len is the UDP
 * header size plus payload_len. When cfg_use_csum_off is set, check is
 * computed over the 2 * alen bytes (alen 16-bit words) immediately
 * preceding the UDP header, with IPPROTO_UDP and the UDP length passed as
 * the third argument of build_ip_csum(). Presumably those preceding bytes
 * are the IPv4 source and destination addresses, making this a
 * pseudo-header sum -- confirm against the IPv4 header layout.
 *
 * Returns sizeof(*udph).
 *
 * NOTE(review): what happens to check when cfg_use_csum_off is not set is
 * not visible in this listing.
 */
134 static int build_udp_header(void *header
, int payload_len
)
136 const int alen
= sizeof(uint32_t);
137 struct udphdr
*udph
= header
;
138 int len
= sizeof(*udph
) + payload_len
;
140 udph
->source
= htons(9);
141 udph
->dest
= htons(cfg_port
);
142 udph
->len
= htons(len
);
144 if (cfg_use_csum_off
)
145 udph
->check
= build_ip_csum(header
- (2 * alen
), alen
,
146 htons(IPPROTO_UDP
) + udph
->len
);
150 return sizeof(*udph
);
/*
 * Assemble the test packet in tbuf.
 *
 * Headers are written back to back in this order: virtio-net header,
 * Ethernet header, IPv4 header, UDP header. Each builder's return value
 * advances off. The payload is payload_len bytes of DATA_CHAR. If the
 * headers plus payload would exceed sizeof(tbuf), the failure is reported
 * through error(1, ...). Note that this check runs after the headers have
 * already been written.
 *
 * Returns the total packet length (off + payload_len).
 *
 * NOTE(review): the declaration and initial value of off are not visible
 * in this listing.
 */
153 static int build_packet(int payload_len
)
157 off
+= build_vnet_header(tbuf
);
158 off
+= build_eth_header(tbuf
+ off
);
159 off
+= build_ipv4_header(tbuf
+ off
, payload_len
);
160 off
+= build_udp_header(tbuf
+ off
, payload_len
);
162 if (off
+ payload_len
> sizeof(tbuf
))
163 error(1, 0, "payload length exceeds max");
165 memset(tbuf
+ off
, DATA_CHAR
, payload_len
);
167 return off
+ payload_len
;
/*
 * Bind the socket fd to the interface named by cfg_ifname.
 *
 * Uses an AF_PACKET link-layer address with protocol ETH_P_IP. A failed
 * interface lookup or a failed bind() is reported through error(1, ...).
 */
170 static void do_bind(int fd
)
172 struct sockaddr_ll laddr
= {0};
174 laddr
.sll_family
= AF_PACKET
;
175 laddr
.sll_protocol
= htons(ETH_P_IP
);
176 laddr
.sll_ifindex
= if_nametoindex(cfg_ifname
);
/* if_nametoindex() yields 0 when the name cannot be resolved */
177 if (!laddr
.sll_ifindex
)
178 error(1, errno
, "if_nametoindex");
180 if (bind(fd
, (void *)&laddr
, sizeof(laddr
)))
181 error(1, errno
, "bind");
/*
 * Transmit len bytes from buf on the socket fd.
 *
 * Visible steps: buf/len are advanced past the virtio-net header, the data
 * is sent either with write() or with sendto() addressed to the cfg_ifname
 * interface (protocol ETH_P_IP), failures are reported through
 * error(1, ...), and the number of bytes sent is logged to stderr.
 *
 * NOTE(review): the conditions guarding the header skip, the choice
 * between write() and sendto(), and the two error checks are not visible
 * in this listing. Presumably the header is skipped when cfg_use_vnet is
 * off and write() is used when the socket was bound (cfg_use_bind) --
 * confirm against the full file.
 */
184 static void do_send(int fd
, char *buf
, int len
)
189 buf
+= sizeof(struct virtio_net_hdr
);
190 len
-= sizeof(struct virtio_net_hdr
);
198 ret
= write(fd
, buf
, len
);
200 struct sockaddr_ll laddr
= {0};
202 laddr
.sll_protocol
= htons(ETH_P_IP
);
203 laddr
.sll_ifindex
= if_nametoindex(cfg_ifname
);
204 if (!laddr
.sll_ifindex
)
205 error(1, errno
, "if_nametoindex");
207 ret
= sendto(fd
, buf
, len
, 0, (void *)&laddr
, sizeof(laddr
));
211 error(1, errno
, "write");
/* reports both the returned and the requested length */
213 error(1, 0, "write: %u %u", ret
, len
);
215 fprintf(stderr
, "tx: %u\n", ret
);
/*
 * Transmit side of the test.
 *
 * Opens a PF_PACKET socket (SOCK_DGRAM when cfg_use_dgram is set,
 * otherwise SOCK_RAW), applies PACKET_QDISC_BYPASS when
 * cfg_use_qdisc_bypass is set, enables PACKET_VNET_HDR, builds the packet
 * with cfg_payload_len bytes of payload, limits its length to
 * cfg_truncate_len if that is smaller, and sends it from tbuf. Failures
 * are reported through error(1, ...).
 *
 * NOTE(review): the condition guarding the PACKET_VNET_HDR setsockopt, any
 * call to do_bind(), the close() call and the return value are not visible
 * in this listing.
 */
218 static int do_tx(void)
223 fd
= socket(PF_PACKET
, cfg_use_dgram
? SOCK_DGRAM
: SOCK_RAW
, 0);
225 error(1, errno
, "socket t");
230 if (cfg_use_qdisc_bypass
&&
231 setsockopt(fd
, SOL_PACKET
, PACKET_QDISC_BYPASS
, &one
, sizeof(one
)))
232 error(1, errno
, "setsockopt qdisc bypass");
235 setsockopt(fd
, SOL_PACKET
, PACKET_VNET_HDR
, &one
, sizeof(one
)))
236 error(1, errno
, "setsockopt vnet");
238 len
= build_packet(cfg_payload_len
);
/* optionally send fewer bytes than were built */
240 if (cfg_truncate_len
< len
)
241 len
= cfg_truncate_len
;
243 do_send(fd
, tbuf
, len
);
246 error(1, errno
, "close t");
/*
 * Create the UDP receive socket.
 *
 * Opens a PF_INET/SOCK_DGRAM socket, sets a 100 ms receive timeout
 * (tv_usec = 100 * 1000) and binds it to INADDR_ANY on cfg_port. Failures
 * are reported through error(1, ...).
 *
 * NOTE(review): the return statement is not visible in this listing;
 * presumably the socket descriptor is returned -- confirm.
 */
251 static int setup_rx(void)
253 struct timeval tv
= { .tv_usec
= 100 * 1000 };
254 struct sockaddr_in raddr
= {0};
257 fd
= socket(PF_INET
, SOCK_DGRAM
, 0);
259 error(1, errno
, "socket r");
261 if (setsockopt(fd
, SOL_SOCKET
, SO_RCVTIMEO
, &tv
, sizeof(tv
)))
262 error(1, errno
, "setsockopt rcv timeout");
264 raddr
.sin_family
= AF_INET
;
265 raddr
.sin_port
= htons(cfg_port
);
266 raddr
.sin_addr
.s_addr
= htonl(INADDR_ANY
);
268 if (bind(fd
, (void *)&raddr
, sizeof(raddr
)))
269 error(1, errno
, "bind r");
/*
 * Receive one message on fd into rbuf and verify it.
 *
 * The received length must equal expected_len and the received bytes must
 * match the buffer at expected. Any failure is reported through
 * error(1, ...). On success the received length is logged to stderr.
 *
 * NOTE(review): the condition guarding the "recv" error call is not
 * visible in this listing.
 */
274 static void do_rx(int fd
, int expected_len
, char *expected
)
278 ret
= recv(fd
, rbuf
, sizeof(rbuf
), 0);
280 error(1, errno
, "recv");
281 if (ret
!= expected_len
)
282 error(1, 0, "recv: %u != %u", ret
, expected_len
);
284 if (memcmp(rbuf
, expected
, ret
))
285 error(1, 0, "recv: data mismatch");
287 fprintf(stderr
, "rx: %u\n", ret
);
/*
 * Create the packet-socket sniffer.
 *
 * Opens a PF_PACKET/SOCK_RAW socket, sets a 100 ms receive timeout
 * (tv_usec = 100 * 1000) and installs a filter with pair_udp_setfilter(),
 * which is not defined in this file (presumably it comes from
 * psock_lib.h). Failures are reported through error(1, ...).
 *
 * NOTE(review): the return statement is not visible in this listing;
 * presumably the socket descriptor is returned -- confirm.
 */
290 static int setup_sniffer(void)
292 struct timeval tv
= { .tv_usec
= 100 * 1000 };
295 fd
= socket(PF_PACKET
, SOCK_RAW
, 0);
297 error(1, errno
, "socket p");
299 if (setsockopt(fd
, SOL_SOCKET
, SO_RCVTIMEO
, &tv
, sizeof(tv
)))
300 error(1, errno
, "setsockopt rcv timeout");
302 pair_udp_setfilter(fd
);
/*
 * Parse the command line into the cfg_* globals and reject inconsistent
 * combinations.
 *
 * Options are read with getopt() using the string "bcCdgl:qt:vV"; the two
 * options followed by ':' take an argument, which is converted with
 * strtoul() (base auto-detected) into cfg_payload_len and
 * cfg_truncate_len. An unrecognized option is reported as a parse error.
 *
 * After parsing, these constraints are enforced through error(1, ...):
 *   - vlan (-V) conflicts with dgram (-d)
 *   - csum offload (-c) requires vnet (-v)
 *   - csum bad (-C) requires csum offload (-c)
 *   - gso (-g) requires csum offload (-c)
 *
 * NOTE(review): the case labels and several assignments are not visible in
 * this listing, so the mapping of each option letter to its variable is
 * taken from the error messages above rather than from the switch itself.
 */
308 static void parse_opts(int argc
, char **argv
)
312 while ((c
= getopt(argc
, argv
, "bcCdgl:qt:vV")) != -1) {
318 cfg_use_csum_off
= true;
321 cfg_use_csum_off_bad
= true;
324 cfg_use_dgram
= true;
330 cfg_payload_len
= strtoul(optarg
, NULL
, 0);
333 cfg_use_qdisc_bypass
= true;
336 cfg_truncate_len
= strtoul(optarg
, NULL
, 0);
345 error(1, 0, "%s: parse error", argv
[0]);
349 if (cfg_use_vlan
&& cfg_use_dgram
)
350 error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
352 if (cfg_use_csum_off
&& !cfg_use_vnet
)
353 error(1, 0, "option csum offload (-c) requires vnet (-v)");
355 if (cfg_use_csum_off_bad
&& !cfg_use_csum_off
)
356 error(1, 0, "option csum bad (-C) requires csum offload (-c)");
358 if (cfg_use_gso
&& !cfg_use_csum_off
)
359 error(1, 0, "option gso (-g) requires csum offload (-c)");
/*
 * Run one send/receive cycle and verify what arrives.
 *
 * fds is the sniffer socket from setup_sniffer(). When the payload length
 * is DATA_LEN and VLAN is off, the sniffer is expected to see the packet
 * from tbuf without its virtio-net header. The socket fdr is always
 * expected to receive exactly the payload, i.e. the last cfg_payload_len
 * bytes of the built packet. Failures closing the sockets are reported
 * through error(1, ...).
 *
 * NOTE(review): the assignments to fdr and total_len and the close() calls
 * are not visible in this listing; presumably fdr comes from setup_rx()
 * and total_len from do_tx() -- confirm.
 */
362 static void run_test(void)
364 int fdr
, fds
, total_len
;
367 fds
= setup_sniffer();
371 /* BPF filter accepts only this length, vlan changes MAC */
372 if (cfg_payload_len
== DATA_LEN
&& !cfg_use_vlan
)
373 do_rx(fds
, total_len
- sizeof(struct virtio_net_hdr
),
374 tbuf
+ sizeof(struct virtio_net_hdr
));
/* the payload sits at the tail of the built packet */
376 do_rx(fdr
, cfg_payload_len
, tbuf
+ total_len
- cfg_payload_len
);
379 error(1, errno
, "close s");
381 error(1, errno
, "close r");
384 int main(int argc
, char **argv
)
386 parse_opts(argc
, argv
);
388 if (system("ip link set dev lo mtu 1500"))
389 error(1, errno
, "ip link set mtu");
390 if (system("ip addr add dev lo 172.17.0.1/24"))
391 error(1, errno
, "ip addr add");
392 if (system("sysctl -w net.ipv4.conf.lo.accept_local=1"))
393 error(1, errno
, "sysctl lo.accept_local");
397 fprintf(stderr
, "OK\n\n");