nixos/doh-server: init
[NixPkgs.git] / pkgs / development / python-modules / dm-haiku / default.nix
blob 96fbd446ecb4f6cf2be0293935f0a30417c80d79
2   lib,
3   buildPythonPackage,
4   fetchFromGitHub,
5   fetchpatch,
6   setuptools,
7   absl-py,
8   flax,
9   jax,
10   jaxlib,
11   jmp,
12   numpy,
13   tabulate,
14   pytest-xdist,
15   pytestCheckHook,
16   bsuite,
17   chex,
18   cloudpickle,
19   dill,
20   dm-env,
21   dm-tree,
22   optax,
23   rlax,
24   tensorflow,
let
  # The derivation is bound to a name (rather than returned directly) so that
  # it can refer to itself: it appears in its own `nativeCheckInputs` and is
  # the base of `passthru.tests.pytest` below.
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.13";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      tag = "v${version}";
      hash = "sha256-RJpQ9BzlbQ4X31XoJFnsZASiaC9fP2AdyuTAGINhMxs=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      # The package itself (the let-bound name above), used when the
      # passthru test below runs the suite.
      dm-haiku
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a newer tensorflow than the packaged one (2.13).
      "test_reshape_convert"

      # This test requires JAX support for double precision (64-bit), but
      # enabling it causes several other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a newer tensorflow than the packaged one (2.13).
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # Tests are not run as part of the main build; they run in
    # passthru.tests.pytest instead to escape infinite recursion with bsuite.
    doCheck = false;

    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # Nothing needs to be installed: the only purpose of this passthru
      # derivation is to run the tests. Skipping the install phase also
      # avoids having to set `catchConflicts` to false.
      dontInstall = true;
    });

    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/deepmind/dm-haiku";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku