30 # optional-dependencies
34 buildPythonPackage rec {
39 src = fetchFromGitHub {
42 rev = "refs/tags/v${version}";
43 hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
63 optional-dependencies = {
67 pythonImportsCheck = [ "flax" ];
80 "-W ignore::FutureWarning"
81 "-W ignore::DeprecationWarning"
85 # Docs test: needs extra dependencies, and we are not interested in it.
86 "docs/_ext/codediff_test.py"
87 # The tests in `examples` are not designed to be executed from a single test
88 # session, so they either contain modules that conflict with each other or
89 # have wrong import paths, depending on how they are invoked. Many tests also have
90 # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
91 # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
92 # would be limited anyway.
95 # See https://github.com/google/flax/issues/3232.
96 "tests/jax_utils_test.py"
97 # The packaged version of tensorflow is too old:
98 # ModuleNotFoundError: No module named 'keras.api._v2'
99 "tests/tensorboard_test.py"
103 # ValueError: Checkpoint path should be absolute
104 "test_overwrite_checkpoints0"
105 # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
106 # TODO: Re-enable once jax>0.4.28 is available in nixpkgs
107 "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
108 "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
112 description = "Neural network library for JAX";
113 homepage = "https://github.com/google/flax";
114 changelog = "https://github.com/google/flax/releases/tag/v${version}";
115 license = lib.licenses.asl20;
116 maintainers = with lib.maintainers; [ ndl ];