1 #! /usr/bin/env python3
3 from itertools
import chain
6 from pathlib
import Path
11 from typing
import Dict
, List
, Optional
, Set
, TextIO
12 from urllib
.request
import urlopen
13 from urllib
.error
import HTTPError
# Attribute path of the package set. name_to_attr_path() prefixes preferred
# package names with it, and provider_reqs_to_attr_paths() strips the
# "<PKG_SET>." prefix again before emitting dependency names.
PKG_SET = "apache-airflow.pythonPackages"
18 # If some requirements are matched by multiple or no Python packages, the
19 # following can be used to choose the correct one
21 "dnspython": "dnspython",
22 "elasticsearch-dsl": "elasticsearch-dsl",
23 "google-api-python-client": "google-api-python-client",
24 "protobuf": "protobuf",
25 "psycopg2-binary": "psycopg2",
26 "requests_toolbelt": "requests-toolbelt",
29 # Requirements missing from the airflow provider metadata
36 with
open(os
.path
.dirname(sys
.argv
[0]) + "/default.nix") as fh
:
37 # A version consists of digits, dots, and possibly a "b" (for beta)
38 m
= re
.search('version = "([\\d\\.b]+)";', fh
.read())
42 def get_file_from_github(version
: str, path
: str):
44 f
"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
46 return yaml
.safe_load(response
)
def repository_root() -> Path:
    """Return the path three directory levels above this script's directory.

    The script location is taken from ``sys.argv[0]``. The ``..`` components
    are appended verbatim; the result is not resolved or made absolute.
    """
    script_dir = os.path.dirname(sys.argv[0])
    return Path(script_dir).joinpath("../../..")
53 def dump_packages() -> Dict
[str, Dict
[str, str]]:
54 # Store a JSON dump of Nixpkgs' python3Packages
55 output
= subprocess
.check_output(
65 "{ allowAliases = false; }",
69 return json
.loads(output
)
def remove_version_constraint(req: str) -> str:
    """Return the bare distribution name of a requirement string.

    Everything from the first version operator (``=``, ``>``, ``<``, ``~``,
    ``!``), extras bracket (``[``) or environment-marker separator (``;``)
    onwards is dropped, along with surrounding whitespace. For example
    ``"aiohttp>=3.6.3,<4"`` becomes ``"aiohttp"``.

    Args:
        req: A requirement specifier such as ``"pandas>=0.17.1"``.

    Returns:
        The requirement name without version constraints, extras or markers.
    """
    # "!" is in the class so that "pkg!=1.0" does not leave a dangling "pkg!".
    # "[" and ";" cover extras and markers, which would otherwise survive on
    # requirements that carry no version operator before them.
    return re.sub(r"\s*[=><~!;\[].*$", "", req).strip()
76 def name_to_attr_path(req
: str, packages
: Dict
[str, Dict
[str, str]]) -> Optional
[str]:
77 if req
in PKG_PREFERENCES
:
78 return f
"{PKG_SET}.{PKG_PREFERENCES[req]}"
81 # E.g. python-mpd2 is actually called python3.6-mpd2
82 # instead of python3.6-python-mpd2 inside Nixpkgs
83 if req
.startswith("python-") or req
.startswith("python_"):
84 names
.append(req
[len("python-") :])
86 # treat "-" and "_" equally
87 name
= re
.sub("[-_]", "[-_]", name
)
88 # python(major).(minor)-(pname)-(version or unstable-date)
89 # we need the version qualifier, or we'll have multiple matches
90 # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
92 f
"^python\\d+\\.\\d+-{name}-(?:\\d|unstable-.*)", re
.I
94 for attr_path
, package
in packages
.items():
95 # logging.debug("Checking match for %s with %s", name, package["name"])
96 if pattern
.match(package
["name"]):
97 attr_paths
.append(attr_path
)
98 # Let's hope there's only one derivation with a matching name
99 assert len(attr_paths
) <= 1, f
"{req} matches more than one derivation: {attr_paths}"
105 def provider_reqs_to_attr_paths(reqs
: List
, packages
: Dict
) -> List
:
106 no_version_reqs
= map(remove_version_constraint
, reqs
)
108 req
for req
in no_version_reqs
if not re
.match(r
"^apache-airflow", req
)
111 for req
in filtered_reqs
:
112 attr_path
= name_to_attr_path(req
, packages
)
113 if attr_path
is not None:
114 # Add attribute path without the PKG_SET ("apache-airflow.pythonPackages.") prefix
115 pname
= attr_path
[len(PKG_SET
+ ".") :]
116 attr_paths
.append(pname
)
118 # If we can't find it, we just skip and warn the user
119 logging
.warning("Could not find package attr for %s", req
)
123 def get_cross_provider_reqs(
124 provider
: str, provider_reqs
: Dict
, cross_provider_deps
: Dict
, seen
: List
= None
126 # Unfortunately there are circular cross-provider dependencies, so keep a
127 # list of ones we've seen already
129 reqs
= set(provider_reqs
[provider
])
130 if len(cross_provider_deps
[provider
]) > 0:
133 get_cross_provider_reqs(
134 d
, provider_reqs
, cross_provider_deps
, seen
+ [provider
]
138 for d
in cross_provider_deps
[provider
]
144 def get_provider_reqs(version
: str, packages
: Dict
) -> Dict
:
145 provider_dependencies
= get_file_from_github(
146 version
, "generated/provider_dependencies.json"
149 cross_provider_deps
= {}
150 for provider
, provider_data
in provider_dependencies
.items():
151 provider_reqs
[provider
] = list(
152 provider_reqs_to_attr_paths(provider_data
["deps"], packages
)
153 ) + EXTRA_REQS
.get(provider
, [])
154 cross_provider_deps
[provider
] = [
155 d
for d
in provider_data
["cross-providers-deps"] if d
!= "common.sql"
157 transitive_provider_reqs
= {}
158 # Add transitive cross-provider reqs
159 for provider
in provider_reqs
:
160 transitive_provider_reqs
[provider
] = get_cross_provider_reqs(
161 provider
, provider_reqs
, cross_provider_deps
163 return transitive_provider_reqs
166 def get_provider_yaml(version
: str, provider
: str) -> Dict
:
167 provider_dir
= provider
.replace(".", "/")
168 path
= f
"airflow/providers/{provider_dir}/provider.yaml"
170 return get_file_from_github(version
, path
)
172 logging
.warning("Couldn't get provider yaml for %s", provider
)
176 def get_provider_imports(version
: str, providers
) -> Dict
:
177 provider_imports
= {}
178 for provider
in providers
:
179 provider_yaml
= get_provider_yaml(version
, provider
)
180 imports
: List
[str] = []
181 if "hooks" in provider_yaml
:
184 hook
["python-modules"] for hook
in provider_yaml
["hooks"]
187 if "operators" in provider_yaml
:
190 operator
["python-modules"]
191 for operator
in provider_yaml
["operators"]
194 provider_imports
[provider
] = imports
195 return provider_imports
198 def to_nix_expr(provider_reqs
: Dict
, provider_imports
: Dict
, fh
: TextIO
) -> None:
199 fh
.write("# Warning: generated by update-providers.py, do not update manually\n")
201 for provider
, reqs
in provider_reqs
.items():
202 provider_name
= provider
.replace(".", "_")
203 fh
.write(f
" {provider_name} = {{\n")
205 " deps = [ " + " ".join(sorted(f
'"{req}"' for req
in reqs
)) + " ];\n"
209 + " ".join(sorted(f
'"{imp}"' for imp
in provider_imports
[provider
]))
217 logging
.basicConfig(level
=logging
.INFO
)
218 version
= get_version()
219 packages
= dump_packages()
220 logging
.info("Generating providers.nix for version %s", version
)
221 provider_reqs
= get_provider_reqs(version
, packages
)
222 provider_imports
= get_provider_imports(version
, provider_reqs
.keys())
223 with
open("providers.nix", "w") as fh
:
224 to_nix_expr(provider_reqs
, provider_imports
, fh
)
227 if __name__
== "__main__":