stripe-cli: 1.23.3 -> 1.23.5 (#375724)
[NixPkgs.git] / pkgs / development / python-modules / optax / default.nix
blobe8b265fe8512b444804a3e93061abbd40c8bfb99
2   lib,
3   buildPythonPackage,
4   fetchFromGitHub,
6   # build-system
7   flit-core,
9   # dependencies
10   absl-py,
11   chex,
12   jax,
13   jaxlib,
14   numpy,
15   etils,
17   # tests
18   callPackage,
21 buildPythonPackage rec {
22   pname = "optax";
23   version = "0.2.4";
24   pyproject = true;
26   src = fetchFromGitHub {
27     owner = "deepmind";
28     repo = "optax";
29     tag = "v${version}";
30     hash = "sha256-7UPWeo/Q9/tjewaM7HN8/e7U1U1QzAliuk95+9GOi0E=";
31   };
33   outputs = [
34     "out"
35     "testsout"
36   ];
38   build-system = [ flit-core ];
40   dependencies = [
41     absl-py
42     chex
43     etils
44     jax
45     jaxlib
46     numpy
47   ] ++ etils.optional-dependencies.epy;
49   postInstall = ''
50     mkdir $testsout
51     cp -R examples $testsout/examples
52   '';
54   pythonImportsCheck = [ "optax" ];
56   # check in passthru.tests.pytest to escape infinite recursion with flax
57   doCheck = false;
59   passthru.tests = {
60     pytest = callPackage ./tests.nix { };
61   };
63   meta = {
64     description = "Gradient processing and optimization library for JAX";
65     homepage = "https://github.com/deepmind/optax";
66     changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
67     license = lib.licenses.asl20;
68     maintainers = with lib.maintainers; [ ndl ];
69   };