pytrainer: unpin python 3.10
[NixPkgs.git] / pkgs / development / python-modules / jmp / default.nix
blobb0ec259f0723c502d1d81c370dc6c1a407e605ca
2   buildPythonPackage,
3   fetchFromGitHub,
4   jax,
5   jaxlib,
6   lib,
7   pytestCheckHook,
8 }:
10 buildPythonPackage rec {
11   pname = "jmp";
12   version = "0.0.4";
13   format = "setuptools";
15   src = fetchFromGitHub {
16     owner = "deepmind";
17     repo = pname;
18     rev = "refs/tags/v${version}";
19     hash = "sha256-+PefZU1209vvf1SfF8DXiTvKYEnZ4y8iiIr8yKikx9Y=";
20   };
22   # Wheel requires only `numpy`, but the import needs `jax`.
23   propagatedBuildInputs = [ jax ];
25   pythonImportsCheck = [ "jmp" ];
27   nativeCheckInputs = [
28     jaxlib
29     pytestCheckHook
30   ];
32   meta = with lib; {
33     description = "This library implements support for mixed precision training in JAX";
34     homepage = "https://github.com/deepmind/jmp";
35     license = licenses.asl20;
36     maintainers = with maintainers; [ ndl ];
37   };