python312Packages.fnllm: 0.0.11 -> 0.0.12 (#364582)
[NixPkgs.git] / pkgs / development / python-modules / distrax / default.nix
blobbae7ef5ffb93b1dbd478f25252280377560fc21c
2   lib,
3   buildPythonPackage,
4   pythonOlder,
5   fetchFromGitHub,
6   chex,
7   jaxlib,
8   numpy,
9   tensorflow-probability,
10   dm-haiku,
11   pytest-xdist,
12   pytestCheckHook,
15 buildPythonPackage rec {
# Package identity.
pname = "distrax";
version = "0.1.5";

# Use buildPythonPackage's pyproject (PEP 517) build path.
pyproject = true;

# Mark the package unavailable on interpreters older than Python 3.9.
disabled = pythonOlder "3.9";

# Tagged upstream release; the repository name equals pname ("distrax").
src = fetchFromGitHub {
  owner = "google-deepmind";
  repo = pname;
  rev = "refs/tags/v${version}";
  hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0=";
};
# Runtime Python dependencies of distrax.
# They are declared as `dependencies` (buildPythonPackage's propagated
# inputs) rather than plain `buildInputs`. `buildInputs` are not propagated,
# so anything depending on distrax would not get these modules on its
# Python path.
# NOTE(review): `pyproject = true` is set but no `build-system` is declared.
# Upstream appears to build with setuptools, so consider adding
# `build-system = [ setuptools ];` (and `setuptools` to the function
# arguments). TODO confirm.
dependencies = [
  chex
  jaxlib
  numpy
  tensorflow-probability
];
# Smoke test: the installed top-level module must be importable.
pythonImportsCheck = [ "distrax" ];

# Inputs used only when running the test suite: the pytest hook, xdist for
# running tests in parallel, and dm-haiku, which is needed by the tests.
nativeCheckInputs = [
  dm-haiku
  pytest-xdist
  pytestCheckHook
];
# All of these fail with an AssertionError on numerical values.
# Reported upstream in https://github.com/google-deepmind/distrax/issues/267
# Entries are grouped by base test name (blank line between groups).
disabledTests = [
  "test_method_with_input_unnormalized_probs__with_device"
  "test_method_with_input_unnormalized_probs__with_jit"
  "test_method_with_input_unnormalized_probs__without_device"
  "test_method_with_input_unnormalized_probs__without_jit"

  "test_method_with_value_1d"

  "test_nested_distributions__with_device"
  "test_nested_distributions__without_device"
  "test_nested_distributions__with_jit"
  "test_nested_distributions__without_jit"

  "test_stability__with_device"
  "test_stability__with_jit"
  "test_stability__without_device"
  "test_stability__without_jit"

  "test_von_mises_sample_gradient"
  "test_von_mises_sample_moments"
];
# Whole test modules excluded from the pytest run, in two groups joined with
# `++`; the resulting list is the same single list in the same order.
disabledTestPaths =
  # Modules that fail with TypeErrors.
  [
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"
    "distrax/_src/distributions/laplace_test.py"
    "distrax/_src/distributions/multinomial_test.py"
    "distrax/_src/distributions/mvn_diag_plus_low_rank_test.py"
    "distrax/_src/distributions/mvn_kl_test.py"
    "distrax/_src/distributions/straight_through_test.py"
    "distrax/_src/distributions/tfp_compatible_distribution_test.py"
    "distrax/_src/distributions/transformed_test.py"
    "distrax/_src/distributions/uniform_test.py"
    "distrax/_src/utils/transformations_test.py"
  ]
  # See https://github.com/google-deepmind/distrax/pull/270
  ++ [
    "distrax/_src/distributions/deterministic_test.py"
    "distrax/_src/distributions/epsilon_greedy_test.py"
    "distrax/_src/distributions/gamma_test.py"
    "distrax/_src/distributions/greedy_test.py"
    "distrax/_src/distributions/gumbel_test.py"
    "distrax/_src/distributions/logistic_test.py"
    "distrax/_src/distributions/log_stddev_normal_test.py"
    "distrax/_src/distributions/mvn_diag_test.py"
    "distrax/_src/distributions/mvn_full_covariance_test.py"
    "distrax/_src/distributions/mvn_tri_test.py"
    "distrax/_src/distributions/one_hot_categorical_test.py"
    "distrax/_src/distributions/softmax_test.py"
    "distrax/_src/utils/hmm_test.py"
  ];
# Package metadata. Uses explicit `lib.` references instead of `with lib;`,
# which nixpkgs discourages because it obscures where names come from.
meta = {
  description = "Probability distributions in JAX";
  # Same google-deepmind repository that `src` is fetched from
  # (the old `deepmind` org URL was stale).
  homepage = "https://github.com/google-deepmind/distrax";
  changelog = "https://github.com/google-deepmind/distrax/releases/tag/v${version}";
  license = lib.licenses.asl20;
  maintainers = with lib.maintainers; [ onny ];
};