biome: 1.9.2 -> 1.9.3 (#349335)
[NixPkgs.git] / pkgs / development / python-modules / flax / default.nix
blob4f93bd4f8ea50bfebf03f26daa8eef074cb826a7
2   lib,
3   buildPythonPackage,
4   fetchFromGitHub,
6   # build-system
7   jaxlib,
8   setuptools-scm,
10   # dependencies
11   jax,
12   msgpack,
13   numpy,
14   optax,
15   orbax-checkpoint,
16   pyyaml,
17   rich,
18   tensorstore,
19   typing-extensions,
21   # checks
22   cloudpickle,
23   einops,
24   keras,
25   pytest-xdist,
26   pytestCheckHook,
27   tensorflow,
28   treescope,
30   # optional-dependencies
31   matplotlib,
34 buildPythonPackage rec {
35   pname = "flax";
36   version = "0.9.0";
37   pyproject = true;
39   src = fetchFromGitHub {
40     owner = "google";
41     repo = "flax";
42     rev = "refs/tags/v${version}";
43     hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
44   };
46   build-system = [
47     jaxlib
48     setuptools-scm
49   ];
51   dependencies = [
52     jax
53     msgpack
54     numpy
55     optax
56     orbax-checkpoint
57     pyyaml
58     rich
59     tensorstore
60     typing-extensions
61   ];
63   optional-dependencies = {
64     all = [ matplotlib ];
65   };
67   pythonImportsCheck = [ "flax" ];
69   nativeCheckInputs = [
70     cloudpickle
71     einops
72     keras
73     pytest-xdist
74     pytestCheckHook
75     tensorflow
76     treescope
77   ];
79   pytestFlagsArray = [
80     "-W ignore::FutureWarning"
81     "-W ignore::DeprecationWarning"
82   ];
84   disabledTestPaths = [
85     # Docs test, needs extra deps + we're not interested in it.
86     "docs/_ext/codediff_test.py"
87     # The tests in `examples` are not designed to be executed from a single test
88     # session and thus either have the modules that conflict with each other or
89     # wrong import paths, depending on how they're invoked. Many tests also have
90     # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
91     # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
92     # would be limited anyway.
93     "examples/*"
94     "flax/nnx/examples/*"
95     # See https://github.com/google/flax/issues/3232.
96     "tests/jax_utils_test.py"
97     # Too old version of tensorflow:
98     # ModuleNotFoundError: No module named 'keras.api._v2'
99     "tests/tensorboard_test.py"
100   ];
102   disabledTests = [
103     # ValueError: Checkpoint path should be absolute
104     "test_overwrite_checkpoints0"
105     # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
106     # TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
107     "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
108     "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
109   ];
111   meta = {
112     description = "Neural network library for JAX";
113     homepage = "https://github.com/google/flax";
114     changelog = "https://github.com/google/flax/releases/tag/v${version}";
115     license = lib.licenses.asl20;
116     maintainers = with lib.maintainers; [ ndl ];
117   };