framework/replay: Fix Content-Length header name
[piglit.git] / framework / replay / download_utils.py
blob5c5c891284d4a23bef17fdf4f1bdbd55ecbe7a0f
1 # coding=utf-8
3 # Copyright © 2020, 2022 Collabora Ltd
4 # Copyright © 2020 Valve Corporation.
6 # Permission is hereby granted, free of charge, to any person obtaining a
7 # copy of this software and associated documentation files (the "Software"),
8 # to deal in the Software without restriction, including without limitation
9 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 # and/or sell copies of the Software, and to permit persons to whom the
11 # Software is furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included
14 # in all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 # OTHER DEALINGS IN THE SOFTWARE.
24 # SPDX-License-Identifier: MIT
26 import base64
27 import hashlib
28 import hmac
29 import xml.etree.ElementTree as ET
30 from email.utils import formatdate
31 from os import path
32 from pathlib import Path
33 from typing import Any, Dict, List
34 from urllib.parse import urlparse
36 import requests
37 from requests.adapters import HTTPAdapter, Retry
38 from requests.utils import requote_uri
40 from framework import core, exceptions
41 from framework.replay.local_file_adapter import LocalFileAdapter
42 from framework.replay.options import OPTIONS
# Public API of this module.
__all__ = ['ensure_file']

# Module-level cache for the STS credentials fetched by
# get_minio_credentials(); None until that function populates it.
minio_credentials = None
def sign_with_hmac(key, message):
    """Return the base64-encoded HMAC-SHA1 of *message* keyed with *key*.

    Both arguments are str; the result is the ASCII signature string used
    for AWS-style request signing.
    """
    digest = hmac.new(key.encode("UTF-8"),
                      message.encode("UTF-8"),
                      hashlib.sha1).digest()
    encoded = base64.encodebytes(digest)
    return encoded.strip().decode()
def get_minio_credentials(url):
    """Return (access_key, secret_key, session_token) for the MinIO host.

    The credentials are obtained from the configured MinIO server via an
    STS AssumeRoleWithWebIdentity call authenticated with the JWT from
    OPTIONS, and cached at module level so the STS round trip happens at
    most once per process.

    :param url: unused; kept for interface compatibility with the other
        authorization helpers.
    :raises requests.HTTPError: if the STS request fails.
    :raises KeyError: if the STS response lacks one of the expected fields.
    """
    global minio_credentials

    if minio_credentials is not None:
        return (minio_credentials['AccessKeyId'],
                minio_credentials['SecretAccessKey'],
                minio_credentials['SessionToken'])

    params = {'Action': 'AssumeRoleWithWebIdentity',
              'Version': '2011-06-15',
              'RoleArn': 'arn:aws:iam::123456789012:role/FederatedWebIdentityRole',
              'RoleSessionName': OPTIONS.download['role_session_name'],
              'DurationSeconds': 3600,
              'WebIdentityToken': OPTIONS.download['jwt']}
    r = requests.post('https://%s' % OPTIONS.download['minio_host'], params=params)
    if r.status_code >= 400:
        print(r.text)
    r.raise_for_status()

    # Collect the fields into a local dict first and publish the
    # module-level cache only once the response parsed completely.
    # Previously the cache was set to {} before the request, so a single
    # failed or truncated STS reply poisoned every subsequent call with a
    # KeyError instead of allowing a retry.
    sts_ns = '{https://sts.amazonaws.com/doc/2011-06-15/}'
    credentials = {}
    for attr in ET.fromstring(r.text).iter():
        if attr.tag == sts_ns + 'AccessKeyId':
            credentials['AccessKeyId'] = attr.text
        elif attr.tag == sts_ns + 'SecretAccessKey':
            credentials['SecretAccessKey'] = attr.text
        elif attr.tag == sts_ns + 'SessionToken':
            credentials['SessionToken'] = attr.text

    if all(key in credentials
           for key in ('AccessKeyId', 'SecretAccessKey', 'SessionToken')):
        minio_credentials = credentials

    return (credentials['AccessKeyId'],
            credentials['SecretAccessKey'],
            credentials['SessionToken'])
def get_minio_authorization_headers(url, resource):
    """Build AWS-v2-style signed request headers for a MinIO download.

    Signs a GET for *resource* inside the configured MinIO bucket with the
    STS credentials from get_minio_credentials().
    """
    access_key, secret_key, session_token = get_minio_credentials(url)

    date = formatdate(timeval=None, localtime=False, usegmt=True)
    bucket = OPTIONS.download['minio_bucket']
    string_to_sign = ("GET\n\n\n"
                      f"{date}\n"
                      f"x-amz-security-token:{session_token}\n"
                      f"/{bucket}/{requote_uri(resource)}")
    signature = sign_with_hmac(secret_key, string_to_sign)

    return {'Host': OPTIONS.download['minio_host'],
            'Date': date,
            'Authorization': f'AWS {access_key}:{signature}',
            'x-amz-security-token': session_token}
def get_jwt_authorization_headers(url, resource):
    """Build Bearer-token request headers for a JWT-protected download.

    *resource* is unused; kept for interface parity with
    get_minio_authorization_headers().
    """
    token = OPTIONS.download['jwt']
    return {
        'Host': urlparse(url).netloc,
        'Date': formatdate(timeval=None, localtime=False, usegmt=True),
        'Authorization': f'Bearer {token}',
    }
def calc_etags(inputfile: Path, partsize: int = 10 * 1024 * 1024) -> List[str]:
    """Return the two checksum candidates for *inputfile*.

    The first entry is the multipart ETag as generated by the FDO upload
    script (s3cp): the MD5 of the concatenated per-part MD5 digests,
    suffixed with '-<part count>'.  The second entry is the plain
    whole-file MD5 hexdigest.
    """
    whole_file = hashlib.md5()
    part_digests: List[bytes] = []
    with open(inputfile, 'rb') as stream:
        while True:
            chunk = stream.read(partsize)
            if not chunk:
                break
            whole_file.update(chunk)
            part_digests.append(hashlib.md5(chunk).digest())
    multipart = hashlib.md5(b''.join(part_digests)).hexdigest()
    return [
        multipart + '-' + str(len(part_digests)),
        whole_file.hexdigest(),
    ]
@core.timer_ms
def download(url: str, file_path: str, headers: Dict[str, str], attempts: int = 2) -> None:
    """Download the content of a URL into a file and verify the result.

    :param url: URL to download
    :param file_path: local file name that receives the downloaded data
    :param headers: extra HTTP headers to send with the request
    :param attempts: number of retries for connect/read/redirect failures
    """
    retry_policy = Retry(
        backoff_factor=30,
        connect=attempts,
        read=attempts,
        redirect=attempts,
        status_forcelist=[429, 500, 502, 503, 504],
        raise_on_redirect=False,
    )
    session = requests.Session()
    session.mount("http://", HTTPAdapter(max_retries=retry_policy))
    session.mount("https://", HTTPAdapter(max_retries=retry_policy))
    session.mount("file://", LocalFileAdapter())

    whole_file = hashlib.md5()
    part_digests: List[bytes] = []
    with session.get(url,
                     allow_redirects=True,
                     stream=True,
                     headers=headers) as response:
        with open(file_path, "wb") as output:
            response.raise_for_status()
            # chunk_size must equal the s3cp upload chunk size for the
            # multipart MD5 digest to match the remote ETag.
            for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
                if chunk:
                    output.write(chunk)
                    whole_file.update(chunk)
                    part_digests.append(hashlib.md5(chunk).digest())
        local_file_checksums = [
            hashlib.md5(b''.join(part_digests)).hexdigest()
            + '-' + str(len(part_digests)),
            whole_file.hexdigest(),
        ]
    verify_file_integrity(file_path, response.headers, local_file_checksums)
def verify_file_integrity(file_path: str, headers: Any, local_file_checksums: Any) -> None:
    """Check a downloaded file against the server-provided metadata.

    Prefers the remote ETag (an MD5, possibly multipart); when the server
    sent none, falls back to comparing Content-Length with the local file
    size, and when that is also missing, skips verification entirely.

    :param file_path: path to the local file
    :param headers: response headers from the request
    :param local_file_checksums: list of already generated MD5 candidates
    :raises exceptions.PiglitFatalError: on checksum or size mismatch
    """
    try:
        etag = headers["etag"]
    except KeyError:
        print("ETag is missing from the HTTPS header. "
              "Fall back to Content-length verification.")

        try:
            remote_file_size = int(headers["Content-Length"])
        except KeyError:
            print("Error getting Content-Length from server. "
                  "Skipping file size check.")
            return

        local_size = path.getsize(file_path)
        if remote_file_size != local_size:
            raise exceptions.PiglitFatalError(
                f"Invalid filesize src {remote_file_size} "
                f"doesn't match {local_size}")
        return

    remote_file_checksum: str = etag.strip('\"').lower()
    if remote_file_checksum not in local_file_checksums:
        raise exceptions.PiglitFatalError(
            f"MD5 checksum {local_file_checksums} "
            f"doesn't match remote ETag MD5 {remote_file_checksum}")
def verify_local_file_checksum(url, file_path, headers, destination_file_path):
    """Re-verify an already-downloaded file against the remote metadata.

    Issues a HEAD request for the remote resource and, when it succeeds,
    checks the local copy's checksums (or size) against its headers.  A
    failed HEAD request only logs a message and leaves the file as-is.
    """
    print(f"[check_image] Requesting headers for {file_path}", end=" ", flush=True)
    try:
        response = requests.head(url + file_path, timeout=60, headers=headers)
    except requests.exceptions.RequestException as err:
        print(f"Not verified! HTTP request failed with {err}", flush=True)
        return
    print(
        f"returned {response.status_code}.",
        f"Took {response.elapsed.microseconds / 1000} ms",
        flush=True,
    )
    remote_headers = response.headers

    @core.timer_ms
    def check_md5():
        # Timed so the cost of hashing the local file shows up in the log.
        print(
            f"[check_image] Verifying already downloaded file {file_path}",
            end=" ",
            flush=True,
        )
        verify_file_integrity(
            destination_file_path, remote_headers, calc_etags(destination_file_path)
        )

    check_md5()
def ensure_file(file_path):
    """Make sure *file_path* exists under OPTIONS.db_path, downloading it if needed.

    Without a configured download URL the file must already exist locally.
    With one, the file is fetched (through the caching proxy when
    configured) using MinIO or JWT authorization when credentials are
    available; an existing local copy is verified instead of re-downloaded
    unless the 'force' option is set.

    :raises exceptions.PiglitFatalError: when no URL is configured and the
        file is missing, or when verification of a download fails.
    """
    download_opts = OPTIONS.download
    destination = path.join(OPTIONS.db_path, file_path)

    if download_opts['url'] is None:
        # No remote configured: the file must already be present.
        if not path.exists(destination):
            raise exceptions.PiglitFatalError(
                '{} missing'.format(destination))
        return

    url = download_opts['url'].geturl()
    if download_opts['caching_proxy_url'] is not None:
        url = download_opts['caching_proxy_url'].geturl() + url

    core.check_dir(path.dirname(destination))

    if download_opts['minio_host']:
        assert download_opts['minio_bucket']
        assert download_opts['role_session_name']
        assert download_opts['jwt']
        headers = get_minio_authorization_headers(url, file_path)
    elif download_opts['jwt']:
        headers = get_jwt_authorization_headers(url, file_path)
    else:
        headers = None

    if not download_opts['force'] and path.exists(destination):
        verify_local_file_checksum(url, file_path, headers, destination)
        return

    print(f"[check_image] Downloading file {file_path}", end=" ", flush=True)

    download(url + file_path, destination, headers)