3 # Allow direct execution
10 from test
.helper
import verify_address_availability
11 from yt_dlp
.networking
.common
import Features
, DEFAULT_TIMEOUT
13 sys
.path
.insert(0, os
.path
.dirname(os
.path
.dirname(os
.path
.abspath(__file__
))))
23 from yt_dlp
import socks
, traverse_obj
24 from yt_dlp
.cookies
import YoutubeDLCookieJar
25 from yt_dlp
.dependencies
import websockets
26 from yt_dlp
.networking
import Request
27 from yt_dlp
.networking
.exceptions
import (
28 CertificateVerifyError
,
35 from yt_dlp
.utils
.networking
import HTTPHeaderDict
# Absolute path of the directory holding this test module; fixture files such
# as testcert.pem are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
def websocket_handler(websocket):
    """Answer the first message received on *websocket*, then return.

    A few control messages let tests inspect the connection from the server side:

    * ``b'bytes'`` -> ``'2'`` and ``'str'`` -> ``'1'`` (the opcode checked by
      test_send_types for binary and text frames respectively)
    * ``'headers'`` -> JSON object of the handshake request headers
    * ``'path'`` -> the handshake request path
    * ``'source_address'`` -> the client's IP address

    Any other message is echoed back unchanged. Returns whatever
    ``websocket.send`` returns, or ``None`` if no message arrives.
    """
    # Replies are computed lazily so only the requested attribute is touched.
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers)),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for message in websocket:
        if isinstance(message, bytes):
            if message == b'bytes':
                return websocket.send('2')
        elif isinstance(message, str) and message in text_replies:
            return websocket.send(text_replies[message]())
        # Fallback: echo the message back.
        return websocket.send(message)
def process_request(self, request):
    """Pre-handshake hook for the test websocket server.

    ``self`` is the connection object supplied when this hook is registered
    via ``process_request=`` (presumably the websockets ServerConnection —
    confirm); its ``protocol`` builds the handshake response.

    * ``/gen_<code>``: answer with HTTP status ``<code>`` instead of upgrading.
      3xx codes are sent with a ``Location: /`` header so they look like real
      redirects; all other codes are rejected via ``protocol.reject``.
    * ``/get_cookie``: accept the handshake and attach ``Set-Cookie: test=ytdlp``.
    * anything else: accept the handshake unchanged.

    Returns the response object to send back to the client.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # Whole 3xx range: the previous check (300 <= code <= 300) only matched
        # 300, so 301/302/303 were sent without a Location header.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    if request.path.startswith('/get_cookie'):
        response = self.protocol.accept(request)
        response.headers['Set-Cookie'] = 'test=ytdlp'
        # Return the modified response; falling through would discard the header.
        return response
    return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
    """Start a websocket test server on an OS-chosen localhost port.

    Each connection is served by ``websocket_handler``, with
    ``process_request`` run before the handshake. Extra keyword arguments
    (e.g. ``ssl=``) are forwarded to ``websockets.sync.server.serve``. The
    server loop runs in a daemon thread so it does not block interpreter exit.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler,
        '127.0.0.1',
        0,  # port 0: let the OS pick a free port
        process_request=process_request,
        open_timeout=2,
        **ws_kwargs,
    )
    port = server.socket.getsockname()[1]

    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
def create_ws_websocket_server():
    """Start a plain-text ``ws://`` test server; returns ``(thread, port)``."""
    # No ssl= argument, so the server speaks unencrypted websockets.
    return create_websocket_server()
def create_wss_websocket_server():
    """Start a ``wss://`` test server using the self-signed testcert.pem.

    Returns ``(thread, port)`` like ``create_websocket_server``.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=server_context)
# Directory holding the CA and client certificates/keys used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate (mutual TLS).

    Clients must present a certificate that verifies against
    ``MTLS_CERT_DIR/ca.crt``; the server itself presents testcert.pem.
    Returns ``(thread, port)`` like ``create_websocket_server``.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Refuse any client that does not present a certificate signed by the test CA.
    server_context.verify_mode = ssl.CERT_REQUIRED
    server_context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    server_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=server_context)
def create_legacy_wss_websocket_server():
    """Start a ``wss://`` test server that only offers legacy TLS parameters.

    The server is capped at TLS 1.2 and limited to the cipher string
    'SHA1:AESCCM:aDSS:eNULL:aNULL'. Presumably a client with default
    (modern) settings fails the handshake unless legacy SSL support is
    enabled, as the legacy_ssl tests expect. Returns ``(thread, port)``.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.maximum_version = ssl.TLSVersion.TLSv1_2
    server_context.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
    server_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=server_context)
def ws_validate_and_send(rh, req):
    """Send *req* through request handler *rh*, retrying a flaky handshake.

    From the visible lines: up to ``max_tries`` attempts are made. A
    ``TransportError`` whose message contains 'connection closed during
    handshake' is handled specially on every attempt except the last, because
    the test websockets server sometimes hangs on new connections.

    NOTE(review): original lines 119-120, 122-123 and 127-130 are absent from
    this excerpt, so ``max_tries``, the ``try:``/send call and the retry and
    re-raise statements are not visible here. Presumably the missing lines
    validate *req* with ``rh`` (per the function name), define ``max_tries``,
    return ``rh.send(req)`` inside a ``try``, and ``continue``/``raise`` in
    the handler. Confirm against the full file.
    """
    # NOTE(review): original lines 119-120 missing (presumably validation of req and `max_tries = ...`)
    for i in range(max_tries):
        # NOTE(review): original lines 122-123 missing (presumably `try:` and the send call)
        except TransportError as e:
            if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
                # websockets server sometimes hangs on new connections
                # NOTE(review): original lines 127-130 missing (presumably `continue`, then a re-raise)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for websocket request handlers against local test servers.

    ``setup_class`` starts five servers:

    * plain ``ws://``
    * ``wss://`` with testcert.pem
    * ``wss://`` with a TLS context that has no certificate loaded
    * mutual-TLS ``wss://``
    * legacy-cipher ``wss://``

    Each test connects through the ``handler`` fixture and checks the
    server's replies (see websocket_handler for the control messages).

    NOTE(review): this excerpt is missing several original lines. Some are
    presumably blank lines or ``ws.close()`` calls. Gaps that look like they
    held code (``ws.send(...)`` before a ``recv``, ``now = time.time()``,
    closing brackets) are marked below and need confirming against the full
    file.
    """

    # NOTE(review): original line 134 missing — presumably `@classmethod`; confirm
    def setup_class(cls):
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # Server TLS context with no certificate loaded: used by test_ssl_error to provoke a handshake failure
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

        cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
        cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'

    def test_basic_websockets(self, handler):
        # The handshake must complete with status 101 and a message must round-trip
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            # NOTE(review): original line 156 missing — presumably ws.send('foo') (the server echoes unknown messages); confirm
            assert ws.recv() == 'foo'

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        # The server replies with the frame opcode it saw: 1 for text, 2 for binary
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 165 missing — presumably ws.send(msg); confirm
            assert int(ws.recv()) == opcode

    def test_verify_cert(self, handler):
        # The self-signed wss certificate must fail verification by default...
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        # ...and the connection must succeed once verification is disabled
        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101

    def test_ssl_error(self, handler):
        # A TLS handshake failure (server has no certificate) must surface as a generic
        # SSLError, not as a CertificateVerifyError
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    def test_legacy_ssl_extension(self, handler):
        # The per-request 'legacy_ssl' extension allows connecting to the legacy-cipher server
        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
            assert ws.status == 101

            # Ensure only applies to request extension
            with pytest.raises(SSLError):
                ws_validate_and_send(rh, Request(self.legacy_wss_host))

    def test_legacy_ssl_support(self, handler):
        # Handler-wide legacy_ssl_support allows connecting to the legacy-cipher server
        with handler(verify=False, legacy_ssl_support=True) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
            assert ws.status == 101

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
        # NOTE(review): original line 206 missing — presumably the closing `])`; confirm
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            # NOTE(review): original line 210 missing — presumably ws.send('path') so the server echoes the request path; confirm
            assert ws.recv() == expected
            assert ws.status == 101

    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            # NOTE(review): original line 221 missing — presumably ws.send('path'); confirm
            assert ws.recv() == '/test'

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        # The server's /gen_<status> path answers the handshake with that status instead of 101
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
        # NOTE(review): original line 237 missing — presumably the closing `])`; confirm
    def test_read_timeout(self, handler, params, extensions):
        # A near-zero timeout, set on the handler or per request, must make the request fail
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        # nothing should answer at this address, so connecting can only time out
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            # NOTE(review): original line 247 missing — presumably `now = time.time()`; `now` is otherwise undefined below
            ws_validate_and_send(rh, Request(connect_timeout_url))
        # Elapsed-time check placed after the `with`: inside it, it would be skipped once the send raises
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            # NOTE(review): original line 254 missing — presumably `now = time.time()`; confirm
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        # A cookie jar configured on the handler is applied to the handshake request
        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 268 missing — presumably ws.send('headers'); confirm
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'

        # Without a handler cookie jar no Cookie header is sent...
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 274 missing — presumably ws.send('headers'); confirm
            assert 'cookie' not in json.loads(ws.recv())

            # ...unless a cookie jar is passed as a per-request extension
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            # NOTE(review): original line 279 missing — presumably ws.send('headers'); confirm
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_only_cookiejar(self, handler):
        # Ensure that cookies are ONLY being handled by the cookiejar
        with handler() as rh:
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
            # NOTE(review): original line 289 missing — presumably ws.send('headers'); confirm
            assert 'cookie' not in json.loads(ws.recv())

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_delete_cookie(self, handler):
        # Ensure that cookies are ONLY being handled by the cookiejar: once the jar's
        # session cookies are cleared, the cookie must no longer be sent
        cookiejar = YoutubeDLCookieJar()
        with handler(verbose=True, cookiejar=cookiejar) as rh:
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 300 missing — presumably ws.send('headers'); confirm
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            cookiejar.clear_session_cookies()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 305 missing — presumably ws.send('headers'); confirm
            assert 'cookie' not in json.loads(ws.recv())

    def test_source_address(self, handler):
        # Bind the client to a random 127.0.0.x address and check the server sees it as the peer address
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()

    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            # NOTE(review): original lines 322-324 missing — no assertion is visible here; presumably
            # `assert ws.url == url` followed by ws.close(); confirm

    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Handler-level (global) headers are sent with the handshake
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            # NOTE(review): original line 329 missing — presumably ws.send('headers'); confirm
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            # NOTE(review): original lines 332-333 missing — possibly a 'test2' assertion and ws.close(); confirm

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            # NOTE(review): original line 337 missing — presumably ws.send('headers'); confirm
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'

    # Each parameter set supplies the client certificate differently: certificate with embedded key,
    # separate key file, certificate with encrypted key + password, separate encrypted key + password
    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        # NOTE(review): original line 346 missing — presumably `{` opening the next dict
        'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
        'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        # NOTE(review): original lines 349-350 missing — presumably `},` and `{`
        'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
        'client_certificate_password': 'foobar',
        # NOTE(review): original lines 353-354 missing — presumably `},` and `{`
        'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
        'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
        'client_certificate_password': 'foobar',
        # NOTE(review): original lines 358-359 missing — presumably `},` and `))`
    def test_mtls(self, handler, client_cert):
        # NOTE(review): original line 361 missing — presumably `with handler(`
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            # NOTE(review): original line 364 missing — presumably `verify=False,` (per the comment above)
            client_cert=client_cert,
            # NOTE(review): original line 366 missing — presumably `) as rh:`
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # NOTE(review): the per-request mapping disables 'http', not 'ws'; presumably this works because
                # request-level proxies replace the handler's — confirm whether {'ws': None} was intended
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'http': None}))
                # Then no proxy should be used
                assert ws.status == 101

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        # 'all' proxy configured on the handler
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        # 'all' proxy configured on the request only
        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
def create_fake_ws_connection(raised):
    """Build a stand-in websockets client connection for the error-mapping tests.

    *raised* is a zero-argument factory; the tests pass lambdas that return
    exception instances. The returned object subclasses
    ``websockets.sync.client.ClientConnection`` and sets ``self.response``
    to a ``FakeResponse``.

    NOTE(review): original lines 413-416, 418, 422-423, 425-426 and 428-429
    are absent from this excerpt. Presumably:

    * ``__init__`` defines a nested ``FakeResponse`` class (line 417 is its
      ``reason_phrase``);
    * ``send``/``recv`` raise ``raised()``.

    Confirm against the full file.
    """
    import websockets.sync.client

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # NOTE(review): original lines 413-416 missing (presumably `class FakeResponse:` and its other attributes)
                reason_phrase = 'test'

            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            # NOTE(review): original lines 422-423 missing (presumably `raise raised()`)

        def recv(self, *args, **kwargs):
            # NOTE(review): original lines 425-426 missing (presumably `raise raised()`)

        def close(self, *args, **kwargs):
            # NOTE(review): original lines 428-429 missing (body not visible)

    return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Websockets-handler-specific tests: mapping of library exceptions to yt-dlp networking errors.

    NOTE(review): several original lines (closing `])` of the parametrize lists, the body of
    fake_connect, and the send/recv calls inside pytest.raises) are absent from this excerpt;
    they are marked below — confirm against the full file.
    """

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # NOTE(review): original line 444 missing (possibly a comment for the catch-all case below)
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
        # NOTE(review): original line 452 missing — presumably the closing `])`; confirm
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        # Exceptions raised while opening the connection must map to the exact expected yt-dlp error type
        import websockets.sync.client

        import yt_dlp.networking._websockets
        with handler() as rh:
            def fake_connect(*args, **kwargs):
                # NOTE(review): original line 459 missing — presumably `raise raised()`; confirm
            # Replace the socket factory and the websockets connect call so no real network I/O happens
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            # Exact type, not just a subclass
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # NOTE(review): original line 473 missing (possibly a comment for the catch-all case below)
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
        # NOTE(review): original line 475 missing — presumably the closing `])`; confirm
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        # Exceptions raised by the connection's send() must map to the exact expected yt-dlp error type
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            # NOTE(review): original line 480 missing — presumably a ws.send(...) call; confirm
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # NOTE(review): original line 489 missing (possibly a comment for the catch-all case below)
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
        # NOTE(review): original line 491 missing — presumably the closing `])`; confirm
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        # Exceptions raised by the connection's recv() must map to the exact expected yt-dlp error type
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            # NOTE(review): original line 496 missing — presumably ws.recv(); confirm
        assert exc_info.type is expected