Merge pull request #330634 from r-ryantm/auto-update/circumflex
[NixPkgs.git] / pkgs / servers / apache-airflow / update-providers.py
blob207c3811925863f5c15fcfc1e6985b30fe33bc54
1 #! /usr/bin/env python3
3 from itertools import chain
4 import json
5 import logging
6 from pathlib import Path
7 import os
8 import re
9 import subprocess
10 import sys
11 from typing import Dict, List, Optional, Set, TextIO
12 from urllib.request import urlopen
13 from urllib.error import HTTPError
14 import yaml
# Nix attribute set searched for Python packages. Attribute paths returned by
# the resolver start with "<PKG_SET>." and that prefix is stripped before the
# names are written out.
PKG_SET: str = "apache-airflow.pythonPackages"

# Explicit overrides: requirement name (after version constraints are
# stripped) -> attribute name inside PKG_SET. Use this when a requirement is
# matched by multiple Python packages, or by none, to choose the correct one.
PKG_PREFERENCES: Dict[str, str] = {
    "dnspython": "dnspython",
    "elasticsearch-dsl": "elasticsearch-dsl",
    "google-api-python-client": "google-api-python-client",
    "protobuf": "protobuf",
    "psycopg2-binary": "psycopg2",
    "requests_toolbelt": "requests-toolbelt",
}

# Requirements missing from the Airflow provider metadata: provider id ->
# attribute names (inside PKG_SET) appended verbatim to that provider's deps.
EXTRA_REQS: Dict[str, List[str]] = {
    "sftp": ["pysftp"],
}
def get_version() -> str:
    """Return the Airflow version pinned in the default.nix next to this script.

    The version is taken from the first ``version = "...";`` assignment in
    that file. A version consists of digits, dots, and possibly a "b" (for
    beta), e.g. ``2.7.3``.

    Raises:
        ValueError: if default.nix contains no such version assignment.
    """
    # os.path.join rather than string concatenation: when the script is run as
    # ``python update-providers.py`` the dirname is "" and concatenating
    # "/default.nix" would point at the filesystem root instead of the CWD.
    default_nix = os.path.join(os.path.dirname(sys.argv[0]), "default.nix")
    with open(default_nix, encoding="utf-8") as fh:
        m = re.search(r'version = "([\d.b]+)";', fh.read())
    if m is None:
        raise ValueError(f"Could not find a version assignment in {default_nix}")
    return m.group(1)
def get_file_from_github(version: str, path: str):
    """Download ``path`` from the apache/airflow GitHub repo at ref ``version``.

    The response body is parsed with ``yaml.safe_load`` and the resulting
    object is returned. Errors raised by ``urlopen`` (e.g. HTTPError)
    propagate to the caller.
    """
    url = f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
    with urlopen(url) as response:
        parsed = yaml.safe_load(response)
    return parsed
def repository_root() -> Path:
    """Return the path three directories above the directory of this script.

    The path is not normalised, i.e. it has the form ``<script dir>/../../..``.
    """
    script_dir = Path(os.path.dirname(sys.argv[0]))
    return script_dir.joinpath("..", "..", "..")
def dump_packages() -> Dict[str, Dict[str, str]]:
    """Return the parsed ``nix-env --json`` listing of every package in PKG_SET.

    The Nixpkgs checkout at repository_root() is evaluated with
    ``config = { allowAliases = false; }``. Presumably the result maps
    attribute paths to package metadata (callers read each entry's "name"
    and strip the "<PKG_SET>." prefix from the keys) -- confirm against
    nix-env's JSON output.
    """
    command = [
        "nix-env",
        "-f", repository_root(),
        "-qa",
        "-A", PKG_SET,
        "--arg", "config", "{ allowAliases = false; }",
        "--json",
    ]
    raw_json = subprocess.check_output(command)
    return json.loads(raw_json)
def remove_version_constraint(req: str) -> str:
    """Return the bare project name of a PEP 508 requirement string.

    Everything from the first character that can begin one of the following
    is dropped:

    - an extras list ("[");
    - a version specifier ("=", "<", ">", "~", "!" or a parenthesised "(");
    - an environment marker (";");
    - a URL reference ("@");
    - whitespace.

    For example::

        "pandas>=1.2.5,<2.2"             -> "pandas"
        "PyGithub!=1.58"                 -> "PyGithub"
        "pyhive[hive_pure_sasl]>=0.7.0"  -> "pyhive"
        'foo; python_version < "3.12"'   -> "foo"
    """
    # PEP 508 project names consist only of letters, digits, ".", "-" and "_",
    # so the name is the leading run of characters not in the class below.
    # A "*" pattern always matches (possibly empty), so the match is never None.
    return re.match(r"[^\s\[(=<>~!;@]*", req.strip()).group(0)
def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Resolve requirement name ``req`` to an attribute path in ``packages``.

    ``packages`` maps attribute paths to package metadata whose ``"name"``
    is the derivation name (``python<X>.<Y>-<pname>-<version>``).
    An explicit override in PKG_PREFERENCES wins. Otherwise derivation
    names are matched case-insensitively against ``req``, treating "-", "_"
    and "." as the same separator. A leading "python-"/"python_" is also
    tried without that prefix.

    Returns:
        The single matching attribute path, or None if nothing matches.

    Raises:
        ValueError: if more than one derivation matches.
    """
    if req in PKG_PREFERENCES:
        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"
    names = [req]
    # E.g. python-mpd2 is actually called python3.6-mpd2
    # instead of python-3.6-python-mpd2 inside Nixpkgs
    if req.startswith(("python-", "python_")):
        names.append(req[len("python-"):])
    attr_paths = []
    for name in names:
        # "-", "_" and "." are equivalent in project names (PEP 503). Escape
        # the remaining text so that e.g. "." is not treated as a regex wildcard.
        name_re = "[-_.]".join(re.escape(part) for part in re.split(r"[-_.]", name))
        # python<major>.<minor>-<pname>-<version or unstable-date>
        # we need the version qualifier, or we'll have multiple matches
        # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
        pattern = re.compile(rf"^python\d+\.\d+-{name_re}-(?:\d|unstable-.*)", re.I)
        for attr_path, package in packages.items():
            if pattern.match(package["name"]):
                attr_paths.append(attr_path)
    # Ambiguity must be resolved via PKG_PREFERENCES; an explicit raise (unlike
    # assert) is not stripped when Python runs with -O.
    if len(attr_paths) > 1:
        raise ValueError(f"{req} matches more than one derivation: {attr_paths}")
    return attr_paths[0] if attr_paths else None
def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
    """Translate a provider's requirement strings into attribute names in PKG_SET.

    Processing steps:

    - Version constraints are removed.
    - Requirements whose name starts with "apache-airflow" are dropped.
    - Each remaining name is resolved with name_to_attr_path, and the
      "<PKG_SET>." prefix is removed from the result.
    - Names that cannot be resolved are logged as warnings and skipped.
    """
    prefix_len = len(PKG_SET + ".")
    resolved = []
    for req in map(remove_version_constraint, reqs):
        if req.startswith("apache-airflow"):
            continue
        attr_path = name_to_attr_path(req, packages)
        if attr_path is None:
            # If we can't find it, we just skip and warn the user
            logging.warning("Could not find package attr for %s", req)
            continue
        resolved.append(attr_path[prefix_len:])
    return resolved
def get_cross_provider_reqs(
    provider: str,
    provider_reqs: Dict,
    cross_provider_deps: Dict,
    seen: Optional[List[str]] = None,
) -> Set:
    """Return ``provider``'s requirements plus those of every provider it reaches.

    Args:
        provider: provider id to start from.
        provider_reqs: provider id -> iterable of its direct requirements.
        cross_provider_deps: provider id -> list of provider ids it depends on.
        seen: providers already on the current recursion path (not mutated).

    Returns:
        The union of the direct requirements of ``provider`` and of all
        providers reachable from it via ``cross_provider_deps``.
    """
    # Unfortunately there are circular cross-provider dependencies, so track
    # the providers on the current path and never descend into one of them
    # again (including ``provider`` itself, whose reqs are already included).
    path = (seen or []) + [provider]
    reqs = set(provider_reqs[provider])
    for dep in cross_provider_deps[provider]:
        if dep not in path:
            reqs.update(
                get_cross_provider_reqs(dep, provider_reqs, cross_provider_deps, path)
            )
    return reqs
def get_provider_reqs(version: str, packages: Dict) -> Dict:
    """Return provider id -> set of PKG_SET attribute names that provider needs.

    Providers are read from Airflow's generated/provider_dependencies.json
    at ``version``. A provider's direct requirements are its "deps" (resolved
    through provider_reqs_to_attr_paths) plus any EXTRA_REQS entry. The
    requirements of its "cross-providers-deps" are then added transitively,
    except that "common.sql" is never followed.
    """
    provider_dependencies = get_file_from_github(
        version, "generated/provider_dependencies.json"
    )
    direct_reqs = {}
    cross_deps = {}
    for provider, data in provider_dependencies.items():
        resolved = list(provider_reqs_to_attr_paths(data["deps"], packages))
        direct_reqs[provider] = resolved + EXTRA_REQS.get(provider, [])
        # NOTE(review): the reason for excluding common.sql is not recorded
        # here -- confirm before changing.
        cross_deps[provider] = [
            dep for dep in data["cross-providers-deps"] if dep != "common.sql"
        ]
    # Add transitive cross-provider reqs
    return {
        provider: get_cross_provider_reqs(provider, direct_reqs, cross_deps)
        for provider in direct_reqs
    }
def get_provider_yaml(version: str, provider: str) -> Dict:
    """Fetch and parse a provider's provider.yaml from GitHub at ``version``.

    The file is looked up at ``airflow/providers/<provider>/provider.yaml``,
    with the dots in ``provider`` turned into path separators. If the
    download fails with an HTTPError, a warning is logged and ``{}`` is
    returned.
    """
    provider_dir = provider.replace(".", "/")
    try:
        provider_yaml = get_file_from_github(
            version, f"airflow/providers/{provider_dir}/provider.yaml"
        )
    except HTTPError:
        logging.warning("Couldn't get provider yaml for %s", provider)
        provider_yaml = {}
    return provider_yaml
def get_provider_imports(version: str, providers) -> Dict:
    """Return provider id -> Python modules declared in its provider.yaml.

    Two sections of provider.yaml are consulted, in this order:

    1. the "python-modules" of every "hooks" entry;
    2. the "python-modules" of every "operators" entry.

    Missing sections contribute nothing.
    """
    result = {}
    for provider in providers:
        provider_yaml = get_provider_yaml(version, provider)
        modules: List[str] = []
        for section in ("hooks", "operators"):
            if section not in provider_yaml:
                continue
            for entry in provider_yaml[section]:
                modules.extend(entry["python-modules"])
        result[provider] = modules
    return result
def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
    """Write the providers as a Nix attribute set to ``fh``.

    Each provider id becomes an attribute (dots replaced by underscores).
    Its ``deps`` list holds the quoted entries of ``provider_reqs[provider]``
    and its ``imports`` list holds those of ``provider_imports[provider]``.
    Both lists are sorted.
    """
    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
    fh.write("{\n")
    for provider in provider_reqs:
        reqs = provider_reqs[provider]
        provider_name = provider.replace(".", "_")
        fh.write(f" {provider_name} = {{\n")
        deps = " ".join(sorted(f'"{req}"' for req in reqs))
        fh.write(" deps = [ " + deps + " ];\n")
        imports = " ".join(sorted(f'"{imp}"' for imp in provider_imports[provider]))
        fh.write(" imports = [ " + imports + " ];\n")
        fh.write(" };\n")
    fh.write("}\n")
def main() -> None:
    """Regenerate providers.nix for the Airflow version pinned in default.nix.

    The output is written to the directory containing this script, next to
    the default.nix the version is read from, regardless of the current
    working directory.
    """
    logging.basicConfig(level=logging.INFO)
    version = get_version()
    packages = dump_packages()
    logging.info("Generating providers.nix for version %s", version)
    provider_reqs = get_provider_reqs(version, packages)
    provider_imports = get_provider_imports(version, provider_reqs.keys())
    # Resolve the output relative to the script, like default.nix and the
    # repository root, so running it from another directory does not drop
    # providers.nix into the CWD.
    output_path = os.path.join(os.path.dirname(sys.argv[0]), "providers.nix")
    with open(output_path, "w", encoding="utf-8") as fh:
        to_nix_expr(provider_reqs, provider_imports, fh)
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()