Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / diffusers / default.nix
blob39464efe47fdbf09d097ba832d5b77632a4a156a
# Arguments of this package function: nixpkgs helpers plus the Python
# packages used to build, run, and test diffusers.
{ lib
, stdenv
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
, writeText
# build tooling
, setuptools
, wheel
# runtime dependencies
, filelock
, huggingface-hub
, importlib-metadata
, numpy
, pillow
, regex
, requests
, safetensors
# optional dependencies
, accelerate
, datasets
, flax
, jax
, jaxlib
, jinja2
, peft
, protobuf
, tensorboard
, torch
# test dependencies
, parameterized
, pytest-timeout
, pytest-xdist
, pytestCheckHook
, requests-mock
, scipy
, sentencepiece
, torchsde
, transformers
}:
40 buildPythonPackage rec {
41   pname = "diffusers";
42   version = "0.27.2";
43   pyproject = true;
45   disabled = pythonOlder "3.8";
47   src = fetchFromGitHub {
48     owner = "huggingface";
49     repo = "diffusers";
50     rev = "refs/tags/v${version}";
51     hash = "sha256-aRnbU3jN40xaCsoMFyRt1XB+hyIYMJP2b/T1yZho90c=";
52   };
54   nativeBuildInputs = [
55     setuptools
56     wheel
57   ];
59   propagatedBuildInputs = [
60     filelock
61     huggingface-hub
62     importlib-metadata
63     numpy
64     pillow
65     regex
66     requests
67     safetensors
68   ];
70   passthru.optional-dependencies = {
71     flax = [
72       flax
73       jax
74       jaxlib
75     ];
76     torch = [
77       accelerate
78       torch
79     ];
80     training = [
81       accelerate
82       datasets
83       jinja2
84       peft
85       protobuf
86       tensorboard
87     ];
88   };
90   pythonImportsCheck = [
91     "diffusers"
92   ];
94   # tests crash due to torch segmentation fault
95   doCheck = !(stdenv.isLinux && stdenv.isAarch64);
97   nativeCheckInputs = [
98     parameterized
99     pytest-timeout
100     pytest-xdist
101     pytestCheckHook
102     requests-mock
103     scipy
104     sentencepiece
105     torchsde
106     transformers
107   ] ++ passthru.optional-dependencies.torch;
109   preCheck = let
110     # This pytest hook mocks and catches attempts at accessing the network
111     # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
112     # cf. python3Packages.shap
113     conftestSkipNetworkErrors = writeText "conftest.py" ''
114       from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
115       import urllib3
117       class NetworkAccessDeniedError(RuntimeError): pass
118       def deny_network_access(*a, **kw):
119         raise NetworkAccessDeniedError
121       urllib3.connection.HTTPSConnection._new_conn = deny_network_access
123       def pytest_runtest_makereport(item, call):
124         tr = orig_pytest_runtest_makereport(item, call)
125         if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
126             tr.outcome = 'skipped'
127             tr.wasxfail = "reason: Requires network access."
128         return tr
129     '';
130   in ''
131     export HOME=$TMPDIR
132     cat ${conftestSkipNetworkErrors} >> tests/conftest.py
133   '';
135   pytestFlagsArray = [
136     "tests/"
137   ];
139   disabledTests = [
140     # depends on current working directory
141     "test_deprecate_stacklevel"
142     # fails due to precision of floating point numbers
143     "test_model_cpu_offload_forward_pass"
144     # tries to run ruff which we have intentionally removed from nativeCheckInputs
145     "test_is_copy_consistent"
146   ];
148   meta = with lib; {
149     description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
150     mainProgram = "diffusers-cli";
151     homepage = "https://github.com/huggingface/diffusers";
152     changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.rev}";
153     license = licenses.asl20;
154     maintainers = with maintainers; [ natsukium ];
155   };