1 //===------ CGGPUBuiltin.cpp - Codegen for GPU builtins -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Generates code for built-in GPU calls which are not runtime-specific.
10 // (Runtime-specific codegen lives in programming model specific files.)
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenFunction.h"
15 #include "clang/Basic/Builtins.h"
16 #include "llvm/IR/DataLayout.h"
17 #include "llvm/IR/Instruction.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
21 using namespace clang
;
22 using namespace CodeGen
;
25 llvm::Function
*GetVprintfDeclaration(llvm::Module
&M
) {
26 llvm::Type
*ArgTypes
[] = {llvm::Type::getInt8PtrTy(M
.getContext()),
27 llvm::Type::getInt8PtrTy(M
.getContext())};
28 llvm::FunctionType
*VprintfFuncType
= llvm::FunctionType::get(
29 llvm::Type::getInt32Ty(M
.getContext()), ArgTypes
, false);
31 if (auto *F
= M
.getFunction("vprintf")) {
32 // Our CUDA system header declares vprintf with the right signature, so
33 // nobody else should have been able to declare vprintf with a bogus
35 assert(F
->getFunctionType() == VprintfFuncType
);
39 // vprintf doesn't already exist; create a declaration and insert it into the
41 return llvm::Function::Create(
42 VprintfFuncType
, llvm::GlobalVariable::ExternalLinkage
, "vprintf", &M
);
45 llvm::Function
*GetOpenMPVprintfDeclaration(CodeGenModule
&CGM
) {
46 const char *Name
= "__llvm_omp_vprintf";
47 llvm::Module
&M
= CGM
.getModule();
48 llvm::Type
*ArgTypes
[] = {llvm::Type::getInt8PtrTy(M
.getContext()),
49 llvm::Type::getInt8PtrTy(M
.getContext()),
50 llvm::Type::getInt32Ty(M
.getContext())};
51 llvm::FunctionType
*VprintfFuncType
= llvm::FunctionType::get(
52 llvm::Type::getInt32Ty(M
.getContext()), ArgTypes
, false);
54 if (auto *F
= M
.getFunction(Name
)) {
55 if (F
->getFunctionType() != VprintfFuncType
) {
56 CGM
.Error(SourceLocation(),
57 "Invalid type declaration for __llvm_omp_vprintf");
63 return llvm::Function::Create(
64 VprintfFuncType
, llvm::GlobalVariable::ExternalLinkage
, Name
, &M
);
67 // Transforms a call to printf into a call to the NVPTX vprintf syscall (which
68 // isn't particularly special; it's invoked just like a regular function).
69 // vprintf takes two args: A format string, and a pointer to a buffer containing
72 // For example, the call
74 // printf("format string", arg1, arg2, arg3);
76 // is converted into something resembling
83 // char* buf = alloca(sizeof(Tmp));
84 // *(Tmp*)buf = {a1, a2, a3};
85 // vprintf("format string", buf);
87 // buf is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of the
88 // args is itself aligned to its preferred alignment.
90 // Note that by the time this function runs, E's args have already undergone the
91 // standard C vararg promotion (short -> int, float -> double, etc.).
93 std::pair
<llvm::Value
*, llvm::TypeSize
>
94 packArgsIntoNVPTXFormatBuffer(CodeGenFunction
*CGF
, const CallArgList
&Args
) {
95 const llvm::DataLayout
&DL
= CGF
->CGM
.getDataLayout();
96 llvm::LLVMContext
&Ctx
= CGF
->CGM
.getLLVMContext();
97 CGBuilderTy
&Builder
= CGF
->Builder
;
99 // Construct and fill the args buffer that we'll pass to vprintf.
100 if (Args
.size() <= 1) {
101 // If there are no args, pass a null pointer and size 0
102 llvm::Value
* BufferPtr
= llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx
));
103 return {BufferPtr
, llvm::TypeSize::Fixed(0)};
105 llvm::SmallVector
<llvm::Type
*, 8> ArgTypes
;
106 for (unsigned I
= 1, NumArgs
= Args
.size(); I
< NumArgs
; ++I
)
107 ArgTypes
.push_back(Args
[I
].getRValue(*CGF
).getScalarVal()->getType());
109 // Using llvm::StructType is correct only because printf doesn't accept
110 // aggregates. If we had to handle aggregates here, we'd have to manually
111 // compute the offsets within the alloca -- we wouldn't be able to assume
112 // that the alignment of the llvm type was the same as the alignment of the
114 llvm::Type
*AllocaTy
= llvm::StructType::create(ArgTypes
, "printf_args");
115 llvm::Value
*Alloca
= CGF
->CreateTempAlloca(AllocaTy
);
117 for (unsigned I
= 1, NumArgs
= Args
.size(); I
< NumArgs
; ++I
) {
118 llvm::Value
*P
= Builder
.CreateStructGEP(AllocaTy
, Alloca
, I
- 1);
119 llvm::Value
*Arg
= Args
[I
].getRValue(*CGF
).getScalarVal();
120 Builder
.CreateAlignedStore(Arg
, P
, DL
.getPrefTypeAlign(Arg
->getType()));
122 llvm::Value
*BufferPtr
=
123 Builder
.CreatePointerCast(Alloca
, llvm::Type::getInt8PtrTy(Ctx
));
124 return {BufferPtr
, DL
.getTypeAllocSize(AllocaTy
)};
128 bool containsNonScalarVarargs(CodeGenFunction
*CGF
, const CallArgList
&Args
) {
129 return llvm::any_of(llvm::drop_begin(Args
), [&](const CallArg
&A
) {
130 return !A
.getRValue(*CGF
).isScalar();
134 RValue
EmitDevicePrintfCallExpr(const CallExpr
*E
, CodeGenFunction
*CGF
,
135 llvm::Function
*Decl
, bool WithSizeArg
) {
136 CodeGenModule
&CGM
= CGF
->CGM
;
137 CGBuilderTy
&Builder
= CGF
->Builder
;
138 assert(E
->getBuiltinCallee() == Builtin::BIprintf
);
139 assert(E
->getNumArgs() >= 1); // printf always has at least one arg.
141 // Uses the same format as nvptx for the argument packing, but also passes
142 // an i32 for the total size of the passed pointer
144 CGF
->EmitCallArgs(Args
,
145 E
->getDirectCallee()->getType()->getAs
<FunctionProtoType
>(),
146 E
->arguments(), E
->getDirectCallee(),
147 /* ParamsToSkip = */ 0);
149 // We don't know how to emit non-scalar varargs.
150 if (containsNonScalarVarargs(CGF
, Args
)) {
151 CGM
.ErrorUnsupported(E
, "non-scalar arg to printf");
152 return RValue::get(llvm::ConstantInt::get(CGF
->IntTy
, 0));
155 auto r
= packArgsIntoNVPTXFormatBuffer(CGF
, Args
);
156 llvm::Value
*BufferPtr
= r
.first
;
158 llvm::SmallVector
<llvm::Value
*, 3> Vec
= {
159 Args
[0].getRValue(*CGF
).getScalarVal(), BufferPtr
};
161 // Passing > 32bit of data as a local alloca doesn't work for nvptx or
163 llvm::Constant
*Size
=
164 llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM
.getLLVMContext()),
165 static_cast<uint32_t>(r
.second
.getFixedValue()));
169 return RValue::get(Builder
.CreateCall(Decl
, Vec
));
173 RValue
CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr
*E
) {
174 assert(getTarget().getTriple().isNVPTX());
175 return EmitDevicePrintfCallExpr(
176 E
, this, GetVprintfDeclaration(CGM
.getModule()), false);
179 RValue
CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr
*E
) {
180 assert(getTarget().getTriple().getArch() == llvm::Triple::amdgcn
);
181 assert(E
->getBuiltinCallee() == Builtin::BIprintf
||
182 E
->getBuiltinCallee() == Builtin::BI__builtin_printf
);
183 assert(E
->getNumArgs() >= 1); // printf always has at least one arg.
185 CallArgList CallArgs
;
186 EmitCallArgs(CallArgs
,
187 E
->getDirectCallee()->getType()->getAs
<FunctionProtoType
>(),
188 E
->arguments(), E
->getDirectCallee(),
189 /* ParamsToSkip = */ 0);
191 SmallVector
<llvm::Value
*, 8> Args
;
192 for (const auto &A
: CallArgs
) {
193 // We don't know how to emit non-scalar varargs.
194 if (!A
.getRValue(*this).isScalar()) {
195 CGM
.ErrorUnsupported(E
, "non-scalar arg to printf");
196 return RValue::get(llvm::ConstantInt::get(IntTy
, -1));
199 llvm::Value
*Arg
= A
.getRValue(*this).getScalarVal();
203 llvm::IRBuilder
<> IRB(Builder
.GetInsertBlock(), Builder
.GetInsertPoint());
204 IRB
.SetCurrentDebugLocation(Builder
.getCurrentDebugLocation());
206 bool isBuffered
= (CGM
.getTarget().getTargetOpts().AMDGPUPrintfKindVal
==
207 clang::TargetOptions::AMDGPUPrintfKind::Buffered
);
208 auto Printf
= llvm::emitAMDGPUPrintfCall(IRB
, Args
, isBuffered
);
209 Builder
.SetInsertPoint(IRB
.GetInsertBlock(), IRB
.GetInsertPoint());
210 return RValue::get(Printf
);
213 RValue
CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr
*E
) {
214 assert(getTarget().getTriple().isNVPTX() ||
215 getTarget().getTriple().isAMDGCN());
216 return EmitDevicePrintfCallExpr(E
, this, GetOpenMPVprintfDeclaration(CGM
),