biome: 1.9.2 -> 1.9.3 (#349335)
[NixPkgs.git] / pkgs / development / cuda-modules / flags.nix
blob93952a66216b4d8804827efef6868419ca6ba92b
1 # Type aliases
2 # Gpu :: AttrSet
3 #   - See the documentation in ./gpus.nix.
5   config,
6   cudaCapabilities ? (config.cudaCapabilities or [ ]),
7   cudaForwardCompat ? (config.cudaForwardCompat or true),
8   lib,
9   cudaVersion,
10   stdenv,
11   # gpus :: List Gpu
12   gpus,
14 let
15   inherit (lib)
16     asserts
17     attrsets
18     lists
19     strings
20     trivial
21     ;
23   inherit (stdenv) hostPlatform;
25   # Flags are determined based on your CUDA toolkit by default.  You may benefit
26   # from improved performance, reduced file size, or greater hardware support by
27   # passing a configuration based on your specific GPU environment.
28   #
29   # cudaCapabilities :: List Capability
30   # List of hardware generations to build.
31   # E.g. [ "8.0" ]
32   # Currently, the last item is considered the optional forward-compatibility arch,
33   # but this may change in the future.
34   #
35   # cudaForwardCompat :: Bool
36   # Whether to include the forward compatibility gencode (+PTX)
37   # to support future GPU generations.
38   # E.g. true
39   #
40   # Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351
42   # isSupported :: Gpu -> Bool
43   isSupported =
44     gpu:
45     let
46       inherit (gpu) minCudaVersion maxCudaVersion;
47       lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion;
48       upperBoundSatisfied =
49         (maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaVersion);
50     in
51     lowerBoundSatisfied && upperBoundSatisfied;
53   # NOTE: Jetson is never built by default.
54   # isDefault :: Gpu -> Bool
55   isDefault =
56     gpu:
57     let
58       inherit (gpu) dontDefaultAfter isJetson;
59       newGpu = dontDefaultAfter == null;
60       recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
61     in
62     recentGpu && !isJetson;
64   # supportedGpus :: List Gpu
65   # GPUs which are supported by the provided CUDA version.
66   supportedGpus = builtins.filter isSupported gpus;
68   # defaultGpus :: List Gpu
69   # GPUs which are supported by the provided CUDA version and we want to build for by default.
70   defaultGpus = builtins.filter isDefault supportedGpus;
72   # supportedCapabilities :: List Capability
73   supportedCapabilities = lists.map (gpu: gpu.computeCapability) supportedGpus;
75   # defaultCapabilities :: List Capability
76   # The default capabilities to target, if not overridden by the user.
77   defaultCapabilities = lists.map (gpu: gpu.computeCapability) defaultGpus;
79   # cudaArchNameToVersions :: AttrSet String (List String)
80   # Maps the name of a GPU architecture to different versions of that architecture.
81   # For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
82   cudaArchNameToVersions = lists.groupBy' (versions: gpu: versions ++ [ gpu.computeCapability ]) [ ] (
83     gpu: gpu.archName
84   ) supportedGpus;
86   # cudaComputeCapabilityToName :: AttrSet String String
87   # Maps the version of a GPU architecture to the name of that architecture.
88   # For example, "8.0" maps to "Ampere".
89   cudaComputeCapabilityToName = builtins.listToAttrs (
90     lists.map (gpu: attrsets.nameValuePair gpu.computeCapability gpu.archName) supportedGpus
91   );
93   # cudaComputeCapabilityToIsJetson :: AttrSet String Boolean
94   cudaComputeCapabilityToIsJetson = builtins.listToAttrs (
95     lists.map (attrs: attrsets.nameValuePair attrs.computeCapability attrs.isJetson) supportedGpus
96   );
98   # jetsonComputeCapabilities :: List String
99   jetsonComputeCapabilities = trivial.pipe cudaComputeCapabilityToIsJetson [
100     (attrsets.filterAttrs (_: isJetson: isJetson))
101     builtins.attrNames
102   ];
104   # Find the intersection with the user-specified list of cudaCapabilities.
105   # NOTE: Jetson devices are never built by default because they cannot be targeted along with
106   # non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
107   # they must be in the user-specified cudaCapabilities.
108   # NOTE: We don't need to worry about mixes of Jetson and non-Jetson devices here -- there's
109   # sanity-checking for all that in below.
110   jetsonTargets = lists.intersectLists jetsonComputeCapabilities cudaCapabilities;
112   # dropDot :: String -> String
113   dropDot = ver: builtins.replaceStrings [ "." ] [ "" ] ver;
115   # archMapper :: String -> List String -> List String
116   # Maps a feature across a list of architecture versions to produce a list of architectures.
117   # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
118   archMapper = feat: lists.map (computeCapability: "${feat}_${dropDot computeCapability}");
120   # gencodeMapper :: String -> List String -> List String
121   # Maps a feature across a list of architecture versions to produce a list of gencode arguments.
122   # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
123   # "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
124   gencodeMapper =
125     feat:
126     lists.map (
127       computeCapability:
128       "-gencode=arch=compute_${dropDot computeCapability},code=${feat}_${dropDot computeCapability}"
129     );
131   # Maps Nix system to NVIDIA redist arch.
132   # NOTE: We swap out the default `linux-sbsa` redist (for server-grade ARM chips) with the
133   # `linux-aarch64` redist (which is for Jetson devices) if we're building any Jetson devices.
134   # Since both are based on aarch64, we can only have one or the other, otherwise there's an
135   # ambiguity as to which should be used.
136   # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
137   # `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
138   # systems gracefully.
139   # getRedistArch :: String -> String
140   getRedistArch =
141     nixSystem:
142     attrsets.attrByPath [ nixSystem ] "unsupported" {
143       aarch64-linux = if jetsonTargets != [ ] then "linux-aarch64" else "linux-sbsa";
144       x86_64-linux = "linux-x86_64";
145       ppc64le-linux = "linux-ppc64le";
146       x86_64-windows = "windows-x86_64";
147     };
149   # Maps NVIDIA redist arch to Nix system.
150   # NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
151   # `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
152   # systems gracefully.
153   # getNixSystem :: String -> String
154   getNixSystem =
155     redistArch:
156     attrsets.attrByPath [ redistArch ] "unsupported-${redistArch}" {
157       linux-sbsa = "aarch64-linux";
158       linux-aarch64 = "aarch64-linux";
159       linux-x86_64 = "x86_64-linux";
160       linux-ppc64le = "ppc64le-linux";
161       windows-x86_64 = "x86_64-windows";
162     };
164   formatCapabilities =
165     {
166       cudaCapabilities,
167       enableForwardCompat ? true,
168     }:
169     rec {
170       inherit cudaCapabilities enableForwardCompat;
172       # archNames :: List String
173       # E.g. [ "Turing" "Ampere" ]
174       #
175       # Unknown architectures are rendered as sm_XX gencode flags.
176       archNames = lists.unique (
177         lists.map (cap: cudaComputeCapabilityToName.${cap} or "sm_${dropDot cap}") cudaCapabilities
178       );
180       # realArches :: List String
181       # The real architectures are physical architectures supported by the CUDA version.
182       # E.g. [ "sm_75" "sm_86" ]
183       realArches = archMapper "sm" cudaCapabilities;
185       # virtualArches :: List String
186       # The virtual architectures are typically used for forward compatibility, when trying to support
187       # an architecture newer than the CUDA version allows.
188       # E.g. [ "compute_75" "compute_86" ]
189       virtualArches = archMapper "compute" cudaCapabilities;
191       # arches :: List String
192       # By default, build for all supported architectures and forward compatibility via a virtual
193       # architecture for the newest supported architecture.
194       # E.g. [ "sm_75" "sm_86" "compute_86" ]
195       arches = realArches ++ lists.optional enableForwardCompat (lists.last virtualArches);
197       # gencode :: List String
198       # A list of CUDA gencode arguments to pass to NVCC.
199       # E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
200       gencode =
201         let
202           base = gencodeMapper "sm" cudaCapabilities;
203           forward = gencodeMapper "compute" [ (lists.last cudaCapabilities) ];
204         in
205         base ++ lib.optionals enableForwardCompat forward;
207       # gencodeString :: String
208       # A space-separated string of CUDA gencode arguments to pass to NVCC.
209       # E.g. "-gencode=arch=compute_75,code=sm_75 ... -gencode=arch=compute_86,code=compute_86"
210       gencodeString = strings.concatStringsSep " " gencode;
212       # cmakeCudaArchitecturesString :: String
213       # A semicolon-separated string of CUDA capabilities without dots, suitable for passing to CMake.
214       # E.g. "75;86"
215       cmakeCudaArchitecturesString = strings.concatMapStringsSep ";" dropDot cudaCapabilities;
217       # Jetson devices cannot be targeted by the same binaries which target non-Jetson devices. While
218       # NVIDIA provides both `linux-aarch64` and `linux-sbsa` packages, which both target `aarch64`,
219       # they are built with different settings and cannot be mixed.
220       # isJetsonBuild :: Boolean
221       isJetsonBuild =
222         let
223           requestedJetsonDevices = lists.filter (
224             cap: cudaComputeCapabilityToIsJetson.${cap} or false
225           ) cudaCapabilities;
226           requestedNonJetsonDevices = lists.filter (
227             cap: !(builtins.elem cap requestedJetsonDevices)
228           ) cudaCapabilities;
229           jetsonBuildSufficientCondition = requestedJetsonDevices != [ ];
230           jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [ ] && hostPlatform.isAarch64;
231         in
232         trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition) ''
233           Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
234           You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
235           Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
236           Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
237           Exactly one of the following must be true:
238           - All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
239           - No CUDA capabilities belong to Jetson devices.
240           See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
241         '' jetsonBuildSufficientCondition
242         && jetsonBuildNecessaryCondition;
243     };
245 # When changing names or formats: pause, validate, and update the assert
246 assert
247   let
248     expected = {
249       cudaCapabilities = [
250         "7.5"
251         "8.6"
252       ];
253       enableForwardCompat = true;
255       archNames = [
256         "Turing"
257         "Ampere"
258       ];
259       realArches = [
260         "sm_75"
261         "sm_86"
262       ];
263       virtualArches = [
264         "compute_75"
265         "compute_86"
266       ];
267       arches = [
268         "sm_75"
269         "sm_86"
270         "compute_86"
271       ];
273       gencode = [
274         "-gencode=arch=compute_75,code=sm_75"
275         "-gencode=arch=compute_86,code=sm_86"
276         "-gencode=arch=compute_86,code=compute_86"
277       ];
278       gencodeString = "-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86";
280       cmakeCudaArchitecturesString = "75;86";
282       isJetsonBuild = false;
283     };
284     actual = formatCapabilities {
285       cudaCapabilities = [
286         "7.5"
287         "8.6"
288       ];
289     };
290     actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
291   in
292   asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) ''
293     This test should only fail when using a version of CUDA older than 11.2, the first to support
294     8.6.
295     Expected: ${builtins.toJSON expected}
296     Actual: ${builtins.toJSON actualWrapped}
297   '';
298 # Check mixed Jetson and non-Jetson devices
299 assert
300   let
301     expected = false;
302     actual = formatCapabilities {
303       cudaCapabilities = [
304         "7.2"
305         "7.5"
306       ];
307     };
308     actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
309   in
310   asserts.assertMsg (expected == actualWrapped) ''
311     Jetson devices capabilities cannot be mixed with non-jetson devices.
312     Capability 7.5 is non-Jetson and should not be allowed with Jetson 7.2.
313     Expected: ${builtins.toJSON expected}
314     Actual: ${builtins.toJSON actualWrapped}
315   '';
316 # Check Jetson-only
317 assert
318   let
319     expected = {
320       cudaCapabilities = [
321         "6.2"
322         "7.2"
323       ];
324       enableForwardCompat = true;
326       archNames = [
327         "Pascal"
328         "Volta"
329       ];
330       realArches = [
331         "sm_62"
332         "sm_72"
333       ];
334       virtualArches = [
335         "compute_62"
336         "compute_72"
337       ];
338       arches = [
339         "sm_62"
340         "sm_72"
341         "compute_72"
342       ];
344       gencode = [
345         "-gencode=arch=compute_62,code=sm_62"
346         "-gencode=arch=compute_72,code=sm_72"
347         "-gencode=arch=compute_72,code=compute_72"
348       ];
349       gencodeString = "-gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_72,code=compute_72";
351       cmakeCudaArchitecturesString = "62;72";
353       isJetsonBuild = true;
354     };
355     actual = formatCapabilities {
356       cudaCapabilities = [
357         "6.2"
358         "7.2"
359       ];
360     };
361     actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
362   in
363   asserts.assertMsg
364     # We can't do this test unless we're targeting aarch64
365     (hostPlatform.isAarch64 -> (expected == actualWrapped))
366     ''
367       Jetson devices can only be built with other Jetson devices.
368       Both 6.2 and 7.2 are Jetson devices.
369       Expected: ${builtins.toJSON expected}
370       Actual: ${builtins.toJSON actualWrapped}
371     '';
373   # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } ->  { ... }
374   inherit formatCapabilities;
376   # cudaArchNameToVersions :: String => String
377   inherit cudaArchNameToVersions;
379   # cudaComputeCapabilityToName :: String => String
380   inherit cudaComputeCapabilityToName;
382   # dropDot :: String -> String
383   inherit dropDot;
385   inherit
386     defaultCapabilities
387     supportedCapabilities
388     jetsonComputeCapabilities
389     jetsonTargets
390     getNixSystem
391     getRedistArch
392     ;
394 // formatCapabilities {
395   cudaCapabilities = if cudaCapabilities == [ ] then defaultCapabilities else cudaCapabilities;
396   enableForwardCompat = cudaForwardCompat;