1 //=== AMDGPUPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // This pass binds printfs to a kernel arg pointer that will be bound to a buffer
11 // later by the runtime.
13 // This pass traverses the functions in the module and converts
14 // each call to printf to a sequence of operations that
15 // store the following into the printf buffer:
16 // - format string (passed as a module's metadata unique ID)
17 // - bitwise copies of printf arguments
18 // The backend passes will need to store metadata in the kernel
19 //===----------------------------------------------------------------------===//
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/Triple.h"
25 #include "llvm/Analysis/InstructionSimplify.h"
26 #include "llvm/Analysis/TargetLibraryInfo.h"
27 #include "llvm/CodeGen/Passes.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/GlobalVariable.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/InstVisitor.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #define DEBUG_TYPE "printfToRuntime"
// Module pass that lowers calls to printf into stores to a printf buffer
// (see the file header comment). It is also an InstVisitor so that it can
// collect the printf call sites.
47 class LLVM_LIBRARY_VISIBILITY AMDGPUPrintfRuntimeBinding final
49 public InstVisitor
<AMDGPUPrintfRuntimeBinding
> {
54 explicit AMDGPUPrintfRuntimeBinding();
// InstVisitor callback: remember every call site whose callee is a function
// named "printf". lowerPrintfForGpu later iterates over the Printfs list.
56 void visitCallSite(CallSite CS
) {
57 Function
*F
= CS
.getCalledFunction();
58 if (F
&& F
->hasName() && F
->getName() == "printf")
59 Printfs
.push_back(CS
.getInstruction());
63 bool runOnModule(Module
&M
) override
;
// Collects conversion-specifier characters found in the format string into
// OpConvSpecifiers.
64 void getConversionSpecifiers(SmallVectorImpl
<char> &OpConvSpecifiers
,
65 StringRef fmt
, size_t num_ops
) const;
// Decides, from the conversion specifier and the operand type, whether an
// argument is handled as string data.
67 bool shouldPrintAsStr(char Specifier
, Type
*OpType
) const;
// Rewrites the collected printf calls.
68 bool lowerPrintfForGpu(Module
&M
);
// Declares the analyses this pass requires: TargetLibraryInfo and the
// dominator tree.
70 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
71 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
72 AU
.addRequired
<DominatorTreeWrapperPass
>();
// Forwards I to SimplifyInstruction together with this pass's data layout
// (TD), TLI and DT.
75 Value
*simplify(Instruction
*I
) {
76 return SimplifyInstruction(I
, {*TD
, TLI
, DT
});
// Analysis results; both pointers are assigned in runOnModule.
80 const DominatorTree
*DT
;
81 const TargetLibraryInfo
*TLI
;
// Call instructions recorded by visitCallSite.
82 SmallVector
<Value
*, 32> Printfs
;
// Pass identification object; the constructor passes it to ModulePass.
86 char AMDGPUPrintfRuntimeBinding::ID
= 0;
// Pass registration under the command-line name
// "amdgpu-printf-runtime-binding", listing TargetLibraryInfoWrapperPass and
// DominatorTreeWrapperPass as dependencies.
88 INITIALIZE_PASS_BEGIN(AMDGPUPrintfRuntimeBinding
,
89 "amdgpu-printf-runtime-binding", "AMDGPU Printf lowering",
91 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass
)
92 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
93 INITIALIZE_PASS_END(AMDGPUPrintfRuntimeBinding
, "amdgpu-printf-runtime-binding",
94 "AMDGPU Printf lowering", false, false)
// Reference to the pass ID exposed in the llvm namespace.
96 char &llvm::AMDGPUPrintfRuntimeBindingID
= AMDGPUPrintfRuntimeBinding::ID
;
// Factory function: returns a newly allocated AMDGPUPrintfRuntimeBinding pass.
// NOTE(review): the caller presumably takes ownership of the returned
// pointer (e.g. a pass manager) - confirm.
99 ModulePass
*createAMDGPUPrintfRuntimeBinding() {
100 return new AMDGPUPrintfRuntimeBinding();
// Constructor: starts with null TD, DT and TLI pointers and initializes this
// pass in the global PassRegistry.
104 AMDGPUPrintfRuntimeBinding::AMDGPUPrintfRuntimeBinding()
105 : ModulePass(ID
), TD(nullptr), DT(nullptr), TLI(nullptr) {
106 initializeAMDGPUPrintfRuntimeBindingPass(*PassRegistry::getPassRegistry());
// Scans the format string Fmt for conversion-specifier characters and appends
// them to OpConvSpecifiers.
//
// OpConvSpecifiers - output list of specifier characters.
// Fmt              - the printf format string.
// NumOps           - NOTE(review): does not appear to be referenced in this
//                    function - confirm.
109 void AMDGPUPrintfRuntimeBinding::getConversionSpecifiers(
110 SmallVectorImpl
<char> &OpConvSpecifiers
, StringRef Fmt
,
111 size_t NumOps
) const {
112 // Not all format characters are collected.
113 // At this time the format characters of interest
114 // are %p and %s, which are used to know if we
115 // are either storing a literal string or a
116 // pointer to the printf buffer.
117 static const char ConvSpecifiers
[] = "cdieEfgGaosuxXp";
118 size_t CurFmtSpecifierIdx
= 0;
119 size_t PrevFmtSpecifierIdx
= 0;
// Visit each occurrence of a candidate specifier character in Fmt.
121 while ((CurFmtSpecifierIdx
= Fmt
.find_first_of(
122 ConvSpecifiers
, CurFmtSpecifierIdx
)) != StringRef::npos
) {
123 bool ArgDump
= false;
// CurFmt is the text between the previous candidate and this one; pTag is the
// position of the last '%' inside it, if any.
124 StringRef CurFmt
= Fmt
.substr(PrevFmtSpecifierIdx
,
125 CurFmtSpecifierIdx
- PrevFmtSpecifierIdx
);
126 size_t pTag
= CurFmt
.find_last_of("%");
127 if (pTag
!= StringRef::npos
) {
// Walk backwards over a run of consecutive '%' characters (presumably to
// distinguish an escaped "%%" from a real conversion - TODO confirm).
129 while (pTag
&& CurFmt
[--pTag
] == '%') {
// Record the specifier character found at the current position.
// NOTE(review): presumably guarded by ArgDump - confirm.
135 OpConvSpecifiers
.push_back(Fmt
[CurFmtSpecifierIdx
]);
// Continue the search just past the character that was examined.
137 PrevFmtSpecifierIdx
= ++CurFmtSpecifierIdx
;
// Tells whether an argument with the given conversion specifier and type is
// treated as a string. The checks below test that the specifier is 's', that
// OpType is a pointer in AMDGPUAS::CONSTANT_ADDRESS, and that the pointee is
// an integer type; the final result is whether that integer is 8 bits wide.
// NOTE(review): each failing guard presumably bails out with false - confirm.
141 bool AMDGPUPrintfRuntimeBinding::shouldPrintAsStr(char Specifier
,
142 Type
*OpType
) const {
143 if (Specifier
!= 's')
145 const PointerType
*PT
= dyn_cast
<PointerType
>(OpType
);
146 if (!PT
|| PT
->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS
)
// Inspect the pointee type of the pointer.
148 Type
*ElemType
= PT
->getContainedType(0);
149 if (ElemType
->getTypeID() != Type::IntegerTyID
)
151 IntegerType
*ElemIType
= cast
<IntegerType
>(ElemType
);
152 return ElemIType
->getBitWidth() == 8;
// Lowers every call recorded in Printfs. For each call this function:
//  - recovers the format string and its conversion specifiers,
//  - builds a sizes string and adds it, prefixed by a unique id, to the
//    "llvm.printf.fmts" named metadata,
//  - inserts a call to __printf_alloc and a compare of its result against
//    null,
//  - stores the unique id and the argument values through the returned
//    pointer,
// and finally erases the original printf calls.
155 bool AMDGPUPrintfRuntimeBinding::lowerPrintfForGpu(Module
&M
) {
156 LLVMContext
&Ctx
= M
.getContext();
157 IRBuilder
<> Builder(Ctx
);
158 Type
*I32Ty
= Type::getInt32Ty(Ctx
);
160 // NB: It is important for this string size to be divisible by 4
161 const char NonLiteralStr
[4] = "???";
163 for (auto P
: Printfs
) {
// NOTE(review): the dyn_cast result is dereferenced below; confirm that only
// CallInsts can reach this point or that a null check exists.
164 CallInst
*CI
= dyn_cast
<CallInst
>(P
);
166 unsigned NumOps
= CI
->getNumArgOperands();
168 SmallString
<16> OpConvSpecifiers
;
// Operand 0 of the call is the format string.
169 Value
*Op
= CI
->getArgOperand(0);
// If the format operand is a load, look through it: scan the users of the
// loaded pointer for a store and take the stored value instead.
171 if (auto LI
= dyn_cast
<LoadInst
>(Op
)) {
172 Op
= LI
->getPointerOperand();
173 for (auto Use
: Op
->users()) {
174 if (auto SI
= dyn_cast
<StoreInst
>(Use
)) {
175 Op
= SI
->getValueOperand();
// Try to simplify an instruction operand (presumably so that it folds to a
// constant expression - TODO confirm).
181 if (auto I
= dyn_cast
<Instruction
>(Op
)) {
182 Value
*Op_simplified
= simplify(I
);
// The format string is looked up as a constant expression whose first operand
// is a global variable.
187 ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Op
);
190 GlobalVariable
*GVar
= dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
// Str keeps the placeholder "unknown" unless the global's initializer yields
// the actual text.
192 StringRef
Str("unknown");
193 if (GVar
&& GVar
->hasInitializer()) {
194 auto Init
= GVar
->getInitializer();
195 if (auto CA
= dyn_cast
<ConstantDataArray
>(Init
)) {
197 Str
= CA
->getAsCString();
198 } else if (isa
<ConstantAggregateZero
>(Init
)) {
202 // we need this call to ascertain
203 // that we are printing a string
204 // or a pointer. It takes out the
205 // specifiers and fills up the first
207 getConversionSpecifiers(OpConvSpecifiers
, Str
, NumOps
- 1);
209 // Add metadata for the string
210 std::string AStreamHolder
;
211 raw_string_ostream
Sizes(AStreamHolder
);
212 int Sum
= DWORD_ALIGN
;
213 Sizes
<< CI
->getNumArgOperands() - 1;
// First pass over the printf arguments: determine the size each argument
// occupies in the buffer and append it to Sizes.
215 for (unsigned ArgCount
= 1; ArgCount
< CI
->getNumArgOperands() &&
216 ArgCount
<= OpConvSpecifiers
.size();
218 Value
*Arg
= CI
->getArgOperand(ArgCount
);
219 Type
*ArgType
= Arg
->getType();
220 unsigned ArgSize
= TD
->getTypeAllocSizeInBits(ArgType
);
221 ArgSize
= ArgSize
/ 8;
223 // ArgSize by design should be a multiple of DWORD_ALIGN,
224 // expand the arguments that do not follow this rule.
226 if (ArgSize
% DWORD_ALIGN
!= 0) {
// The extension target is i32, or a vector of i32 with the same element count
// for vector arguments.
227 llvm::Type
*ResType
= llvm::Type::getInt32Ty(Ctx
);
228 VectorType
*LLVMVecType
= llvm::dyn_cast
<llvm::VectorType
>(ArgType
);
229 int NumElem
= LLVMVecType
? LLVMVecType
->getNumElements() : 1;
230 if (LLVMVecType
&& NumElem
> 1)
231 ResType
= llvm::VectorType::get(ResType
, NumElem
);
232 Builder
.SetInsertPoint(CI
);
233 Builder
.SetCurrentDebugLocation(CI
->getDebugLoc());
// Unsigned conversions (x, X, u, o) are zero-extended; the CreateSExt that
// follows is presumably the else branch for the remaining conversions.
234 if (OpConvSpecifiers
[ArgCount
- 1] == 'x' ||
235 OpConvSpecifiers
[ArgCount
- 1] == 'X' ||
236 OpConvSpecifiers
[ArgCount
- 1] == 'u' ||
237 OpConvSpecifiers
[ArgCount
- 1] == 'o')
238 Arg
= Builder
.CreateZExt(Arg
, ResType
);
240 Arg
= Builder
.CreateSExt(Arg
, ResType
);
// Recompute type and size for the extended value and make the call use it.
241 ArgType
= Arg
->getType();
242 ArgSize
= TD
->getTypeAllocSizeInBits(ArgType
);
243 ArgSize
= ArgSize
/ 8;
244 CI
->setOperand(ArgCount
, Arg
);
// For %f, check for a floating-point constant or an fpext from float to
// double (presumably so the value can be sized as a float - TODO confirm).
246 if (OpConvSpecifiers
[ArgCount
- 1] == 'f') {
247 ConstantFP
*FpCons
= dyn_cast
<ConstantFP
>(Arg
);
251 FPExtInst
*FpExt
= dyn_cast
<FPExtInst
>(Arg
);
252 if (FpExt
&& FpExt
->getType()->isDoubleTy() &&
253 FpExt
->getOperand(0)->getType()->isFloatTy())
// String arguments: the size is derived from the initializer of the
// referenced global when available; sizeof(NonLiteralStr) is used on the
// fallback paths below.
257 if (shouldPrintAsStr(OpConvSpecifiers
[ArgCount
- 1], ArgType
)) {
258 if (ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Arg
)) {
260 dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
261 if (GV
&& GV
->hasInitializer()) {
262 Constant
*Init
= GV
->getInitializer();
263 ConstantDataArray
*CA
= dyn_cast
<ConstantDataArray
>(Init
);
// NOTE(review): CA may be null when Init is neither a zero value nor a
// ConstantDataArray; CA->isString() would then dereference null - confirm.
264 if (Init
->isZeroValue() || CA
->isString()) {
265 size_t SizeStr
= Init
->isZeroValue()
267 : (strlen(CA
->getAsCString().data()) + 1);
268 size_t Rem
= SizeStr
% DWORD_ALIGN
;
270 LLVM_DEBUG(dbgs() << "Printf string original size = " << SizeStr
// Size of the string rounded up past SizeStr to a multiple of DWORD_ALIGN.
273 NSizeStr
= SizeStr
+ (DWORD_ALIGN
- Rem
);
280 ArgSize
= sizeof(NonLiteralStr
);
283 ArgSize
= sizeof(NonLiteralStr
);
286 LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize
287 << " for type: " << *ArgType
<< '\n');
288 Sizes
<< ArgSize
<< ':';
291 LLVM_DEBUG(dbgs() << "Printf format string in source = " << Str
.str()
// Walk the format string character by character (presumably to escape it for
// the metadata string, as the comments below describe - TODO confirm).
293 for (size_t I
= 0; I
< Str
.size(); ++I
) {
294 // Rest of the C escape sequences (e.g. \') are handled correctly
316 // ':' cannot be scanned by Flex, as it is defined as a delimiter
317 // Replace it with its octal representation \72
326 // Insert the printf_alloc call
327 Builder
.SetInsertPoint(CI
);
328 Builder
.SetCurrentDebugLocation(CI
->getDebugLoc());
330 AttributeList Attr
= AttributeList::get(Ctx
, AttributeList::FunctionIndex
,
331 Attribute::NoUnwind
);
333 Type
*SizetTy
= Type::getInt32Ty(Ctx
);
// __printf_alloc takes one i32 and returns an i8 pointer in address space 1.
335 Type
*Tys_alloc
[1] = {SizetTy
};
336 Type
*I8Ptr
= PointerType::get(Type::getInt8Ty(Ctx
), 1);
337 FunctionType
*FTy_alloc
= FunctionType::get(I8Ptr
, Tys_alloc
, false);
338 FunctionCallee PrintfAllocFn
=
339 M
.getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc
, Attr
);
341 LLVM_DEBUG(dbgs() << "Printf metadata = " << Sizes
.str() << '\n');
// Metadata string: a freshly incremented unique id, ':' and the sizes string.
342 std::string fmtstr
= itostr(++UniqID
) + ":" + Sizes
.str().c_str();
343 MDString
*fmtStrArray
= MDString::get(Ctx
, fmtstr
);
345 // Instead of creating global variables, the
346 // printf format strings are extracted
347 // and passed as metadata. This avoids
348 // polluting llvm's symbol tables in this module.
349 // Metadata is going to be extracted
350 // by the backend passes and inserted
351 // into the OpenCL binary as appropriate.
352 StringRef
amd("llvm.printf.fmts");
353 NamedMDNode
*metaD
= M
.getOrInsertNamedMetadata(amd
);
354 MDNode
*myMD
= MDNode::get(Ctx
, fmtStrArray
);
355 metaD
->addOperand(myMD
);
// Sum is passed as the single argument of __printf_alloc.
356 Value
*sumC
= ConstantInt::get(SizetTy
, Sum
, false);
357 SmallVector
<Value
*, 1> alloc_args
;
358 alloc_args
.push_back(sumC
);
360 CallInst::Create(PrintfAllocFn
, alloc_args
, "printf_alloc_fn", CI
);
363 // Insert code to split basicblock with a
364 // piece of hammock code.
365 // basicblock splits after buffer overflow check
367 ConstantPointerNull
*zeroIntPtr
=
368 ConstantPointerNull::get(PointerType::get(Type::getInt8Ty(Ctx
), 1));
370 dyn_cast
<ICmpInst
>(Builder
.CreateICmpNE(pcall
, zeroIntPtr
, ""));
// If the printf result is used, replace it with sext(not cmp) as an i32
// (presumably cmp holds the pointer-not-null compare created above, giving 0
// when the allocation returned a non-null pointer - TODO confirm).
371 if (!CI
->use_empty()) {
373 Builder
.CreateSExt(Builder
.CreateNot(cmp
), I32Ty
, "printf_res");
374 CI
->replaceAllUsesWith(result
);
376 SplitBlock(CI
->getParent(), cmp
);
378 SplitBlockAndInsertIfThen(cmp
, cmp
->getNextNode(), false);
// The instructions below are inserted before Brnch.
380 Builder
.SetInsertPoint(Brnch
);
382 // store unique printf id in the buffer
384 SmallVector
<Value
*, 1> ZeroIdxList
;
385 ConstantInt
*zeroInt
=
386 ConstantInt::get(Ctx
, APInt(32, StringRef("0"), 10));
387 ZeroIdxList
.push_back(zeroInt
);
389 GetElementPtrInst
*BufferIdx
=
390 dyn_cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
391 nullptr, pcall
, ZeroIdxList
, "PrintBuffID", Brnch
));
393 Type
*idPointer
= PointerType::get(I32Ty
, AMDGPUAS::GLOBAL_ADDRESS
);
395 new BitCastInst(BufferIdx
, idPointer
, "PrintBuffIdCast", Brnch
);
398 new StoreInst(ConstantInt::get(I32Ty
, UniqID
), id_gep_cast
);
399 stbuff
->insertBefore(Brnch
); // to Remove unused variable warning
401 SmallVector
<Value
*, 2> FourthIdxList
;
402 ConstantInt
*fourInt
=
403 ConstantInt::get(Ctx
, APInt(32, StringRef("4"), 10));
405 FourthIdxList
.push_back(fourInt
); // 1st 4 bytes hold the printf_id
406 // the following GEP is the buffer pointer
407 BufferIdx
= cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
408 nullptr, pcall
, FourthIdxList
, "PrintBuffGep", Brnch
));
410 Type
*Int32Ty
= Type::getInt32Ty(Ctx
);
411 Type
*Int64Ty
= Type::getInt64Ty(Ctx
);
// Second pass over the arguments: for each one, build the list of values to
// store (WhatToStore) and emit the stores into the buffer.
412 for (unsigned ArgCount
= 1; ArgCount
< CI
->getNumArgOperands() &&
413 ArgCount
<= OpConvSpecifiers
.size();
415 Value
*Arg
= CI
->getArgOperand(ArgCount
);
416 Type
*ArgType
= Arg
->getType();
417 SmallVector
<Value
*, 32> WhatToStore
;
// Scalar floating-point arguments are bitcast to an integer type: i32 for
// float, i64 otherwise.
418 if (ArgType
->isFPOrFPVectorTy() &&
419 (ArgType
->getTypeID() != Type::VectorTyID
)) {
420 Type
*IType
= (ArgType
->isFloatTy()) ? Int32Ty
: Int64Ty
;
421 if (OpConvSpecifiers
[ArgCount
- 1] == 'f') {
422 ConstantFP
*fpCons
= dyn_cast
<ConstantFP
>(Arg
);
// A constant is converted to IEEE single precision.
424 APFloat
Val(fpCons
->getValueAPF());
426 Val
.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven
,
428 Arg
= ConstantFP::get(Ctx
, Val
);
// A float value that was extended to double is replaced by the original
// float operand.
431 FPExtInst
*FpExt
= dyn_cast
<FPExtInst
>(Arg
);
432 if (FpExt
&& FpExt
->getType()->isDoubleTy() &&
433 FpExt
->getOperand(0)->getType()->isFloatTy()) {
434 Arg
= FpExt
->getOperand(0);
439 Arg
= new BitCastInst(Arg
, IType
, "PrintArgFP", Brnch
);
440 WhatToStore
.push_back(Arg
);
441 } else if (ArgType
->getTypeID() == Type::PointerTyID
) {
// String arguments: the characters are packed into 32-bit constants;
// NonLiteralStr is used unless the text can be read from a global's
// initializer.
442 if (shouldPrintAsStr(OpConvSpecifiers
[ArgCount
- 1], ArgType
)) {
443 const char *S
= NonLiteralStr
;
444 if (ConstantExpr
*ConstExpr
= dyn_cast
<ConstantExpr
>(Arg
)) {
446 dyn_cast
<GlobalVariable
>(ConstExpr
->getOperand(0));
447 if (GV
&& GV
->hasInitializer()) {
448 Constant
*Init
= GV
->getInitializer();
449 ConstantDataArray
*CA
= dyn_cast
<ConstantDataArray
>(Init
);
// NOTE(review): same potential null dereference of CA as in the sizing pass
// above - confirm.
450 if (Init
->isZeroValue() || CA
->isString()) {
451 S
= Init
->isZeroValue() ? "" : CA
->getAsCString().data();
455 size_t SizeStr
= strlen(S
) + 1;
456 size_t Rem
= SizeStr
% DWORD_ALIGN
;
459 NSizeStr
= SizeStr
+ (DWORD_ALIGN
- Rem
);
// Zero-initialized scratch copy of the padded string.
// NOTE(review): raw new[]; confirm the buffer is released on every path.
464 char *MyNewStr
= new char[NSizeStr
]();
466 int NumInts
= NSizeStr
/ 4;
// Reads the scratch buffer one 32-bit word at a time.
// NOTE(review): type punning through an int pointer - confirm alignment and
// aliasing are acceptable here.
469 int ANum
= *(int *)(MyNewStr
+ CharC
);
472 Value
*ANumV
= ConstantInt::get(Int32Ty
, ANum
, false);
473 WhatToStore
.push_back(ANumV
);
477 // Empty string, give a hint to RT that it is not NULL
478 Value
*ANumV
= ConstantInt::get(Int32Ty
, 0xFFFFFF00, false);
479 WhatToStore
.push_back(ANumV
);
// Other pointers are converted to a 32- or 64-bit integer, chosen from the
// pointer type's size.
482 uint64_t Size
= TD
->getTypeAllocSizeInBits(ArgType
);
483 assert((Size
== 32 || Size
== 64) && "unsupported size");
484 Type
*DstType
= (Size
== 32) ? Int32Ty
: Int64Ty
;
485 Arg
= new PtrToIntInst(Arg
, DstType
, "PrintArgPtr", Brnch
);
486 WhatToStore
.push_back(Arg
);
// Vector arguments are bitcast to an integer or integer-vector type (IType)
// before being stored.
488 } else if (ArgType
->getTypeID() == Type::VectorTyID
) {
490 uint32_t EleCount
= cast
<VectorType
>(ArgType
)->getNumElements();
491 uint32_t EleSize
= ArgType
->getScalarSizeInBits();
492 uint32_t TotalSize
= EleCount
* EleSize
;
// Shuffle with mask <0, 1, 2, 2> (presumably widening a 3-element vector to
// 4 elements - TODO confirm).
494 IntegerType
*Int32Ty
= Type::getInt32Ty(ArgType
->getContext());
495 Constant
*Indices
[4] = {
496 ConstantInt::get(Int32Ty
, 0), ConstantInt::get(Int32Ty
, 1),
497 ConstantInt::get(Int32Ty
, 2), ConstantInt::get(Int32Ty
, 2)};
498 Constant
*Mask
= ConstantVector::get(Indices
);
499 ShuffleVectorInst
*Shuffle
= new ShuffleVectorInst(Arg
, Arg
, Mask
);
500 Shuffle
->insertBefore(Brnch
);
502 ArgType
= Arg
->getType();
503 TotalSize
+= EleSize
;
// Choose the integer element type and count used for the bitcast.
507 EleCount
= TotalSize
/ 64;
508 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
512 EleCount
= TotalSize
/ 64;
513 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
514 } else if (EleCount
>= 3) {
516 IType
= dyn_cast
<Type
>(Type::getInt32Ty(ArgType
->getContext()));
519 IType
= dyn_cast
<Type
>(Type::getInt16Ty(ArgType
->getContext()));
524 EleCount
= TotalSize
/ 64;
525 IType
= dyn_cast
<Type
>(Type::getInt64Ty(ArgType
->getContext()));
528 IType
= dyn_cast
<Type
>(Type::getInt32Ty(ArgType
->getContext()));
533 IType
= dyn_cast
<Type
>(VectorType::get(IType
, EleCount
));
535 Arg
= new BitCastInst(Arg
, IType
, "PrintArgVect", Brnch
);
536 WhatToStore
.push_back(Arg
);
538 WhatToStore
.push_back(Arg
);
// Emit one store per collected value, then advance the buffer pointer by the
// size of the value just stored.
540 for (unsigned I
= 0, E
= WhatToStore
.size(); I
!= E
; ++I
) {
541 Value
*TheBtCast
= WhatToStore
[I
];
543 TD
->getTypeAllocSizeInBits(TheBtCast
->getType()) / 8;
544 SmallVector
<Value
*, 1> BuffOffset
;
545 BuffOffset
.push_back(ConstantInt::get(I32Ty
, ArgSize
));
547 Type
*ArgPointer
= PointerType::get(TheBtCast
->getType(), 1);
549 new BitCastInst(BufferIdx
, ArgPointer
, "PrintBuffPtrCast", Brnch
);
550 StoreInst
*StBuff
= new StoreInst(TheBtCast
, CastedGEP
, Brnch
);
551 LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n"
// Condition is true for the last value of the last argument (presumably no
// further pointer advance is needed then - TODO confirm).
554 if (I
+ 1 == E
&& ArgCount
+ 1 == CI
->getNumArgOperands())
556 BufferIdx
= dyn_cast
<GetElementPtrInst
>(GetElementPtrInst::Create(
557 nullptr, BufferIdx
, BuffOffset
, "PrintBuffNextPtr", Brnch
));
558 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n"
559 << *BufferIdx
<< '\n');
565 // erase the printf calls
566 for (auto P
: Printfs
) {
567 CallInst
*CI
= dyn_cast
<CallInst
>(P
);
568 CI
->eraseFromParent();
// Pass entry point: caches the module's data layout, the dominator tree (when
// available) and the target library info, then returns the result of
// lowerPrintfForGpu(M).
575 bool AMDGPUPrintfRuntimeBinding::runOnModule(Module
&M
) {
// The r600 architecture is special-cased (presumably the pass does nothing
// for it - TODO confirm).
576 Triple
TT(M
.getTargetTriple());
577 if (TT
.getArch() == Triple::r600
)
585 TD
= &M
.getDataLayout();
// NOTE(review): getAnalysisUsage marks DominatorTreeWrapperPass as required,
// yet it is fetched with getAnalysisIfAvailable and DT may be left null -
// confirm this is intended.
586 auto DTWP
= getAnalysisIfAvailable
<DominatorTreeWrapperPass
>();
587 DT
= DTWP
? &DTWP
->getDomTree() : nullptr;
588 TLI
= &getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI();
590 return lowerPrintfForGpu(M
);