[Hexagon] Handle all compares of i1 and vNi1
[llvm-project.git] / llvm / docs / CompileCudaWithLLVM.rst
blob631691ef9b472a1d010d7a483cc2d3e43b2573b8
1 =========================
2 Compiling CUDA with clang
3 =========================
5 .. contents::
6    :local:
8 Introduction
9 ============
11 This document describes how to compile CUDA code with clang, and gives some
12 details about LLVM and clang's CUDA implementations.
14 This document assumes a basic familiarity with CUDA. Information about CUDA
15 programming can be found in the
16 `CUDA programming guide
17 <http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_.
19 Compiling CUDA Code
20 ===================
22 Prerequisites
23 -------------
25 CUDA is supported since llvm 3.9. Clang currently supports CUDA 7.0 through
26 12.1. If clang detects a newer CUDA version, it will issue a warning and will
27 attempt to use detected CUDA SDK it as if it were CUDA 12.1.
29 Before you build CUDA code, you'll need to have installed the CUDA SDK.  See
30 `NVIDIA's CUDA installation guide
31 <https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html>`_ for
32 details.  Note that clang `maynot support
33 <https://bugs.llvm.org/show_bug.cgi?id=26966>`_ the CUDA toolkit as installed by
34 some Linux package managers. Clang does attempt to deal with specific details of
35 CUDA installation on a handful of common Linux distributions, but in general the
36 most reliable way to make it work is to install CUDA in a single directory from
37 NVIDIA's `.run` package and specify its location via `--cuda-path=...` argument.
39 CUDA compilation is supported on Linux. Compilation on MacOS and Windows may or
40 may not work and currently have no maintainers.
42 Invoking clang
43 --------------
45 Invoking clang for CUDA compilation works similarly to compiling regular C++.
46 You just need to be aware of a few additional flags.
48 You can use `this <https://gist.github.com/855e277884eb6b388cd2f00d956c2fd4>`_
49 program as a toy example.  Save it as ``axpy.cu``.  (Clang detects that you're
50 compiling CUDA code by noticing that your filename ends with ``.cu``.
51 Alternatively, you can pass ``-x cuda``.)
53 To build and run, run the following commands, filling in the parts in angle
54 brackets as described below:
56 .. code-block:: console
58   $ clang++ axpy.cu -o axpy --cuda-gpu-arch=<GPU arch> \
59       -L<CUDA install path>/<lib64 or lib>             \
60       -lcudart_static -ldl -lrt -pthread
61   $ ./axpy
62   y[0] = 2
63   y[1] = 4
64   y[2] = 6
65   y[3] = 8
67 On MacOS, replace `-lcudart_static` with `-lcudart`; otherwise, you may get
68 "CUDA driver version is insufficient for CUDA runtime version" errors when you
69 run your program.
71 * ``<CUDA install path>`` -- the directory where you installed CUDA SDK.
72   Typically, ``/usr/local/cuda``.
74   Pass e.g. ``-L/usr/local/cuda/lib64`` if compiling in 64-bit mode; otherwise,
75   pass e.g. ``-L/usr/local/cuda/lib``.  (In CUDA, the device code and host code
76   always have the same pointer widths, so if you're compiling 64-bit code for
77   the host, you're also compiling 64-bit code for the device.) Note that as of
78   v10.0 CUDA SDK `no longer supports compilation of 32-bit
79   applications <https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#deprecated-features>`_.
81 * ``<GPU arch>`` -- the `compute capability
82   <https://developer.nvidia.com/cuda-gpus>`_ of your GPU. For example, if you
83   want to run your program on a GPU with compute capability of 3.5, specify
84   ``--cuda-gpu-arch=sm_35``.
86   Note: You cannot pass ``compute_XX`` as an argument to ``--cuda-gpu-arch``;
87   only ``sm_XX`` is currently supported.  However, clang always includes PTX in
88   its binaries, so e.g. a binary compiled with ``--cuda-gpu-arch=sm_30`` would be
89   forwards-compatible with e.g. ``sm_35`` GPUs.
91   You can pass ``--cuda-gpu-arch`` multiple times to compile for multiple archs.
93 The `-L` and `-l` flags only need to be passed when linking.  When compiling,
94 you may also need to pass ``--cuda-path=/path/to/cuda`` if you didn't install
95 the CUDA SDK into ``/usr/local/cuda`` or ``/usr/local/cuda-X.Y``.
97 Flags that control numerical code
98 ---------------------------------
100 If you're using GPUs, you probably care about making numerical code run fast.
101 GPU hardware allows for more control over numerical operations than most CPUs,
102 but this results in more compiler options for you to juggle.
104 Flags you may wish to tweak include:
106 * ``-ffp-contract={on,off,fast}`` (defaults to ``fast`` on host and device when
107   compiling CUDA) Controls whether the compiler emits fused multiply-add
108   operations.
110   * ``off``: never emit fma operations, and prevent ptxas from fusing multiply
111     and add instructions.
112   * ``on``: fuse multiplies and adds within a single statement, but never
113     across statements (C11 semantics).  Prevent ptxas from fusing other
114     multiplies and adds.
115   * ``fast``: fuse multiplies and adds wherever profitable, even across
116     statements.  Doesn't prevent ptxas from fusing additional multiplies and
117     adds.
119   Fused multiply-add instructions can be much faster than the unfused
120   equivalents, but because the intermediate result in an fma is not rounded,
121   this flag can affect numerical code.
123 * ``-fcuda-flush-denormals-to-zero`` (default: off) When this is enabled,
124   floating point operations may flush `denormal
125   <https://en.wikipedia.org/wiki/Denormal_number>`_ inputs and/or outputs to 0.
126   Operations on denormal numbers are often much slower than the same operations
127   on normal numbers.
129 * ``-fcuda-approx-transcendentals`` (default: off) When this is enabled, the
130   compiler may emit calls to faster, approximate versions of transcendental
131   functions, instead of using the slower, fully IEEE-compliant versions.  For
132   example, this flag allows clang to emit the ptx ``sin.approx.f32``
133   instruction.
135   This is implied by ``-ffast-math``.
137 Standard library support
138 ========================
140 In clang and nvcc, most of the C++ standard library is not supported on the
141 device side.
143 ``<math.h>`` and ``<cmath>``
144 ----------------------------
146 In clang, ``math.h`` and ``cmath`` are available and `pass
147 <https://github.com/llvm/llvm-test-suite/blob/main/External/CUDA/math_h.cu>`_
148 `tests
149 <https://github.com/llvm/llvm-test-suite/blob/main/External/CUDA/cmath.cu>`_
150 adapted from libc++'s test suite.
152 In nvcc ``math.h`` and ``cmath`` are mostly available.  Versions of ``::foof``
153 in namespace std (e.g. ``std::sinf``) are not available, and where the standard
154 calls for overloads that take integral arguments, these are usually not
155 available.
157 .. code-block:: c++
159   #include <math.h>
160   #include <cmath.h>
162   // clang is OK with everything in this function.
163   __device__ void test() {
164     std::sin(0.); // nvcc - ok
165     std::sin(0);  // nvcc - error, because no std::sin(int) override is available.
166     sin(0);       // nvcc - same as above.
168     sinf(0.);       // nvcc - ok
169     std::sinf(0.);  // nvcc - no such function
170   }
172 ``<std::complex>``
173 ------------------
175 nvcc does not officially support ``std::complex``.  It's an error to use
176 ``std::complex`` in ``__device__`` code, but it often works in ``__host__
177 __device__`` code due to nvcc's interpretation of the "wrong-side rule" (see
178 below).  However, we have heard from implementers that it's possible to get
179 into situations where nvcc will omit a call to an ``std::complex`` function,
180 especially when compiling without optimizations.
182 As of 2016-11-16, clang supports ``std::complex`` without these caveats.  It is
183 tested with libstdc++ 4.8.5 and newer, but is known to work only with libc++
184 newer than 2016-11-16.
186 ``<algorithm>``
187 ---------------
189 In C++14, many useful functions from ``<algorithm>`` (notably, ``std::min`` and
190 ``std::max``) become constexpr.  You can therefore use these in device code,
191 when compiling with clang.
193 Detecting clang vs NVCC from code
194 =================================
196 Although clang's CUDA implementation is largely compatible with NVCC's, you may
197 still want to detect when you're compiling CUDA code specifically with clang.
199 This is tricky, because NVCC may invoke clang as part of its own compilation
200 process!  For example, NVCC uses the host compiler's preprocessor when
201 compiling for device code, and that host compiler may in fact be clang.
203 When clang is actually compiling CUDA code -- rather than being used as a
204 subtool of NVCC's -- it defines the ``__CUDA__`` macro.  ``__CUDA_ARCH__`` is
205 defined only in device mode (but will be defined if NVCC is using clang as a
206 preprocessor).  So you can use the following incantations to detect clang CUDA
207 compilation, in host and device modes:
209 .. code-block:: c++
211   #if defined(__clang__) && defined(__CUDA__) && !defined(__CUDA_ARCH__)
212   // clang compiling CUDA code, host mode.
213   #endif
215   #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
216   // clang compiling CUDA code, device mode.
217   #endif
219 Both clang and nvcc define ``__CUDACC__`` during CUDA compilation.  You can
220 detect NVCC specifically by looking for ``__NVCC__``.
222 Dialect Differences Between clang and nvcc
223 ==========================================
225 There is no formal CUDA spec, and clang and nvcc speak slightly different
226 dialects of the language.  Below, we describe some of the differences.
228 This section is painful; hopefully you can skip this section and live your life
229 blissfully unaware.
231 Compilation Models
232 ------------------
234 Most of the differences between clang and nvcc stem from the different
235 compilation models used by clang and nvcc.  nvcc uses *split compilation*,
236 which works roughly as follows:
238  * Run a preprocessor over the input ``.cu`` file to split it into two source
239    files: ``H``, containing source code for the host, and ``D``, containing
240    source code for the device.
242  * For each GPU architecture ``arch`` that we're compiling for, do:
244    * Compile ``D`` using nvcc proper.  The result of this is a ``ptx`` file for
245      ``P_arch``.
247    * Optionally, invoke ``ptxas``, the PTX assembler, to generate a file,
248      ``S_arch``, containing GPU machine code (SASS) for ``arch``.
250  * Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
251    single "fat binary" file, ``F``.
253  * Compile ``H`` using an external host compiler (gcc, clang, or whatever you
254    like).  ``F`` is packaged up into a header file which is force-included into
255    ``H``; nvcc generates code that calls into this header to e.g. launch
256    kernels.
258 clang uses *merged parsing*.  This is similar to split compilation, except all
259 of the host and device code is present and must be semantically-correct in both
260 compilation steps.
262   * For each GPU architecture ``arch`` that we're compiling for, do:
264     * Compile the input ``.cu`` file for device, using clang.  ``__host__`` code
265       is parsed and must be semantically correct, even though we're not
266       generating code for the host at this time.
268       The output of this step is a ``ptx`` file ``P_arch``.
270     * Invoke ``ptxas`` to generate a SASS file, ``S_arch``.  Note that, unlike
271       nvcc, clang always generates SASS code.
273   * Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
274     single fat binary file, ``F``.
276   * Compile ``H`` using clang.  ``__device__`` code is parsed and must be
277     semantically correct, even though we're not generating code for the device
278     at this time.
280     ``F`` is passed to this compilation, and clang includes it in a special ELF
281     section, where it can be found by tools like ``cuobjdump``.
283 (You may ask at this point, why does clang need to parse the input file
284 multiple times?  Why not parse it just once, and then use the AST to generate
285 code for the host and each device architecture?
287 Unfortunately this can't work because we have to define different macros during
288 host compilation and during device compilation for each GPU architecture.)
290 clang's approach allows it to be highly robust to C++ edge cases, as it doesn't
291 need to decide at an early stage which declarations to keep and which to throw
292 away.  But it has some consequences you should be aware of.
294 Overloading Based on ``__host__`` and ``__device__`` Attributes
295 ---------------------------------------------------------------
297 Let "H", "D", and "HD" stand for "``__host__`` functions", "``__device__``
298 functions", and "``__host__ __device__`` functions", respectively.  Functions
299 with no attributes behave the same as H.
301 nvcc does not allow you to create H and D functions with the same signature:
303 .. code-block:: c++
305   // nvcc: error - function "foo" has already been defined
306   __host__ void foo() {}
307   __device__ void foo() {}
309 However, nvcc allows you to "overload" H and D functions with different
310 signatures:
312 .. code-block:: c++
314   // nvcc: no error
315   __host__ void foo(int) {}
316   __device__ void foo() {}
318 In clang, the ``__host__`` and ``__device__`` attributes are part of a
319 function's signature, and so it's legal to have H and D functions with
320 (otherwise) the same signature:
322 .. code-block:: c++
324   // clang: no error
325   __host__ void foo() {}
326   __device__ void foo() {}
328 HD functions cannot be overloaded by H or D functions with the same signature:
330 .. code-block:: c++
332   // nvcc: error - function "foo" has already been defined
333   // clang: error - redefinition of 'foo'
334   __host__ __device__ void foo() {}
335   __device__ void foo() {}
337   // nvcc: no error
338   // clang: no error
339   __host__ __device__ void bar(int) {}
340   __device__ void bar() {}
342 When resolving an overloaded function, clang considers the host/device
343 attributes of the caller and callee.  These are used as a tiebreaker during
344 overload resolution.  See `IdentifyCUDAPreference
345 <https://clang.llvm.org/doxygen/SemaCUDA_8cpp.html>`_ for the full set of rules,
346 but at a high level they are:
348  * D functions prefer to call other Ds.  HDs are given lower priority.
350  * Similarly, H functions prefer to call other Hs, or ``__global__`` functions
351    (with equal priority).  HDs are given lower priority.
353  * HD functions prefer to call other HDs.
355    When compiling for device, HDs will call Ds with lower priority than HD, and
356    will call Hs with still lower priority.  If it's forced to call an H, the
357    program is malformed if we emit code for this HD function.  We call this the
358    "wrong-side rule", see example below.
360    The rules are symmetrical when compiling for host.
362 Some examples:
364 .. code-block:: c++
366    __host__ void foo();
367    __device__ void foo();
369    __host__ void bar();
370    __host__ __device__ void bar();
372    __host__ void test_host() {
373      foo();  // calls H overload
374      bar();  // calls H overload
375    }
377    __device__ void test_device() {
378      foo();  // calls D overload
379      bar();  // calls HD overload
380    }
382    __host__ __device__ void test_hd() {
383      foo();  // calls H overload when compiling for host, otherwise D overload
384      bar();  // always calls HD overload
385    }
387 Wrong-side rule example:
389 .. code-block:: c++
391   __host__ void host_only();
393   // We don't codegen inline functions unless they're referenced by a
394   // non-inline function.  inline_hd1() is called only from the host side, so
395   // does not generate an error.  inline_hd2() is called from the device side,
396   // so it generates an error.
397   inline __host__ __device__ void inline_hd1() { host_only(); }  // no error
398   inline __host__ __device__ void inline_hd2() { host_only(); }  // error
400   __host__ void host_fn() { inline_hd1(); }
401   __device__ void device_fn() { inline_hd2(); }
403   // This function is not inline, so it's always codegen'ed on both the host
404   // and the device.  Therefore, it generates an error.
405   __host__ __device__ void not_inline_hd() { host_only(); }
407 For the purposes of the wrong-side rule, templated functions also behave like
408 ``inline`` functions: They aren't codegen'ed unless they're instantiated
409 (usually as part of the process of invoking them).
411 clang's behavior with respect to the wrong-side rule matches nvcc's, except
412 nvcc only emits a warning for ``not_inline_hd``; device code is allowed to call
413 ``not_inline_hd``.  In its generated code, nvcc may omit ``not_inline_hd``'s
414 call to ``host_only`` entirely, or it may try to generate code for
415 ``host_only`` on the device.  What you get seems to depend on whether or not
416 the compiler chooses to inline ``host_only``.
418 Member functions, including constructors, may be overloaded using H and D
419 attributes.  However, destructors cannot be overloaded.
421 Using a Different Class on Host/Device
422 --------------------------------------
424 Occasionally you may want to have a class with different host/device versions.
426 If all of the class's members are the same on the host and device, you can just
427 provide overloads for the class's member functions.
429 However, if you want your class to have different members on host/device, you
430 won't be able to provide working H and D overloads in both classes. In this
431 case, clang is likely to be unhappy with you.
433 .. code-block:: c++
435   #ifdef __CUDA_ARCH__
436   struct S {
437     __device__ void foo() { /* use device_only */ }
438     int device_only;
439   };
440   #else
441   struct S {
442     __host__ void foo() { /* use host_only */ }
443     double host_only;
444   };
446   __device__ void test() {
447     S s;
448     // clang generates an error here, because during host compilation, we
449     // have ifdef'ed away the __device__ overload of S::foo().  The __device__
450     // overload must be present *even during host compilation*.
451     S.foo();
452   }
453   #endif
455 We posit that you don't really want to have classes with different members on H
456 and D.  For example, if you were to pass one of these as a parameter to a
457 kernel, it would have a different layout on H and D, so would not work
458 properly.
460 To make code like this compatible with clang, we recommend you separate it out
461 into two classes.  If you need to write code that works on both host and
462 device, consider writing an overloaded wrapper function that returns different
463 types on host and device.
465 .. code-block:: c++
467   struct HostS { ... };
468   struct DeviceS { ... };
470   __host__ HostS MakeStruct() { return HostS(); }
471   __device__ DeviceS MakeStruct() { return DeviceS(); }
473   // Now host and device code can call MakeStruct().
475 Unfortunately, this idiom isn't compatible with nvcc, because it doesn't allow
476 you to overload based on the H/D attributes.  Here's an idiom that works with
477 both clang and nvcc:
479 .. code-block:: c++
481   struct HostS { ... };
482   struct DeviceS { ... };
484   #ifdef __NVCC__
485     #ifndef __CUDA_ARCH__
486       __host__ HostS MakeStruct() { return HostS(); }
487     #else
488       __device__ DeviceS MakeStruct() { return DeviceS(); }
489     #endif
490   #else
491     __host__ HostS MakeStruct() { return HostS(); }
492     __device__ DeviceS MakeStruct() { return DeviceS(); }
493   #endif
495   // Now host and device code can call MakeStruct().
497 Hopefully you don't have to do this sort of thing often.
499 Optimizations
500 =============
502 Modern CPUs and GPUs are architecturally quite different, so code that's fast
503 on a CPU isn't necessarily fast on a GPU.  We've made a number of changes to
504 LLVM to make it generate good GPU code.  Among these changes are:
506 * `Straight-line scalar optimizations <https://goo.gl/4Rb9As>`_ -- These
507   reduce redundancy within straight-line code.
509 * `Aggressive speculative execution
510   <https://llvm.org/docs/doxygen/html/SpeculativeExecution_8cpp_source.html>`_
511   -- This is mainly for promoting straight-line scalar optimizations, which are
512   most effective on code along dominator paths.
514 * `Memory space inference
515   <https://llvm.org/doxygen/NVPTXInferAddressSpaces_8cpp_source.html>`_ --
516   In PTX, we can operate on pointers that are in a particular "address space"
517   (global, shared, constant, or local), or we can operate on pointers in the
518   "generic" address space, which can point to anything.  Operations in a
519   non-generic address space are faster, but pointers in CUDA are not explicitly
520   annotated with their address space, so it's up to LLVM to infer it where
521   possible.
523 * `Bypassing 64-bit divides
524   <https://llvm.org/docs/doxygen/html/BypassSlowDivision_8cpp_source.html>`_ --
525   This was an existing optimization that we enabled for the PTX backend.
527   64-bit integer divides are much slower than 32-bit ones on NVIDIA GPUs.
528   Many of the 64-bit divides in our benchmarks have a divisor and dividend
529   which fit in 32-bits at runtime. This optimization provides a fast path for
530   this common case.
532 * Aggressive loop unrolling and function inlining -- Loop unrolling and
533   function inlining need to be more aggressive for GPUs than for CPUs because
534   control flow transfer in GPU is more expensive. More aggressive unrolling and
535   inlining also promote other optimizations, such as constant propagation and
536   SROA, which sometimes speed up code by over 10x.
538   (Programmers can force unrolling and inline using clang's `loop unrolling pragmas
539   <https://clang.llvm.org/docs/AttributeReference.html#pragma-unroll-pragma-nounroll>`_
540   and ``__attribute__((always_inline))``.)
542 Publication
543 ===========
545 The team at Google published a paper in CGO 2016 detailing the optimizations
546 they'd made to clang/LLVM.  Note that "gpucc" is no longer a meaningful name:
547 The relevant tools are now just vanilla clang/LLVM.
549 | `gpucc: An Open-Source GPGPU Compiler <http://dl.acm.org/citation.cfm?id=2854041>`_
550 | Jingyue Wu, Artem Belevich, Eli Bendersky, Mark Heffernan, Chris Leary, Jacques Pienaar, Bjarke Roune, Rob Springer, Xuetian Weng, Robert Hundt
551 | *Proceedings of the 2016 International Symposium on Code Generation and Optimization (CGO 2016)*
553 | `Slides from the CGO talk <http://wujingyue.github.io/docs/gpucc-talk.pdf>`_
555 | `Tutorial given at CGO <http://wujingyue.github.io/docs/gpucc-tutorial.pdf>`_
557 Obtaining Help
558 ==============
560 To obtain help on LLVM in general and its CUDA support, see `the LLVM
561 community <https://llvm.org/docs/#mailing-lists>`_.