Release 2024.12.06
[yt-dlp.git] / test / test_websockets.py
blob06112cc0b8fbf034c820d39c8beef21d48c5b0fe
1 #!/usr/bin/env python3
3 # Allow direct execution
4 import os
5 import sys
6 import time
8 import pytest
10 from test.helper import verify_address_availability
11 from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
13 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15 import http.client
16 import http.cookiejar
17 import http.server
18 import json
19 import random
20 import ssl
21 import threading
23 from yt_dlp import socks, traverse_obj
24 from yt_dlp.cookies import YoutubeDLCookieJar
25 from yt_dlp.dependencies import websockets
26 from yt_dlp.networking import Request
27 from yt_dlp.networking.exceptions import (
28 CertificateVerifyError,
29 HTTPError,
30 ProxyError,
31 RequestError,
32 SSLError,
33 TransportError,
35 from yt_dlp.utils.networking import HTTPHeaderDict
# Absolute directory containing this test module; test certificates and data live under it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
def websocket_handler(websocket):
    """Answer the first message received on *websocket*, then return.

    Text commands: ``'headers'`` replies with a JSON dump of the handshake
    request headers, ``'path'`` with the request path, ``'source_address'``
    with the peer's address and ``'str'`` with ``'1'``. The binary message
    ``b'bytes'`` replies ``'2'``. Any other message is echoed back unchanged.
    Returns whatever ``websocket.send`` returns, or ``None`` if no message
    arrives.
    """
    # Lazily evaluated so request attributes are only read for the matching command.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif isinstance(incoming, str) and incoming in text_replies:
            reply = text_replies[incoming]()
        else:
            reply = incoming
        # Only the first message is answered; the handler then returns.
        return websocket.send(reply)
def process_request(self, request):
    """Customise the handshake response for the special test paths.

    Installed as the server's ``process_request`` hook in
    ``create_websocket_server``. ``self`` is the object websockets passes to
    the hook; its ``protocol`` builds the accept/reject responses.

    * ``/gen_<code>`` refuses the upgrade with HTTP status ``<code>``. Any
      3xx code is sent with a ``Location: /`` header so it is a well-formed
      redirect.
    * ``/get_cookie`` accepts the upgrade and adds ``Set-Cookie: test=ytdlp``.
    * Every other path is accepted unchanged.

    Raises ValueError if ``<code>`` is not a known ``http.HTTPStatus``.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[len('/gen_'):]))
        # Cover the whole 3xx class (301/302/303 are requested by the tests),
        # not just 300 itself.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    elif request.path.startswith('/get_cookie'):
        response = self.protocol.accept(request)
        response.headers['Set-Cookie'] = 'test=ytdlp'
        return response
    return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync test server on an ephemeral 127.0.0.1 port.

    Extra keyword arguments are forwarded to ``websockets.sync.server.serve``
    (the TLS variants pass ``ssl=``). The server runs ``serve_forever`` in a
    daemon thread.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    # Port 0 asks the OS for a free port; read back which one was chosen.
    port = server.socket.getsockname()[1]
    worker = threading.Thread(target=server.serve_forever, daemon=True)
    worker.start()
    return worker, port
def create_ws_websocket_server():
    """Start a plain (``ws://``) test server; returns ``(thread, port)``."""
    thread, port = create_websocket_server()
    return thread, port
def create_wss_websocket_server():
    """Start a TLS (``wss://``) test server using ``testcert.pem``; returns ``(thread, port)``."""
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
# Directory holding the CA and client certificates/keys used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
def create_mtls_wss_websocket_server():
    """Start a TLS test server that requires a client certificate.

    Clients must present a certificate verifiable against ``ca.crt`` in
    ``MTLS_CERT_DIR``. The server itself presents ``testcert.pem``.
    Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
def create_legacy_wss_websocket_server():
    """Start a TLS test server limited to TLS 1.2 and the cipher string below.

    The legacy-SSL tests expect a handler to reach this server only when
    legacy SSL support is enabled. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.maximum_version = ssl.TLSVersion.TLSv1_2
    context.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
def ws_validate_and_send(rh, req):
    """Validate *req* with request handler *rh*, then send it.

    Sending is attempted up to three times. A retry happens only when the
    TransportError message contains 'connection closed during handshake'.
    Any other error, or the error on the final attempt, is re-raised.
    """
    rh.validate(req)
    attempts_left = 3
    while True:
        attempts_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            if attempts_left and 'connection closed during handshake' in str(err):
                # websockets server sometimes hangs on new connections
                continue
            raise
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Behavioural conformance tests for websocket request handlers.

    ``setup_class`` starts one local server per scenario: plain ws, wss,
    TLS without a certificate, mutual TLS and legacy TLS. Each test connects
    to the relevant server through the parametrized ``handler`` fixture.
    """

    @classmethod
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # Server TLS context with no certificate loaded; test_ssl_error expects handshakes to it to fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

        cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
        cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'

    def test_basic_websockets(self, handler):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    def test_legacy_ssl_extension(self, handler):
        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
            assert ws.status == 101
            ws.close()

            # Ensure only applies to request extension
            with pytest.raises(SSLError):
                ws_validate_and_send(rh, Request(self.legacy_wss_host))

    def test_legacy_ssl_support(self, handler):
        with handler(verify=False, legacy_ssl_support=True) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        # 10.255.255.255 is expected to be unreachable, so connecting should time out
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        # Must be checked after pytest.raises exits: inside it this line would
        # never run, because the call above raises.
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_only_cookiejar(self, handler):
        # Ensure that cookies are ONLY being handled by the cookiejar
        with handler() as rh:
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_delete_cookie(self, handler):
        # Ensure that cookies removed from the cookiejar are no longer sent on later requests
        cookiejar = YoutubeDLCookieJar()
        with handler(verbose=True, cookiejar=cookiejar) as rh:
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()
            cookiejar.clear_session_cookies()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When the proxy for the request's own scheme is explicitly set to None
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
def create_fake_ws_connection(raised):
    """Build a ``ClientConnection`` stand-in whose ``send``/``recv`` always raise.

    *raised* is a zero-argument callable; each ``send``/``recv`` call raises
    the object it returns. The stand-in reports a 101 handshake response,
    and ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        # Intentionally does not call the parent __init__; only `response` is set up.
        def __init__(self, *args, **kwargs):
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Check how the Websockets handler maps library/OS exceptions to yt-dlp networking errors."""

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def failing_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Skip real socket creation and make the websockets client connect fail immediately.
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', failing_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            # Exact type, not merely a subclass of the expected error
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        adapter = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        adapter = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected