ansible-later: 2.0.22 -> 2.0.23
[NixPkgs.git] / pkgs / development / python-modules / torchvision / default.nix
blobc44413ba4948d1d513dc9e8505e670680eea68a1
1 { lib
2 , symlinkJoin
3 , buildPythonPackage
4 , fetchFromGitHub
5 , ninja
6 , which
7 , libjpeg_turbo
8 , libpng
9 , numpy
10 , scipy
11 , pillow
12 , torch
13 , pytest
14 , cudaSupport ? torch.cudaSupport or false # by default uses the value from torch
17 let
18   inherit (torch.cudaPackages) cudatoolkit cudnn;
20   cudatoolkit_joined = symlinkJoin {
21     name = "${cudatoolkit.name}-unsplit";
22     paths = [ cudatoolkit.out cudatoolkit.lib ];
23   };
24   cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" torch.cudaArchList;
25 in buildPythonPackage rec {
26   pname = "torchvision";
27   version = "0.13.1";
29   src = fetchFromGitHub {
30     owner = "pytorch";
31     repo = "vision";
32     rev = "refs/tags/v${version}";
33     hash = "sha256-QlUAFAG6zEDCDSXR5n2CznspU3fT0kbqySzofGLPgK4=";
34   };
36   nativeBuildInputs = [ libpng ninja which ]
37     ++ lib.optionals cudaSupport [ cudatoolkit_joined ];
39   TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
40   TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
42   buildInputs = [ libjpeg_turbo libpng ]
43     ++ lib.optionals cudaSupport [ cudnn ];
45   propagatedBuildInputs = [ numpy pillow torch scipy ];
47   preBuild = lib.optionalString cudaSupport ''
48     export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
49     export FORCE_CUDA=1
50   '';
52   # tries to download many datasets for tests
53   doCheck = false;
55   checkPhase = ''
56     HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
57   '';
59   checkInputs = [ pytest ];
61   meta = with lib; {
62     description = "PyTorch vision library";
63     homepage = "https://pytorch.org/";
64     license = licenses.bsd3;
65     platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
66     maintainers = with maintainers; [ ericsagnes ];
67   };