Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / flax / default.nix
blobce41f8e561394e524c0af96548053f4532f3bab5
1 { lib
2 , buildPythonPackage
3 , cloudpickle
4 , einops
5 , fetchFromGitHub
6 , jax
7 , jaxlib
8 , keras
9 , matplotlib
10 , msgpack
11 , numpy
12 , optax
13 , orbax-checkpoint
14 , pytest-xdist
15 , pytestCheckHook
16 , pythonOlder
17 , pythonRelaxDepsHook
18 , pyyaml
19 , rich
20 , setuptools-scm
21 , tensorflow
22 , tensorstore
23 , typing-extensions
26 buildPythonPackage rec {
27   pname = "flax";
28   version = "0.8.2";
29   pyproject = true;
31   disabled = pythonOlder "3.9";
33   src = fetchFromGitHub {
34     owner = "google";
35     repo = "flax";
36     rev = "refs/tags/v${version}";
37     hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
38   };
40   nativeBuildInputs = [
41     jaxlib
42     pythonRelaxDepsHook
43     setuptools-scm
44   ];
46   propagatedBuildInputs = [
47     jax
48     msgpack
49     numpy
50     optax
51     orbax-checkpoint
52     pyyaml
53     rich
54     tensorstore
55     typing-extensions
56   ];
58   passthru.optional-dependencies = {
59     all = [ matplotlib ];
60   };
62   pythonImportsCheck = [
63     "flax"
64   ];
66   nativeCheckInputs = [
67     cloudpickle
68     einops
69     keras
70     pytest-xdist
71     pytestCheckHook
72     tensorflow
73   ];
75   pytestFlagsArray = [
76     "-W ignore::FutureWarning"
77     "-W ignore::DeprecationWarning"
78   ];
80   disabledTestPaths = [
81     # Docs test, needs extra deps + we're not interested in it.
82     "docs/_ext/codediff_test.py"
83     # The tests in `examples` are not designed to be executed from a single test
84     # session and thus either have the modules that conflict with each other or
85     # wrong import paths, depending on how they're invoked. Many tests also have
86     # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
87     # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
88     # would be limited anyway.
89     "examples/*"
90     "flax/experimental/nnx/examples/*"
91     # See https://github.com/google/flax/issues/3232.
92     "tests/jax_utils_test.py"
93     # Requires tree
94     "tests/tensorboard_test.py"
95   ];
97   disabledTests = [
98     # ValueError: Checkpoint path should be absolute
99     "test_overwrite_checkpoints0"
100   ];
102   meta = with lib; {
103     description = "Neural network library for JAX";
104     homepage = "https://github.com/google/flax";
105     changelog = "https://github.com/google/flax/releases/tag/v${version}";
106     license = licenses.asl20;
107     maintainers = with maintainers; [ ndl ];
108   };