Revert r371023 "[lib/ObjectYAML] - Stop calling error(1) when mapping the st_other...
[llvm-complete.git] / docs / CompileCudaWithLLVM.rst
blob6e181c84e6881b9c956bcf06f9cc515ef4753c31
1 =========================
2 Compiling CUDA with clang
3 =========================
5 .. contents::
6    :local:
8 Introduction
9 ============
11 This document describes how to compile CUDA code with clang, and gives some
12 details about LLVM and clang's CUDA implementations.
14 This document assumes a basic familiarity with CUDA. Information about CUDA
15 programming can be found in the
16 `CUDA programming guide
17 <http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_.
19 Compiling CUDA Code
20 ===================
22 Prerequisites
23 -------------
25 CUDA is supported since llvm 3.9. Current release of clang (7.0.0) supports CUDA
26 7.0 through 9.2. If you need support for CUDA 10, you will need to use clang
27 built from r342924 or newer.
29 Before you build CUDA code, you'll need to have installed the appropriate driver
30 for your nvidia GPU and the CUDA SDK.  See `NVIDIA's CUDA installation guide
31 <https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html>`_ for
32 details.  Note that clang `does not support
33 <https://llvm.org/bugs/show_bug.cgi?id=26966>`_ the CUDA toolkit as installed by
34 many Linux package managers; you probably need to install CUDA in a single
35 directory from NVIDIA's package.
37 CUDA compilation is supported on Linux. Compilation on MacOS and Windows may or
38 may not work and currently have no maintainers. Compilation with CUDA-9.x is
39 `currently broken on Windows <https://bugs.llvm.org/show_bug.cgi?id=38811>`_.
41 Invoking clang
42 --------------
44 Invoking clang for CUDA compilation works similarly to compiling regular C++.
45 You just need to be aware of a few additional flags.
47 You can use `this <https://gist.github.com/855e277884eb6b388cd2f00d956c2fd4>`_
48 program as a toy example.  Save it as ``axpy.cu``.  (Clang detects that you're
49 compiling CUDA code by noticing that your filename ends with ``.cu``.
50 Alternatively, you can pass ``-x cuda``.)
52 To build and run, run the following commands, filling in the parts in angle
53 brackets as described below:
55 .. code-block:: console
57   $ clang++ axpy.cu -o axpy --cuda-gpu-arch=<GPU arch> \
58       -L<CUDA install path>/<lib64 or lib>             \
59       -lcudart_static -ldl -lrt -pthread
60   $ ./axpy
61   y[0] = 2
62   y[1] = 4
63   y[2] = 6
64   y[3] = 8
66 On MacOS, replace `-lcudart_static` with `-lcudart`; otherwise, you may get
67 "CUDA driver version is insufficient for CUDA runtime version" errors when you
68 run your program.
70 * ``<CUDA install path>`` -- the directory where you installed CUDA SDK.
71   Typically, ``/usr/local/cuda``.
73   Pass e.g. ``-L/usr/local/cuda/lib64`` if compiling in 64-bit mode; otherwise,
74   pass e.g. ``-L/usr/local/cuda/lib``.  (In CUDA, the device code and host code
75   always have the same pointer widths, so if you're compiling 64-bit code for
76   the host, you're also compiling 64-bit code for the device.) Note that as of
77   v10.0 CUDA SDK `no longer supports compilation of 32-bit
78   applications <https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#deprecated-features>`_.
80 * ``<GPU arch>`` -- the `compute capability
81   <https://developer.nvidia.com/cuda-gpus>`_ of your GPU. For example, if you
82   want to run your program on a GPU with compute capability of 3.5, specify
83   ``--cuda-gpu-arch=sm_35``.
85   Note: You cannot pass ``compute_XX`` as an argument to ``--cuda-gpu-arch``;
86   only ``sm_XX`` is currently supported.  However, clang always includes PTX in
87   its binaries, so e.g. a binary compiled with ``--cuda-gpu-arch=sm_30`` would be
88   forwards-compatible with e.g. ``sm_35`` GPUs.
90   You can pass ``--cuda-gpu-arch`` multiple times to compile for multiple archs.
92 The `-L` and `-l` flags only need to be passed when linking.  When compiling,
93 you may also need to pass ``--cuda-path=/path/to/cuda`` if you didn't install
94 the CUDA SDK into ``/usr/local/cuda`` or ``/usr/local/cuda-X.Y``.
96 Flags that control numerical code
97 ---------------------------------
99 If you're using GPUs, you probably care about making numerical code run fast.
100 GPU hardware allows for more control over numerical operations than most CPUs,
101 but this results in more compiler options for you to juggle.
103 Flags you may wish to tweak include:
105 * ``-ffp-contract={on,off,fast}`` (defaults to ``fast`` on host and device when
106   compiling CUDA) Controls whether the compiler emits fused multiply-add
107   operations.
109   * ``off``: never emit fma operations, and prevent ptxas from fusing multiply
110     and add instructions.
111   * ``on``: fuse multiplies and adds within a single statement, but never
112     across statements (C11 semantics).  Prevent ptxas from fusing other
113     multiplies and adds.
114   * ``fast``: fuse multiplies and adds wherever profitable, even across
115     statements.  Doesn't prevent ptxas from fusing additional multiplies and
116     adds.
118   Fused multiply-add instructions can be much faster than the unfused
119   equivalents, but because the intermediate result in an fma is not rounded,
120   this flag can affect numerical code.
122 * ``-fcuda-flush-denormals-to-zero`` (default: off) When this is enabled,
123   floating point operations may flush `denormal
124   <https://en.wikipedia.org/wiki/Denormal_number>`_ inputs and/or outputs to 0.
125   Operations on denormal numbers are often much slower than the same operations
126   on normal numbers.
128 * ``-fcuda-approx-transcendentals`` (default: off) When this is enabled, the
129   compiler may emit calls to faster, approximate versions of transcendental
130   functions, instead of using the slower, fully IEEE-compliant versions.  For
131   example, this flag allows clang to emit the ptx ``sin.approx.f32``
132   instruction.
134   This is implied by ``-ffast-math``.
136 Standard library support
137 ========================
139 In clang and nvcc, most of the C++ standard library is not supported on the
140 device side.
142 ``<math.h>`` and ``<cmath>``
143 ----------------------------
145 In clang, ``math.h`` and ``cmath`` are available and `pass
146 <https://github.com/llvm/llvm-test-suite/blob/master/External/CUDA/math_h.cu>`_
147 `tests
148 <https://github.com/llvm/llvm-test-suite/blob/master/External/CUDA/cmath.cu>`_
149 adapted from libc++'s test suite.
151 In nvcc ``math.h`` and ``cmath`` are mostly available.  Versions of ``::foof``
152 in namespace std (e.g. ``std::sinf``) are not available, and where the standard
153 calls for overloads that take integral arguments, these are usually not
154 available.
156 .. code-block:: c++
158   #include <math.h>
159   #include <cmath.h>
161   // clang is OK with everything in this function.
162   __device__ void test() {
163     std::sin(0.); // nvcc - ok
164     std::sin(0);  // nvcc - error, because no std::sin(int) override is available.
165     sin(0);       // nvcc - same as above.
167     sinf(0.);       // nvcc - ok
168     std::sinf(0.);  // nvcc - no such function
169   }
171 ``<std::complex>``
172 ------------------
174 nvcc does not officially support ``std::complex``.  It's an error to use
175 ``std::complex`` in ``__device__`` code, but it often works in ``__host__
176 __device__`` code due to nvcc's interpretation of the "wrong-side rule" (see
177 below).  However, we have heard from implementers that it's possible to get
178 into situations where nvcc will omit a call to an ``std::complex`` function,
179 especially when compiling without optimizations.
181 As of 2016-11-16, clang supports ``std::complex`` without these caveats.  It is
182 tested with libstdc++ 4.8.5 and newer, but is known to work only with libc++
183 newer than 2016-11-16.
185 ``<algorithm>``
186 ---------------
188 In C++14, many useful functions from ``<algorithm>`` (notably, ``std::min`` and
189 ``std::max``) become constexpr.  You can therefore use these in device code,
190 when compiling with clang.
192 Detecting clang vs NVCC from code
193 =================================
195 Although clang's CUDA implementation is largely compatible with NVCC's, you may
196 still want to detect when you're compiling CUDA code specifically with clang.
198 This is tricky, because NVCC may invoke clang as part of its own compilation
199 process!  For example, NVCC uses the host compiler's preprocessor when
200 compiling for device code, and that host compiler may in fact be clang.
202 When clang is actually compiling CUDA code -- rather than being used as a
203 subtool of NVCC's -- it defines the ``__CUDA__`` macro.  ``__CUDA_ARCH__`` is
204 defined only in device mode (but will be defined if NVCC is using clang as a
205 preprocessor).  So you can use the following incantations to detect clang CUDA
206 compilation, in host and device modes:
208 .. code-block:: c++
210   #if defined(__clang__) && defined(__CUDA__) && !defined(__CUDA_ARCH__)
211   // clang compiling CUDA code, host mode.
212   #endif
214   #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
215   // clang compiling CUDA code, device mode.
216   #endif
218 Both clang and nvcc define ``__CUDACC__`` during CUDA compilation.  You can
219 detect NVCC specifically by looking for ``__NVCC__``.
221 Dialect Differences Between clang and nvcc
222 ==========================================
224 There is no formal CUDA spec, and clang and nvcc speak slightly different
225 dialects of the language.  Below, we describe some of the differences.
227 This section is painful; hopefully you can skip this section and live your life
228 blissfully unaware.
230 Compilation Models
231 ------------------
233 Most of the differences between clang and nvcc stem from the different
234 compilation models used by clang and nvcc.  nvcc uses *split compilation*,
235 which works roughly as follows:
237  * Run a preprocessor over the input ``.cu`` file to split it into two source
238    files: ``H``, containing source code for the host, and ``D``, containing
239    source code for the device.
241  * For each GPU architecture ``arch`` that we're compiling for, do:
243    * Compile ``D`` using nvcc proper.  The result of this is a ``ptx`` file for
244      ``P_arch``.
246    * Optionally, invoke ``ptxas``, the PTX assembler, to generate a file,
247      ``S_arch``, containing GPU machine code (SASS) for ``arch``.
249  * Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
250    single "fat binary" file, ``F``.
252  * Compile ``H`` using an external host compiler (gcc, clang, or whatever you
253    like).  ``F`` is packaged up into a header file which is force-included into
254    ``H``; nvcc generates code that calls into this header to e.g. launch
255    kernels.
257 clang uses *merged parsing*.  This is similar to split compilation, except all
258 of the host and device code is present and must be semantically-correct in both
259 compilation steps.
261   * For each GPU architecture ``arch`` that we're compiling for, do:
263     * Compile the input ``.cu`` file for device, using clang.  ``__host__`` code
264       is parsed and must be semantically correct, even though we're not
265       generating code for the host at this time.
267       The output of this step is a ``ptx`` file ``P_arch``.
269     * Invoke ``ptxas`` to generate a SASS file, ``S_arch``.  Note that, unlike
270       nvcc, clang always generates SASS code.
272   * Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
273     single fat binary file, ``F``.
275   * Compile ``H`` using clang.  ``__device__`` code is parsed and must be
276     semantically correct, even though we're not generating code for the device
277     at this time.
279     ``F`` is passed to this compilation, and clang includes it in a special ELF
280     section, where it can be found by tools like ``cuobjdump``.
282 (You may ask at this point, why does clang need to parse the input file
283 multiple times?  Why not parse it just once, and then use the AST to generate
284 code for the host and each device architecture?
286 Unfortunately this can't work because we have to define different macros during
287 host compilation and during device compilation for each GPU architecture.)
289 clang's approach allows it to be highly robust to C++ edge cases, as it doesn't
290 need to decide at an early stage which declarations to keep and which to throw
291 away.  But it has some consequences you should be aware of.
293 Overloading Based on ``__host__`` and ``__device__`` Attributes
294 ---------------------------------------------------------------
296 Let "H", "D", and "HD" stand for "``__host__`` functions", "``__device__``
297 functions", and "``__host__ __device__`` functions", respectively.  Functions
298 with no attributes behave the same as H.
300 nvcc does not allow you to create H and D functions with the same signature:
302 .. code-block:: c++
304   // nvcc: error - function "foo" has already been defined
305   __host__ void foo() {}
306   __device__ void foo() {}
308 However, nvcc allows you to "overload" H and D functions with different
309 signatures:
311 .. code-block:: c++
313   // nvcc: no error
314   __host__ void foo(int) {}
315   __device__ void foo() {}
317 In clang, the ``__host__`` and ``__device__`` attributes are part of a
318 function's signature, and so it's legal to have H and D functions with
319 (otherwise) the same signature:
321 .. code-block:: c++
323   // clang: no error
324   __host__ void foo() {}
325   __device__ void foo() {}
327 HD functions cannot be overloaded by H or D functions with the same signature:
329 .. code-block:: c++
331   // nvcc: error - function "foo" has already been defined
332   // clang: error - redefinition of 'foo'
333   __host__ __device__ void foo() {}
334   __device__ void foo() {}
336   // nvcc: no error
337   // clang: no error
338   __host__ __device__ void bar(int) {}
339   __device__ void bar() {}
341 When resolving an overloaded function, clang considers the host/device
342 attributes of the caller and callee.  These are used as a tiebreaker during
343 overload resolution.  See `IdentifyCUDAPreference
344 <http://clang.llvm.org/doxygen/SemaCUDA_8cpp.html>`_ for the full set of rules,
345 but at a high level they are:
347  * D functions prefer to call other Ds.  HDs are given lower priority.
349  * Similarly, H functions prefer to call other Hs, or ``__global__`` functions
350    (with equal priority).  HDs are given lower priority.
352  * HD functions prefer to call other HDs.
354    When compiling for device, HDs will call Ds with lower priority than HD, and
355    will call Hs with still lower priority.  If it's forced to call an H, the
356    program is malformed if we emit code for this HD function.  We call this the
357    "wrong-side rule", see example below.
359    The rules are symmetrical when compiling for host.
361 Some examples:
363 .. code-block:: c++
365    __host__ void foo();
366    __device__ void foo();
368    __host__ void bar();
369    __host__ __device__ void bar();
371    __host__ void test_host() {
372      foo();  // calls H overload
373      bar();  // calls H overload
374    }
376    __device__ void test_device() {
377      foo();  // calls D overload
378      bar();  // calls HD overload
379    }
381    __host__ __device__ void test_hd() {
382      foo();  // calls H overload when compiling for host, otherwise D overload
383      bar();  // always calls HD overload
384    }
386 Wrong-side rule example:
388 .. code-block:: c++
390   __host__ void host_only();
392   // We don't codegen inline functions unless they're referenced by a
393   // non-inline function.  inline_hd1() is called only from the host side, so
394   // does not generate an error.  inline_hd2() is called from the device side,
395   // so it generates an error.
396   inline __host__ __device__ void inline_hd1() { host_only(); }  // no error
397   inline __host__ __device__ void inline_hd2() { host_only(); }  // error
399   __host__ void host_fn() { inline_hd1(); }
400   __device__ void device_fn() { inline_hd2(); }
402   // This function is not inline, so it's always codegen'ed on both the host
403   // and the device.  Therefore, it generates an error.
404   __host__ __device__ void not_inline_hd() { host_only(); }
406 For the purposes of the wrong-side rule, templated functions also behave like
407 ``inline`` functions: They aren't codegen'ed unless they're instantiated
408 (usually as part of the process of invoking them).
410 clang's behavior with respect to the wrong-side rule matches nvcc's, except
411 nvcc only emits a warning for ``not_inline_hd``; device code is allowed to call
412 ``not_inline_hd``.  In its generated code, nvcc may omit ``not_inline_hd``'s
413 call to ``host_only`` entirely, or it may try to generate code for
414 ``host_only`` on the device.  What you get seems to depend on whether or not
415 the compiler chooses to inline ``host_only``.
417 Member functions, including constructors, may be overloaded using H and D
418 attributes.  However, destructors cannot be overloaded.
420 Using a Different Class on Host/Device
421 --------------------------------------
423 Occasionally you may want to have a class with different host/device versions.
425 If all of the class's members are the same on the host and device, you can just
426 provide overloads for the class's member functions.
428 However, if you want your class to have different members on host/device, you
429 won't be able to provide working H and D overloads in both classes. In this
430 case, clang is likely to be unhappy with you.
432 .. code-block:: c++
434   #ifdef __CUDA_ARCH__
435   struct S {
436     __device__ void foo() { /* use device_only */ }
437     int device_only;
438   };
439   #else
440   struct S {
441     __host__ void foo() { /* use host_only */ }
442     double host_only;
443   };
445   __device__ void test() {
446     S s;
447     // clang generates an error here, because during host compilation, we
448     // have ifdef'ed away the __device__ overload of S::foo().  The __device__
449     // overload must be present *even during host compilation*.
450     S.foo();
451   }
452   #endif
454 We posit that you don't really want to have classes with different members on H
455 and D.  For example, if you were to pass one of these as a parameter to a
456 kernel, it would have a different layout on H and D, so would not work
457 properly.
459 To make code like this compatible with clang, we recommend you separate it out
460 into two classes.  If you need to write code that works on both host and
461 device, consider writing an overloaded wrapper function that returns different
462 types on host and device.
464 .. code-block:: c++
466   struct HostS { ... };
467   struct DeviceS { ... };
469   __host__ HostS MakeStruct() { return HostS(); }
470   __device__ DeviceS MakeStruct() { return DeviceS(); }
472   // Now host and device code can call MakeStruct().
474 Unfortunately, this idiom isn't compatible with nvcc, because it doesn't allow
475 you to overload based on the H/D attributes.  Here's an idiom that works with
476 both clang and nvcc:
478 .. code-block:: c++
480   struct HostS { ... };
481   struct DeviceS { ... };
483   #ifdef __NVCC__
484     #ifndef __CUDA_ARCH__
485       __host__ HostS MakeStruct() { return HostS(); }
486     #else
487       __device__ DeviceS MakeStruct() { return DeviceS(); }
488     #endif
489   #else
490     __host__ HostS MakeStruct() { return HostS(); }
491     __device__ DeviceS MakeStruct() { return DeviceS(); }
492   #endif
494   // Now host and device code can call MakeStruct().
496 Hopefully you don't have to do this sort of thing often.
498 Optimizations
499 =============
501 Modern CPUs and GPUs are architecturally quite different, so code that's fast
502 on a CPU isn't necessarily fast on a GPU.  We've made a number of changes to
503 LLVM to make it generate good GPU code.  Among these changes are:
505 * `Straight-line scalar optimizations <https://goo.gl/4Rb9As>`_ -- These
506   reduce redundancy within straight-line code.
508 * `Aggressive speculative execution
509   <http://llvm.org/docs/doxygen/html/SpeculativeExecution_8cpp_source.html>`_
510   -- This is mainly for promoting straight-line scalar optimizations, which are
511   most effective on code along dominator paths.
513 * `Memory space inference
514   <http://llvm.org/doxygen/NVPTXInferAddressSpaces_8cpp_source.html>`_ --
515   In PTX, we can operate on pointers that are in a paricular "address space"
516   (global, shared, constant, or local), or we can operate on pointers in the
517   "generic" address space, which can point to anything.  Operations in a
518   non-generic address space are faster, but pointers in CUDA are not explicitly
519   annotated with their address space, so it's up to LLVM to infer it where
520   possible.
522 * `Bypassing 64-bit divides
523   <http://llvm.org/docs/doxygen/html/BypassSlowDivision_8cpp_source.html>`_ --
524   This was an existing optimization that we enabled for the PTX backend.
526   64-bit integer divides are much slower than 32-bit ones on NVIDIA GPUs.
527   Many of the 64-bit divides in our benchmarks have a divisor and dividend
528   which fit in 32-bits at runtime. This optimization provides a fast path for
529   this common case.
531 * Aggressive loop unrooling and function inlining -- Loop unrolling and
532   function inlining need to be more aggressive for GPUs than for CPUs because
533   control flow transfer in GPU is more expensive. More aggressive unrolling and
534   inlining also promote other optimizations, such as constant propagation and
535   SROA, which sometimes speed up code by over 10x.
537   (Programmers can force unrolling and inline using clang's `loop unrolling pragmas
538   <http://clang.llvm.org/docs/AttributeReference.html#pragma-unroll-pragma-nounroll>`_
539   and ``__attribute__((always_inline))``.)
541 Publication
542 ===========
544 The team at Google published a paper in CGO 2016 detailing the optimizations
545 they'd made to clang/LLVM.  Note that "gpucc" is no longer a meaningful name:
546 The relevant tools are now just vanilla clang/LLVM.
548 | `gpucc: An Open-Source GPGPU Compiler <http://dl.acm.org/citation.cfm?id=2854041>`_
549 | Jingyue Wu, Artem Belevich, Eli Bendersky, Mark Heffernan, Chris Leary, Jacques Pienaar, Bjarke Roune, Rob Springer, Xuetian Weng, Robert Hundt
550 | *Proceedings of the 2016 International Symposium on Code Generation and Optimization (CGO 2016)*
552 | `Slides from the CGO talk <http://wujingyue.github.io/docs/gpucc-talk.pdf>`_
554 | `Tutorial given at CGO <http://wujingyue.github.io/docs/gpucc-tutorial.pdf>`_
556 Obtaining Help
557 ==============
559 To obtain help on LLVM in general and its CUDA support, see `the LLVM
560 community <http://llvm.org/docs/#mailing-lists>`_.