Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / accelerate / default.nix
blobeecbd37a906345606416b19cd992c6484720b71c
1 { stdenv
2 , lib
3 , buildPythonPackage
4 , fetchFromGitHub
5 , pythonOlder
6 , pytest7CheckHook
7 , setuptools
8 , numpy
9 , packaging
10 , psutil
11 , pyyaml
12 , safetensors
13 , torch
14 , config
15 , cudatoolkit
16 , evaluate
17 , parameterized
18 , transformers
21 buildPythonPackage rec {
22   pname = "accelerate";
23   version = "0.27.0";
24   pyproject = true;
26   disabled = pythonOlder "3.7";
28   src = fetchFromGitHub {
29     owner = "huggingface";
30     repo = pname;
31     rev = "refs/tags/v${version}";
32     hash = "sha256-7rnI8UXyAql8fLMKoSRrWzVw5CnyYVE2o6dJOzSgWxw=";
33   };
35   nativeBuildInputs = [ setuptools ];
37   propagatedBuildInputs = [
38     numpy
39     packaging
40     psutil
41     pyyaml
42     safetensors
43     torch
44   ];
46   nativeCheckInputs = [
47     evaluate
48     parameterized
49     pytest7CheckHook
50     transformers
51   ];
52   preCheck = ''
53     export HOME=$(mktemp -d)
54     export PATH=$out/bin:$PATH
55   '' + lib.optionalString config.cudaSupport ''
56     export TRITON_PTXAS_PATH="${cudatoolkit}/bin/ptxas"
57   '';
58   pytestFlagsArray = [ "tests" ];
59   disabledTests = [
60     # try to download data:
61     "FeatureExamplesTests"
62     "test_infer_auto_device_map_on_t0pp"
64     # require socket communication
65     "test_explicit_dtypes"
66     "test_gated"
67     "test_invalid_model_name"
68     "test_invalid_model_name_transformers"
69     "test_no_metadata"
70     "test_no_split_modules"
71     "test_remote_code"
72     "test_transformers_model"
74     # set the environment variable, CC, which conflicts with standard environment
75     "test_patch_environment_key_exists"
76   ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
77     # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
78     "CheckpointTest"
79   ] ++ lib.optionals (!config.cudaSupport) [
80     # requires ptxas from cudatoolkit, which is unfree
81     "test_dynamo_extract_model"
82   ] ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
83     # RuntimeError: torch_shm_manager: execl failed: Permission denied
84     "CheckpointTest"
85   ];
87   disabledTestPaths = lib.optionals (!(stdenv.isLinux && stdenv.isx86_64)) [
88     # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException:
89     "tests/test_cpu.py"
90     "tests/test_grad_sync.py"
91     "tests/test_metrics.py"
92     "tests/test_scheduler.py"
93   ];
95   pythonImportsCheck = [
96     "accelerate"
97   ];
99   meta = with lib; {
100     homepage = "https://huggingface.co/docs/accelerate";
101     description = "A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
102     changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
103     license = licenses.asl20;
104     maintainers = with maintainers; [ bcdarwin ];
105     mainProgram = "accelerate";
106   };