Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / openai-triton / default.nix
blob2bdb8d918af3f9dddb2db8c41e6b9356af8b04d0
1 { lib
2 , config
3 , buildPythonPackage
4 , fetchFromGitHub
5 , fetchpatch
6 , addOpenGLRunpath
7 , setuptools
8 , pytestCheckHook
9 , pythonRelaxDepsHook
10 , cmake
11 , ninja
12 , pybind11
13 , gtest
14 , zlib
15 , ncurses
16 , libxml2
17 , lit
18 , llvm
19 , filelock
20 , torchWithRocm
21 , python
23 , runCommand
25 , cudaPackages
26 , cudaSupport ? config.cudaSupport
29 let
30   ptxas = "${cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
32 buildPythonPackage rec {
33   pname = "triton";
34   version = "2.1.0";
35   pyproject = true;
37   src = fetchFromGitHub {
38     owner = "openai";
39     repo = pname;
40     rev = "v${version}";
41     hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
42   };
44   patches = [
45     # fix overflow error
46     (fetchpatch {
47       url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
48       hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
49     })
50   ] ++ lib.optionals (!cudaSupport) [
51     ./0000-dont-download-ptxas.patch
52     # openai-triton wants to get ptxas version even if ptxas is not
53     # used, resulting in ptxas not found error.
54     ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
55   ];
57   nativeBuildInputs = [
58     setuptools
59     pythonRelaxDepsHook
60     # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
61     cmake
62     ninja
64     # Note for future:
65     # These *probably* should go in depsTargetTarget
66     # ...but we cannot test cross right now anyway
67     # because we only support cudaPackages on x86_64-linux atm
68     lit
69     llvm
70   ];
72   buildInputs = [
73     gtest
74     libxml2.dev
75     ncurses
76     pybind11
77     zlib
78   ];
80   propagatedBuildInputs = [
81     filelock
82     # openai-triton uses setuptools at runtime:
83     # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
84     setuptools
85   ];
87   postPatch = let
88     # Bash was getting weird without linting,
89     # but basically upstream contains [cc, ..., "-lcuda", ...]
90     # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
91     old = [ "-lcuda" ];
92     new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cudaPackages.cuda_cudart}/lib/stubs/" ];
94     quote = x: ''"${x}"'';
95     oldStr = lib.concatMapStringsSep ", " quote old;
96     newStr = lib.concatMapStringsSep ", " quote new;
97   in ''
98     # Use our `cmakeFlags` instead and avoid downloading dependencies
99     substituteInPlace python/setup.py \
100       --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"
102     # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
103     substituteInPlace bin/CMakeLists.txt \
104       --replace "add_subdirectory(FileCheck)" ""
106     # Don't fetch googletest
107     substituteInPlace unittest/CMakeLists.txt \
108       --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
109       --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
110   '' + lib.optionalString cudaSupport ''
111     # Use our linker flags
112     substituteInPlace python/triton/common/build.py \
113       --replace '${oldStr}' '${newStr}'
114   '';
116   # Avoid GLIBCXX mismatch with other cuda-enabled python packages
117   preConfigure = ''
118     # Ensure that the build process uses the requested number of cores
119     export MAX_JOBS="$NIX_BUILD_CORES"
121     # Upstream's setup.py tries to write cache somewhere in ~/
122     export HOME=$(mktemp -d)
124     # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
125     echo "
126     [build_ext]
127     base-dir=$PWD" >> python/setup.cfg
129     # The rest (including buildPhase) is relative to ./python/
130     cd python
131   '' + lib.optionalString cudaSupport ''
132     export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
133     export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;
135     # Work around download_and_copy_ptxas()
136     mkdir -p $PWD/triton/third_party/cuda/bin
137     ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
138   '';
140   # CMake is run by setup.py instead
141   dontUseCmakeConfigure = true;
143   # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
144   postFixup = lib.optionalString cudaSupport ''
145     rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
146     ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
147   '';
149   checkInputs = [ cmake ]; # ctest
150   dontUseSetuptoolsCheck = true;
152   preCheck = ''
153     # build/temp* refers to build_ext.build_temp (looked up in the build logs)
154     (cd /build/source/python/build/temp* ; ctest)
156     # For pytestCheckHook
157     cd test/unit
158   '';
160   # Circular dependency on torch
161   # pythonImportsCheck = [
162   #   "triton"
163   #   "triton.language"
164   # ];
166   # Ultimately, torch is our test suite:
167   passthru.tests = {
168     inherit torchWithRocm;
169     # Implemented as alternative to pythonImportsCheck, in case if circular dependency on torch occurs again,
170     # and pythonImportsCheck is commented back.
171     import-triton = runCommand "import-triton" { nativeBuildInputs = [(python.withPackages (ps: [ps.openai-triton]))]; } ''
172       python << \EOF
173       import triton
174       import triton.language
175       EOF
176       touch "$out"
177     '';
178   };
180   pythonRemoveDeps = [
181     # Circular dependency, cf. https://github.com/openai/triton/issues/1374
182     "torch"
184     # CLI tools without dist-info
185     "cmake"
186     "lit"
187   ];
189   meta = with lib; {
190     description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
191     homepage = "https://github.com/openai/triton";
192     platforms = lib.platforms.unix;
193     license = licenses.mit;
194     maintainers = with maintainers; [ SomeoneSerge Madouura ];
195   };