biome: 1.9.2 -> 1.9.3 (#349335)
[NixPkgs.git] / pkgs / development / python-modules / xformers / default.nix
blob 69c7583c9eac6fdf84be3ff3cf9fa2a96afdc42f
2   lib,
3   stdenv,
4   buildPythonPackage,
5   pythonOlder,
6   fetchFromGitHub,
7   which,
8   # runtime dependencies
9   numpy,
10   torch,
11   # check dependencies
12   pytestCheckHook,
13   pytest-cov,
14   # , pytest-mpi
15   pytest-timeout,
16   # , pytorch-image-models
17   hydra-core,
18   fairscale,
19   scipy,
20   cmake,
21   ninja,
22   triton,
23   networkx,
24   #, apex
25   einops,
26   transformers,
27   timm,
28 #, flash-attn
let
  # Pinned upstream release; used for the source tag and for the generated
  # xformers/version.py.
  version = "0.0.23.post1";

  # CUDA settings are reused from the torch this package is built against
  # (support flag, CUDA package set, target capabilities) rather than being
  # taken as separate arguments.
  cudaSupport = torch.cudaSupport;
  cudaPackages = torch.cudaPackages;
  cudaCapabilities = torch.cudaCapabilities;
34 buildPythonPackage {
# Package identity; the version string comes from the `let` binding.
pname = "xformers";
inherit version;

# Mark the package unavailable for interpreters older than Python 3.7.
disabled = pythonOlder "3.7";
# Built with the setuptools format.
format = "setuptools";
src = fetchFromGitHub {
  owner = "facebookresearch";
  repo = "xformers";
  # Upstream release tags are `v`-prefixed.
  rev = "refs/tags/v${version}";
  # Git submodules are fetched as well, so the fixed-output hash below
  # covers their contents too; update it together with `version`.
  fetchSubmodules = true;
  hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY=";
};
patches = [
  # Judging by its name, lets the build run from a source tree without git
  # metadata. NOTE(review): confirm against the patch contents.
  ./0001-fix-allow-building-without-git.patch
];
# Shell run before the build phase:
#  - writes ./xformers/version.py containing the pinned `version`, presumably
#    because the upstream build would otherwise derive it from git metadata
#    that the fetched tree lacks (NOTE(review): confirm against the
#    0001-fix-allow-building-without-git patch);
#  - exports MAX_JOBS from NIX_BUILD_CORES, presumably so the extension
#    build's parallelism follows the cores Nix grants this build.
# The heredoc lines are deliberately flush with the string's common indent:
# Nix strips that indent, leaving `EOF` at column 0 as the terminator.
preBuild = ''
  cat << EOF > ./xformers/version.py
  # noqa: C801
  __version__ = "${version}"
  EOF
  export MAX_JOBS=$NIX_BUILD_CORES
'';
# Extra build environment, set only when building against a CUDA-enabled torch.
# TORCH_CUDA_ARCH_LIST (consumed by torch.utils.cpp_extension) gets the target
# CUDA capabilities as a ";"-separated list. The value comes from the
# `cudaCapabilities` let-binding, which is inherited from torch, so the
# extensions target the same capabilities torch was configured with.
# `concatStringsSep` already returns a string, so no extra interpolation is needed.
env = lib.optionalAttrs cudaSupport {
  TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
};
# Use the CUDA package set's backendStdenv when CUDA is enabled; otherwise
# keep the stdenv passed to this file (non-recursive attrset, so the
# right-hand `stdenv` is the function argument).
# NOTE(review): confirm buildPythonPackage actually honours a `stdenv`
# attribute; the alternative is `buildPythonPackage.override { stdenv = ...; }`.
stdenv = if !cudaSupport then stdenv else cudaPackages.backendStdenv;
# CUDA libraries needed by the flash-attn part of the build; the trailing
# comment on each entry names the header it provides. Empty without CUDA.
buildInputs = lib.optionals cudaSupport [
  cudaPackages.cuda_cudart # cuda_runtime_api.h
  cudaPackages.libcusparse # cusparse.h
  cudaPackages.cuda_cccl # nv/target
  cudaPackages.libcublas # cublas_v2.h
  cudaPackages.libcusolver # cusolverDn.h
  cudaPackages.libcurand # curand_kernel.h
];
# Build-time tools: ninja and which always, plus nvcc when CUDA is enabled.
nativeBuildInputs =
  [
    ninja
    which
  ]
  ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];
# Runtime Python dependencies, propagated to packages that depend on xformers.
propagatedBuildInputs = [ numpy torch ];
# Opt out of automated bulk updates: upstream has a broken 0.03 version, see
# https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
passthru = {
  skipBulkUpdate = true;
};

# Smoke test run after installation: the top-level module must import.
pythonImportsCheck = [ "xformers" ];
# The test suite is disabled: several of its dependencies are missing
# (they appear commented out in the arguments and in nativeCheckInputs).
doCheck = false;

# cmake appears only among the check inputs; keep its setup hook from taking
# over the configure phase of this setuptools build.
dontUseCmakeConfigure = true;
# Test-time inputs; only used when doCheck is enabled. Still missing, and
# therefore commented out in the arguments: pytest-mpi, pytorch-image-models,
# apex and flash-attn.
nativeCheckInputs = [
  # test runner and pytest plugins
  pytestCheckHook
  pytest-cov
  pytest-timeout
  # other test dependencies
  hydra-core
  fairscale
  scipy
  cmake
  networkx
  triton
  einops
  transformers
  timm
];
# Package metadata. The changelog link uses the same `v<version>` tag that
# `src` is fetched from (upstream tags are `v`-prefixed); a bare
# `${version}` ref does not name an upstream tag.
meta = {
  description = "XFormers: A collection of composable Transformer building blocks";
  homepage = "https://github.com/facebookresearch/xformers";
  changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
  license = lib.licenses.bsd3;
  maintainers = with lib.maintainers; [ happysalada ];
};