cmake: respect indentation
[piglit.git] / framework / replay / download_utils.py
blob261a41e91b6cdb6431d919740b13316b930cb6d3
1 # coding=utf-8
3 # Copyright © 2020, 2022 Collabora Ltd
4 # Copyright © 2020 Valve Corporation.
6 # Permission is hereby granted, free of charge, to any person obtaining a
7 # copy of this software and associated documentation files (the "Software"),
8 # to deal in the Software without restriction, including without limitation
9 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 # and/or sell copies of the Software, and to permit persons to whom the
11 # Software is furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included
14 # in all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 # OTHER DEALINGS IN THE SOFTWARE.
24 # SPDX-License-Identifier: MIT
26 import base64
27 import hashlib
28 import hmac
29 import xml.etree.ElementTree as ET
30 from email.utils import formatdate
31 from os import path, remove
32 from pathlib import Path
33 from typing import Any, Dict, List
34 from urllib.parse import urlparse
36 import requests
37 from requests.adapters import HTTPAdapter, Retry
38 from requests.utils import requote_uri
40 from framework import core, exceptions
41 from framework.replay.local_file_adapter import LocalFileAdapter
42 from framework.replay.options import OPTIONS
# Public API of this module.
__all__ = ['ensure_file']

# Module-level cache for the STS credentials fetched by
# get_minio_credentials(); None until the first successful fetch.
minio_credentials = None
def sign_with_hmac(key, message):
    """Return the base64-encoded HMAC-SHA1 of *message* keyed with *key*.

    Both arguments are str; they are UTF-8 encoded before signing and the
    digest is returned as a stripped base64 str (AWS v2 signature format).
    """
    digest = hmac.new(key.encode("UTF-8"),
                      message.encode("UTF-8"),
                      hashlib.sha1).digest()
    return base64.encodebytes(digest).strip().decode()
def get_minio_credentials(url):
    """Return (access_key, secret_key, session_token) for the MinIO host.

    On first use, performs an STS AssumeRoleWithWebIdentity request against
    OPTIONS.download['minio_host'] using the configured JWT, parses the XML
    response, and caches the credentials at module level; subsequent calls
    return the cached values.

    :raises requests.HTTPError: when the STS request fails (body printed
        first for diagnosis).
    :raises KeyError: when the STS response lacks one of the expected fields.
    """
    global minio_credentials

    if minio_credentials is not None:
        return (minio_credentials['AccessKeyId'],
                minio_credentials['SecretAccessKey'],
                minio_credentials['SessionToken'])

    params = {'Action': 'AssumeRoleWithWebIdentity',
              'Version': '2011-06-15',
              'RoleArn': 'arn:aws:iam::123456789012:role/FederatedWebIdentityRole',
              'RoleSessionName': OPTIONS.download['role_session_name'],
              'DurationSeconds': 3600,
              'WebIdentityToken': OPTIONS.download['jwt']}
    r = requests.post('https://%s' % OPTIONS.download['minio_host'], params=params)
    if r.status_code >= 400:
        print(r.text)
    r.raise_for_status()

    # Collect the three credential fields from the STS XML response.
    credentials = {}
    sts_ns = '{https://sts.amazonaws.com/doc/2011-06-15/}'
    root = ET.fromstring(r.text)
    for attr in root.iter():
        if attr.tag == sts_ns + 'AccessKeyId':
            credentials['AccessKeyId'] = attr.text
        elif attr.tag == sts_ns + 'SecretAccessKey':
            credentials['SecretAccessKey'] = attr.text
        elif attr.tag == sts_ns + 'SessionToken':
            credentials['SessionToken'] = attr.text

    result = (credentials['AccessKeyId'],
              credentials['SecretAccessKey'],
              credentials['SessionToken'])

    # Publish the cache only after a fully successful fetch+parse.  The
    # previous code assigned an empty dict to the global *before* the HTTP
    # request, so any failure (network error, HTTP error, missing field)
    # left a poisoned non-None cache and every later call raised KeyError
    # instead of retrying the STS request.
    minio_credentials = credentials
    return result
def get_minio_authorization_headers(url, resource):
    """Build AWS-v2-style signed headers for a MinIO GET of *resource*.

    Fetches (cached) STS credentials, signs the canonical GET string with
    HMAC-SHA1, and returns the Host/Date/Authorization/x-amz-security-token
    header dict.
    """
    access_key, secret_key, session_token = get_minio_credentials(url)

    date = formatdate(timeval=None, localtime=False, usegmt=True)
    string_to_sign = "GET\n\n\n%s\nx-amz-security-token:%s\n/%s/%s" % (
        date,
        session_token,
        OPTIONS.download['minio_bucket'],
        requote_uri(resource))
    signature = sign_with_hmac(secret_key, string_to_sign)

    return {'Host': OPTIONS.download['minio_host'],
            'Date': date,
            'Authorization': 'AWS %s:%s' % (access_key, signature),
            'x-amz-security-token': session_token}
def get_jwt_authorization_headers(url, resource):
    """Build Bearer-token headers for a JWT-protected download host."""
    host = urlparse(url).netloc
    date = formatdate(timeval=None, localtime=False, usegmt=True)

    return {'Host': host,
            'Date': date,
            'Authorization': 'Bearer %s' % (OPTIONS.download['jwt'])}
def calc_etags(inputfile: Path, partsize: int = 10 * 1024 * 1024) -> List[str]:
    """Return the two checksum candidates for *inputfile*.

    The first entry reproduces the multipart e-tag generated by the FDO
    upload script (s3cp): the MD5 of the concatenated per-part MD5 digests,
    suffixed with '-<part count>'.  The second entry is the plain MD5 of
    the whole file.
    """
    whole_file_md5 = hashlib.md5()
    part_digests = []
    with open(inputfile, 'rb') as stream:
        while True:
            chunk = stream.read(partsize)
            if not chunk:
                break
            whole_file_md5.update(chunk)
            part_digests.append(hashlib.md5(chunk).digest())
    multipart_etag = '%s-%d' % (
        hashlib.md5(b''.join(part_digests)).hexdigest(), len(part_digests))
    return [multipart_etag, whole_file_md5.hexdigest()]
@core.timer_ms
def download(url: str, file_path: str, headers: Dict[str, str], attempts: int = 2) -> None:
    """Download a URL's content into a file and verify its integrity.

    :param url: URL to download
    :param file_path: local file name to contain the data downloaded
    :param headers: extra HTTP headers to send with the request
    :param attempts: number of attempts for connect/read/redirect retries
    """
    retry_policy = Retry(
        backoff_factor=30,
        connect=attempts,
        read=attempts,
        redirect=attempts,
        status_forcelist=[429, 500, 502, 503, 504],
        raise_on_redirect=False,
    )
    session = requests.Session()
    for scheme in ("http://", "https://"):
        session.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    # file:// URLs are served through piglit's local-file adapter.
    session.mount("file://", LocalFileAdapter())

    whole_file_md5 = hashlib.md5()
    part_digests = []
    local_file_checksums: List[Any] = []
    with session.get(url,
                     allow_redirects=True,
                     stream=True,
                     headers=headers) as response:
        with open(file_path, "wb") as file:
            response.raise_for_status()
            # chunk_size must be equal to the s3cp upload chunk for the
            # multipart md5 digest to match the remote ETag.
            for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
                if chunk:
                    file.write(chunk)
                    whole_file_md5.update(chunk)
                    part_digests.append(hashlib.md5(chunk).digest())
        local_file_checksums = [
            hashlib.md5(b''.join(part_digests)).hexdigest()
            + '-' + str(len(part_digests)),
            whole_file_md5.hexdigest(),
        ]

    verify_file_integrity(file_path, response.headers, local_file_checksums)
def verify_file_integrity(file_path: str, headers: Any, local_file_checksums: Any) -> None:
    """Validate a downloaded file against its HTTP response metadata.

    Prefers matching the remote ETag against the locally computed MD5
    candidates; when no ETag header exists, falls back to comparing
    Content-Length with the local file size, and skips verification when
    that header is missing too.  On any mismatch the local file is removed
    before raising.

    :param file_path: path to the local file
    :param headers: headers of the response that produced the file
    :param local_file_checksums: list of already generated MD5 candidates
    :raises PiglitFatalError: on checksum or size mismatch
    """
    try:
        remote_md5: str = headers["etag"].strip('\"').lower()
    except KeyError:
        print("ETag is missing from the HTTPS header. "
              "Fall back to Content-length verification.")
    else:
        if remote_md5 not in local_file_checksums:
            remove(file_path)
            raise exceptions.PiglitFatalError(
                f"MD5 checksum {local_file_checksums} "
                f"doesn't match remote ETag MD5 {remote_md5}, removing file..."
            )
        return

    try:
        remote_size = int(headers["Content-Length"])
    except KeyError:
        print("Error getting Content-Length from server. "
              "Skipping file size check.")
        return

    local_size = path.getsize(file_path)
    if remote_size != local_size:
        remove(file_path)
        raise exceptions.PiglitFatalError(
            f"Invalid filesize src {remote_size} "
            f"doesn't match {local_size}, removing file..."
        )
def verify_local_file_checksum(url, file_path, headers, destination_file_path):
    """HEAD the remote file and verify an already-downloaded local copy.

    Network failures are tolerated (reported, local file kept); checksum or
    size mismatches are delegated to verify_file_integrity, which removes
    the file and raises.
    """
    print(f"[check_image] Requesting headers for {file_path}", end=" ", flush=True)
    try:
        response = requests.head(url + file_path, timeout=60, headers=headers)
    except requests.exceptions.RequestException as err:
        # Best effort: if the server is unreachable, keep the local file.
        print(f"Not verified! HTTP request failed with {err}", flush=True)
        return
    print(
        f"returned {response.status_code}.",
        f"Took {response.elapsed.microseconds / 1000} ms",
        flush=True,
    )
    remote_headers = response.headers

    @core.timer_ms
    def check_md5():
        # Timed so re-hashing a large trace shows up in the logs.
        print(
            f"[check_image] Verifying already downloaded file {file_path}",
            end=" ",
            flush=True,
        )
        verify_file_integrity(
            destination_file_path, remote_headers, calc_etags(destination_file_path)
        )

    check_md5()
def ensure_file(file_path):
    """Guarantee that *file_path* exists under OPTIONS.db_path.

    Without a configured download URL the file must already be on disk.
    Otherwise it is fetched from the remote (optionally through a caching
    proxy, with MinIO or JWT auth headers when configured); an existing
    local copy is only checksum-verified unless a forced re-download was
    requested.

    :raises PiglitFatalError: when no URL is configured and the file is
        missing, or when verification of a downloaded file fails.
    """
    dest = path.join(OPTIONS.db_path, file_path)
    download_opts = OPTIONS.download

    if download_opts['url'] is None:
        # No remote configured: the file has to be present already.
        if not path.exists(dest):
            raise exceptions.PiglitFatalError(
                '{} missing'.format(dest))
        return

    url = download_opts['url'].geturl()
    if download_opts['caching_proxy_url'] is not None:
        url = download_opts['caching_proxy_url'].geturl() + url

    core.check_dir(path.dirname(dest))

    if download_opts['minio_host']:
        assert download_opts['minio_bucket']
        assert download_opts['role_session_name']
        assert download_opts['jwt']
        headers = get_minio_authorization_headers(url, file_path)
    elif download_opts['jwt']:
        headers = get_jwt_authorization_headers(url, file_path)
    else:
        headers = None

    if not download_opts['force'] and path.exists(dest):
        verify_local_file_checksum(url, file_path, headers, dest)
        return

    print(f"[check_image] Downloading file {file_path}", end=" ", flush=True)
    download(url + file_path, dest, headers)