otadump: init at 0.1.2 (#329129)
[NixPkgs.git] / pkgs / development / python-modules / diffusers / default.nix
blob7b4c38096f23bd8bba76a4de4aefcb03dfcb836c
2   lib,
3   buildPythonPackage,
4   pythonOlder,
5   fetchFromGitHub,
6   writeText,
7   setuptools,
8   filelock,
9   huggingface-hub,
10   importlib-metadata,
11   numpy,
12   pillow,
13   regex,
14   requests,
15   safetensors,
16   # optional dependencies
17   accelerate,
18   datasets,
19   flax,
20   jax,
21   jaxlib,
22   jinja2,
23   peft,
24   protobuf,
25   tensorboard,
26   torch,
27   # test dependencies
28   parameterized,
29   pytest-timeout,
30   pytest-xdist,
31   pytestCheckHook,
32   requests-mock,
33   scipy,
34   sentencepiece,
35   torchsde,
36   transformers,
37   pythonAtLeast,
38   diffusers,
41 buildPythonPackage rec {
42   pname = "diffusers";
43   version = "0.30.3";
44   pyproject = true;
46   disabled = pythonOlder "3.8";
48   src = fetchFromGitHub {
49     owner = "huggingface";
50     repo = "diffusers";
51     rev = "refs/tags/v${version}";
52     hash = "sha256-/3lHJdsNblKb6xX03OluSCApMK3EXJbRLboBk8CjobE=";
53   };
55   build-system = [ setuptools ];
57   dependencies = [
58     filelock
59     huggingface-hub
60     importlib-metadata
61     numpy
62     pillow
63     regex
64     requests
65     safetensors
66   ];
68   optional-dependencies = {
69     flax = [
70       flax
71       jax
72       jaxlib
73     ];
74     torch = [
75       accelerate
76       torch
77     ];
78     training = [
79       accelerate
80       datasets
81       jinja2
82       peft
83       protobuf
84       tensorboard
85     ];
86   };
88   pythonImportsCheck = [ "diffusers" ];
90   # it takes a few hours
91   doCheck = false;
93   passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };
95   nativeCheckInputs = [
96     parameterized
97     pytest-timeout
98     pytest-xdist
99     pytestCheckHook
100     requests-mock
101     scipy
102     sentencepiece
103     torchsde
104     transformers
105   ] ++ optional-dependencies.torch;
107   preCheck =
108     let
109       # This pytest hook mocks and catches attempts at accessing the network
110       # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
111       # cf. python3Packages.shap
112       conftestSkipNetworkErrors = writeText "conftest.py" ''
113         from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
114         import urllib3
116         class NetworkAccessDeniedError(RuntimeError): pass
117         def deny_network_access(*a, **kw):
118           raise NetworkAccessDeniedError
120         urllib3.connection.HTTPSConnection._new_conn = deny_network_access
122         def pytest_runtest_makereport(item, call):
123           tr = orig_pytest_runtest_makereport(item, call)
124           if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
125               tr.outcome = 'skipped'
126               tr.wasxfail = "reason: Requires network access."
127           return tr
128       '';
129     in
130     ''
131       export HOME=$TMPDIR
132       cat ${conftestSkipNetworkErrors} >> tests/conftest.py
133     '';
135   pytestFlagsArray = [ "tests/" ];
137   disabledTests =
138     [
139       # depends on current working directory
140       "test_deprecate_stacklevel"
141       # fails due to precision of floating point numbers
142       "test_model_cpu_offload_forward_pass"
143       # tries to run ruff which we have intentionally removed from nativeCheckInputs
144       "test_is_copy_consistent"
145     ]
146     ++ lib.optionals (pythonAtLeast "3.12") [
148       # RuntimeError: Dynamo is not supported on Python 3.12+
149       "test_from_save_pretrained_dynamo"
150     ];
152   meta = with lib; {
153     description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
154     mainProgram = "diffusers-cli";
155     homepage = "https://github.com/huggingface/diffusers";
156     changelog = "https://github.com/huggingface/diffusers/releases/tag/${lib.removePrefix "refs/tags/" src.rev}";
157     license = licenses.asl20;
158     maintainers = with maintainers; [ natsukium ];
159   };