portfolio: 0.71.2 -> 0.72.2 (#360387)
[NixPkgs.git] / pkgs / development / python-modules / flax / default.nix
blob9ee281ed5dbaba32bb8436e3c12d0094c21f4821
2   lib,
3   buildPythonPackage,
4   fetchFromGitHub,
6   # build-system
7   jaxlib,
8   setuptools-scm,
10   # dependencies
11   jax,
12   msgpack,
13   numpy,
14   optax,
15   orbax-checkpoint,
16   pyyaml,
17   rich,
18   tensorstore,
19   typing-extensions,
21   # checks
22   cloudpickle,
23   einops,
24   flaxlib,
25   keras,
26   pytestCheckHook,
27   pytest-xdist,
28   sphinx,
29   tensorflow,
30   treescope,
32   # optional-dependencies
33   matplotlib,
35   writeScript,
36   tomlq,
39 buildPythonPackage rec {
40   pname = "flax";
41   version = "0.10.1";
42   pyproject = true;
44   src = fetchFromGitHub {
45     owner = "google";
46     repo = "flax";
47     rev = "refs/tags/v${version}";
48     hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
49   };
51   build-system = [
52     jaxlib
53     setuptools-scm
54   ];
56   dependencies = [
57     jax
58     msgpack
59     numpy
60     optax
61     orbax-checkpoint
62     pyyaml
63     rich
64     tensorstore
65     typing-extensions
66   ];
68   optional-dependencies = {
69     all = [ matplotlib ];
70   };
72   pythonImportsCheck = [ "flax" ];
74   nativeCheckInputs = [
75     cloudpickle
76     einops
77     flaxlib
78     keras
79     pytestCheckHook
80     pytest-xdist
81     sphinx
82     tensorflow
83     treescope
84   ];
86   pytestFlagsArray = [
87     "-W ignore::FutureWarning"
88     "-W ignore::DeprecationWarning"
89   ];
91   disabledTestPaths = [
92     # Docs test, needs extra deps + we're not interested in it.
93     "docs/_ext/codediff_test.py"
94     # The tests in `examples` are not designed to be executed from a single test
95     # session and thus either have the modules that conflict with each other or
96     # wrong import paths, depending on how they're invoked. Many tests also have
97     # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
98     # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
99     # would be limited anyway.
100     "examples/*"
101     "flax/nnx/examples/*"
102     # See https://github.com/google/flax/issues/3232.
103     "tests/jax_utils_test.py"
104     # Too old version of tensorflow:
105     # ModuleNotFoundError: No module named 'keras.api._v2'
106     "tests/tensorboard_test.py"
107   ];
109   disabledTests = [
110     # ValueError: Checkpoint path should be absolute
111     "test_overwrite_checkpoints0"
112     # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
113     # TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
114     "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
115     "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
116   ];
118   passthru = {
119     updateScript = writeScript "update.sh" ''
120       nix-update flax # does not --build by default
121       nix-build . -A flax.src # src is essentially a passthru
122       nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
123     '';
124   };
126   meta = {
127     description = "Neural network library for JAX";
128     homepage = "https://github.com/google/flax";
129     changelog = "https://github.com/google/flax/releases/tag/v${version}";
130     license = lib.licenses.asl20;
131     maintainers = with lib.maintainers; [ ndl ];
132   };