1 //=== AMDGPUPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // The pass binds printfs to a kernel arg pointer that will be bound to a buffer
11 // later by the runtime.
13 // This pass traverses the functions in the module and converts
14 // each call to printf to a sequence of operations that
15 // store the following into the printf buffer:
16 // - format string (passed as a module's metadata unique ID)
17 // - bitwise copies of printf arguments
18 // The backend passes will need to store metadata in the kernel.
19 //===----------------------------------------------------------------------===//
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/Triple.h"
25 #include "llvm/Analysis/InstructionSimplify.h"
26 #include "llvm/Analysis/TargetLibraryInfo.h"
27 #include "llvm/CodeGen/Passes.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/GlobalVariable.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/InstVisitor.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #define DEBUG_TYPE "printfToRuntime"
// NOTE(review): this copy of the file is damaged. Every line starts with a
// stale line number, tokens are split across lines, and numbered lines are
// missing (e.g. 48, 50-53, 60-62, 68, 75, 79, 81, 84). The missing lines
// include the closing braces of this class and of its inline methods, and
// the declaration of TD used by simplify(). Restore from upstream before
// building.
//
// Pass (constructed as a ModulePass, see the constructor) that collects
// calls to `printf` through the InstVisitor callback below. It rewrites
// each call into writes to a buffer obtained from __printf_alloc
// (see lowerPrintfForGpu).
47 class LLVM_LIBRARY_VISIBILITY AMDGPUPrintfRuntimeBinding final
49 public InstVisitor
<AMDGPUPrintfRuntimeBinding
> {
54 explicit AMDGPUPrintfRuntimeBinding();
// InstVisitor hook: records every call whose callee is a function named
// "printf". Calls without a statically known callee are ignored.
56 void visitCallSite(CallSite CS
) {
57 Function
*F
= CS
.getCalledFunction();
58 if (F
&& F
->hasName() && F
->getName() == "printf")
59 Printfs
.push_back(CS
.getInstruction());
// Pass entry point (defined out of line).
63 bool runOnModule(Module
&M
) override
;
// Appends the conversion characters found in the format string to
// OpConvSpecifiers.
64 void getConversionSpecifiers(SmallVectorImpl
<char> &OpConvSpecifiers
,
65 StringRef fmt
, size_t num_ops
) const;
// True when a %s argument is an i8 pointer in the constant address space.
// Such a string is copied into the printf buffer as a literal.
67 bool shouldPrintAsStr(char Specifier
, Type
*OpType
) const;
// Rewrites all collected printf calls. The out-of-line definition returns
// bool; the return-type line (68) is missing here.
69 lowerPrintfForGpu(Module
&M
,
70 function_ref
<const TargetLibraryInfo
&(Function
&)> GetTLI
);
// Declares the analyses this pass depends on.
72 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
73 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
74 AU
.addRequired
<DominatorTreeWrapperPass
>();
// Runs InstructionSimplify on I with the module DataLayout (TD), the given
// TLI and DT. DT may be null; see runOnModule.
77 Value
*simplify(Instruction
*I
, const TargetLibraryInfo
*TLI
) {
78 return SimplifyInstruction(I
, {*TD
, TLI
, DT
});
// Dominator tree handed to simplify(). It is nullptr unless one was
// available when runOnModule ran.
82 const DominatorTree
*DT
;
// printf call instructions found by visitCallSite; erased once lowered.
83 SmallVector
<Value
*, 32> Printfs
;
// Storage for the pass ID. The constructor passes it to ModulePass, and it
// is exported below as llvm::AMDGPUPrintfRuntimeBindingID.
87 char AMDGPUPrintfRuntimeBinding::ID
= 0;
// Registers the pass as "amdgpu-printf-runtime-binding" together with its
// TargetLibraryInfo and DominatorTree dependencies.
// NOTE(review): line 91 (the remaining INITIALIZE_PASS_BEGIN arguments) is
// missing in this copy.
89 INITIALIZE_PASS_BEGIN(AMDGPUPrintfRuntimeBinding
,
90 "amdgpu-printf-runtime-binding", "AMDGPU Printf lowering",
92 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass
)
93 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
94 INITIALIZE_PASS_END(AMDGPUPrintfRuntimeBinding
, "amdgpu-printf-runtime-binding",
95 "AMDGPU Printf lowering", false, false)
// Handle to this pass's ID in namespace llvm (presumably declared in a shared
// target header so other code can refer to the pass).
97 char &llvm::AMDGPUPrintfRuntimeBindingID
= AMDGPUPrintfRuntimeBinding::ID
;
// Factory for the pass. Returns a new heap-allocated instance; the caller
// (presumably a pass manager) takes ownership.
// NOTE(review): this function's closing brace (line 102) and the surrounding
// lines 98-99 and 103 are missing in this copy.
100 ModulePass
*createAMDGPUPrintfRuntimeBinding() {
101 return new AMDGPUPrintfRuntimeBinding();
// Constructs the pass with TD and DT unset (runOnModule fills them in) and
// registers it with the global PassRegistry.
// NOTE(review): the closing brace (line 108) is missing in this copy.
105 AMDGPUPrintfRuntimeBinding::AMDGPUPrintfRuntimeBinding()
106 : ModulePass(ID
), TD(nullptr), DT(nullptr) {
107 initializeAMDGPUPrintfRuntimeBindingPass(*PassRegistry::getPassRegistry());
// Scans the printf format string Fmt for characters listed in ConvSpecifiers.
// For each candidate, it inspects the text since the previous candidate for
// a '%' and appends the conversion character to OpConvSpecifiers, in order.
// NumOps is not read by the visible code.
//
// NOTE(review): lines 129, 131-135, 137 and 139-141 are missing in this copy.
// As shown, ArgDump is never set to true, and the loop bodies and closing
// braces are incomplete. Presumably the missing lines:
//   - set ArgDump when a '%' is found;
//   - toggle it for each '%' immediately before it, so that "%%" escapes
//     are skipped;
//   - guard the push_back on ArgDump.
// Confirm against upstream.
110 void AMDGPUPrintfRuntimeBinding::getConversionSpecifiers(
111 SmallVectorImpl
<char> &OpConvSpecifiers
, StringRef Fmt
,
112 size_t NumOps
) const {
113 // Not all format characters are collected: only the conversion
114 // characters listed in ConvSpecifiers. Callers use them to choose
115 // zero- vs sign-extension (%x/%X/%u/%o), float handling (%f), and
116 // whether a %s argument is stored as a literal string or as a
117 // pointer into the printf buffer.
118 static const char ConvSpecifiers
[] = "cdieEfgGaosuxXp";
// CurFmtSpecifierIdx is the index of the current candidate conversion
// character. PrevFmtSpecifierIdx is just past the previous candidate, i.e.
// the start of the text to inspect.
119 size_t CurFmtSpecifierIdx
= 0;
120 size_t PrevFmtSpecifierIdx
= 0;
122 while ((CurFmtSpecifierIdx
= Fmt
.find_first_of(
123 ConvSpecifiers
, CurFmtSpecifierIdx
)) != StringRef::npos
) {
124 bool ArgDump
= false;
// Text between the previous candidate and this one. The last '%' in it,
// if any, introduces the current specifier.
125 StringRef CurFmt
= Fmt
.substr(PrevFmtSpecifierIdx
,
126 CurFmtSpecifierIdx
- PrevFmtSpecifierIdx
);
127 size_t pTag
= CurFmt
.find_last_of("%");
128 if (pTag
!= StringRef::npos
) {
// Walk back over consecutive '%' characters in front of pTag
// (presumably to recognise "%%" escapes; see the NOTE above).
130 while (pTag
&& CurFmt
[--pTag
] == '%') {
136 OpConvSpecifiers
.push_back(Fmt
[CurFmtSpecifierIdx
]);
// Resume scanning just past this candidate.
138 PrevFmtSpecifierIdx
= ++CurFmtSpecifierIdx
;
// Returns true when an argument for conversion character Specifier should be
// written into the printf buffer as a literal string. That is the case when
// Specifier is 's' and OpType is a pointer, in the constant address space,
// to an 8-bit integer.
//
// NOTE(review): lines 145, 148, 151 and 154 are missing in this copy. As
// shown, each `if` below would guard the following declaration. Presumably
// each was followed by an early `return false;`, and line 154 closed the
// function. Confirm against upstream.
142 bool AMDGPUPrintfRuntimeBinding::shouldPrintAsStr(char Specifier
,
143 Type
*OpType
) const {
// Only %s arguments can be printed as strings.
144 if (Specifier
!= 's')
// The argument must be a pointer into the constant address space...
146 const PointerType
*PT
= dyn_cast
<PointerType
>(OpType
);
147 if (!PT
|| PT
->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS
)
// ...whose pointee type is an integer type...
149 Type
*ElemType
= PT
->getContainedType(0);
150 if (ElemType
->getTypeID() != Type::IntegerTyID
)
// ...of exactly 8 bits.
152 IntegerType
*ElemIType
= cast
<IntegerType
>(ElemType
);
153 return ElemIType
->getBitWidth() == 8;
// Lowers every call collected in Printfs into writes to a buffer obtained
// from __printf_alloc. It also records a per-call metadata string in the
// "llvm.printf.fmts" named metadata, and finally erases the original calls.
//
// Visible steps for each printf call CI:
//  1. Recover the format string. Look through a load of a stored value,
//     try InstructionSimplify, then read the initializer of the
//     GlobalVariable under a ConstantExpr ("unknown" if that fails).
//  2. Collect the conversion characters (getConversionSpecifiers) and
//     compute a buffer size for each argument. Arguments whose size is not
//     a multiple of DWORD_ALIGN are widened (zext for x/X/u/o, otherwise
//     presumably sext), and CI's operand is rewritten in place.
//  3. Emit the metadata string "<UniqID>:<Sizes>" and a call to
//     __printf_alloc(Sum) before CI.
//  4. In a block guarded by "buffer != null", store UniqID at byte offset 0
//     and each argument from byte offset 4 onwards.
//
// NOTE(review): this copy of the function is damaged. Many numbered lines
// are missing, including:
//   - most closing braces;
//   - the declaration of UniqID;
//   - the lines that define pcall, cmp, result, Brnch, id_gep_cast, stbuff,
//     GV, CastedGEP, CharC and NSizeStr;
//   - the accumulation into Sum;
//   - most of the format-string escaping loop (296-326);
//   - the function's return statement.
// The comments below describe only the visible statements. Restore the body
// from upstream before building.
156 bool AMDGPUPrintfRuntimeBinding::lowerPrintfForGpu(
157 Module
&M
, function_ref
<const TargetLibraryInfo
&(Function
&)> GetTLI
) {
158 LLVMContext
&Ctx
= M
.getContext();
159 IRBuilder
<> Builder(Ctx
);
160 Type
*I32Ty
= Type::getInt32Ty(Ctx
);
162 // NB: It is important for this string's size to be divisible by 4.
// NonLiteralStr is written in place of a %s argument whose contents are not
// a known constant string.
163 const char NonLiteralStr
[4] = "???";
// Pass 1: for each collected call, recover the format string, size the
// arguments, emit the metadata and the buffer writes.
165 for (auto P
: Printfs
) {
166 auto CI
= cast
<CallInst
>(P
);
167 unsigned NumOps
= CI
->getNumArgOperands();
169 SmallString
<16> OpConvSpecifiers
;
// Operand 0 is the format string.
170 Value
*Op
= CI
->getArgOperand(0);
// If the format pointer is loaded from memory, follow the loaded address to
// a store into it and use the stored value instead.
172 if (auto LI
= dyn_cast
<LoadInst
>(Op
)) {
173 Op
= LI
->getPointerOperand();
174 for (auto Use
: Op
->users()) {
175 if (auto SI
= dyn_cast
<StoreInst
>(Use
)) {
176 Op
= SI
->getValueOperand();
// Try to simplify the format operand to a constant.
// NOTE(review): the use of Op_simplified (lines 184-187) is missing in this
// copy.
182 if (auto I
= dyn_cast
<Instruction
>(Op
)) {
183 Value
*Op_simplified
= simplify(I
, &GetTLI(*I
->getFunction()));
// The format is expected to be a ConstantExpr whose operand 0 is the
// GlobalVariable holding the string.
// NOTE(review): any null check on ConstExpr (lines 189-190) is not visible in
// this copy; ConstExpr->getOperand(0) below dereferences it.
188 ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Op
);
191 GlobalVariable
*GVar
= dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
// Read the format text from the global's initializer. Str stays "unknown"
// when it cannot be determined. (The zero-initializer branch body, lines
// 200-202, is missing in this copy.)
193 StringRef
Str("unknown");
194 if (GVar
&& GVar
->hasInitializer()) {
195 auto Init
= GVar
->getInitializer();
196 if (auto CA
= dyn_cast
<ConstantDataArray
>(Init
)) {
198 Str
= CA
->getAsCString();
199 } else if (isa
<ConstantAggregateZero
>(Init
)) {
203 // We need this call to ascertain whether
204 // we are printing a string or a pointer:
205 // it extracts the conversion characters
206 // of Str into OpConvSpecifiers.
208 getConversionSpecifiers(OpConvSpecifiers
, Str
, NumOps
- 1);
210 // Add metadata for the string. Sizes receives the argument count,
// followed below by a "<size>:" entry per argument.
211 std::string AStreamHolder
;
212 raw_string_ostream
Sizes(AStreamHolder
);
// Total byte count requested from __printf_alloc. It starts at DWORD_ALIGN,
// presumably reserving the slot for the printf id stored at offset 0. The
// statement that adds each ArgSize (around line 290) is missing in this copy.
213 int Sum
= DWORD_ALIGN
;
214 Sizes
<< CI
->getNumArgOperands() - 1;
// First argument pass: size each argument that has a matching conversion
// character.
216 for (unsigned ArgCount
= 1; ArgCount
< CI
->getNumArgOperands() &&
217 ArgCount
<= OpConvSpecifiers
.size();
219 Value
*Arg
= CI
->getArgOperand(ArgCount
);
220 Type
*ArgType
= Arg
->getType();
221 unsigned ArgSize
= TD
->getTypeAllocSizeInBits(ArgType
);
222 ArgSize
= ArgSize
/ 8;
224 // By design ArgSize should be a multiple of DWORD_ALIGN;
225 // widen arguments that do not follow this rule to i32 (or <N x i32>).
227 if (ArgSize
% DWORD_ALIGN
!= 0) {
228 llvm::Type
*ResType
= llvm::Type::getInt32Ty(Ctx
);
229 VectorType
*LLVMVecType
= llvm::dyn_cast
<llvm::VectorType
>(ArgType
);
230 int NumElem
= LLVMVecType
? LLVMVecType
->getNumElements() : 1;
231 if (LLVMVecType
&& NumElem
> 1)
232 ResType
= llvm::VectorType::get(ResType
, NumElem
);
233 Builder
.SetInsertPoint(CI
);
234 Builder
.SetCurrentDebugLocation(CI
->getDebugLoc());
// Unsigned, hex and octal conversions are zero-extended. Otherwise the value
// is presumably sign-extended (the `else` on line 240 is missing in this copy).
235 if (OpConvSpecifiers
[ArgCount
- 1] == 'x' ||
236 OpConvSpecifiers
[ArgCount
- 1] == 'X' ||
237 OpConvSpecifiers
[ArgCount
- 1] == 'u' ||
238 OpConvSpecifiers
[ArgCount
- 1] == 'o')
239 Arg
= Builder
.CreateZExt(Arg
, ResType
);
241 Arg
= Builder
.CreateSExt(Arg
, ResType
);
242 ArgType
= Arg
->getType();
243 ArgSize
= TD
->getTypeAllocSizeInBits(ArgType
);
244 ArgSize
= ArgSize
/ 8;
// Replace the call operand so the store pass below reads the widened value.
245 CI
->setOperand(ArgCount
, Arg
);
// %f handling.
// NOTE(review): the statements acting on FpCons and FpExt (lines 249-251 and
// 255-257) are missing in this copy.
247 if (OpConvSpecifiers
[ArgCount
- 1] == 'f') {
248 ConstantFP
*FpCons
= dyn_cast
<ConstantFP
>(Arg
);
252 FPExtInst
*FpExt
= dyn_cast
<FPExtInst
>(Arg
);
253 if (FpExt
&& FpExt
->getType()->isDoubleTy() &&
254 FpExt
->getOperand(0)->getType()->isFloatTy())
// %s of a constant string: reserve the string length including the
// terminator, rounded up to DWORD_ALIGN. Otherwise reserve
// sizeof(NonLiteralStr).
// NOTE(review): several lines in 260-286 are missing in this copy, including
// the declaration of GV (260), the condition around the round-up
// (272-273) and the else-branches.
258 if (shouldPrintAsStr(OpConvSpecifiers
[ArgCount
- 1], ArgType
)) {
259 if (ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Arg
)) {
261 dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
262 if (GV
&& GV
->hasInitializer()) {
263 Constant
*Init
= GV
->getInitializer();
264 ConstantDataArray
*CA
= dyn_cast
<ConstantDataArray
>(Init
);
// NOTE(review): CA comes from dyn_cast and is null when Init is neither zero
// nor a ConstantDataArray, so `CA->isString()` would dereference null.
// Consider `(CA && CA->isString())`.
265 if (Init
->isZeroValue() || CA
->isString()) {
266 size_t SizeStr
= Init
->isZeroValue()
268 : (strlen(CA
->getAsCString().data()) + 1);
269 size_t Rem
= SizeStr
% DWORD_ALIGN
;
271 LLVM_DEBUG(dbgs() << "Printf string original size = " << SizeStr
274 NSizeStr
= SizeStr
+ (DWORD_ALIGN
- Rem
);
// Fallback: contents unknown, so reserve room for NonLiteralStr.
281 ArgSize
= sizeof(NonLiteralStr
);
284 ArgSize
= sizeof(NonLiteralStr
);
287 LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize
288 << " for type: " << *ArgType
<< '\n');
// Append this argument's buffer size to the metadata string.
289 Sizes
<< ArgSize
<< ':';
// NOTE(review): line 293 (the end of this debug statement) and most of the
// following loop (lines 296-326) are missing in this copy. The loop
// presumably appends an escaped copy of Str to Sizes.
292 LLVM_DEBUG(dbgs() << "Printf format string in source = " << Str
.str()
294 for (size_t I
= 0; I
< Str
.size(); ++I
) {
295 // Rest of the C escape sequences (e.g. \') are handled correctly
317 // ':' cannot be scanned by Flex, as it is defined as a delimiter
318 // Replace it with its octal representation \72
327 // Insert the __printf_alloc call before CI.
328 Builder
.SetInsertPoint(CI
);
329 Builder
.SetCurrentDebugLocation(CI
->getDebugLoc());
// Declare (or reuse) `i8 addrspace(1)* __printf_alloc(i32)` with the
// nounwind attribute.
331 AttributeList Attr
= AttributeList::get(Ctx
, AttributeList::FunctionIndex
,
332 Attribute::NoUnwind
);
334 Type
*SizetTy
= Type::getInt32Ty(Ctx
);
336 Type
*Tys_alloc
[1] = {SizetTy
};
337 Type
*I8Ptr
= PointerType::get(Type::getInt8Ty(Ctx
), 1);
338 FunctionType
*FTy_alloc
= FunctionType::get(I8Ptr
, Tys_alloc
, false);
339 FunctionCallee PrintfAllocFn
=
340 M
.getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc
, Attr
);
342 LLVM_DEBUG(dbgs() << "Printf metadata = " << Sizes
.str() << '\n');
// Metadata entry "<id>:<Sizes>" for this call. UniqID is pre-incremented, so
// each call gets a fresh id.
// NOTE(review): UniqID's declaration is not visible in this copy.
343 std::string fmtstr
= itostr(++UniqID
) + ":" + Sizes
.str().c_str();
344 MDString
*fmtStrArray
= MDString::get(Ctx
, fmtstr
);
346 // Instead of creating global variables, the
347 // printf format strings are extracted
348 // and passed as metadata. This avoids
349 // polluting llvm's symbol tables in this module.
350 // Metadata is going to be extracted
351 // by the backend passes and inserted
352 // into the OpenCL binary as appropriate.
353 StringRef
amd("llvm.printf.fmts");
354 NamedMDNode
*metaD
= M
.getOrInsertNamedMetadata(amd
);
355 MDNode
*myMD
= MDNode::get(Ctx
, fmtStrArray
);
356 metaD
->addOperand(myMD
);
// Request Sum bytes from the runtime, inserted before CI.
// NOTE(review): the call is presumably bound to `pcall` on the missing
// line 360; pcall is used below.
357 Value
*sumC
= ConstantInt::get(SizetTy
, Sum
, false);
358 SmallVector
<Value
*, 1> alloc_args
;
359 alloc_args
.push_back(sumC
);
361 CallInst::Create(PrintfAllocFn
, alloc_args
, "printf_alloc_fn", CI
);
364 // Split the block and insert a guarded (if-then)
365 // region, so the buffer is only written when
366 // __printf_alloc returned a non-null pointer.
368 ConstantPointerNull
*zeroIntPtr
=
369 ConstantPointerNull::get(PointerType::get(Type::getInt8Ty(Ctx
), 1));
// Compare the allocation result against null (presumably bound to `cmp` on
// the missing line 370).
371 dyn_cast
<ICmpInst
>(Builder
.CreateICmpNE(pcall
, zeroIntPtr
, ""));
// If printf's result is used, replace it with sext(!cmp): 0 when the buffer
// was allocated, -1 when it was not. The value is presumably bound to
// `result` on line 373.
372 if (!CI
->use_empty()) {
374 Builder
.CreateSExt(Builder
.CreateNot(cmp
), I32Ty
, "printf_res");
375 CI
->replaceAllUsesWith(result
);
// Put the buffer writes in a block that runs only when cmp is true. Brnch
// (defined on a missing line) is presumably that block's terminator and is
// the insertion point for everything below.
377 SplitBlock(CI
->getParent(), cmp
);
379 SplitBlockAndInsertIfThen(cmp
, cmp
->getNextNode(), false);
381 Builder
.SetInsertPoint(Brnch
);
383 // Store the unique printf id at offset 0 of the buffer.
// NOTE(review): ConstantInt::get(I32Ty, 0) would avoid parsing "0" into an
// APInt (likewise for "4" below).
385 SmallVector
<Value
*, 1> ZeroIdxList
;
386 ConstantInt
*zeroInt
=
387 ConstantInt::get(Ctx
, APInt(32, StringRef("0"), 10));
388 ZeroIdxList
.push_back(zeroInt
);
390 GetElementPtrInst
*BufferIdx
=
391 dyn_cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
392 nullptr, pcall
, ZeroIdxList
, "PrintBuffID", Brnch
));
394 Type
*idPointer
= PointerType::get(I32Ty
, AMDGPUAS::GLOBAL_ADDRESS
);
// NOTE(review): the bitcast and the id store are presumably bound to
// id_gep_cast and stbuff on the missing lines 395 and 398.
396 new BitCastInst(BufferIdx
, idPointer
, "PrintBuffIdCast", Brnch
);
399 new StoreInst(ConstantInt::get(I32Ty
, UniqID
), id_gep_cast
);
// Place the id store at the end of the guarded block.
400 stbuff
->insertBefore(Brnch
); // to Remove unused variable warning
402 SmallVector
<Value
*, 2> FourthIdxList
;
403 ConstantInt
*fourInt
=
404 ConstantInt::get(Ctx
, APInt(32, StringRef("4"), 10));
406 FourthIdxList
.push_back(fourInt
); // 1st 4 bytes hold the printf_id
407 // The following GEP is the buffer pointer for the first argument.
408 BufferIdx
= cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
409 nullptr, pcall
, FourthIdxList
, "PrintBuffGep", Brnch
));
411 Type
*Int32Ty
= Type::getInt32Ty(Ctx
);
412 Type
*Int64Ty
= Type::getInt64Ty(Ctx
);
// Second argument pass: convert each argument to integer value(s) and store
// them into the buffer inside the guarded block, in the same order as the
// sizing pass.
413 for (unsigned ArgCount
= 1; ArgCount
< CI
->getNumArgOperands() &&
414 ArgCount
<= OpConvSpecifiers
.size();
416 Value
*Arg
= CI
->getArgOperand(ArgCount
);
417 Type
*ArgType
= Arg
->getType();
418 SmallVector
<Value
*, 32> WhatToStore
;
// Scalar floating point: store the bits as i32 for float, i64 otherwise
// (assumes ArgType is float or double).
419 if (ArgType
->isFPOrFPVectorTy() &&
420 (ArgType
->getTypeID() != Type::VectorTyID
)) {
421 Type
*IType
= (ArgType
->isFloatTy()) ? Int32Ty
: Int64Ty
;
// %f: fold a constant to single precision, and look through fpext
// float->double.
// NOTE(review): the null check on fpCons (line 424) and lines 426-431 and
// 436-439 are missing in this copy; fpCons is dereferenced below.
422 if (OpConvSpecifiers
[ArgCount
- 1] == 'f') {
423 ConstantFP
*fpCons
= dyn_cast
<ConstantFP
>(Arg
);
425 APFloat
Val(fpCons
->getValueAPF());
427 Val
.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven
,
429 Arg
= ConstantFP::get(Ctx
, Val
);
432 FPExtInst
*FpExt
= dyn_cast
<FPExtInst
>(Arg
);
433 if (FpExt
&& FpExt
->getType()->isDoubleTy() &&
434 FpExt
->getOperand(0)->getType()->isFloatTy()) {
435 Arg
= FpExt
->getOperand(0);
440 Arg
= new BitCastInst(Arg
, IType
, "PrintArgFP", Brnch
);
441 WhatToStore
.push_back(Arg
);
442 } else if (ArgType
->getTypeID() == Type::PointerTyID
) {
// Pointer arguments. For %s of a constant string, the contents are copied
// into the buffer as i32 words (NonLiteralStr when the contents are
// unknown). Any other pointer is stored as an integer.
443 if (shouldPrintAsStr(OpConvSpecifiers
[ArgCount
- 1], ArgType
)) {
444 const char *S
= NonLiteralStr
;
445 if (ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Arg
)) {
// NOTE(review): the declaration `GlobalVariable *GV =` for this expression
// (line 446) is missing in this copy.
447 dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
448 if (GV
&& GV
->hasInitializer()) {
449 Constant
*Init
= GV
->getInitializer();
450 ConstantDataArray
*CA
= dyn_cast
<ConstantDataArray
>(Init
);
// NOTE(review): as in the sizing pass, CA may be null here when Init is not
// zero; consider `(CA && CA->isString())`.
451 if (Init
->isZeroValue() || CA
->isString()) {
452 S
= Init
->isZeroValue() ? "" : CA
->getAsCString().data();
// Copy the string (with terminator) into a zeroed scratch buffer whose size
// is rounded up to DWORD_ALIGN, then emit it as i32 words.
456 size_t SizeStr
= strlen(S
) + 1;
457 size_t Rem
= SizeStr
% DWORD_ALIGN
;
460 NSizeStr
= SizeStr
+ (DWORD_ALIGN
- Rem
);
// NOTE(review): MyNewStr is allocated with new[]. Its copy from S (lines
// 461-466) and its delete[] are not visible in this copy; verify that it is
// released. A SmallVector<char> would avoid the manual delete[].
465 char *MyNewStr
= new char[NSizeStr
]();
467 int NumInts
= NSizeStr
/ 4;
// NOTE(review): the loop over CharC (lines 468-469) is missing in this copy.
// Reading an int through `(int *)` of a char buffer violates strict aliasing,
// may be misaligned, and yields host byte order. Prefer memcpy into an
// int32_t (or an explicit little-endian read).
470 int ANum
= *(int *)(MyNewStr
+ CharC
);
473 Value
*ANumV
= ConstantInt::get(Int32Ty
, ANum
, false);
474 WhatToStore
.push_back(ANumV
);
478 // Empty string: give the runtime a hint that it is not NULL.
479 Value
*ANumV
= ConstantInt::get(Int32Ty
, 0xFFFFFF00, false);
480 WhatToStore
.push_back(ANumV
);
// Non-string pointer: store its address as an i32 or i64, depending on the
// pointer's allocation size.
483 uint64_t Size
= TD
->getTypeAllocSizeInBits(ArgType
);
484 assert((Size
== 32 || Size
== 64) && "unsupported size");
485 Type
*DstType
= (Size
== 32) ? Int32Ty
: Int64Ty
;
486 Arg
= new PtrToIntInst(Arg
, DstType
, "PrintArgPtr", Brnch
);
487 WhatToStore
.push_back(Arg
);
489 } else if (ArgType
->getTypeID() == Type::VectorTyID
) {
// Vectors: pad (presumably only 3-element vectors) to 4 elements with a
// shufflevector, then bitcast to a vector of integers chosen from the total
// bit size.
// NOTE(review): the conditions selecting among the IType alternatives below
// and the assignment of the shuffle back to Arg are missing in this copy.
491 uint32_t EleCount
= cast
<VectorType
>(ArgType
)->getNumElements();
492 uint32_t EleSize
= ArgType
->getScalarSizeInBits();
493 uint32_t TotalSize
= EleCount
* EleSize
;
// NOTE(review): this local shadows the outer Int32Ty.
495 IntegerType
*Int32Ty
= Type::getInt32Ty(ArgType
->getContext());
// Shuffle mask <0, 1, 2, 2>: repeats element 2 to fill a fourth lane.
496 Constant
*Indices
[4] = {
497 ConstantInt::get(Int32Ty
, 0), ConstantInt::get(Int32Ty
, 1),
498 ConstantInt::get(Int32Ty
, 2), ConstantInt::get(Int32Ty
, 2)};
499 Constant
*Mask
= ConstantVector::get(Indices
);
500 ShuffleVectorInst
*Shuffle
= new ShuffleVectorInst(Arg
, Arg
, Mask
);
501 Shuffle
->insertBefore(Brnch
);
// The padded lane adds one element's worth of bits.
503 ArgType
= Arg
->getType();
504 TotalSize
+= EleSize
;
// Choose the integer element type and count for the bitcast from TotalSize.
508 EleCount
= TotalSize
/ 64;
509 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
513 EleCount
= TotalSize
/ 64;
514 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
515 } else if (EleCount
>= 3) {
517 IType
= dyn_cast
<Type
>(Type::getInt32Ty(ArgType
->getContext()));
520 IType
= dyn_cast
<Type
>(Type::getInt16Ty(ArgType
->getContext()));
525 EleCount
= TotalSize
/ 64;
526 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
529 IType
= dyn_cast
<Type
>(Type::getInt32Ty(ArgType
->getContext()));
// Bitcast the vector to <EleCount x IType>.
534 IType
= dyn_cast
<Type
>(VectorType::get(IType
, EleCount
));
536 Arg
= new BitCastInst(Arg
, IType
, "PrintArgVect", Brnch
);
537 WhatToStore
.push_back(Arg
);
// Any other type is stored as-is (the `else` on line 538 is presumably
// missing in this copy).
539 WhatToStore
.push_back(Arg
);
// Store each value at BufferIdx and advance BufferIdx by the value's
// allocation size.
// NOTE(review): the following are missing in this copy:
//   - the `unsigned ArgSize =` declaration (line 543);
//   - the definition of CastedGEP (line 549);
//   - the statement guarded by the last-value check (line 556), which
//     presumably skips the final advance.
541 for (unsigned I
= 0, E
= WhatToStore
.size(); I
!= E
; ++I
) {
542 Value
*TheBtCast
= WhatToStore
[I
];
544 TD
->getTypeAllocSizeInBits(TheBtCast
->getType()) / 8;
545 SmallVector
<Value
*, 1> BuffOffset
;
546 BuffOffset
.push_back(ConstantInt::get(I32Ty
, ArgSize
));
548 Type
*ArgPointer
= PointerType::get(TheBtCast
->getType(), 1);
550 new BitCastInst(BufferIdx
, ArgPointer
, "PrintBuffPtrCast", Brnch
);
551 StoreInst
*StBuff
= new StoreInst(TheBtCast
, CastedGEP
, Brnch
);
552 LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n"
555 if (I
+ 1 == E
&& ArgCount
+ 1 == CI
->getNumArgOperands())
557 BufferIdx
= dyn_cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
558 nullptr, BufferIdx
, BuffOffset
, "PrintBuffNextPtr", Brnch
));
559 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n"
560 << *BufferIdx
<< '\n');
566 // Erase the original printf calls now that they have been lowered.
567 for (auto P
: Printfs
) {
568 auto CI
= cast
<CallInst
>(P
);
569 CI
->eraseFromParent();
// Pass entry point. Sets up TD, DT and a per-function TLI lookup, then lowers
// the collected printf calls via lowerPrintfForGpu and returns its result.
//
// NOTE(review): lines 579-585 and the closing brace (line 594) are missing in
// this copy. They presumably hold the r600 early exit's statement and the
// code that fills Printfs (this class is an InstVisitor). As shown, the `if`
// below would guard the TD assignment instead. Confirm against upstream.
576 bool AMDGPUPrintfRuntimeBinding::runOnModule(Module
&M
) {
// r600 targets are presumably excluded here (the guarded statement on
// line 579 is missing).
577 Triple
TT(M
.getTargetTriple());
578 if (TT
.getArch() == Triple::r600
)
// Module DataLayout, used for type sizes and by simplify().
586 TD
= &M
.getDataLayout();
// Use a dominator tree only if one is already available; otherwise DT stays
// null and simplify() passes the null through.
// NOTE(review): getAnalysisUsage marks DominatorTreeWrapperPass as required,
// but here it is only queried with getAnalysisIfAvailable. Because it is a
// per-function analysis, confirm that DT is ever non-null in this module pass.
587 auto DTWP
= getAnalysisIfAvailable
<DominatorTreeWrapperPass
>();
588 DT
= DTWP
? &DTWP
->getDomTree() : nullptr;
// Per-function TargetLibraryInfo lookup handed to the lowering.
589 auto GetTLI
= [this](Function
&F
) -> TargetLibraryInfo
& {
590 return this->getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
593 return lowerPrintfForGpu(M
, GetTLI
);