1 #! /usr/bin/env python3
2 # SPDX-License-Identifier: GPL-2.0
15 from pwd
import getpwuid
# Allow the utils module to be imported from a different directory
19 this_dir
= os
.path
.dirname(os
.path
.realpath(__file__
))
20 sys
.path
.append(os
.path
.join(this_dir
, "../"))
21 from lib
.py
.utils
import ip
23 libc
= ctypes
.cdll
.LoadLibrary('libc.so.6')
32 # Helper function for creating a socket inside a network namespace.
33 # We need this because otherwise RDS will detect that the two TCP
34 # sockets are on the same interface and use the loop transport instead
35 # of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    The namespace switch is done in a short-lived child process so the
    caller's own namespace is left untouched; the child creates the socket
    and passes its descriptor back over a UNIX socketpair.

    Args:
        netns: name of the namespace, resolved as /var/run/netns/<netns>.
        *args: forwarded unchanged to socket.socket() (family, type, ...).

    Returns:
        A socket object wrapping the descriptor created in the namespace.

    Raises:
        OSError: if the child could not enter the namespace or create the
            socket, i.e. no descriptor was received.
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # Child: never return into the caller's code, whatever happens.
        status = 1
        try:
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                # NOTE(review): assumes setns is the raw libc entry point,
                # which reports failure via a non-zero return value rather
                # than an exception -- confirm against its definition.
                ret = setns(f.fileno(), 0)

            if ret:
                print(f'setns() into {netns} failed (returned {ret})',
                      file=sys.stderr)
            else:
                # create socket in target namespace
                s = socket.socket(*args)

                # send resulting socket to parent
                socket.send_fds(u0, [], [s.fileno()])
                status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr)
        finally:
            sys.stderr.flush()
            os._exit(status)

    # Parent: drop our copy of the sending end so that recv_fds() sees
    # end-of-file instead of blocking forever if the child dies early.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise OSError(f'could not create socket in namespace {netns}')

    # Adopt the received descriptor directly; socket.fromfd() would
    # duplicate it and leave the original descriptor open forever.
    return socket.socket(*args, fileno=fds[0])
def signal_handler(sig, frame):
    """Abort the test run when the timeout alarm fires.

    Installed for SIGALRM; the alarm is armed with the --timeout value,
    whose purpose is to terminate a hung test.

    Args:
        sig: number of the delivered signal (unused).
        frame: interrupted stack frame (unused).

    Raises:
        SystemExit: always, with exit status 1, so the timeout is reported
            as a failure instead of letting the hung test keep running.
    """
    print('Test timed out')
    sys.exit(1)
# Parse the command-line arguments: an optional log output directory and
# test timeout, plus optional simulated TCP packet loss, corruption and
# duplication.
70 parser
= argparse
.ArgumentParser(description
="init script args",
71 formatter_class
=argparse
.ArgumentDefaultsHelpFormatter
)
72 parser
.add_argument("-d", "--logdir", action
="store",
73 help="directory to store logs", default
="/tmp")
74 parser
.add_argument('--timeout', help="timeout to terminate hung test",
76 parser
.add_argument('-l', '--loss', help="Simulate tcp packet loss",
78 parser
.add_argument('-c', '--corruption', help="Simulate tcp packet corruption",
80 parser
.add_argument('-u', '--duplicate', help="Simulate tcp packet duplication",
82 args
= parser
.parse_args()
# Percentage strings that are later interpolated into the tc/netem command.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'
88 ip(f
"netns add {net0}")
89 ip(f
"netns add {net1}")
90 ip(f
"link add type veth")
93 # we technically don't need different port numbers, but this will
94 # help identify traffic in the network analyzer
99 # move interfaces to separate namespaces so they can no longer be
100 # bound directly; this prevents rds from switching over from the tcp
101 # transport to the loop transport.
102 ip(f
"link set {veth0} netns {net0} up")
103 ip(f
"link set {veth1} netns {net1} up")
108 ip(f
"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
109 ip(f
"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")
112 ip(f
"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
113 ip(f
"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")
115 # sanity check that our two interfaces/addresses are correctly set up
116 # and communicating by doing a single ping
117 ip(f
"netns exec {net0} ping -c 1 {addrs[1][0]}")
119 # Start a packet capture on each network
120 for net
in [net0
, net1
]:
121 tcpdump_pid
= os
.fork()
123 pcap
= logdir
+'/'+net
+'.pcap'
124 subprocess
.check_call(['touch', pcap
])
125 user
= getpwuid(stat(pcap
).st_uid
).pw_name
126 ip(f
"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
129 # simulate packet loss, duplication and corruption
130 for net
, iface
in [(net0
, veth0
), (net1
, veth1
)]:
131 ip(f
"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem \
132 corrupt {packet_corruption} loss {packet_loss} duplicate \
137 signal
.alarm(args
.timeout
)
138 signal
.signal(signal
.SIGALRM
, signal_handler
)
141 netns_socket(net0
, socket
.AF_RDS
, socket
.SOCK_SEQPACKET
),
142 netns_socket(net1
, socket
.AF_RDS
, socket
.SOCK_SEQPACKET
),
145 for s
, addr
in zip(sockets
, addrs
):
150 s
.fileno(): s
for s
in sockets
154 addr
: s
for addr
, s
in zip(addrs
, sockets
)
158 s
: addr
for addr
, s
in zip(addrs
, sockets
)
167 ep
.register(s
, select
.EPOLLRDNORM
)
174 # Send as much as we can without blocking
175 print("sending...", nr_send
, nr_recv
)
177 send_data
= hashlib
.sha256(
178 f
'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')
180 # pseudo-random send/receive pattern
181 sender
= sockets
[nr_send
% 2]
182 receiver
= sockets
[1 - (nr_send
% 3) % 2]
185 sender
.sendto(send_data
, socket_to_addr
[receiver
])
186 send_hashes
.setdefault((sender
.fileno(), receiver
.fileno()),
187 hashlib
.sha256()).update(f
'<{send_data}>'.encode('utf-8'))
188 nr_send
= nr_send
+ 1
189 except BlockingIOError
as e
:
192 if e
.errno
in [errno
.ENOBUFS
, errno
.ECONNRESET
, errno
.EPIPE
]:
196 # Receive as much as we can without blocking
197 print("receiving...", nr_send
, nr_recv
)
198 while nr_recv
< nr_send
:
199 for fileno
, eventmask
in ep
.poll():
200 receiver
= fileno_to_socket
[fileno
]
202 if eventmask
& select
.EPOLLRDNORM
:
205 recv_data
, address
= receiver
.recvfrom(1024)
206 sender
= addr_to_socket
[address
]
207 recv_hashes
.setdefault((sender
.fileno(),
208 receiver
.fileno()), hashlib
.sha256()).update(
209 f
'<{recv_data}>'.encode('utf-8'))
210 nr_recv
= nr_recv
+ 1
211 except BlockingIOError
as e
:
214 # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
215 for net
in [net0
, net1
]:
216 ip(f
"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
217 ip(f
"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")
219 print("done", nr_send
, nr_recv
)
221 # the Python socket module doesn't know these
222 RDS_INFO_FIRST
= 10000
223 RDS_INFO_LAST
= 10017
229 for optname
in range(RDS_INFO_FIRST
, RDS_INFO_LAST
+ 1):
230 # Sigh, the Python socket module doesn't allow us to pass
231 # buffer lengths greater than 1024 for some reason. RDS
232 # wants multiple pages.
234 s
.getsockopt(socket
.SOL_RDS
, optname
, 1024)
235 nr_success
= nr_success
+ 1
237 nr_error
= nr_error
+ 1
238 if e
.errno
== errno
.ENOSPC
:
242 print(f
"getsockopt(): {nr_success}/{nr_error}")
244 print("Stopping network packet captures")
245 subprocess
.check_call(['killall', '-q', 'tcpdump'])
247 # We're done sending and receiving stuff, now let's check if what
248 # we received is what we sent.
# Compare the running digest of everything sent on each (sender, receiver)
# pair with the digest of everything received on it. Every pair is
# reported; any failure makes the script exit with a non-zero status.
verify_failed = False
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))

    if recv_hash is None:
        # Nothing arrived for this pair; there is no digest to compare.
        print("FAIL: No data received")
        verify_failed = True
        continue

    if send_hash.hexdigest() != recv_hash.hexdigest():
        print("FAIL: Send/recv mismatch")
        print("hash expected:", send_hash.hexdigest())
        print("hash received:", recv_hash.hexdigest())
        verify_failed = True
        continue

    print(f"{sender}/{receiver}: ok")

if verify_failed:
    sys.exit(1)