# Copyright © 2020 Valve Corporation.
# Copyright © 2022 ‒ 2023 Collabora Ltd
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
#
# SPDX-License-Identifier: MIT

"""Tests for replayer's download_utils module."""
import os
from contextlib import contextmanager
from contextlib import nullcontext as does_not_raise
from dataclasses import dataclass
from hashlib import md5
from os import path
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import pytest
import requests

from framework import exceptions
from framework.replay import download_utils
from framework.replay.options import OPTIONS
# Canned AWS STS AssumeRoleWithWebIdentity reply. test_minio_authorization
# serves it from the mocked MinIO host and then expects the follow-up GET to
# carry an Authorization header starting with "AWS Key", i.e. signed with the
# AccessKeyId below. STS nests the temporary credentials inside a
# <Credentials> element of the Result element.
ASSUME_ROLE_RESPONSE = '''<?xml version="1.0" encoding="UTF-8"?>
    <AssumeRoleWithWebIdentityResponse
        xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
        <AssumeRoleWithWebIdentityResult>
            <Credentials>
                <AccessKeyId>Key</AccessKeyId>
                <SecretAccessKey>Secret</SecretAccessKey>
                <Expiration>2021-03-25T13:59:58Z</Expiration>
                <SessionToken>token</SessionToken>
            </Credentials>
        </AssumeRoleWithWebIdentityResult>
    </AssumeRoleWithWebIdentityResponse>
'''
class MockedResponseData:
    """Fixed payload used as the body of mocked trace downloads."""

    # Raw bytes the mocked server returns; integrity headers in the tests
    # (Content-Length, md5 etag) are derived from this value.
    binary_data: bytes = b"haxter"
@dataclass(frozen=True)
class MockedResponse:
    """Parametrization data and helpers for the integrity-check tests."""

    @staticmethod
    def header_scenarios():
        """Return response-header sets keyed by a human-readable test id.

        Every non-empty set correctly describes
        ``MockedResponseData.binary_data`` (its length and/or its md5 as an
        etag), so a download served with it should pass integrity checks.
        """
        binary_data_md5: str = md5(MockedResponseData.binary_data).hexdigest()
        etag: dict[str, str] = {"etag": binary_data_md5}
        length: dict[str, Any] = {
            "Content-Length": str(len(MockedResponseData.binary_data))
        }
        return {
            "With Content-Length": length,
            "With etag": etag,
            "With Content-Length and etag": {**length, **etag},
            "Without integrity headers": {},
        }

    @staticmethod
    def stored_file_scenarios():
        """Return local trace contents to pre-seed, keyed by test id.

        ``None`` means no local copy exists before the test runs.
        """
        return {
            "nothing stored": None,
            "already has file": MockedResponseData.binary_data,
            "already has wrong file": b"obsolete/corrupted data",
        }

    @staticmethod
    @contextmanager
    def create_local_file(trace_file, data):
        """Seed *trace_file* with *data* for the duration of a ``with`` block.

        :param trace_file: path-like location of the local trace copy.
        :param data: bytes to write, or ``None`` to leave no file at all.

        Previously ``None`` fell back to the correct payload, which made
        "nothing stored" indistinguishable from "already has file". Whatever
        file exists at *trace_file* when the block exits (seeded or
        downloaded during the block) is removed, even if the block raised.
        """
        trace_path = Path(trace_file)
        if data is not None:
            trace_path.write_bytes(data)
        try:
            yield
        finally:
            if trace_path.exists():
                trace_path.unlink()
class TestDownloadUtils(object):
    """Tests for download_utils methods."""

    @pytest.fixture(autouse=True)
    def setup(self, requests_mock, tmpdir):
        """Point OPTIONS at a fake server and a per-test database directory.

        Also registers a default "remote" body for GET and HEAD on the trace
        URL; individual tests override these as needed.
        """
        self.url = 'https://unittest.piglit.org/'
        self.trace_path = 'KhronosGroup-Vulkan-Tools/amd/polaris10/vkcube.gfxr'
        self.full_url = self.url + self.trace_path
        self.trace_file = tmpdir.join(self.trace_path)
        OPTIONS.set_download_url(self.url)
        OPTIONS.download['force'] = False
        OPTIONS.db_path = tmpdir.strpath
        requests_mock.get(self.full_url, text='remote')
        requests_mock.head(self.full_url, text='remote')

    @staticmethod
    def check_same_file(path_local, expected_content, expected_mtime=None):
        """Assert *path_local* holds *expected_content*.

        When *expected_mtime* is given, also assert the file was not touched.
        """
        assert path_local.read() == expected_content
        if expected_mtime is not None:
            m = path_local.mtime()
            assert m == expected_mtime

    @pytest.fixture
    def prepare_trace_file(self):
        # Make sure the temporary directory exists: py.path's write() does not
        # create missing parent directories.
        Path(self.trace_file).parent.mkdir(parents=True, exist_ok=True)

    @pytest.fixture
    def create_mock_response(self, requests_mock):
        """Return a helper registering GET/HEAD mocks serving the canned payload."""
        def inner(url, headers):
            kwargs = {
                "content": MockedResponseData.binary_data,
                "headers": headers,
            }
            requests_mock.get(url, **kwargs)
            requests_mock.head(url, **kwargs)

        return inner

    def test_ensure_file_exists(self,
                                prepare_trace_file):
        """download_utils.ensure_file: Check an existing file doesn't get overwritten"""

        self.trace_file.write("local")
        m = self.trace_file.mtime()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "local", m)

    def test_ensure_file_not_exists(self):
        """download_utils.ensure_file: Check a non existing file gets downloaded"""

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

    def test_ensure_file_exists_force_download(self,
                                               prepare_trace_file):
        """download_utils.ensure_file: Check an existing file gets overwritten when forced"""

        OPTIONS.download['force'] = True
        self.trace_file.write("local")
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

    @pytest.mark.raises(exception=exceptions.PiglitFatalError)
    def test_ensure_file_not_exists_no_url(self):
        """download_utils.ensure_file: Check an exception raises when not passing an URL for a non existing file"""

        OPTIONS.set_download_url("")
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=requests.exceptions.HTTPError)
    def test_ensure_file_not_exists_404(self, requests_mock):
        """download_utils.ensure_file: Check an exception raises when an URL returns a 404"""

        requests_mock.get(self.full_url, text='Not Found', status_code=404)
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=requests.exceptions.ConnectTimeout)
    def test_ensure_file_not_exists_timeout(self, requests_mock):
        """download_utils.ensure_file: Check an exception raises when an URL returns a Connect Timeout"""

        requests_mock.get(self.full_url, exc=requests.exceptions.ConnectTimeout)
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @pytest.fixture
    def already_has_wrong_file(self):
        """Seed the trace location with bogus content, removed on teardown."""
        self.trace_file.write(b"this_is_not_correct_file")
        yield
        Path(self.trace_file).unlink()

    @pytest.mark.parametrize(
        "stored_data",
        MockedResponse.stored_file_scenarios().values(),
        ids=MockedResponse.stored_file_scenarios().keys(),
    )
    @pytest.mark.parametrize(
        "headers",
        MockedResponse.header_scenarios().values(),
        ids=MockedResponse.header_scenarios().keys(),
    )
    def test_ensure_file_checks_integrity(
        self, prepare_trace_file, create_mock_response, headers, stored_data
    ):
        """download_utils.ensure_file: Check a stale local file is rejected when integrity headers are available"""
        create_mock_response(self.full_url, headers)
        with MockedResponse.create_local_file(self.trace_file, stored_data):
            # A failure is only expected when a local copy exists, differs
            # from the served payload, and the server sent something to
            # check it against.
            stored_file_is_wrong: bool = (
                self.trace_file.check()
                and self.trace_file.read() != MockedResponseData.binary_data.decode()
            )
            with (
                pytest.raises(exceptions.PiglitFatalError)
                if headers and stored_file_is_wrong
                else does_not_raise()
            ):
                download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=exceptions.PiglitFatalError)
    def test_download_with_invalid_content_length(self,
                                                  requests_mock,
                                                  mocker,
                                                  prepare_trace_file):
        """download_utils.download: Check if an exception raises
        when filesize doesn't match"""

        headers = {"Content-Length": "1"}
        requests_mock.get(self.full_url,
                          headers=headers,
                          text="Binary file content")

        assert not self.trace_file.check()
        # mocker undoes the patch at teardown; it must not be used as a
        # context manager (that binds __enter__()'s result, not the mock).
        mock_remove = mocker.patch('os.remove')
        download_utils.download(self.full_url, self.trace_file, None)
        # NOTE(review): never reached - download() raises first and the
        # raises marker swallows it. Enabling this check needs pytest.raises
        # plus confirmation that download_utils removes the file via
        # os.remove(<this exact path>).
        mock_remove.assert_called_once_with(self.trace_file)

    def test_download_works_at_last_retry(self,
                                          requests_mock,
                                          prepare_trace_file):
        """download_utils.download: Check download retry mechanism"""

        bad_headers = {"Content-Length": "1"}
        # Mock attempts - 1 bad requests and a working last one
        # NOTE(review): requests_mock lets the most recent registration for a
        # URL win, so the bad responses below are never served. A response
        # list (requests_mock.get(url, [{...}, ...])) would exercise the
        # retries, but only if download() retries at least attempts - 1 times.
        attempts = 3
        for _ in range(attempts - 1):
            requests_mock.get(self.full_url,
                              headers=bad_headers,
                              text="Binary file content")
        requests_mock.get(self.full_url,
                          text="Binary file content")

        assert not self.trace_file.check()
        download_utils.download(self.full_url, self.trace_file, None)
        assert self.trace_file.check()

    def test_download_without_content_length(self,
                                             requests_mock,
                                             prepare_trace_file):
        """download_utils.download: Check the download succeeds
        when response does not have a Content-Length header"""

        missing_headers = {}
        requests_mock.get(self.full_url,
                          headers=missing_headers,
                          text="Binary file content")

        assert not self.trace_file.check()
        download_utils.download(self.full_url, self.trace_file, None)
        assert self.trace_file.check()

    def test_minio_authorization(self, requests_mock):
        """download_utils.ensure_file: Check we send the authentication headers to MinIO"""
        # The first request is expected to be the credential exchange (POST)
        # answered with canned STS credentials; the GET that follows should
        # then be signed with AccessKeyId "Key".
        requests_mock.post(self.url, text=ASSUME_ROLE_RESPONSE)
        # NOTE(review): these OPTIONS entries are global and not reset by
        # setup(); later tests must clear them (see test_jwt_authorization).
        OPTIONS.download['minio_host'] = urlparse(self.url).netloc
        OPTIONS.download['minio_bucket'] = 'minio_bucket'
        OPTIONS.download['role_session_name'] = 'role_session_name'
        OPTIONS.download['jwt'] = 'jwt'

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

        post_request = requests_mock.request_history[0]
        assert post_request.method == 'POST'

        get_request = requests_mock.request_history[1]
        assert get_request.method == 'GET'
        assert get_request.headers['Authorization'].startswith('AWS Key')

    def test_jwt_authorization(self, requests_mock):
        """download_utils.ensure_file: Check we send the authentication headers to the server"""
        # reset minio_host from previous tests
        OPTIONS.download['minio_host'] = ''
        OPTIONS.download['jwt'] = 'jwt'

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

        get_request = requests_mock.request_history[0]
        assert get_request.method == 'GET'
        assert get_request.headers['Authorization'].startswith('Bearer')