1 { stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python,
2 config, cudaSupport ? config.cudaSupport, cudaPackages, magma,
4 MPISupport ? false, mpi,
8 cmake, linkFarm, symlinkJoin, which, pybind11, removeReferencesTo,
13 Accelerate, CoreServices, libobjc,
15 # Propagated build inputs
20 numpy, pyyaml, cffi, click, typing-extensions,
21 # ROCm build and `torch.compile` requires `openai-triton`
22 tritonSupport ? (!stdenv.isDarwin), openai-triton,
# Disable MKLDNN on aarch64-darwin, as it negatively impacts performance;
# this is also what the official pytorch build does
29 mklDnnSupport ? !(stdenv.isDarwin && stdenv.isAarch64),
31 # virtual pkg that consistently instantiates blas across nixpkgs
32 # See https://github.com/NixOS/nixpkgs/pull/83888
35 # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
38 # dependencies for torch.utils.tensorboard
39 pillow, six, future, tensorboard, protobuf,
44 rocmSupport ? config.rocmSupport,
50 inherit (lib) attrsets lists strings trivial;
51 inherit (cudaPackages) cudaFlags cudnn;
53 # Some packages are not available on all platforms
54 nccl = cudaPackages.nccl or null;
# Render a boolean as the "1"/"0" string form expected by env-var style build flags.
setBool = flag: if flag then "1" else "0";
58 # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
59 supportedTorchCudaCapabilities =
61 real = ["3.5" "3.7" "5.0" "5.2" "5.3" "6.0" "6.1" "6.2" "7.0" "7.2" "7.5" "8.0" "8.6" "8.7" "8.9" "9.0"];
62 ptx = lists.map (x: "${x}+PTX") real;
66 # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
67 # of the first list *from* the second list. That means:
68 # lists.subtractLists a b = b - a
71 supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
72 unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;
74 # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
75 gpuArchWarner = supported: unsupported:
76 trivial.throwIf (supported == [ ])
78 "No supported GPU targets specified. Requested GPU targets: "
79 + strings.concatStringsSep ", " unsupported
83 # Create the gpuTargetString.
84 gpuTargetString = strings.concatStringsSep ";" (
85 if gpuTargets != [ ] then
86 # If gpuTargets is specified, it always takes priority.
88 else if cudaSupport then
89 gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
90 else if rocmSupport then
91 rocmPackages.clr.gpuTargets
93 throw "No GPU targets specified"
96 rocmtoolkit_joined = symlinkJoin {
99 paths = with rocmPackages; [
100 rocm-core clr rccl miopen miopengemm rocrand rocblas
101 rocsparse hipsparse rocthrust rocprim hipcub roctracer
102 rocfft rocsolver hipfft hipsolver hipblas
103 rocminfo rocm-thunk rocm-comgr rocm-device-libs
104 rocm-runtime clr.icd hipify
107 # Fix `setuptools` not being found
109 rm -rf $out/nix-support
113 brokenConditions = attrsets.filterAttrs (_: cond: cond) {
114 "CUDA and ROCm are not mutually exclusive" = cudaSupport && rocmSupport;
115 "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
116 "Unsupported CUDA version" = cudaSupport && !(builtins.elem cudaPackages.cudaMajorVersion [ "11" "12" ]);
117 "MPI cudatoolkit does not match cudaPackages.cudatoolkit" = MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
118 "Magma cudaPackages does not match cudaPackages" = cudaSupport && (magma.cudaPackages != cudaPackages);
120 in buildPythonPackage rec {
122 # Don't forget to update torch-bin to the same version.
124 format = "setuptools";
126 disabled = pythonOlder "3.8.0";
129 "out" # output standard python package
130 "dev" # output libtorch headers
131 "lib" # output libtorch libraries
134 src = fetchFromGitHub {
137 rev = "refs/tags/v${version}";
138 fetchSubmodules = true;
139 hash = "sha256-xUj77yKz3IQ3gd/G32pI4OhL3LoN1zS7eFg0/0nZp5I=";
142 patches = lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
# pthreadpool added support for Grand Central Dispatch in April
# 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
# that is only available starting with macOS 10.13, while our current
# base is 10.12. Until we upgrade, we can fall back on the older
./pthreadpool-disable-gcd.diff
149 ] ++ lib.optionals stdenv.isLinux [
150 # Propagate CUPTI to Kineto by overriding the search path with environment variables.
151 # https://github.com/pytorch/pytorch/pull/108847
152 ./pytorch-pr-108847.patch
155 postPatch = lib.optionalString rocmSupport ''
156 # https://github.com/facebookincubator/gloo/pull/297
157 substituteInPlace third_party/gloo/cmake/Hipify.cmake \
158 --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"
160 # Replace hard-coded rocm paths
161 substituteInPlace caffe2/CMakeLists.txt \
162 --replace "/opt/rocm" "${rocmtoolkit_joined}" \
163 --replace "hcc/include" "hip/include" \
164 --replace "rocblas/include" "include/rocblas" \
165 --replace "hipsparse/include" "include/hipsparse"
167 # Doesn't pick up the environment variable?
168 substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
169 --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
170 --replace "/opt/rocm" "${rocmtoolkit_joined}"
172 # Strangely, this is never set in cmake
173 substituteInPlace cmake/public/LoadHIP.cmake \
174 --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
175 "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." rocmPackages.clr.version))})"
177 # Detection of NCCL version doesn't work particularly well when using the static binary.
178 + lib.optionalString cudaSupport ''
179 substituteInPlace cmake/Modules/FindNCCL.cmake \
181 'message(FATAL_ERROR "Found NCCL header version and library version' \
182 'message(WARNING "Found NCCL header version and library version'
184 # TODO(@connorbaker): Remove this patch after 2.1.0 lands.
185 + lib.optionalString cudaSupport ''
186 substituteInPlace torch/utils/cpp_extension.py \
189 "'8.6', '8.7', '8.9'"
# error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
# This lib overrides aligned_alloc, hence the error message. TL;DR: this function is linkable but not declared in the header.
193 + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0") ''
194 substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace '#if __cplusplus >= 201703L
195 inline void *aligned_alloc(size_t align, size_t size)' '#if __cplusplus >= 201703L && 0
196 inline void *aligned_alloc(size_t align, size_t size)'
199 # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
200 # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
201 # without extreme care to ensure they don't lock each other out of shared resources.
202 # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
203 preConfigure = lib.optionalString cudaSupport ''
204 export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
205 export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
206 export CUDNN_LIB_DIR=${cudnn.lib}/lib
207 export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
208 export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
209 '' + lib.optionalString rocmSupport ''
210 export ROCM_PATH=${rocmtoolkit_joined}
211 export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
212 export PYTORCH_ROCM_ARCH="${gpuTargetString}"
213 export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
214 python tools/amd_build/build_amd.py
217 # Use pytorch's custom configurations
218 dontUseCmakeConfigure = true;
220 # causes possible redefinition of _FORTIFY_SOURCE
221 hardeningDisable = [ "fortify3" ];
223 BUILD_NAMEDTENSOR = setBool true;
224 BUILD_DOCS = setBool buildDocs;
226 # We only do an imports check, so do not build tests either.
227 BUILD_TEST = setBool false;
229 # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
230 # it by default. PyTorch currently uses its own vendored version
231 # of oneDNN through Intel iDeep.
232 USE_MKLDNN = setBool mklDnnSupport;
233 USE_MKLDNN_CBLAS = setBool mklDnnSupport;
235 # Avoid using pybind11 from git submodule
236 # Also avoids pytorch exporting the headers of pybind11
237 USE_SYSTEM_PYBIND11 = true;
240 export MAX_JOBS=$NIX_BUILD_CORES
241 ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
242 ${cmake}/bin/cmake build
246 function join_by { local IFS="$1"; shift; echo "$*"; }
249 read -ra RP <<< $(patchelf --print-rpath $1)
251 RP_NEW=$(join_by : ''${RP[@]:2})
252 patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
254 for f in $(find ''${out} -name 'libcaffe2*.so')
260 # Override the (weirdly) wrong version set by default. See
261 # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
262 # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
263 PYTORCH_BUILD_VERSION = version;
264 PYTORCH_BUILD_NUMBER = 0;
266 USE_NCCL = setBool (nccl != null);
267 USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
268 USE_STATIC_NCCL = setBool useSystemNccl;
270 # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
271 # (upstream seems to have fixed this in the wrong place?)
272 # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
273 # https://github.com/pytorch/pytorch/issues/22346
# Also of interest: pytorch ignores CXXFLAGS and uses CFLAGS for both C and C++:
276 # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
277 env.NIX_CFLAGS_COMPILE = toString ((lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
278 # Suppress gcc regression: avx512 math function raises uninitialized variable warning
279 # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
280 # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
281 ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
282 "-Wno-error=maybe-uninitialized"
283 "-Wno-error=uninitialized"
286 # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
287 # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
288 ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12" ) [
289 "-Wno-error=free-nonheap-object"
291 # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
292 # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
293 ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
294 "-Wno-error=cast-function-type-strict"
295 # Suppresses the most spammy warnings.
296 # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
297 ] ++ lib.optionals rocmSupport [
300 "-Wno-unknown-warning-option"
301 "-Wno-ignored-attributes"
302 "-Wno-deprecated-declarations"
303 "-Wno-defaulted-function-deleted"
306 "-Wno-unused-command-line-argument"
309 "-Wno-free-nonheap-object"
311 ] ++ lib.optionals stdenv.cc.isGNU [
312 "-Wno-maybe-uninitialized"
313 "-Wno-stringop-overflow"
316 nativeBuildInputs = [
323 ] ++ lib.optionals cudaSupport (with cudaPackages; [
324 autoAddOpenGLRunpathHook
327 ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
329 buildInputs = [ blas blas.provider ]
330 ++ lib.optionals cudaSupport (with cudaPackages; [
331 cuda_cccl.dev # <thrust/*>
332 cuda_cudart # cuda_runtime.h and libraries
333 cuda_cupti.dev # For kineto
334 cuda_cupti.lib # For kineto
335 cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
336 cuda_nvml_dev.dev # <nvml.h>
340 cuda_nvtx.lib # -llibNVToolsExt
353 ] ++ lists.optionals (nccl != null) [
354 # Some platforms do not support NCCL (i.e., Jetson)
355 nccl.dev # Provides nccl.h AND a static copy of NCCL!
356 ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
357 cuda_nvprof.dev # <cuda_profiler_api.h>
358 ] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
359 cuda_profiler_api.dev # <cuda_profiler_api.h>
361 ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
362 ++ lib.optionals (cudaSupport || rocmSupport) [ magma ]
363 ++ lib.optionals stdenv.isLinux [ numactl ]
364 ++ lib.optionals stdenv.isDarwin [ Accelerate CoreServices libobjc ];
366 propagatedBuildInputs = [
372 # From install_requires:
379 # the following are required for tensorboard support
380 pillow six future tensorboard protobuf
382 # torch/csrc requires `pybind11` at runtime
385 ++ lib.optionals tritonSupport [ openai-triton ]
386 ++ lib.optionals MPISupport [ mpi ]
387 ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
389 # Tests take a long time and may be flaky, so just sanity-check imports
392 pythonImportsCheck = [
396 nativeCheckInputs = [ hypothesis ninja psutil ];
398 checkPhase = with lib.versions; with lib.strings; concatStringsSep " " [
400 "${python.interpreter} test/run_test.py"
402 (concatStringsSep " " [
403 "utils" # utils requires git, which is not allowed in the check phase
405 # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
406 # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build
408 # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
409 (optionalString (majorMinor version == "1.3" ) "tensorboard")
415 # In our dist-info the name is just "triton"
416 "pytorch-triton-rocm"
420 find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +
423 cp -r $out/${python.sitePackages}/torch/include $dev/include
424 cp -r $out/${python.sitePackages}/torch/share $dev/share
426 # Fix up library paths for split outputs
428 $dev/share/cmake/Torch/TorchConfig.cmake \
429 --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"
432 $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
433 --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"
436 mv $out/${python.sitePackages}/torch/lib $lib/lib
437 ln -s $lib/lib $out/${python.sitePackages}/torch/lib
438 '' + lib.optionalString rocmSupport ''
439 substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
440 --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"
442 substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
443 --replace "/build/source/torch/include" "$dev/include"
446 postFixup = lib.optionalString stdenv.isDarwin ''
447 for f in $(ls $lib/lib/*.dylib); do
448 install_name_tool -id $lib/lib/$(basename $f) $f || true
451 install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
452 install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
453 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib
455 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib
457 install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
458 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
461 # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
462 requiredSystemFeatures = [ "big-parallel" ];
465 inherit cudaSupport cudaPackages;
466 # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
467 blasProvider = blas.provider;
468 # To help debug when a package is broken due to CUDA support
469 inherit brokenConditions;
470 cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
474 changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
475 # keep PyTorch in the description so the package can be found under that name on search.nixos.org
476 description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
477 homepage = "https://pytorch.org/";
478 license = licenses.bsd3;
479 maintainers = with maintainers; [ teh thoughtpolice tscholak ]; # tscholak esp. for darwin-related builds
480 platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
481 broken = builtins.any trivial.id (builtins.attrValues brokenConditions);