Release 2024.03.10
[yt-dlp.git] / yt_dlp / networking / _websockets.py
blob159793204b126480271e040d12f501b4a3c67f94
1 from __future__ import annotations
3 import io
4 import logging
5 import ssl
6 import sys
8 from ._helper import (
9 create_connection,
10 create_socks_proxy_socket,
11 make_socks_proxy_opts,
12 select_proxy,
14 from .common import Features, Response, register_rh
15 from .exceptions import (
16 CertificateVerifyError,
17 HTTPError,
18 ProxyError,
19 RequestError,
20 SSLError,
21 TransportError,
23 from .websocket import WebSocketRequestHandler, WebSocketResponse
24 from ..compat import functools
25 from ..dependencies import websockets
26 from ..socks import ProxyError as SocksProxyError
27 from ..utils import int_or_none
# Refuse to load this module when the optional dependency is unavailable.
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version

# Split the dotted version string into a tuple so it can be compared
# component-wise against the minimum supported release.
websockets_version = tuple(
    int_or_none(component) for component in websockets.version.version.split('.'))
if websockets_version < (12, 0):
    raise ImportError('Only websockets>=12.0 is supported')
38 import websockets.sync.client
39 from websockets.uri import parse_uri
class WebsocketsResponseAdapter(WebSocketResponse):
    """Expose a `websockets` sync client connection as a WebSocketResponse.

    The connection's handshake response supplies the body, headers, status
    and reason handed to the parent class. send() and recv() delegate to the
    connection and translate the exceptions it raises into the networking
    exception types imported by this module.
    """

    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
        super().__init__(
            # The handshake response body may be empty/None; fall back to empty bytes
            fp=io.BytesIO(wsw.response.body or b''),
            url=url,
            headers=wsw.response.headers,
            status=wsw.response.status_code,
            reason=wsw.response.reason_phrase,
        )
        self.wsw = wsw

    def close(self):
        """Close the underlying connection, then run the parent's close()."""
        try:
            self.wsw.close()
        finally:
            # Always run the parent cleanup, even if closing the connection raised
            super().close()

    def send(self, message):
        """Send `message` over the connection.

        Raises ProxyError, TransportError or RequestError (for a TypeError
        raised by the underlying send) in place of the original exception.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        try:
            return self.wsw.send(message)
        # Proxy errors are checked first, matching the handler order used in recv()
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
        except TypeError as e:
            raise RequestError(cause=e) from e

    def recv(self):
        """Receive the next message from the connection.

        Raises ProxyError or TransportError in place of the original exception.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        try:
            return self.wsw.recv()
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        """Initialise the handler and attach a stdout log handler to the websockets loggers."""
        super().__init__(*args, **kwargs)
        # Keep track of the handlers we install so close() can remove exactly those
        self.__logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self.__logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                logger.setLevel(logging.DEBUG)

    def _check_extensions(self, extensions):
        """Run the parent check, then drop the extensions this handler consumes itself."""
        super()._check_extensions(extensions)
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)

    def close(self):
        """Detach the log handlers installed by __init__ from the websockets loggers."""
        # Remove the logging handler that contains a reference to our logger
        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
        for name, handler in self.__logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)

    def _send(self, request):
        """Open a websocket connection for `request` and return it wrapped in a response adapter.

        The socket is created directly, or through a SOCKS proxy when one is
        selected for the URL, and then handed to the websockets sync client.

        Raises RequestError for an invalid URI, ProxyError, CertificateVerifyError,
        SSLError, HTTPError for a rejected handshake status, and TransportError
        for other OS, timeout and websocket errors.
        """
        timeout = float(request.extensions.get('timeout') or self.timeout)
        headers = self._merge_headers(request.headers)
        # Only derive a cookie header from the jar when the caller did not set one
        if 'cookie' not in headers:
            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        # Parsing happens before the connection attempt, so translate an invalid
        # URI here; otherwise the raw websockets exception would escape untranslated
        try:
            wsuri = parse_uri(request.url)
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e

        create_conn_kwargs = {
            'source_address': (self.source_address, 0) if self.source_address else None,
            'timeout': timeout,
        }
        proxy = select_proxy(request.url, request.proxies or self.proxies or {})
        try:
            if proxy:
                socks_proxy_options = make_socks_proxy_opts(proxy)
                # Connect to the proxy address; the socket factory targets the websocket host
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs,
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs,
                )
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                user_agent_header=None,
                ssl_context=self._make_sslcontext() if wsuri.secure else None,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            # Surface the rejected handshake response as an HTTP error
            raise HTTPError(
                Response(
                    fp=io.BytesIO(e.response.body),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e