26 buildPythonPackage rec {
31 disabled = pythonOlder "3.9";
33 src = fetchFromGitHub {
36 rev = "refs/tags/v${version}";
37 hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
46 propagatedBuildInputs = [
58 passthru.optional-dependencies = {
62 pythonImportsCheck = [
76 "-W ignore::FutureWarning"
77 "-W ignore::DeprecationWarning"
81 # Docs test, needs extra deps + we're not interested in it.
82 "docs/_ext/codediff_test.py"
83 # The tests in `examples` are not designed to be executed from a single test
84 # session and thus either contain modules that conflict with each other or
85 # use wrong import paths, depending on how they're invoked. Many tests also have
86 # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
87 # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
88 # would be limited anyway.
90 "flax/experimental/nnx/examples/*"
91 # See https://github.com/google/flax/issues/3232.
92 "tests/jax_utils_test.py"
94 "tests/tensorboard_test.py"
98 # ValueError: Checkpoint path should be absolute
99 "test_overwrite_checkpoints0"
103 description = "Neural network library for JAX";
104 homepage = "https://github.com/google/flax";
105 changelog = "https://github.com/google/flax/releases/tag/v${version}";
106 license = licenses.asl20;
107 maintainers = with maintainers; [ ndl ];