3 # - See the documentation in ./gpus.nix.
6 cudaCapabilities ? (config.cudaCapabilities or [ ]),
7 cudaForwardCompat ? (config.cudaForwardCompat or true),
23 inherit (stdenv) hostPlatform;
25 # Flags are determined based on your CUDA toolkit by default. You may benefit
26 # from improved performance, reduced file size, or greater hardware support by
27 # passing a configuration based on your specific GPU environment.
29 # cudaCapabilities :: List Capability
30 # List of hardware generations to build.
32 # Currently, the last item is considered the optional forward-compatibility arch,
33 # but this may change in the future.
35 # cudaForwardCompat :: Bool
36 # Whether to include the forward compatibility gencode (+PTX)
37 # to support future GPU generations.
40 # Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351
42 # isSupported :: Gpu -> Bool
46 inherit (gpu) minCudaVersion maxCudaVersion;
47 lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion;
49 (maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaVersion);
51 lowerBoundSatisfied && upperBoundSatisfied;
53 # NOTE: Jetson is never built by default.
54 # isDefault :: Gpu -> Bool
58 inherit (gpu) dontDefaultAfter isJetson;
59 newGpu = dontDefaultAfter == null;
60 recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
62 recentGpu && !isJetson;
64 # supportedGpus :: List Gpu
65 # The subset of `gpus` accepted by `isSupported` for the provided CUDA version.
66 supportedGpus = lists.filter isSupported gpus;
68 # defaultGpus :: List Gpu
69 # The subset of `supportedGpus` accepted by `isDefault`, i.e. the GPUs we build for by default.
70 defaultGpus = lists.filter isDefault supportedGpus;
72 # supportedCapabilities :: List Capability
73 supportedCapabilities = lists.forEach supportedGpus (gpu: gpu.computeCapability);
75 # defaultCapabilities :: List Capability
76 # Capabilities targeted when the user does not provide their own list.
77 defaultCapabilities = lists.forEach defaultGpus (gpu: gpu.computeCapability);
79 # cudaArchNameToVersions :: AttrSet String (List String)
80 # Maps the name of a GPU architecture to different versions of that architecture.
81 # For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
82 cudaArchNameToVersions = lists.groupBy' (versions: gpu: versions ++ [ gpu.computeCapability ]) [ ] (
86 # cudaComputeCapabilityToName :: AttrSet String String
87 # Maps the version of a GPU architecture to the name of that architecture.
88 # For example, "8.0" maps to "Ampere".
89 cudaComputeCapabilityToName = builtins.listToAttrs (
90 lists.map (gpu: attrsets.nameValuePair gpu.computeCapability gpu.archName) supportedGpus
93 # cudaComputeCapabilityToIsJetson :: AttrSet String Boolean
94 cudaComputeCapabilityToIsJetson = builtins.listToAttrs (
95 lists.map (attrs: attrsets.nameValuePair attrs.computeCapability attrs.isJetson) supportedGpus
98 # jetsonComputeCapabilities :: List String
99 jetsonComputeCapabilities = trivial.pipe cudaComputeCapabilityToIsJetson [
100 (attrsets.filterAttrs (_: isJetson: isJetson))
104 # The Jetson compute capabilities which also appear in the user-specified list of cudaCapabilities.
105 # NOTE: Jetson devices are never built by default because they cannot be targeted along with
106 # non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
107 # they must be in the user-specified cudaCapabilities.
108 # NOTE: We don't need to worry about mixes of Jetson and non-Jetson devices here -- there's
109 # sanity-checking for all of that below.
110 jetsonTargets = lists.intersectLists jetsonComputeCapabilities cudaCapabilities;
112 # dropDot :: String -> String
113 dropDot = builtins.replaceStrings [ "." ] [ "" ];
115 # archMapper :: String -> List String -> List String
116 # Prefixes every architecture version (with its dot removed) with the given feature name.
117 # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
118 archMapper = feat: versions: lists.map (version: "${feat}_${dropDot version}") versions;
120 # gencodeMapper :: String -> List String -> List String
121 # Maps a feature across a list of architecture versions to produce a list of gencode arguments.
122 # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
123 # "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
128 "-gencode=arch=compute_${dropDot computeCapability},code=${feat}_${dropDot computeCapability}"
131 # Maps Nix system to NVIDIA redist arch.
132 # NOTE: We swap out the default `linux-sbsa` redist (for server-grade ARM chips) with the
133 # `linux-aarch64` redist (which is for Jetson devices) if we're building any Jetson devices.
134 # Since both are based on aarch64, we can only have one or the other, otherwise there's an
135 # ambiguity as to which should be used.
136 # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
137 # `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
138 # systems gracefully.
139 # getRedistArch :: String -> String
142 attrsets.attrByPath [ nixSystem ] "unsupported" {
143 aarch64-linux = if jetsonTargets != [ ] then "linux-aarch64" else "linux-sbsa";
144 x86_64-linux = "linux-x86_64";
145 ppc64le-linux = "linux-ppc64le";
146 x86_64-windows = "windows-x86_64";
149 # Maps NVIDIA redist arch to Nix system.
150 # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
151 # `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
152 # systems gracefully.
153 # getNixSystem :: String -> String
156 attrsets.attrByPath [ redistArch ] "unsupported-${redistArch}" {
157 linux-sbsa = "aarch64-linux";
158 linux-aarch64 = "aarch64-linux";
159 linux-x86_64 = "x86_64-linux";
160 linux-ppc64le = "ppc64le-linux";
161 windows-x86_64 = "x86_64-windows";
167 enableForwardCompat ? true,
170 inherit cudaCapabilities enableForwardCompat;
172 # archNames :: List String
173 # E.g. [ "Turing" "Ampere" ]
175 # Capabilities without a known architecture name are rendered as "sm_XX" instead.
176 archNames = lists.unique (
177 lists.map (cap: cudaComputeCapabilityToName.${cap} or "sm_${dropDot cap}") cudaCapabilities
180 # realArches :: List String
181 # The real (physical) architectures corresponding to the requested cudaCapabilities.
182 # E.g. [ "sm_75" "sm_86" ]
183 realArches = archMapper "sm" cudaCapabilities;
185 # virtualArches :: List String
186 # The virtual architectures for the requested cudaCapabilities, typically used for forward
187 # compatibility, when trying to support an architecture newer than the CUDA version allows.
188 # E.g. [ "compute_75" "compute_86" ]
189 virtualArches = archMapper "compute" cudaCapabilities;
191 # arches :: List String
192 # All real architectures, followed by the virtual architecture of the last requested capability
193 # when forward compatibility is enabled.
194 # E.g. [ "sm_75" "sm_86" "compute_86" ]
195 arches = realArches ++ lists.optionals enableForwardCompat [ (lists.last virtualArches) ];
197 # gencode :: List String
198 # CUDA gencode arguments to pass to NVCC: one per requested capability, plus a compute_XX entry for the last capability when enableForwardCompat is set.
199 # E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
202 base = gencodeMapper "sm" cudaCapabilities;
203 forward = gencodeMapper "compute" [ (lists.last cudaCapabilities) ];
205 base ++ lists.optionals enableForwardCompat forward;
207 # gencodeString :: String
208 # The entries of `gencode` joined by single spaces, ready to pass to NVCC.
209 # E.g. "-gencode=arch=compute_75,code=sm_75 ... -gencode=arch=compute_86,code=compute_86"
210 gencodeString = builtins.concatStringsSep " " gencode;
212 # cmakeCudaArchitecturesString :: String
213 # The requested CUDA capabilities with their dots removed, joined by semicolons for passing to CMake.
215 cmakeCudaArchitecturesString = strings.concatStringsSep ";" (lists.map dropDot cudaCapabilities);
217 # Jetson devices cannot be targeted by the same binaries which target non-Jetson devices. While
218 # NVIDIA provides both `linux-aarch64` and `linux-sbsa` packages, which both target `aarch64`,
219 # they are built with different settings and cannot be mixed.
220 # isJetsonBuild :: Boolean
223 requestedJetsonDevices = lists.filter (
224 cap: cudaComputeCapabilityToIsJetson.${cap} or false
226 requestedNonJetsonDevices = lists.filter (
227 cap: !(builtins.elem cap requestedJetsonDevices)
229 jetsonBuildSufficientCondition = requestedJetsonDevices != [ ];
230 jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [ ] && hostPlatform.isAarch64;
232 trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition) ''
233 Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
234 You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
235 Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
236 Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
237 Exactly one of the following must be true:
238 - All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
239 - No CUDA capabilities belong to Jetson devices.
240 See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
241 '' jetsonBuildSufficientCondition
242 && jetsonBuildNecessaryCondition;
245 # When changing names or formats: pause, validate, and update the assert
253 enableForwardCompat = true;
274 "-gencode=arch=compute_75,code=sm_75"
275 "-gencode=arch=compute_86,code=sm_86"
276 "-gencode=arch=compute_86,code=compute_86"
278 gencodeString = "-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86";
280 cmakeCudaArchitecturesString = "75;86";
282 isJetsonBuild = false;
284 actual = formatCapabilities {
290 actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
292 asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) ''
293 This test should only fail when using a version of CUDA older than 11.2, the first to support
295 Expected: ${builtins.toJSON expected}
296 Actual: ${builtins.toJSON actualWrapped}
298 # Check mixed Jetson and non-Jetson devices
302 actual = formatCapabilities {
308 actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
310 asserts.assertMsg (expected == actualWrapped) ''
311 Jetson devices capabilities cannot be mixed with non-jetson devices.
312 Capability 7.5 is non-Jetson and should not be allowed with Jetson 7.2.
313 Expected: ${builtins.toJSON expected}
314 Actual: ${builtins.toJSON actualWrapped}
324 enableForwardCompat = true;
345 "-gencode=arch=compute_62,code=sm_62"
346 "-gencode=arch=compute_72,code=sm_72"
347 "-gencode=arch=compute_72,code=compute_72"
349 gencodeString = "-gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_72,code=compute_72";
351 cmakeCudaArchitecturesString = "62;72";
353 isJetsonBuild = true;
355 actual = formatCapabilities {
361 actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
364 # We can't do this test unless we're targeting aarch64
365 (hostPlatform.isAarch64 -> (expected == actualWrapped))
367 Jetson devices can only be built with other Jetson devices.
368 Both 6.2 and 7.2 are Jetson devices.
369 Expected: ${builtins.toJSON expected}
370 Actual: ${builtins.toJSON actualWrapped}
373 # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } -> { ... }
374 inherit formatCapabilities;
376 # cudaArchNameToVersions :: AttrSet String (List String)
377 inherit cudaArchNameToVersions;
379 # cudaComputeCapabilityToName :: AttrSet String String
380 inherit cudaComputeCapabilityToName;
382 # dropDot :: String -> String
387 supportedCapabilities
388 jetsonComputeCapabilities
394 // formatCapabilities {
395 cudaCapabilities = if cudaCapabilities == [ ] then defaultCapabilities else cudaCapabilities;
396 enableForwardCompat = cudaForwardCompat;