mir,mir_2_15: Fix builds, modernise & fix VM tests (#374873)
[NixPkgs.git] / pkgs / development / python-modules / jmp / default.nix
blob8ca03df1fe0f19a938c225dcb3f2c255decacb2a
2   buildPythonPackage,
3   fetchFromGitHub,
4   jax,
5   jaxlib,
6   lib,
7   pytestCheckHook,
8 }:
10 buildPythonPackage rec {
11   pname = "jmp";
12   version = "0.0.4";
13   format = "setuptools";
15   src = fetchFromGitHub {
16     owner = "deepmind";
17     repo = pname;
18     tag = "v${version}";
19     hash = "sha256-+PefZU1209vvf1SfF8DXiTvKYEnZ4y8iiIr8yKikx9Y=";
20   };
22   # Wheel requires only `numpy`, but the import needs `jax`.
23   propagatedBuildInputs = [ jax ];
25   pythonImportsCheck = [ "jmp" ];
27   nativeCheckInputs = [
28     jaxlib
29     pytestCheckHook
30   ];
32   meta = with lib; {
33     description = "This library implements support for mixed precision training in JAX";
34     homepage = "https://github.com/deepmind/jmp";
35     license = licenses.asl20;
36     maintainers = with maintainers; [ ndl ];
37   };