[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / llvm / lib / Target / SPIRV / SPIRVPrepareFunctions.cpp
bloba8a0577f60564c9132725c3b7ec5e54828124079
1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
16 // NOTE: this pass is a module-level one due to the necessity to modify
17 // GVs/functions.
19 //===----------------------------------------------------------------------===//
21 #include "SPIRV.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVTargetMachine.h"
24 #include "SPIRVUtils.h"
25 #include "llvm/CodeGen/IntrinsicLowering.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
33 using namespace llvm;
35 namespace llvm {
36 void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
39 namespace {
41 class SPIRVPrepareFunctions : public ModulePass {
42 const SPIRVTargetMachine &TM;
43 bool substituteIntrinsicCalls(Function *F);
44 Function *removeAggregateTypesFromSignature(Function *F);
46 public:
47 static char ID;
48 SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
49 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
52 bool runOnModule(Module &M) override;
54 StringRef getPassName() const override { return "SPIRV prepare functions"; }
56 void getAnalysisUsage(AnalysisUsage &AU) const override {
57 ModulePass::getAnalysisUsage(AU);
61 } // namespace
63 char SPIRVPrepareFunctions::ID = 0;
65 INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
66 "SPIRV prepare functions", false, false)
68 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
69 Function *IntrinsicFunc = II->getCalledFunction();
70 assert(IntrinsicFunc && "Missing function");
71 std::string FuncName = IntrinsicFunc->getName().str();
72 std::replace(FuncName.begin(), FuncName.end(), '.', '_');
73 FuncName = "spirv." + FuncName;
74 return FuncName;
77 static Function *getOrCreateFunction(Module *M, Type *RetTy,
78 ArrayRef<Type *> ArgTypes,
79 StringRef Name) {
80 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
81 Function *F = M->getFunction(Name);
82 if (F && F->getFunctionType() == FT)
83 return F;
84 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
85 if (F)
86 NewF->setDSOLocal(F->isDSOLocal());
87 NewF->setCallingConv(CallingConv::SPIR_FUNC);
88 return NewF;
91 static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
92 // For @llvm.memset.* intrinsic cases with constant value and length arguments
93 // are emulated via "storing" a constant array to the destination. For other
94 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
95 // intrinsic to a loop via expandMemSetAsLoop().
96 if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
97 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
98 return false; // It is handled later using OpCopyMemorySized.
100 Module *M = Intrinsic->getModule();
101 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
102 if (Intrinsic->isVolatile())
103 FuncName += ".volatile";
104 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
105 Function *F = M->getFunction(FuncName);
106 if (F) {
107 Intrinsic->setCalledFunction(F);
108 return true;
110 // TODO copy arguments attributes: nocapture writeonly.
111 FunctionCallee FC =
112 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
113 auto IntrinsicID = Intrinsic->getIntrinsicID();
114 Intrinsic->setCalledFunction(FC);
116 F = dyn_cast<Function>(FC.getCallee());
117 assert(F && "Callee must be a function");
119 switch (IntrinsicID) {
120 case Intrinsic::memset: {
121 auto *MSI = static_cast<MemSetInst *>(Intrinsic);
122 Argument *Dest = F->getArg(0);
123 Argument *Val = F->getArg(1);
124 Argument *Len = F->getArg(2);
125 Argument *IsVolatile = F->getArg(3);
126 Dest->setName("dest");
127 Val->setName("val");
128 Len->setName("len");
129 IsVolatile->setName("isvolatile");
130 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
131 IRBuilder<> IRB(EntryBB);
132 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
133 MSI->isVolatile());
134 IRB.CreateRetVoid();
135 expandMemSetAsLoop(cast<MemSetInst>(MemSet));
136 MemSet->eraseFromParent();
137 break;
139 case Intrinsic::bswap: {
140 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
141 IRBuilder<> IRB(EntryBB);
142 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
143 F->getArg(0));
144 IRB.CreateRet(BSwap);
145 IntrinsicLowering IL(M->getDataLayout());
146 IL.LowerIntrinsicCall(BSwap);
147 break;
149 default:
150 break;
152 return true;
155 static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
156 // Get a separate function - otherwise, we'd have to rework the CFG of the
157 // current one. Then simply replace the intrinsic uses with a call to the new
158 // function.
159 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
160 Module *M = FSHIntrinsic->getModule();
161 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
162 Type *FSHRetTy = FSHFuncTy->getReturnType();
163 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
164 Function *FSHFunc =
165 getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
167 if (!FSHFunc->empty()) {
168 FSHIntrinsic->setCalledFunction(FSHFunc);
169 return;
171 BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
172 IRBuilder<> IRB(RotateBB);
173 Type *Ty = FSHFunc->getReturnType();
174 // Build the actual funnel shift rotate logic.
175 // In the comments, "int" is used interchangeably with "vector of int
176 // elements".
177 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
178 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
179 unsigned BitWidth = IntTy->getIntegerBitWidth();
180 ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
181 Value *BitWidthForInsts =
182 VectorTy
183 ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
184 : BitWidthConstant;
185 Value *RotateModVal =
186 IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
187 Value *FirstShift = nullptr, *SecShift = nullptr;
188 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
189 // Shift the less significant number right, the "rotate" number of bits
190 // will be 0-filled on the left as a result of this regular shift.
191 FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
192 } else {
193 // Shift the more significant number left, the "rotate" number of bits
194 // will be 0-filled on the right as a result of this regular shift.
195 FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
197 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
198 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
199 // Therefore, subtract the "rotate" number from the integer bitsize...
200 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
201 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
202 // ...and left-shift the more significant int by this number, zero-filling
203 // the LSBs.
204 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
205 } else {
206 // ...and right-shift the less significant int by this number, zero-filling
207 // the MSBs.
208 SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
210 // A simple binary addition of the shifted ints yields the final result.
211 IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
213 FSHIntrinsic->setCalledFunction(FSHFunc);
216 static void buildUMulWithOverflowFunc(Function *UMulFunc) {
217 // The function body is already created.
218 if (!UMulFunc->empty())
219 return;
221 BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
222 "entry", UMulFunc);
223 IRBuilder<> IRB(EntryBB);
224 // Build the actual unsigned multiplication logic with the overflow
225 // indication. Do unsigned multiplication Mul = A * B. Then check
226 // if unsigned division Div = Mul / A is not equal to B. If so,
227 // then overflow has happened.
228 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
229 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
230 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
232 // umul.with.overflow intrinsic return a structure, where the first element
233 // is the multiplication result, and the second is an overflow bit.
234 Type *StructTy = UMulFunc->getReturnType();
235 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
236 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
237 IRB.CreateRet(Res);
240 static void lowerExpectAssume(IntrinsicInst *II) {
241 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242 // ignore the intrinsic and move on. It should be removed later on by LLVM.
243 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
244 // instruction.
245 // For @llvm.assume we have OpAssumeTrueKHR.
246 // For @llvm.expect we have OpExpectKHR.
248 // We need to lower this into a builtin and then the builtin into a SPIR-V
249 // instruction.
250 if (II->getIntrinsicID() == Intrinsic::assume) {
251 Function *F = Intrinsic::getDeclaration(
252 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
253 II->setCalledFunction(F);
254 } else if (II->getIntrinsicID() == Intrinsic::expect) {
255 Function *F = Intrinsic::getDeclaration(
256 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
257 {II->getOperand(0)->getType()});
258 II->setCalledFunction(F);
259 } else {
260 llvm_unreachable("Unknown intrinsic");
263 return;
266 static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
267 ArrayRef<unsigned> OpNos) {
268 Function *F = nullptr;
269 if (OpNos.empty()) {
270 F = Intrinsic::getDeclaration(II->getModule(), NewID);
271 } else {
272 SmallVector<Type *, 4> Tys;
273 for (unsigned OpNo : OpNos)
274 Tys.push_back(II->getOperand(OpNo)->getType());
275 F = Intrinsic::getDeclaration(II->getModule(), NewID, Tys);
277 II->setCalledFunction(F);
278 return true;
281 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
282 // Get a separate function - otherwise, we'd have to rework the CFG of the
283 // current one. Then simply replace the intrinsic uses with a call to the new
284 // function.
285 Module *M = UMulIntrinsic->getModule();
286 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
287 Type *FSHLRetTy = UMulFuncTy->getReturnType();
288 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
289 Function *UMulFunc =
290 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
291 buildUMulWithOverflowFunc(UMulFunc);
292 UMulIntrinsic->setCalledFunction(UMulFunc);
295 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
296 // or calls to proper generated functions. Returns True if F was modified.
297 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
298 bool Changed = false;
299 for (BasicBlock &BB : *F) {
300 for (Instruction &I : BB) {
301 auto Call = dyn_cast<CallInst>(&I);
302 if (!Call)
303 continue;
304 Function *CF = Call->getCalledFunction();
305 if (!CF || !CF->isIntrinsic())
306 continue;
307 auto *II = cast<IntrinsicInst>(Call);
308 switch (II->getIntrinsicID()) {
309 case Intrinsic::memset:
310 case Intrinsic::bswap:
311 Changed |= lowerIntrinsicToFunction(II);
312 break;
313 case Intrinsic::fshl:
314 case Intrinsic::fshr:
315 lowerFunnelShifts(II);
316 Changed = true;
317 break;
318 case Intrinsic::umul_with_overflow:
319 lowerUMulWithOverflow(II);
320 Changed = true;
321 break;
322 case Intrinsic::assume:
323 case Intrinsic::expect: {
324 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
325 if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
326 lowerExpectAssume(II);
327 Changed = true;
328 } break;
329 case Intrinsic::lifetime_start:
330 Changed |= toSpvOverloadedIntrinsic(
331 II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
332 break;
333 case Intrinsic::lifetime_end:
334 Changed |= toSpvOverloadedIntrinsic(
335 II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
336 break;
340 return Changed;
343 // Returns F if aggregate argument/return types are not present or cloned F
344 // function with the types replaced by i32 types. The change in types is
345 // noted in 'spv.cloned_funcs' metadata for later restoration.
346 Function *
347 SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
348 IRBuilder<> B(F->getContext());
350 bool IsRetAggr = F->getReturnType()->isAggregateType();
351 bool HasAggrArg =
352 std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
353 return Arg.getType()->isAggregateType();
355 bool DoClone = IsRetAggr || HasAggrArg;
356 if (!DoClone)
357 return F;
358 SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
359 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
360 if (IsRetAggr)
361 ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
362 SmallVector<Type *, 4> ArgTypes;
363 for (const auto &Arg : F->args()) {
364 if (Arg.getType()->isAggregateType()) {
365 ArgTypes.push_back(B.getInt32Ty());
366 ChangedTypes.push_back(
367 std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
368 } else
369 ArgTypes.push_back(Arg.getType());
371 FunctionType *NewFTy =
372 FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
373 Function *NewF =
374 Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
376 ValueToValueMapTy VMap;
377 auto NewFArgIt = NewF->arg_begin();
378 for (auto &Arg : F->args()) {
379 StringRef ArgName = Arg.getName();
380 NewFArgIt->setName(ArgName);
381 VMap[&Arg] = &(*NewFArgIt++);
383 SmallVector<ReturnInst *, 8> Returns;
385 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
386 Returns);
387 NewF->takeName(F);
389 NamedMDNode *FuncMD =
390 F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
391 SmallVector<Metadata *, 2> MDArgs;
392 MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
393 for (auto &ChangedTyP : ChangedTypes)
394 MDArgs.push_back(MDNode::get(
395 B.getContext(),
396 {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
397 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
398 MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
399 FuncMD->addOperand(ThisFuncMD);
401 for (auto *U : make_early_inc_range(F->users())) {
402 if (auto *CI = dyn_cast<CallInst>(U))
403 CI->mutateFunctionType(NewF->getFunctionType());
404 U->replaceUsesOfWith(F, NewF);
406 return NewF;
409 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
410 bool Changed = false;
411 for (Function &F : M)
412 Changed |= substituteIntrinsicCalls(&F);
414 std::vector<Function *> FuncsWorklist;
415 for (auto &F : M)
416 FuncsWorklist.push_back(&F);
418 for (auto *F : FuncsWorklist) {
419 Function *NewF = removeAggregateTypesFromSignature(F);
421 if (NewF != F) {
422 F->eraseFromParent();
423 Changed = true;
426 return Changed;
429 ModulePass *
430 llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
431 return new SPIRVPrepareFunctions(TM);