otadump: init at 0.1.2 (#329129)
[NixPkgs.git] / pkgs / development / python-modules / xformers / default.nix
blob8790b380b769930cfb1b387759fafd48d0bf1370
# build tooling and Nix helpers
lib,
stdenv,
buildPythonPackage,
pythonOlder,
fetchFromGitHub,
setuptools,
which,
ninja,

# runtime dependencies
numpy,
torch,

# check dependencies (used by nativeCheckInputs; the test suite is currently disabled)
pytestCheckHook,
pytest-cov-stub,
pytest-timeout,
hydra-core,
fairscale,
scipy,
cmake,
networkx,
triton,
einops,
transformers,
timm,
# missing test dependencies, not taken as arguments yet:
# pytest-mpi,
# pytorch-image-models,
# apex,
# flash-attn,
31 let
version = "0.0.28.post3";

# Reuse torch's CUDA configuration (feature flag, CUDA package set and target
# capabilities) so xformers is built consistently with the torch it extends.
inherit (torch)
  cudaCapabilities
  cudaPackages
  cudaSupport
  ;
35 buildPythonPackage {
pname = "xformers";
inherit version;

# Build through the pyproject (PEP 517) path; the backend is setuptools.
pyproject = true;

# Mark the package unavailable for interpreters older than Python 3.9.
disabled = pythonOlder "3.9";
# Sources for the upstream release tag `v<version>`. Git submodules are
# fetched as well, presumably because the build compiles vendored
# third-party code from them.
src = fetchFromGitHub {
  owner = "facebookresearch";
  repo = "xformers";
  rev = "refs/tags/v${version}";
  fetchSubmodules = true;
  hash = "sha256-23tnhCHK+Z0No8fqZxkgDFp2VIgXZR4jpM+pkb/vvmw=";
};
# Local patch which, judging by its name, lets the build run from a source
# tree without git metadata.
patches = [ ./0001-fix-allow-building-without-git.patch ];

build-system = [ setuptools ];
# Runs before the build:
#  1. Writes xformers/version.py so it contains `__version__ = "<version>"`,
#     presumably because the upstream build would otherwise try to derive the
#     version from git metadata, which the fetched source does not carry.
#     The `# noqa: C801` line is part of the generated file (presumably to
#     silence an upstream lint rule); the heredoc bytes define the file
#     content, so keep them stable.
#  2. Exports MAX_JOBS from NIX_BUILD_CORES, presumably so torch's C++/CUDA
#     extension builder limits its parallel compile jobs to the cores Nix
#     granted.
preBuild = ''
  cat << EOF > ./xformers/version.py
  # noqa: C801
  __version__ = "${version}"
  EOF
  export MAX_JOBS=$NIX_BUILD_CORES
'';
# For CUDA builds, tell torch's extension builder which GPU architectures to
# compile kernels for, as a ";"-separated list. The capabilities are the ones
# inherited from torch in the `let` bindings, so xformers targets the same
# architectures as the torch it is built against. This yields the same string
# as reading `torch.cudaCapabilities` directly, and it also puts the otherwise
# unused inherited binding to use. No extra "${...}" wrapper is needed because
# concatStringsSep already returns a string.
env = lib.optionalAttrs cudaSupport {
  TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
};
# Choose cudaPackages.backendStdenv for CUDA builds (presumably a toolchain
# whose host compiler nvcc supports) and the regular stdenv otherwise. The
# `stdenv` on the right-hand side is the package argument, because the
# attribute set passed to buildPythonPackage is not recursive.
# NOTE(review): this is passed as an ordinary attribute to buildPythonPackage.
# Confirm that buildPythonPackage actually honors it. If it does not, the
# CUDA-compatible toolchain has to be selected with
# `buildPythonPackage.override { stdenv = ...; }` at the call site instead.
stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
# CUDA libraries needed for the flash-attn part of the build; the header each
# one provides is noted next to it. Nothing extra is linked for CPU-only builds.
buildInputs = lib.optionals cudaSupport [
  cudaPackages.cuda_cudart # cuda_runtime_api.h
  cudaPackages.libcusparse # cusparse.h
  cudaPackages.cuda_cccl # nv/target
  cudaPackages.libcublas # cublas_v2.h
  cudaPackages.libcusolver # cusolverDn.h
  cudaPackages.libcurand # curand_kernel.h
];
# Build-host tools. ninja and which are always present; nvcc is added only
# when the CUDA kernels are compiled.
nativeBuildInputs = [
  ninja
  which
]
++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];
# Python packages required at runtime.
dependencies = [ numpy torch ];

# Minimal smoke test: the installed package must be importable.
pythonImportsCheck = [ "xformers" ];
# Upstream has a broken `0.03` version that would presumably mislead automated
# bulk updates, so this package is excluded from them. Background:
# https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
passthru.skipBulkUpdate = true;

# cmake appears only among the check inputs below; this keeps its setup hook
# from running a CMake configure phase for this setuptools-based build.
dontUseCmakeConfigure = true;

# Tests stay disabled because some test dependencies are missing; they are
# listed as comments at the end of nativeCheckInputs.
doCheck = false;

nativeCheckInputs = [
  pytestCheckHook
  pytest-cov-stub
  pytest-timeout
  hydra-core
  fairscale
  scipy
  cmake
  networkx
  triton
  einops
  transformers
  timm
  # still missing:
  # apex
  # flash-attn
];
# Package metadata. Release tags carry a `v` prefix (the source is fetched
# from `refs/tags/v${version}`), so the changelog link must use the same
# prefix. Without it, the link points at a ref that does not exist.
# `lib.` paths are written out instead of using `with lib;`, so names in this
# set resolve unambiguously.
meta = {
  description = "Collection of composable Transformer building blocks";
  homepage = "https://github.com/facebookresearch/xformers";
  changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
  license = lib.licenses.bsd3;
  maintainers = with lib.maintainers; [ happysalada ];
};