Merge tag 'trace-printf-v6.13' of git://git.kernel.org/pub/scm/linux/kernel/git/trace...
[drm/drm-misc.git] / tools / usb / p9_fwd.py
blob12c76cbb046b7277bdd0ec39b663a041c9a1fc89
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: GPL-2.0
4 import argparse
5 import errno
6 import logging
7 import socket
8 import struct
9 import time
11 import usb.core
12 import usb.util
def path_from_usb_dev(dev):
    """Return the position of a pyUSB device in the USB bus tree as a string.

    The path has the form "<bus>-<port>[.<port>...]": the USB bus number
    followed by the chain of hub port numbers leading to the device, e.g.
    "1-2.4". Devices behind the same hub share that hub's path as a prefix,
    so the path can select one specific device on the bus.

    Returns an empty string when the device reports no port numbers.
    """
    if not dev.port_numbers:
        return ""
    ports = ".".join(str(port) for port in dev.port_numbers)
    return f"{dev.bus}-{ports}"
# Lookup table indexed by byte value (0-255) giving the character shown in the
# ASCII column of a hexdump: printable ASCII characters map to themselves,
# control characters and all bytes >= 0x80 map to ".".
HEXDUMP_FILTER = "".join(chr(x) if chr(x).isprintable() else "." for x in range(128)) + "." * 128
class Forwarder:
    """Forward 9P messages between a usb9pfs gadget and a TCP 9P server.

    The constructor locates the gadget's 9P interface, claims it and connects
    to the server. Each c2s() call moves one message from the gadget to the
    server and each s2c() call moves one message back; the caller is expected
    to alternate them. Every message starts with a little-endian 32-bit size
    that counts the whole message, the 4-byte size field included.
    """

    @staticmethod
    def _log_hexdump(data):
        """Log *data* as a hexdump, 16 bytes per line, at TRACE level."""
        if not logging.root.isEnabledFor(logging.TRACE):
            return
        L = 16
        for c in range(0, len(data), L):
            chars = data[c : c + L]
            dump = " ".join(f"{x:02x}" for x in chars)
            printable = "".join(HEXDUMP_FILTER[x] for x in chars)
            line = f"{c:08x} {dump:{L*3}s} |{printable:{L}s}|"
            logging.root.log(logging.TRACE, "%s", line)

    def __init__(self, server, vid, pid, path):
        """Open the gadget's 9P interface and connect to the 9P server.

        server: (host, port) of the 9P server.
        vid, pid: USB vendor and product id of the gadget.
        path: bus path as returned by path_from_usb_dev(); when not None only
            the device at exactly that position is accepted.

        Raises ValueError if the device, its 9P interface, or one of the bulk
        endpoints cannot be found; socket errors from connecting propagate.
        """
        self.stats = {
            "c2s packets": 0,
            "c2s bytes": 0,
            "s2c packets": 0,
            "s2c bytes": 0,
        }
        self.stats_logged = time.monotonic()

        def find_filter(dev):
            # Any vid:pid match is accepted unless a specific path was requested.
            dev_path = path_from_usb_dev(dev)
            if path is not None:
                return dev_path == path
            return True

        dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
        if dev is None:
            raise ValueError("Device not found")

        logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")

        # dev.set_configuration() is not necessary since g_multi has only one
        usb9pfs = None
        # g_multi adds 9pfs as last interface
        cfg = dev.get_active_configuration()
        for intf in cfg:
            # we have to detach the usb-storage driver from multi gadget since
            # stall option could be set, which will lead to spontaneous port
            # resets and our transfers will run dead
            if intf.bInterfaceClass == 0x08:
                if dev.is_kernel_driver_active(intf.bInterfaceNumber):
                    dev.detach_kernel_driver(intf.bInterfaceNumber)

            if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
                usb9pfs = intf
        if usb9pfs is None:
            raise ValueError("Interface not found")

        logging.info(f"claiming interface:\n{usb9pfs}")
        usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
        ep_out = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
        )
        # Explicit checks instead of assert: asserts vanish under `python -O`.
        if ep_out is None:
            raise ValueError("OUT endpoint not found")
        ep_in = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
        )
        if ep_in is None:
            raise ValueError("IN endpoint not found")
        logging.info("interface claimed")

        self.ep_out = ep_out
        self.ep_in = ep_in
        self.dev = dev

        # create and connect socket; create_connection() tries every address
        # the host name resolves to, so IPv6-only servers work as well.
        self.s = socket.create_connection(server)

        logging.info("connected to server")

    def _recv_exact(self, count):
        """Read exactly *count* bytes from the server socket.

        recv() may return fewer bytes than requested and returns b"" once the
        peer has closed the connection; without the EOF check a read loop
        would spin forever.

        Raises ConnectionError if the server closes the connection early.
        """
        buf = bytearray()
        while len(buf) < count:
            chunk = self.s.recv(count - len(buf))
            if not chunk:
                raise ConnectionError(f"server closed connection after {len(buf)} of {count} bytes")
            buf += chunk
        return bytes(buf)

    def c2s(self):
        """Forward one request from the USB client to the TCP server.

        Timeouts and EIO errors on the USB read are retried; any other
        USBError is logged and re-raised.
        """
        data = None
        while data is None:
            try:
                logging.log(logging.TRACE, "c2s: reading")
                data = self.ep_in.read(self.ep_in.wMaxPacketSize)
            except usb.core.USBTimeoutError:
                logging.log(logging.TRACE, "c2s: reading timed out")
                continue
            except usb.core.USBError as e:
                if e.errno == errno.EIO:
                    logging.debug("c2s: reading failed with %s, retrying", repr(e))
                    time.sleep(0.5)
                    continue
                logging.error("c2s: reading failed with %s, aborting", repr(e))
                raise
        # The size field covers the whole message; keep reading until complete.
        size = struct.unpack("<I", data[:4])[0]
        while len(data) < size:
            data += self.ep_in.read(size - len(data))
        logging.log(logging.TRACE, "c2s: writing")
        self._log_hexdump(data)
        # send() may transmit only part of the buffer; sendall() writes all of it.
        self.s.sendall(data)
        logging.debug("c2s: forwarded %i bytes", size)
        self.stats["c2s packets"] += 1
        self.stats["c2s bytes"] += size

    def s2c(self):
        """Forward one response from the TCP server to the USB client.

        Raises ConnectionError if the server closes the connection before a
        complete message has been received.
        """
        logging.log(logging.TRACE, "s2c: reading")
        header = self._recv_exact(4)
        size = struct.unpack("<I", header)[0]
        # size includes the 4 header bytes already read.
        data = header + self._recv_exact(max(size - 4, 0))
        logging.log(logging.TRACE, "s2c: writing")
        self._log_hexdump(data)
        while data:
            written = self.ep_out.write(data)
            # Without progress this loop would never terminate.
            if not written or written <= 0:
                raise RuntimeError("s2c: USB write made no progress")
            data = data[written:]
        # A transfer that is an exact multiple of the packet size needs a
        # zero-length packet so the receiver sees where it ends.
        if size % self.ep_out.wMaxPacketSize == 0:
            logging.log(logging.TRACE, "sending zero length packet")
            self.ep_out.write(b"")
        logging.debug("s2c: forwarded %i bytes", size)
        self.stats["s2c packets"] += 1
        self.stats["s2c bytes"] += size

    def log_stats(self):
        """Log the packet and byte counters at INFO level."""
        logging.info("statistics:")
        for k, v in self.stats.items():
            logging.info(f" {k+':':14s} {v}")

    def log_stats_interval(self, interval=5):
        """Call log_stats() if at least *interval* seconds passed since the last call."""
        if (time.monotonic() - self.stats_logged) < interval:
            return

        self.log_stats()
        self.stats_logged = time.monotonic()
def try_get_usb_str(dev, name, sysfs_root="/sys/bus/usb/devices"):
    """Read the sysfs string attribute *name* (e.g. "product") of a pyUSB device.

    sysfs names USB device directories after the bus number and the hub
    port chain ("1-2.4"), not after the device address, so the directory is
    built from dev.port_numbers.

    sysfs_root: directory containing the per-device entries; overridable so
        the lookup can be pointed at another tree (e.g. in tests).

    Returns the stripped attribute text, or None when the device reports no
    port numbers or the attribute file does not exist.
    """
    if not dev.port_numbers:
        return None
    ports = ".".join(str(port) for port in dev.port_numbers)
    try:
        # errors="replace": an odd byte sequence must not abort the listing.
        with open(f"{sysfs_root}/{dev.bus}-{ports}/{name}", encoding="utf-8", errors="replace") as f:
            return f.read().strip()
    except FileNotFoundError:
        return None
def list_usb(args):
    """Print one table row for every attached USB device matching ``args.id``.

    ``args.id`` is a "vid:pid" pair of hexadecimal numbers. Manufacturer and
    product names come from try_get_usb_str() and are shown as "unknown"
    when they cannot be read.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    print("Bus | Addr | Manufacturer | Product | ID | Path")
    print("--- | ---- | ---------------- | ---------------- | --------- | ----")
    matches = usb.core.find(find_all=True, idVendor=vid, idProduct=pid)
    for device in matches:
        location = path_from_usb_dev(device) or ""
        maker = try_get_usb_str(device, "manufacturer") or "unknown"
        model = try_get_usb_str(device, "product") or "unknown"
        ident = f"{device.idVendor:04x}:{device.idProduct:04x}"
        print(f"{device.bus:3} | {device.address:4} | {maker:16} | {model:16} | {ident} | {location:18}")
def connect(args):
    """Run the forwarding loop for the gadget and server chosen on the command line.

    c2s() and s2c() alternate until an exception (including KeyboardInterrupt)
    ends the loop; statistics are logged periodically and once more on exit.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))
    server = (args.server, args.port)
    forwarder = Forwarder(server=server, vid=vid, pid=pid, path=args.path)

    def pump_once():
        # One request/response round trip, then maybe a statistics line.
        forwarder.c2s()
        forwarder.s2c()
        forwarder.log_stats_interval()

    try:
        while True:
            pump_once()
    finally:
        forwarder.log_stats()
def main():
    """Parse the command line, configure logging and run the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )
    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    parser.add_argument(
        "-v", "--verbose", action="count", default=0, help="increase verbosity (-v: debug, -vv: trace)"
    )

    subparsers = parser.add_subparsers()
    # A subcommand is mandatory; giving the subparsers a dest lets argparse
    # name the missing argument in its error message.
    subparsers.required = True
    subparsers.dest = "command"

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Custom level below DEBUG used by Forwarder for per-transfer messages
    # and hexdumps.
    logging.TRACE = logging.DEBUG - 5
    logging.addLevelName(logging.TRACE, "TRACE")

    if args.verbose >= 2:
        level = logging.TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)


if __name__ == "__main__":
    main()