btrbk: add mainProgram
[NixPkgs.git] / pkgs / by-name / az / azure-cli / extensions-tool.py
blob74e58723a099be46538a812ce9812959e49d6e9d
1 #!/usr/bin/env python
3 import argparse
4 import base64
5 import datetime
6 import json
7 import logging
8 import os
9 import sys
10 from dataclasses import asdict, dataclass, replace
11 from pathlib import Path
12 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
13 from urllib.request import Request, urlopen
15 import git
16 from packaging.version import Version, parse
# Upstream index of all published Azure CLI extensions (format version 1);
# fetched once and cached on disk by get_extension_index().
INDEX_URL = "https://azcliextensionsync.blob.core.windows.net/index1/index.json"

# Module-level logger; handlers/level are configured in main().
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class Ext:
    """One pinned Azure CLI extension as stored in extensions-generated.json.

    Frozen (hashable) so instances can live in sets and be diffed.
    NOTE: `version` holds a packaging `Version` in memory, but is swapped
    for its string form when (de)serializing — see _read_extension_set /
    _write_extension_set.
    """

    pname: str  # extension name as published in the index
    version: Version  # parsed version (plain str only during (de)serialization)
    url: str  # downloadUrl of the extension wheel
    hash: str  # SRI hash, "sha256-<base64>"
    description: str  # index "summary" with the trailing "." stripped
32 def _read_cached_index(path: Path) -> Tuple[datetime.datetime, Any]:
33 with open(path, "r") as f:
34 data = f.read()
36 j = json.loads(data)
37 cache_date_str = j["cache_date"]
38 if cache_date_str:
39 cache_date = datetime.datetime.fromisoformat(cache_date_str)
40 else:
41 cache_date = datetime.datetime.min
42 return cache_date, data
45 def _write_index_to_cache(data: Any, path: Path):
46 j = json.loads(data)
47 j["cache_date"] = datetime.datetime.now().isoformat()
48 with open(path, "w") as f:
49 json.dump(j, f, indent=2)
def _fetch_remote_index():
    """Download and return the raw extension index bytes from INDEX_URL."""
    with urlopen(Request(INDEX_URL)) as response:
        return response.read()
def get_extension_index(cache_dir: Path) -> Any:
    """Return the parsed extension index, refreshing a day-long on-disk cache.

    The raw index is cached in ``cache_dir/index.json``. It is
    (re)downloaded when no cache exists or the cached copy is older than
    one day; in both cases we recurse once to re-read the fresh cache.

    NOTE(review): the original return annotation was ``Set[Ext]``, but the
    function actually returns the decoded JSON document — a dict with
    "formatVersion" and "extensions" keys, which is what main() consumes.
    """
    index_file = cache_dir / "index.json"
    os.makedirs(cache_dir, exist_ok=True)

    try:
        index_cache_date, index_data = _read_cached_index(index_file)
    except FileNotFoundError:
        logger.info("index has not been cached, downloading from source")
        logger.info("creating index cache in %s", index_file)
        _write_index_to_cache(_fetch_remote_index(), index_file)
        # Recurse to take the cache-hit path on the freshly written file.
        return get_extension_index(cache_dir)

    if (
        index_cache_date
        and datetime.datetime.now() - index_cache_date > datetime.timedelta(days=1)
    ):
        logger.info(
            "cache is outdated (%s), refreshing",
            datetime.datetime.now() - index_cache_date,
        )
        _write_index_to_cache(_fetch_remote_index(), index_file)
        return get_extension_index(cache_dir)

    logger.info("using index cache from %s", index_file)
    return json.loads(index_data)
def _read_extension_set(extensions_generated: Path) -> Set[Ext]:
    """Read extensions-generated.json and return it as a set of Ext records.

    The version strings stored in the JSON are parsed into packaging
    Version objects on the way in.
    """
    json_exts = json.loads(extensions_generated.read_text())
    return {
        replace(Ext(**fields), version=parse(fields["version"]))
        for fields in json_exts.values()
    }
def _write_extension_set(extensions_generated: Path, extensions: Set[Ext]) -> None:
    """Serialize *extensions* to extensions-generated.json, sorted by pname.

    Version objects are converted back to plain strings so the file stays
    valid JSON; a trailing newline keeps the file POSIX-friendly.
    """
    ordered = sorted(
        (replace(ext, version=str(ext.version)) for ext in extensions),
        key=lambda ext: ext.pname,
    )
    payload = {ext.pname: asdict(ext) for ext in ordered}
    with open(extensions_generated, "w") as f:
        json.dump(payload, f, indent=2)
        f.write("\n")
107 def _convert_hash_digest_from_hex_to_b64_sri(s: str) -> str:
108 try:
109 b = bytes.fromhex(s)
110 except ValueError as err:
111 logger.error("not a hex value: %s", str(err))
112 raise err
114 return f"sha256-{base64.b64encode(b).decode('utf-8')}"
def _commit(repo: git.Repo, message: str, files: List[Path], actor: git.Actor) -> None:
    """Stage *files* and commit them with *message* as author and committer.

    Skips the commit (with a warning) when staging produced no diff
    against HEAD.
    """
    staged = [str(path.resolve()) for path in files]
    repo.index.add(staged)
    if not repo.index.diff("HEAD"):
        logger.warning("no changes in working tree to commit")
        return
    logger.info(f'committing to nixpkgs "{message}"')
    repo.index.commit(message, author=actor, committer=actor)
126 def _filter_invalid(o: Dict[str, Any]) -> bool:
127 if "metadata" not in o:
128 logger.warning("extension without metadata")
129 return False
130 metadata = o["metadata"]
131 if "name" not in metadata:
132 logger.warning("extension without name")
133 return False
134 if "version" not in metadata:
135 logger.warning(f"{metadata['name']} without version")
136 return False
137 if "azext.minCliCoreVersion" not in metadata:
138 logger.warning(
139 f"{metadata['name']} {metadata['version']} does not have azext.minCliCoreVersion"
141 return False
142 if "summary" not in metadata:
143 logger.info(f"{metadata['name']} {metadata['version']} without summary")
144 return False
145 if "downloadUrl" not in o:
146 logger.warning(f"{metadata['name']} {metadata['version']} without downloadUrl")
147 return False
148 if "sha256Digest" not in o:
149 logger.warning(f"{metadata['name']} {metadata['version']} without sha256Digest")
150 return False
152 return True
def _filter_compatible(o: Dict[str, Any], cli_version: Version) -> bool:
    """Return True when *cli_version* satisfies the entry's azext.minCliCoreVersion."""
    min_core = o["metadata"]["azext.minCliCoreVersion"]
    return cli_version >= parse(min_core)
def _transform_dict_to_obj(o: Dict[str, Any]) -> Ext:
    """Build an Ext record from one raw index entry.

    The sha256Digest is converted to an SRI hash and the summary's
    trailing "." is stripped for use as a package description.
    """
    meta = o["metadata"]
    fields = {
        "pname": meta["name"],
        "version": parse(meta["version"]),
        "url": o["downloadUrl"],
        "hash": _convert_hash_digest_from_hex_to_b64_sri(o["sha256Digest"]),
        "description": meta["summary"].rstrip("."),
    }
    return Ext(**fields)
171 def _get_latest_version(versions: dict) -> dict:
172 return max(versions, key=lambda e: parse(e["metadata"]["version"]), default=None)
def processExtension(
    extVersions: dict,
    cli_version: Version,
    ext_name: Optional[str] = None,
    requirements: bool = False,
) -> Optional[Ext]:
    """Pick the newest valid, CLI-compatible release and convert it to Ext.

    Returns None when no release qualifies, when *ext_name* is given but
    does not match the winning entry, or when the entry declares
    run_requires while *requirements* is False.
    """
    candidates = (
        entry
        for entry in extVersions
        if _filter_invalid(entry) and _filter_compatible(entry, cli_version)
    )
    latest = _get_latest_version(candidates)
    if not latest:
        return None
    metadata = latest["metadata"]
    if ext_name and metadata["name"] != ext_name:
        return None
    if "run_requires" in metadata and not requirements:
        return None
    return _transform_dict_to_obj(latest)
def _diff_sets(
    set_local: Set[Ext], set_remote: Set[Ext]
) -> Tuple[Set[Ext], Set[Ext], Set[Tuple[Ext, Ext]]]:
    """Diff local and remote extension sets, keyed by pname.

    Returns (only_local, only_remote, common): extensions present only
    locally, present only remotely, and (local, remote) pairs for names
    present on both sides.
    """
    by_name_local = {ext.pname: ext for ext in set_local}
    by_name_remote = {ext.pname: ext for ext in set_remote}

    only_local = {
        ext for name, ext in by_name_local.items() if name not in by_name_remote
    }
    only_remote = {
        ext for name, ext in by_name_remote.items() if name not in by_name_local
    }
    common = {
        (by_name_local[name], by_name_remote[name])
        for name in by_name_local.keys() & by_name_remote.keys()
    }
    return only_local, only_remote, common
def _filter_updated(e: Tuple[Ext, Ext]) -> bool:
    """Return True when the (local, remote) pair differs, i.e. an update exists."""
    return e[0] != e[1]
def main() -> None:
    """Sync extensions-generated.json in nixpkgs with the remote index.

    Diffs the locally pinned extension set against the (cached) remote
    index filtered for the given --cli-version, rewrites
    extensions-generated.json for every init/update/removal, and
    optionally creates one git commit per change (--commit).
    """
    sh = logging.StreamHandler(sys.stderr)
    sh.setFormatter(
        logging.Formatter(
            "[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)",
            "%Y-%m-%d %H:%M:%S",
        )
    )
    logging.basicConfig(level=logging.INFO, handlers=[sh])

    parser = argparse.ArgumentParser(
        prog="azure-cli.extensions-tool",
        description="Script to handle Azure CLI extension updates",
    )
    # FIX: the help text already said "(required)" but argparse did not
    # enforce it, so a missing flag crashed later in parse(None).
    parser.add_argument(
        "--cli-version", type=str, required=True, help="version of azure-cli (required)"
    )
    parser.add_argument("--extension", type=str, help="name of extension to query")
    parser.add_argument(
        "--cache-dir",
        type=Path,
        help="path where to cache the extension index",
        default=Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
        / "azure-cli-extensions-tool",
    )
    parser.add_argument(
        "--requirements",
        action=argparse.BooleanOptionalAction,
        help="whether to list extensions that have requirements",
    )
    parser.add_argument(
        "--commit",
        action=argparse.BooleanOptionalAction,
        help="whether to commit changes to git",
    )
    args = parser.parse_args()

    repo = git.Repo(Path(".").resolve(), search_parent_directories=True)
    # Workaround for https://github.com/gitpython-developers/GitPython/issues/1923
    author = repo.config_reader().get_value("user", "name").lstrip('"').rstrip('"')
    email = repo.config_reader().get_value("user", "email").lstrip('"').rstrip('"')
    actor = git.Actor(author, email)

    index = get_extension_index(args.cache_dir)
    assert index["formatVersion"] == "1"  # only support formatVersion 1
    extensions_remote = index["extensions"]

    cli_version = parse(args.cli_version)

    extensions_remote_filtered = set()
    for _ext_name, extension in extensions_remote.items():
        # FIX: forward --requirements. It was parsed but never passed, so
        # extensions with run_requires could never be selected.
        ext = processExtension(
            extension, cli_version, args.extension, bool(args.requirements)
        )
        if ext:
            extensions_remote_filtered.add(ext)

    extension_file = (
        Path(repo.working_dir) / "pkgs/by-name/az/azure-cli/extensions-generated.json"
    )
    extensions_local = _read_extension_set(extension_file)
    if args.extension:
        extensions_local_filtered = filter(
            lambda ext: args.extension == ext.pname, extensions_local
        )
    else:
        extensions_local_filtered = extensions_local

    removed, init, updated = _diff_sets(
        extensions_local_filtered, extensions_remote_filtered
    )
    updated = set(filter(_filter_updated, updated))

    logger.info("initialized extensions:")
    for ext in init:
        logger.info(f" {ext.pname} {ext.version}")
    logger.info("removed extensions:")
    for ext in removed:
        logger.info(f" {ext.pname} {ext.version}")
    logger.info("updated extensions:")
    for prev, new in updated:
        logger.info(f" {prev.pname} {prev.version} -> {new.version}")

    # Apply changes one extension at a time, rewriting the generated file
    # after each so that every optional git commit is self-contained.
    for ext in init:
        extensions_local.add(ext)
        commit_msg = f"azure-cli-extensions.{ext.pname}: init at {ext.version}"
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)

    for prev, new in updated:
        extensions_local.remove(prev)
        extensions_local.add(new)
        commit_msg = (
            f"azure-cli-extensions.{prev.pname}: {prev.version} -> {new.version}"
        )
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)

    for ext in removed:
        extensions_local.remove(ext)
        # TODO: Add additional check why this is removed
        # TODO: Add an alias to extensions manual?
        commit_msg = f"azure-cli-extensions.{ext.pname}: remove"
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)
# Run as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()