32 # optional-dependencies
39 buildPythonPackage rec {
44 src = fetchFromGitHub {
47 rev = "refs/tags/v${version}";
48 hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
68 optional-dependencies = {
72 pythonImportsCheck = [ "flax" ];
87 "-W ignore::FutureWarning"
88 "-W ignore::DeprecationWarning"
92 # Docs test, needs extra deps + we're not interested in it.
93 "docs/_ext/codediff_test.py"
94 # The tests in `examples` are not designed to be executed from a single test
95 # session: depending on how they're invoked, they either have modules that
96 # conflict with each other or have wrong import paths. Many tests also have
97 # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
98 # `tensorflow_datasets`, `vocabulary`), so the benefits of trying to run them
99 # would be limited anyway.
101 "flax/nnx/examples/*"
102 # See https://github.com/google/flax/issues/3232.
103 "tests/jax_utils_test.py"
104 # The packaged tensorflow version is too old:
105 # ModuleNotFoundError: No module named 'keras.api._v2'
106 "tests/tensorboard_test.py"
110 # ValueError: Checkpoint path should be absolute
111 "test_overwrite_checkpoints0"
112 # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
113 # TODO: Re-enable once jax>0.4.28 is available in nixpkgs.
114 "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
115 "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
119 updateScript = writeScript "update.sh" ''
120 nix-update flax # does not --build by default
121 nix-build . -A flax.src # src is essentially a passthru
122 nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
127 description = "Neural network library for JAX";
128 homepage = "https://github.com/google/flax";
129 changelog = "https://github.com/google/flax/releases/tag/v${version}";
130 license = lib.licenses.asl20;
131 maintainers = with lib.maintainers; [ ndl ];