framework/replay: Update simpler integrity unit tests
[piglit.git] / unittests / framework / replay / test_download_utils.py
blob5652425042e6cdbbcc673eaaaeb515c964c27985
1 # coding=utf-8
3 # Copyright © 2020 Valve Corporation.
4 # Copyright © 2022 ‒ 2023 Collabora Ltd
6 # Permission is hereby granted, free of charge, to any person obtaining a
7 # copy of this software and associated documentation files (the "Software"),
8 # to deal in the Software without restriction, including without limitation
9 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 # and/or sell copies of the Software, and to permit persons to whom the
11 # Software is furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included
14 # in all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 # OTHER DEALINGS IN THE SOFTWARE.
24 # SPDX-License-Identifier: MIT
27 """Tests for replayer's download_utils module."""
29 import os
30 from contextlib import contextmanager
31 from contextlib import nullcontext as does_not_raise
32 from dataclasses import dataclass
33 from hashlib import md5
34 from os import path
35 from pathlib import Path
36 from typing import Any
37 from urllib.parse import urlparse
39 import pytest
40 import requests
41 import requests_mock
43 from framework import exceptions
44 from framework.replay import download_utils
45 from framework.replay.options import OPTIONS
# Canned STS ``AssumeRoleWithWebIdentity`` XML reply. test_minio_authorization
# serves it as the response to the POST sent to the MinIO host; the test then
# expects the GET's ``Authorization`` header to start with ``AWS Key``, i.e.
# to carry the AccessKeyId ("Key") declared below.
ASSUME_ROLE_RESPONSE = '''<?xml version="1.0" encoding="UTF-8"?>
<AssumeRoleWithWebIdentityResponse
xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<AccessKeyId>Key</AccessKeyId>
<SecretAccessKey>Secret</SecretAccessKey>
<Expiration>2021-03-25T13:59:58Z</Expiration>
<SessionToken>token</SessionToken>
</Credentials>
</AssumeRoleWithWebIdentityResult>
</AssumeRoleWithWebIdentityResponse>
'''
class MockedResponseData:
    """Payload the mocked trace server returns in the integrity tests."""

    binary_data: bytes = "haxter".encode("ascii")
@dataclass(frozen=True)
class MockedResponse:
    """Scenario factories for the parametrized integrity-check test.

    Everything is a static method so the scenarios can be evaluated at
    class-definition time inside ``pytest.mark.parametrize`` decorators.
    """

    @staticmethod
    def header_scenarios():
        """Return a mapping of scenario id -> mocked response headers.

        The integrity headers all describe ``MockedResponseData.binary_data``:
        its MD5 hex digest as ``etag`` and/or its size as ``Content-Length``.
        One scenario carries no integrity headers at all.
        """
        binary_data_md5: str = md5(MockedResponseData.binary_data).hexdigest()
        etag: dict[str, str] = {"etag": binary_data_md5}
        length: dict[str, Any] = {
            "Content-Length": str(len(MockedResponseData.binary_data))
        }
        return {
            "With Content-Length": length,
            "With etag": etag,
            "With Content-Length and etag": {**length, **etag},
            "Without integrity headers": {},
        }

    @staticmethod
    def stored_file_scenarios():
        """Return a mapping of scenario id -> bytes already stored locally.

        ``None`` means no local file exists before the code under test runs.
        """
        return {
            "nothing stored": None,
            "already has file": MockedResponseData.binary_data,
            "already has wrong file": b"obsolete/corrupted data",
        }

    @staticmethod
    @contextmanager
    def create_local_file(trace_file, data):
        """Pre-populate *trace_file* with *data* for the duration of a block.

        :param trace_file: path (anything ``pathlib.Path`` accepts) of the
            local trace; its parent directory must already exist.
        :param data: bytes to write, or ``None`` to leave the path absent.

        On exit, whatever exists at *trace_file* is removed. That may be the
        pre-populated file or one created inside the block. Removal happens
        even if the block raises.
        """
        trace_path = Path(trace_file)
        # ``None`` is the "nothing stored" scenario: do not fabricate a file,
        # otherwise it silently degenerates into "already has file". Compare
        # against None explicitly so b"" is still written as an empty file.
        if data is not None:
            trace_path.write_bytes(data)
        try:
            yield
        finally:
            if trace_path.exists():
                trace_path.unlink()
class TestDownloadUtils(object):
    """Tests for download_utils methods."""

    @pytest.fixture(autouse=True)
    def setup(self, requests_mock, tmpdir):
        """Configure OPTIONS and a default mocked remote for every test.

        The trace lives under the per-test ``tmpdir``, and the mocked server
        answers both GET and HEAD of the trace URL with the text ``remote``.
        """
        self.url = 'https://unittest.piglit.org/'
        self.trace_path = 'KhronosGroup-Vulkan-Tools/amd/polaris10/vkcube.gfxr'
        self.full_url = self.url + self.trace_path
        self.trace_file = tmpdir.join(self.trace_path)
        OPTIONS.set_download_url(self.url)
        OPTIONS.download['force'] = False
        # OPTIONS is shared module state and test_minio_authorization turns
        # MinIO on; switch it off here instead of relying on test order.
        OPTIONS.download['minio_host'] = ''
        OPTIONS.db_path = tmpdir.strpath
        requests_mock.get(self.full_url, text='remote')
        requests_mock.head(self.full_url, text='remote')

    @staticmethod
    def check_same_file(path_local, expected_content, expected_mtime=None):
        """Assert *path_local* holds *expected_content* (and, if given, mtime)."""
        assert path_local.read() == expected_content
        if expected_mtime is not None:
            assert path_local.mtime() == expected_mtime

    @pytest.fixture
    def prepare_trace_file(self):
        """Create the parent directories of the trace file."""
        os.makedirs(path.dirname(self.trace_file), exist_ok=True)

    @pytest.fixture
    def create_mock_response(self, requests_mock):
        """Return a helper that mocks GET and HEAD of a URL.

        The helper serves ``MockedResponseData.binary_data`` as the body with
        the given *headers* for both methods.
        """
        def inner(url, headers):
            kwargs = {
                "content": MockedResponseData.binary_data,
                "headers": headers,
            }
            requests_mock.get(url, **kwargs)
            requests_mock.head(url, **kwargs)

        return inner

    def test_ensure_file_exists(self,
                                prepare_trace_file):
        """download_utils.ensure_file: Check an existing file doesn't get overwritten"""

        self.trace_file.write("local")
        m = self.trace_file.mtime()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "local", m)

    def test_ensure_file_not_exists(self):
        """download_utils.ensure_file: Check a non existing file gets downloaded"""

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

    def test_ensure_file_exists_force_download(self,
                                               prepare_trace_file):
        """download_utils.ensure_file: Check an existing file gets overwritten when forced"""

        OPTIONS.download['force'] = True
        self.trace_file.write("local")
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

    @pytest.mark.raises(exception=exceptions.PiglitFatalError)
    def test_ensure_file_not_exists_no_url(self):
        """download_utils.ensure_file: Check an exception raises when no URL is set for a non existing file"""

        OPTIONS.set_download_url("")
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=requests.exceptions.HTTPError)
    def test_ensure_file_not_exists_404(self, requests_mock):
        """download_utils.ensure_file: Check an exception raises when a URL returns a 404"""

        requests_mock.get(self.full_url, text='Not Found', status_code=404)
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=requests.exceptions.ConnectTimeout)
    def test_ensure_file_not_exists_timeout(self, requests_mock):
        """download_utils.ensure_file: Check an exception raises when a URL times out on connect"""

        requests_mock.get(self.full_url, exc=requests.exceptions.ConnectTimeout)
        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)

    @contextmanager
    def already_has_wrong_file(self):
        """Temporarily replace the trace file with mismatching content.

        NOTE(review): not referenced by any test in this class; candidate
        for removal in favour of MockedResponse.create_local_file.
        """
        self.trace_file.write(b"this_is_not_correct_file")
        try:
            yield
        finally:
            Path(self.trace_file).unlink()

    @pytest.mark.parametrize(
        "stored_data",
        MockedResponse.stored_file_scenarios().values(),
        ids=MockedResponse.stored_file_scenarios().keys(),
    )
    @pytest.mark.parametrize(
        "headers",
        MockedResponse.header_scenarios().values(),
        ids=MockedResponse.header_scenarios().keys(),
    )
    def test_ensure_file_checks_integrity(
        self, prepare_trace_file, create_mock_response, headers, stored_data
    ):
        """download_utils.ensure_file: Check local files against integrity headers

        When the response carries integrity headers, a stored file that does
        not match the remote payload must raise PiglitFatalError. In every
        other combination ensure_file must succeed.
        """
        create_mock_response(self.full_url, headers)
        with MockedResponse.create_local_file(self.trace_file, stored_data):
            # Compare raw bytes so arbitrary (non-text) payloads work too.
            stored_file_is_wrong: bool = (
                self.trace_file.check()
                and self.trace_file.read_binary() != MockedResponseData.binary_data
            )
            expectation = (
                pytest.raises(exceptions.PiglitFatalError)
                if headers and stored_file_is_wrong
                else does_not_raise()
            )
            with expectation:
                download_utils.ensure_file(self.trace_path)

    @pytest.mark.raises(exception=exceptions.PiglitFatalError)
    def test_download_with_invalid_content_length(self,
                                                  requests_mock,
                                                  prepare_trace_file):
        """download_utils.download: Check if an exception raises
        when filesize doesn't match the Content-Length header"""

        headers = {"Content-Length": "1"}
        requests_mock.get(self.full_url,
                          headers=headers,
                          text="Binary file content")

        assert not self.trace_file.check()
        # The previous ``with mocker.patch(...) as mock_remove`` bound the
        # MagicMock's __enter__() result rather than the patch, and its
        # assertion could never run because download() raises first.
        download_utils.download(self.full_url, self.trace_file, None)

    def test_download_works_at_last_retry(self,
                                          requests_mock,
                                          prepare_trace_file):
        """download_utils.download: Check download retry mechanism"""

        bad_headers = {"Content-Length": "1"}
        # Mock attempts - 1 bad requests and a working last one
        # NOTE(review): requests_mock tries the most recent registration for
        # a URL first. The bad responses registered in this loop are
        # therefore shadowed by the final registration and never served.
        # Serving them in order needs a response list passed to a single
        # ``requests_mock.get`` call. Confirm how many integrity failures
        # download() actually retries before changing this.
        attempts = 3
        for _ in range(attempts - 1):
            requests_mock.get(self.full_url,
                              headers=bad_headers,
                              text="Binary file content")
        requests_mock.get(self.full_url,
                          text="Binary file content")

        assert not self.trace_file.check()
        download_utils.download(self.full_url, self.trace_file, None)
        assert self.trace_file.check()

    def test_download_without_content_length(self,
                                             requests_mock,
                                             prepare_trace_file):
        """download_utils.download: Check a download succeeds
        when the response does not have a Content-Length header"""

        missing_headers = {}
        requests_mock.get(self.full_url,
                          headers=missing_headers,
                          text="Binary file content")

        assert not self.trace_file.check()
        download_utils.download(self.full_url, self.trace_file, None)
        assert self.trace_file.check()

    def test_minio_authorization(self, requests_mock):
        """download_utils.ensure_file: Check we send the authentication headers to MinIO"""
        requests_mock.post(self.url, text=ASSUME_ROLE_RESPONSE)
        OPTIONS.download['minio_host'] = urlparse(self.url).netloc
        OPTIONS.download['minio_bucket'] = 'minio_bucket'
        OPTIONS.download['role_session_name'] = 'role_session_name'
        OPTIONS.download['jwt'] = 'jwt'

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

        post_request = requests_mock.request_history[0]
        assert post_request.method == 'POST'

        get_request = requests_mock.request_history[1]
        assert get_request.method == 'GET'
        assert get_request.headers['Authorization'].startswith('AWS Key')

    def test_jwt_authorization(self, requests_mock):
        """download_utils.ensure_file: Check we send the authentication headers to the server"""
        # Plain bearer-token authentication: make sure MinIO is disabled.
        OPTIONS.download['minio_host'] = ''
        OPTIONS.download['jwt'] = 'jwt'

        assert not self.trace_file.check()
        download_utils.ensure_file(self.trace_path)
        TestDownloadUtils.check_same_file(self.trace_file, "remote")

        get_request = requests_mock.request_history[0]
        assert get_request.method == 'GET'
        assert get_request.headers['Authorization'].startswith('Bearer')