2 # SPDX-License-Identifier: GPL-2.0
def path_from_usb_dev(dev):
    """Takes a pyUSB device as argument and returns a string.
    The string is a Path representation of the position of the USB device on the USB bus tree.

    This path is used to find a USB device on the bus or all devices connected to a HUB.
    The path is made up of the number of the USB controller followed by the ports of the HUB tree,
    e.g. ``"1-2.3"`` for port 3 of a hub plugged into port 2 of bus 1.

    Returns an empty string when the port chain is unknown: pyUSB reports
    ``port_numbers`` as ``None`` when the backend cannot determine it, and a
    root hub has an empty chain. Joining ``None`` would raise ``TypeError``.
    """
    ports = dev.port_numbers
    if not ports:
        return ""
    dev_path = ".".join(str(i) for i in ports)
    return f"{dev.bus}-{dev_path}"
# Byte -> display-character table for the hexdump's ASCII column, indexed by
# byte value (0-255): printable ASCII code points map to themselves, while
# control characters (0x00-0x1F, 0x7F) and every byte >= 0x80 map to ".".
HEXDUMP_FILTER = "".join(chr(x) if chr(x).isprintable() else "." for x in range(128)) + "." * 128
# Logs `data` as a hexdump (8-digit hex offset, space-separated hex bytes,
# then the HEXDUMP_FILTER-mapped characters between bars) at the custom
# TRACE level, one log record per row of L bytes.
# Assumes iterating `data` yields ints 0-255 (bytes / array('B')), since each
# value indexes HEXDUMP_FILTER -- TODO confirm at call sites.
# NOTE(review): in this view the guard below has no visible body (presumably
# an early `return` when TRACE is disabled), and the row width `L` is not
# defined anywhere visible -- confirm both against the full file.
# NOTE(review): call sites use `self._log_hexdump(data)` while this signature
# has no `self`; presumably it is decorated @staticmethod (not visible here).
32 def _log_hexdump(data
):
33 if not logging
.root
.isEnabledFor(logging
.TRACE
):
# One output row per L-byte slice of the input.
36 for c
in range(0, len(data
), L
):
37 chars
= data
[c
: c
+ L
]
38 dump
= " ".join(f
"{x:02x}" for x
in chars
)
39 printable
= "".join(HEXDUMP_FILTER
[x
] for x
in chars
)
# {L*3}s pads the hex column so a short final row stays aligned with full rows.
40 line
= f
"{c:08x} {dump:{L*3}s} |{printable:{L}s}|"
41 logging
.root
.log(logging
.TRACE
, "%s", line
)
# Sets up a forwarder:
#   - finds the usb9pfs gadget by vid/pid, filtered on its bus/port path via
#     find_filter;
#   - detaches the kernel driver from any mass-storage (class 0x08) interface;
#   - claims the vendor-specific 9P interface (class/subclass 0xFF,
#     protocol 0x09);
#   - locates its OUT/IN endpoints;
#   - connects a TCP socket to `server` (callers pass a (host, port) tuple).
# Raises ValueError when the device or the 9P interface cannot be found.
# NOTE(review): several statements are not visible in this view:
#   - the self.stats initialization;
#   - the `def find_filter(dev):` line (and any handling of `path is None`);
#   - the `dev is None` check guarding the first raise;
#   - the loop over `cfg` that binds `intf` / `usb9pfs`;
#   - the interface argument of find_descriptor;
#   - the assignments to self.ep_in / self.ep_out.
# Confirm against the full file.
43 def __init__(self
, server
, vid
, pid
, path
):
50 self
.stats_logged
= time
.monotonic()
53 dev_path
= path_from_usb_dev(dev
)
55 return dev_path
== path
58 dev
= usb
.core
.find(idVendor
=vid
, idProduct
=pid
, custom_match
=find_filter
)
60 raise ValueError("Device not found")
62 logging
.info(f
"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")
64 # dev.set_configuration() is not necessary since g_multi has only one
66 # g_multi adds 9pfs as last interface
67 cfg
= dev
.get_active_configuration()
69 # we have to detach the usb-storage driver from multi gadget since
70 # stall option could be set, which will lead to spontaneous port
71 # resets and our transfers will run dead
72 if intf
.bInterfaceClass
== 0x08:
73 if dev
.is_kernel_driver_active(intf
.bInterfaceNumber
):
74 dev
.detach_kernel_driver(intf
.bInterfaceNumber
)
76 if intf
.bInterfaceClass
== 0xFF and intf
.bInterfaceSubClass
== 0xFF and intf
.bInterfaceProtocol
== 0x09:
79 raise ValueError("Interface not found")
81 logging
.info(f
"claiming interface:\n{usb9pfs}")
82 usb
.util
.claim_interface(dev
, usb9pfs
.bInterfaceNumber
)
# Pick the first endpoint of the claimed interface in each direction.
83 ep_out
= usb
.util
.find_descriptor(
85 custom_match
=lambda e
: usb
.util
.endpoint_direction(e
.bEndpointAddress
) == usb
.util
.ENDPOINT_OUT
,
# NOTE(review): asserts are stripped under `python -O`; an explicit
# `raise ValueError(...)` would keep this check in optimized runs.
87 assert ep_out
is not None
88 ep_in
= usb
.util
.find_descriptor(
90 custom_match
=lambda e
: usb
.util
.endpoint_direction(e
.bEndpointAddress
) == usb
.util
.ENDPOINT_IN
,
92 assert ep_in
is not None
93 logging
.info("interface claimed")
99 # create and connect socket
# NOTE(review): AF_INET only, so IPv6 server addresses will fail;
# socket.create_connection(server) would resolve either family.
100 self
.s
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
101 self
.s
.connect(server
)
103 logging
.info("connected to server")
# Fragment of the USB->TCP forwarding method. Not visible in this view:
#   - its `def` line;
#   - the `try:` opening the first read;
#   - the retry/abort statements after the log calls;
#   - the socket write between the hexdump and the final debug log.
# Confirm against the full file.
# Reads one 9P request from the IN endpoint:
#   - the first read is capped at wMaxPacketSize;
#   - the leading 4 bytes are decoded as a little-endian uint32 message
#     length;
#   - reads continue until that many bytes have been collected.
106 """forward a request from the USB client to the TCP server"""
110 logging
.log(logging
.TRACE
, "c2s: reading")
111 data
= self
.ep_in
.read(self
.ep_in
.wMaxPacketSize
)
# A timeout is expected while the client is idle and is only traced.
112 except usb
.core
.USBTimeoutError
:
113 logging
.log(logging
.TRACE
, "c2s: reading timed out")
# EIO is logged as retryable; any other USBError is logged as fatal.
115 except usb
.core
.USBError
as e
:
116 if e
.errno
== errno
.EIO
:
117 logging
.debug("c2s: reading failed with %s, retrying", repr(e
))
120 logging
.error("c2s: reading failed with %s, aborting", repr(e
))
# NOTE(review): struct.unpack raises struct.error if the first read
# returned fewer than 4 bytes.
122 size
= struct
.unpack("<I", data
[:4])[0]
# The length is compared against len(data), which includes the 4 length
# bytes, so it counts the whole message.
123 while len(data
) < size
:
124 data
+= self
.ep_in
.read(size
- len(data
))
125 logging
.log(logging
.TRACE
, "c2s: writing")
126 self
._log
_hexdump
(data
)
128 logging
.debug("c2s: forwarded %i bytes", size
)
# Running counters; presumably initialized in __init__ (not visible here).
129 self
.stats
["c2s packets"] += 1
130 self
.stats
["c2s bytes"] += size
# Fragment of the TCP->USB forwarding method. Its `def` line and the loop
# around the endpoint write are not visible in this view.
# Reads one 9P response from the socket (a 4-byte little-endian length, then
# the remainder) and writes it to the OUT endpoint.
133 """forward a response from the TCP server to the USB client"""
134 logging
.log(logging
.TRACE
, "s2c: reading")
# NOTE(review): recv(4) may return fewer than 4 bytes, or b"" once the
# server has closed the connection; struct.unpack below then raises
# struct.error.
135 data
= self
.s
.recv(4)
136 size
= struct
.unpack("<I", data
[:4])[0]
# NOTE(review): if the server closes mid-message, recv keeps returning b""
# and this loop never terminates; treat an empty recv as EOF.
137 while len(data
) < size
:
138 data
+= self
.s
.recv(size
- len(data
))
139 logging
.log(logging
.TRACE
, "s2c: writing")
140 self
._log
_hexdump
(data
)
# ep_out.write returns the number of bytes sent; the unsent tail stays in `data`.
142 written
= self
.ep_out
.write(data
)
144 data
= data
[written
:]
# When the message length is an exact multiple of the endpoint's max packet
# size, a zero-length packet is sent afterwards. Presumably this lets the
# host-side reader detect the end of the transfer (USB bulk convention).
145 if size
% self
.ep_out
.wMaxPacketSize
== 0:
146 logging
.log(logging
.TRACE
, "sending zero length packet")
147 self
.ep_out
.write(b
"")
148 logging
.debug("s2c: forwarded %i bytes", size
)
149 self
.stats
["s2c packets"] += 1
150 self
.stats
["s2c bytes"] += size
# Fragment: the enclosing method's `def` line is not visible in this view.
# Logs every counter in self.stats at INFO level, one per line. Each key, with
# a trailing ":", is left-aligned in a 14-character column.
153 logging
.info("statistics:")
154 for k
, v
in self
.stats
.items():
155 logging
.info(f
" {k+':':14s} {v}")
# Rate-limits statistics logging. It compares the seconds elapsed since
# self.stats_logged (time.monotonic() clock) against `interval` (default 5)
# and then refreshes the timestamp.
# NOTE(review): the body of the `if` (presumably an early `return`) and the
# lines before the timestamp update (presumably a log_stats() call) are not
# visible in this view -- confirm.
157 def log_stats_interval(self
, interval
=5):
158 if (time
.monotonic() - self
.stats_logged
) < interval
:
162 self
.stats_logged
= time
.monotonic()
def try_get_usb_str(dev, name, sysfs_root="/sys/bus/usb/devices"):
    """Read a string attribute (e.g. "manufacturer", "product") of a USB device from sysfs.

    Linux names USB device directories in sysfs after the bus number and the
    port chain (``<bus>-<port>[.<port>...]``, e.g. ``1-2.3``), not after the
    device address. The directory is therefore derived from
    ``dev.port_numbers``. Building it from ``dev.address`` would point at the
    wrong device or at nothing.

    :param dev: pyUSB device providing ``bus`` and ``port_numbers``.
    :param name: name of the sysfs attribute file to read.
    :param sysfs_root: directory holding the per-device entries (overridable for testing).
    :returns: the attribute's stripped contents, or ``None`` when the device's
        port chain is unknown or the attribute file does not exist.
    """
    ports = dev.port_numbers
    if not ports:
        # Root hubs ("usbN" in sysfs) and backends without port information
        # cannot be mapped to a "<bus>-<ports>" directory.
        return None
    port_path = ".".join(str(p) for p in ports)
    try:
        # The kernel exposes USB string descriptors in sysfs as UTF-8.
        with open(f"{sysfs_root}/{dev.bus}-{port_path}/{name}", encoding="utf-8", errors="replace") as f:
            return f.read().strip()
    except FileNotFoundError:
        # Optional attribute (or device gone): callers fall back to a default.
        return None
# Fragment, presumably the body of list_usb (the `list` sub-command handler).
# Its `def` line and the print(...) call around the row f-string are not
# visible in this view.
# Parses --id as hexadecimal "vid:pid", then prints a Markdown-style table of
# every matching device. Each row shows the bus, address, sysfs manufacturer
# and product strings ("unknown" when unavailable), vid:pid and port path.
174 vid
, pid
= [int(x
, 16) for x
in args
.id.split(":", 1)]
176 print("Bus | Addr | Manufacturer | Product | ID | Path")
177 print("--- | ---- | ---------------- | ---------------- | --------- | ----")
178 for dev
in usb
.core
.find(find_all
=True, idVendor
=vid
, idProduct
=pid
):
179 path
= path_from_usb_dev(dev
) or ""
180 manufacturer
= try_get_usb_str(dev
, "manufacturer") or "unknown"
181 product
= try_get_usb_str(dev
, "product") or "unknown"
183 f
"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}"
# Fragment, presumably the body of the `connect` sub-command handler. Its
# `def` line and the forwarding loop around log_stats_interval() are not
# visible in this view.
# Parses --id as hexadecimal "vid:pid" and builds a Forwarder between the
# matching USB gadget (optionally selected by --path) and the TCP server
# given by --server/--port.
# NOTE(review): this "vid:pid" parsing duplicates the list handler's. A
# malformed --id raises ValueError (from int(..., 16) or the unpacking).
188 vid
, pid
= [int(x
, 16) for x
in args
.id.split(":", 1)]
190 f
= Forwarder(server
=(args
.server
, args
.port
), vid
=vid
, pid
=pid
, path
=args
.path
)
196 f
.log_stats_interval()
# Fragment, presumably the body of the script's main(). Not visible in this
# view:
#   - its `def` line;
#   - the closing parentheses of the multi-line calls;
#   - the branches between the verbosity levels;
#   - the dispatch to the selected sub-command.
# Builds the CLI (global --id/--path/-v options plus the required "list" and
# "connect" sub-commands), parses it and configures logging.
202 parser
= argparse
.ArgumentParser(
203 description
="Forward 9PFS requests from USB to TCP",
206 parser
.add_argument("--id", type=str, default
="1d6b:0109", help="vid:pid of target device")
207 parser
.add_argument("--path", type=str, required
=False, help="path of target device")
208 parser
.add_argument("-v", "--verbose", action
="count", default
=0)
# Setting .required/.dest after creation makes a sub-command mandatory and
# records which one was chosen in args.command.
210 subparsers
= parser
.add_subparsers()
211 subparsers
.required
= True
212 subparsers
.dest
= "command"
214 parser_list
= subparsers
.add_parser("list", help="List all connected 9p gadgets")
215 parser_list
.set_defaults(func
=list_usb
)
217 parser_connect
= subparsers
.add_parser(
218 "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
220 parser_connect
.set_defaults(func
=connect
)
# NOTE(review): argparse only honours `required` on mutually exclusive groups.
# On a plain argument group (with no arguments added to it in this view) this
# assignment appears to be a no-op.
221 connect_group
= parser_connect
.add_argument_group()
222 connect_group
.required
= True
# 564 is the standard 9P (9pfs) TCP port.
223 parser_connect
.add_argument("-s", "--server", type=str, default
="127.0.0.1", help="server hostname")
224 parser_connect
.add_argument("-p", "--port", type=int, default
=564, help="server port")
226 args
= parser
.parse_args()
# Registers a custom TRACE level, 5 below DEBUG, as an attribute on the
# logging module. Other code reads logging.TRACE, so this must run before
# any forwarding/hexdump logging.
228 logging
.TRACE
= logging
.DEBUG
- 5
229 logging
.addLevelName(logging
.TRACE
, "TRACE")
231 if args
.verbose
>= 2:
232 level
= logging
.TRACE
234 level
= logging
.DEBUG
237 logging
.basicConfig(level
=level
, format
="%(asctime)-15s %(levelname)-8s %(message)s")
# Script entry point. The body of this guard is not visible in this view
# (presumably a call to main()).
242 if __name__
== "__main__":