[CostModel][X86] Attempt to match v4f32 shuffles that map to MOVSS/INSERTPS instruction
[llvm-project.git] / llvm / lib / Transforms / Utils / AMDGPUEmitPrintf.cpp
bloba25632acbfcc3a914539689f10cdbac70e843464
1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility function to lower a printf call into a series of device
10 // library calls on the AMDGPU target.
12 // WARNING: This file knows about certain library functions. It recognizes them
13 // by name, and hardwires knowledge of their semantics.
15 //===----------------------------------------------------------------------===//
17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
18 #include "llvm/ADT/SparseBitVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Support/DataExtractor.h"
23 #include "llvm/Support/MD5.h"
24 #include "llvm/Support/MathExtras.h"
26 using namespace llvm;
28 #define DEBUG_TYPE "amdgpu-emit-printf"
30 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
31 auto Int64Ty = Builder.getInt64Ty();
32 auto Ty = Arg->getType();
34 if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
35 switch (IntTy->getBitWidth()) {
36 case 32:
37 return Builder.CreateZExt(Arg, Int64Ty);
38 case 64:
39 return Arg;
43 if (Ty->getTypeID() == Type::DoubleTyID) {
44 return Builder.CreateBitCast(Arg, Int64Ty);
47 if (isa<PointerType>(Ty)) {
48 return Builder.CreatePtrToInt(Arg, Int64Ty);
51 llvm_unreachable("unexpected type");
54 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
55 auto Int64Ty = Builder.getInt64Ty();
56 auto M = Builder.GetInsertBlock()->getModule();
57 auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
58 return Builder.CreateCall(Fn, Version);
61 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
62 Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
63 Value *Arg4, Value *Arg5, Value *Arg6,
64 bool IsLast) {
65 auto Int64Ty = Builder.getInt64Ty();
66 auto Int32Ty = Builder.getInt32Ty();
67 auto M = Builder.GetInsertBlock()->getModule();
68 auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
69 Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
70 Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
71 auto IsLastValue = Builder.getInt32(IsLast);
72 auto NumArgsValue = Builder.getInt32(NumArgs);
73 return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
74 Arg4, Arg5, Arg6, IsLastValue});
77 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
78 bool IsLast) {
79 auto Arg0 = fitArgInto64Bits(Builder, Arg);
80 auto Zero = Builder.getInt64(0);
81 return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
82 Zero, IsLast);
85 // The device library does not provide strlen, so we build our own loop
86 // here. While we are at it, we also include the terminating null in the length.
87 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
88 auto *Prev = Builder.GetInsertBlock();
89 Module *M = Prev->getModule();
91 auto CharZero = Builder.getInt8(0);
92 auto One = Builder.getInt64(1);
93 auto Zero = Builder.getInt64(0);
94 auto Int64Ty = Builder.getInt64Ty();
96 // The length is either zero for a null pointer, or the computed value for an
97 // actual string. We need a join block for a phi that represents the final
98 // value.
100 // Strictly speaking, the zero does not matter since
101 // __ockl_printf_append_string_n ignores the length if the pointer is null.
102 BasicBlock *Join = nullptr;
103 if (Prev->getTerminator()) {
104 Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
105 "strlen.join");
106 Prev->getTerminator()->eraseFromParent();
107 } else {
108 Join = BasicBlock::Create(M->getContext(), "strlen.join",
109 Prev->getParent());
111 BasicBlock *While =
112 BasicBlock::Create(M->getContext(), "strlen.while",
113 Prev->getParent(), Join);
114 BasicBlock *WhileDone = BasicBlock::Create(
115 M->getContext(), "strlen.while.done",
116 Prev->getParent(), Join);
118 // Emit an early return for when the pointer is null.
119 Builder.SetInsertPoint(Prev);
120 auto CmpNull =
121 Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
122 BranchInst::Create(Join, While, CmpNull, Prev);
124 // Entry to the while loop.
125 Builder.SetInsertPoint(While);
127 auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
128 PtrPhi->addIncoming(Str, Prev);
129 auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One);
130 PtrPhi->addIncoming(PtrNext, While);
132 // Condition for the while loop.
133 auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);
134 auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
135 Builder.CreateCondBr(Cmp, WhileDone, While);
137 // Add one to the computed length.
138 Builder.SetInsertPoint(WhileDone, WhileDone->begin());
139 auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
140 auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
141 auto Len = Builder.CreateSub(End, Begin);
142 Len = Builder.CreateAdd(Len, One);
144 // Final join.
145 BranchInst::Create(Join, WhileDone);
146 Builder.SetInsertPoint(Join, Join->begin());
147 auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
148 LenPhi->addIncoming(Len, WhileDone);
149 LenPhi->addIncoming(Zero, Prev);
151 return LenPhi;
154 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
155 Value *Length, bool isLast) {
156 auto Int64Ty = Builder.getInt64Ty();
157 auto IsLastInt32 = Builder.getInt32(isLast);
158 auto M = Builder.GetInsertBlock()->getModule();
159 auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
160 Desc->getType(), Str->getType(),
161 Length->getType(), IsLastInt32->getType());
162 return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
165 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
166 bool IsLast) {
167 auto Length = getStrlenWithNull(Builder, Arg);
168 return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
171 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
172 bool SpecIsCString, bool IsLast) {
173 if (SpecIsCString && isa<PointerType>(Arg->getType())) {
174 return appendString(Builder, Desc, Arg, IsLast);
176 // If the format specifies a string but the argument is not, the frontend will
177 // have printed a warning. We just rely on undefined behaviour and send the
178 // argument anyway.
179 return appendArg(Builder, Desc, Arg, IsLast);
182 // Scan the format string to locate all specifiers, and mark the ones that
183 // specify a string, i.e, the "%s" specifier with optional '*' characters.
184 static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) {
185 static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
186 size_t SpecPos = 0;
187 // Skip the first argument, the format string.
188 unsigned ArgIdx = 1;
190 while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
191 if (Str[SpecPos + 1] == '%') {
192 SpecPos += 2;
193 continue;
195 auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
196 if (SpecEnd == StringRef::npos)
197 return;
198 auto Spec = Str.slice(SpecPos, SpecEnd + 1);
199 ArgIdx += Spec.count('*');
200 if (Str[SpecEnd] == 's') {
201 BV.set(ArgIdx);
203 SpecPos = SpecEnd + 1;
204 ++ArgIdx;
208 // helper struct to package the string related data
209 struct StringData {
210 StringRef Str;
211 Value *RealSize = nullptr;
212 Value *AlignedSize = nullptr;
213 bool IsConst = true;
215 StringData(StringRef ST, Value *RS, Value *AS, bool IC)
216 : Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {}
219 // Calculates frame size required for current printf expansion and allocates
220 // space on printf buffer. Printf frame includes following contents
221 // [ ControlDWord , format string/Hash , Arguments (each aligned to 8 byte) ]
222 static Value *callBufferedPrintfStart(
223 IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt,
224 bool isConstFmtStr, SparseBitVector<8> &SpecIsCString,
225 SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) {
226 Module *M = Builder.GetInsertBlock()->getModule();
227 Value *NonConstStrLen = nullptr;
228 Value *LenWithNull = nullptr;
229 Value *LenWithNullAligned = nullptr;
230 Value *TempAdd = nullptr;
232 // First 4 bytes to be reserved for control dword
233 size_t BufSize = 4;
234 if (isConstFmtStr)
235 // First 8 bytes of MD5 hash
236 BufSize += 8;
237 else {
238 LenWithNull = getStrlenWithNull(Builder, Fmt);
240 // Align the computed length to next 8 byte boundary
241 TempAdd = Builder.CreateAdd(LenWithNull,
242 ConstantInt::get(LenWithNull->getType(), 7U));
243 NonConstStrLen = Builder.CreateAnd(
244 TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
246 StringContents.push_back(
247 StringData(StringRef(), LenWithNull, NonConstStrLen, false));
250 for (size_t i = 1; i < Args.size(); i++) {
251 if (SpecIsCString.test(i)) {
252 StringRef ArgStr;
253 if (getConstantStringInfo(Args[i], ArgStr)) {
254 auto alignedLen = alignTo(ArgStr.size() + 1, 8);
255 StringContents.push_back(StringData(
256 ArgStr,
257 /*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true));
258 BufSize += alignedLen;
259 } else {
260 LenWithNull = getStrlenWithNull(Builder, Args[i]);
262 // Align the computed length to next 8 byte boundary
263 TempAdd = Builder.CreateAdd(
264 LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U));
265 LenWithNullAligned = Builder.CreateAnd(
266 TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
268 if (NonConstStrLen) {
269 auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen,
270 "cumulativeAdd");
271 NonConstStrLen = Val;
272 } else
273 NonConstStrLen = LenWithNullAligned;
275 StringContents.push_back(
276 StringData(StringRef(), LenWithNull, LenWithNullAligned, false));
278 } else {
279 int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType());
280 // We end up expanding non string arguments to 8 bytes
281 // (args smaller than 8 bytes)
282 BufSize += std::max(AllocSize, 8);
286 // calculate final size value to be passed to printf_alloc
287 Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false);
288 SmallVector<Value *, 1> Alloc_args;
289 if (NonConstStrLen)
290 SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve);
292 ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty());
293 Alloc_args.push_back(ArgSize);
295 // call the printf_alloc function
296 AttributeList Attr = AttributeList::get(
297 Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);
299 Type *Tys_alloc[1] = {Builder.getInt32Ty()};
300 Type *PtrTy =
301 Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
302 FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false);
303 auto PrintfAllocFn =
304 M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);
306 return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn");
309 // Prepare constant string argument to push onto the buffer
310 static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder,
311 SmallVectorImpl<Value *> &WhatToStore) {
312 std::string Str(SD->Str.str() + '\0');
314 DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8);
315 DataExtractor::Cursor Offset(0);
316 while (Offset && Offset.tell() < Str.size()) {
317 const uint64_t ReadSize = 4;
318 uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell());
319 uint64_t ReadBytes = 0;
320 switch (ReadNow) {
321 default:
322 llvm_unreachable("min(4, X) > 4?");
323 case 1:
324 ReadBytes = Extractor.getU8(Offset);
325 break;
326 case 2:
327 ReadBytes = Extractor.getU16(Offset);
328 break;
329 case 3:
330 ReadBytes = Extractor.getU24(Offset);
331 break;
332 case 4:
333 ReadBytes = Extractor.getU32(Offset);
334 break;
336 cantFail(Offset.takeError(), "failed to read bytes from constant array");
338 APInt IntVal(8 * ReadSize, ReadBytes);
340 // TODO: Should not bother aligning up.
341 if (ReadNow < ReadSize)
342 IntVal = IntVal.zext(8 * ReadSize);
344 Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth());
345 WhatToStore.push_back(ConstantInt::get(IntTy, IntVal));
347 // Additional padding for 8 byte alignment
348 int Rem = (Str.size() % 8);
349 if (Rem > 0 && Rem <= 4)
350 WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
353 static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) {
354 const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout();
355 auto Ty = Arg->getType();
357 if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
358 if (IntTy->getBitWidth() < 64) {
359 return Builder.CreateZExt(Arg, Builder.getInt64Ty());
363 if (Ty->isFloatingPointTy()) {
364 if (DL.getTypeAllocSize(Ty) < 8) {
365 return Builder.CreateFPExt(Arg, Builder.getDoubleTy());
369 return Arg;
372 static void
373 callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args,
374 Value *PtrToStore, SparseBitVector<8> &SpecIsCString,
375 SmallVectorImpl<StringData> &StringContents,
376 bool IsConstFmtStr) {
377 Module *M = Builder.GetInsertBlock()->getModule();
378 const DataLayout &DL = M->getDataLayout();
379 auto StrIt = StringContents.begin();
380 size_t i = IsConstFmtStr ? 1 : 0;
381 for (; i < Args.size(); i++) {
382 SmallVector<Value *, 32> WhatToStore;
383 if ((i == 0) || SpecIsCString.test(i)) {
384 if (StrIt->IsConst) {
385 processConstantStringArg(StrIt, Builder, WhatToStore);
386 StrIt++;
387 } else {
388 // This copies the contents of the string, however the next offset
389 // is at aligned length, the extra space that might be created due
390 // to alignment padding is not populated with any specific value
391 // here. This would be safe as long as runtime is sync with
392 // the offsets.
393 Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i],
394 /*SrcAlign*/ Args[i]->getPointerAlignment(DL),
395 StrIt->RealSize);
397 PtrToStore =
398 Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore,
399 {StrIt->AlignedSize}, "PrintBuffNextPtr");
400 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:"
401 << *PtrToStore << '\n');
403 // done with current argument, move to next
404 StrIt++;
405 continue;
407 } else {
408 WhatToStore.push_back(processNonStringArg(Args[i], Builder));
411 for (Value *toStore : WhatToStore) {
412 StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore);
413 LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff
414 << '\n');
415 (void)StBuff;
416 PtrToStore = Builder.CreateConstInBoundsGEP1_32(
417 Builder.getInt8Ty(), PtrToStore,
418 M->getDataLayout().getTypeAllocSize(toStore->getType()),
419 "PrintBuffNextPtr");
420 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore
421 << '\n');
426 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args,
427 bool IsBuffered) {
428 auto NumOps = Args.size();
429 assert(NumOps >= 1);
431 auto Fmt = Args[0];
432 SparseBitVector<8> SpecIsCString;
433 StringRef FmtStr;
435 if (getConstantStringInfo(Fmt, FmtStr))
436 locateCStrings(SpecIsCString, FmtStr);
438 if (IsBuffered) {
439 SmallVector<StringData, 8> StringContents;
440 Module *M = Builder.GetInsertBlock()->getModule();
441 LLVMContext &Ctx = Builder.getContext();
442 auto Int8Ty = Builder.getInt8Ty();
443 auto Int32Ty = Builder.getInt32Ty();
444 bool IsConstFmtStr = !FmtStr.empty();
446 Value *ArgSize = nullptr;
447 Value *Ptr =
448 callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr,
449 SpecIsCString, StringContents, ArgSize);
451 // The buffered version still follows OpenCL printf standards for
452 // printf return value, i.e 0 on success, -1 on failure.
453 ConstantPointerNull *zeroIntPtr =
454 ConstantPointerNull::get(cast<PointerType>(Ptr->getType()));
456 auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, ""));
458 BasicBlock *End = BasicBlock::Create(Ctx, "end.block",
459 Builder.GetInsertBlock()->getParent());
460 BasicBlock *ArgPush = BasicBlock::Create(
461 Ctx, "argpush.block", Builder.GetInsertBlock()->getParent());
463 BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock());
464 Builder.SetInsertPoint(ArgPush);
466 // Create controlDWord and store as the first entry, format as follows
467 // Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout, printf always outputs to
468 // stdout) Bit 1 -> constant format string (1 if constant) Bits 2-31 -> size
469 // of printf data frame
470 auto ConstantTwo = Builder.getInt32(2);
471 auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo);
472 if (IsConstFmtStr)
473 ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo);
475 Builder.CreateStore(ControlDWord, Ptr);
477 Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4);
479 // Create MD5 hash for costant format string, push low 64 bits of the
480 // same onto buffer and metadata.
481 NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts");
482 if (IsConstFmtStr) {
483 MD5 Hasher;
484 MD5::MD5Result Hash;
485 Hasher.update(FmtStr);
486 Hasher.final(Hash);
488 // Try sticking to llvm.printf.fmts format, although we are not going to
489 // use the ID and argument size fields while printing,
490 std::string MetadataStr =
491 "0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," +
492 FmtStr.str();
493 MDString *fmtStrArray = MDString::get(Ctx, MetadataStr);
494 MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
495 metaD->addOperand(myMD);
497 Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr);
498 Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8);
499 } else {
500 // Include a dummy metadata instance in case of only non constant
501 // format string usage, This might be an absurd usecase but needs to
502 // be done for completeness
503 if (metaD->getNumOperands() == 0) {
504 MDString *fmtStrArray =
505 MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\"");
506 MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
507 metaD->addOperand(myMD);
511 // Push The printf arguments onto buffer
512 callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents,
513 IsConstFmtStr);
515 // End block, returns -1 on failure
516 BranchInst::Create(End, ArgPush);
517 Builder.SetInsertPoint(End);
518 return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result");
521 auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
522 Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
524 // FIXME: This invokes hostcall once for each argument. We can pack up to
525 // seven scalar printf arguments in a single hostcall. See the signature of
526 // callAppendArgs().
527 for (unsigned int i = 1; i != NumOps; ++i) {
528 bool IsLast = i == NumOps - 1;
529 bool IsCString = SpecIsCString.test(i);
530 Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
533 return Builder.CreateTrunc(Desc, Builder.getInt32Ty());