1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
16 // NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
19 //===----------------------------------------------------------------------===//
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVTargetMachine.h"
24 #include "SPIRVUtils.h"
25 #include "llvm/CodeGen/IntrinsicLowering.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
36 void initializeSPIRVPrepareFunctionsPass(PassRegistry
&);
41 class SPIRVPrepareFunctions
: public ModulePass
{
42 const SPIRVTargetMachine
&TM
;
43 bool substituteIntrinsicCalls(Function
*F
);
44 Function
*removeAggregateTypesFromSignature(Function
*F
);
48 SPIRVPrepareFunctions(const SPIRVTargetMachine
&TM
) : ModulePass(ID
), TM(TM
) {
49 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
52 bool runOnModule(Module
&M
) override
;
54 StringRef
getPassName() const override
{ return "SPIRV prepare functions"; }
56 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
57 ModulePass::getAnalysisUsage(AU
);
// Pass identification, replacement for typeid.
char SPIRVPrepareFunctions::ID = 0;

// Registers the pass under the "prepare-functions" command-line name.
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
68 std::string
lowerLLVMIntrinsicName(IntrinsicInst
*II
) {
69 Function
*IntrinsicFunc
= II
->getCalledFunction();
70 assert(IntrinsicFunc
&& "Missing function");
71 std::string FuncName
= IntrinsicFunc
->getName().str();
72 std::replace(FuncName
.begin(), FuncName
.end(), '.', '_');
73 FuncName
= "spirv." + FuncName
;
77 static Function
*getOrCreateFunction(Module
*M
, Type
*RetTy
,
78 ArrayRef
<Type
*> ArgTypes
,
80 FunctionType
*FT
= FunctionType::get(RetTy
, ArgTypes
, false);
81 Function
*F
= M
->getFunction(Name
);
82 if (F
&& F
->getFunctionType() == FT
)
84 Function
*NewF
= Function::Create(FT
, GlobalValue::ExternalLinkage
, Name
, M
);
86 NewF
->setDSOLocal(F
->isDSOLocal());
87 NewF
->setCallingConv(CallingConv::SPIR_FUNC
);
91 static bool lowerIntrinsicToFunction(IntrinsicInst
*Intrinsic
) {
92 // For @llvm.memset.* intrinsic cases with constant value and length arguments
93 // are emulated via "storing" a constant array to the destination. For other
94 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
95 // intrinsic to a loop via expandMemSetAsLoop().
96 if (auto *MSI
= dyn_cast
<MemSetInst
>(Intrinsic
))
97 if (isa
<Constant
>(MSI
->getValue()) && isa
<ConstantInt
>(MSI
->getLength()))
98 return false; // It is handled later using OpCopyMemorySized.
100 Module
*M
= Intrinsic
->getModule();
101 std::string FuncName
= lowerLLVMIntrinsicName(Intrinsic
);
102 if (Intrinsic
->isVolatile())
103 FuncName
+= ".volatile";
104 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
105 Function
*F
= M
->getFunction(FuncName
);
107 Intrinsic
->setCalledFunction(F
);
110 // TODO copy arguments attributes: nocapture writeonly.
112 M
->getOrInsertFunction(FuncName
, Intrinsic
->getFunctionType());
113 auto IntrinsicID
= Intrinsic
->getIntrinsicID();
114 Intrinsic
->setCalledFunction(FC
);
116 F
= dyn_cast
<Function
>(FC
.getCallee());
117 assert(F
&& "Callee must be a function");
119 switch (IntrinsicID
) {
120 case Intrinsic::memset
: {
121 auto *MSI
= static_cast<MemSetInst
*>(Intrinsic
);
122 Argument
*Dest
= F
->getArg(0);
123 Argument
*Val
= F
->getArg(1);
124 Argument
*Len
= F
->getArg(2);
125 Argument
*IsVolatile
= F
->getArg(3);
126 Dest
->setName("dest");
129 IsVolatile
->setName("isvolatile");
130 BasicBlock
*EntryBB
= BasicBlock::Create(M
->getContext(), "entry", F
);
131 IRBuilder
<> IRB(EntryBB
);
132 auto *MemSet
= IRB
.CreateMemSet(Dest
, Val
, Len
, MSI
->getDestAlign(),
135 expandMemSetAsLoop(cast
<MemSetInst
>(MemSet
));
136 MemSet
->eraseFromParent();
139 case Intrinsic::bswap
: {
140 BasicBlock
*EntryBB
= BasicBlock::Create(M
->getContext(), "entry", F
);
141 IRBuilder
<> IRB(EntryBB
);
142 auto *BSwap
= IRB
.CreateIntrinsic(Intrinsic::bswap
, Intrinsic
->getType(),
144 IRB
.CreateRet(BSwap
);
145 IntrinsicLowering
IL(M
->getDataLayout());
146 IL
.LowerIntrinsicCall(BSwap
);
155 static void lowerFunnelShifts(IntrinsicInst
*FSHIntrinsic
) {
156 // Get a separate function - otherwise, we'd have to rework the CFG of the
157 // current one. Then simply replace the intrinsic uses with a call to the new
159 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
160 Module
*M
= FSHIntrinsic
->getModule();
161 FunctionType
*FSHFuncTy
= FSHIntrinsic
->getFunctionType();
162 Type
*FSHRetTy
= FSHFuncTy
->getReturnType();
163 const std::string FuncName
= lowerLLVMIntrinsicName(FSHIntrinsic
);
165 getOrCreateFunction(M
, FSHRetTy
, FSHFuncTy
->params(), FuncName
);
167 if (!FSHFunc
->empty()) {
168 FSHIntrinsic
->setCalledFunction(FSHFunc
);
171 BasicBlock
*RotateBB
= BasicBlock::Create(M
->getContext(), "rotate", FSHFunc
);
172 IRBuilder
<> IRB(RotateBB
);
173 Type
*Ty
= FSHFunc
->getReturnType();
174 // Build the actual funnel shift rotate logic.
175 // In the comments, "int" is used interchangeably with "vector of int
177 FixedVectorType
*VectorTy
= dyn_cast
<FixedVectorType
>(Ty
);
178 Type
*IntTy
= VectorTy
? VectorTy
->getElementType() : Ty
;
179 unsigned BitWidth
= IntTy
->getIntegerBitWidth();
180 ConstantInt
*BitWidthConstant
= IRB
.getInt({BitWidth
, BitWidth
});
181 Value
*BitWidthForInsts
=
183 ? IRB
.CreateVectorSplat(VectorTy
->getNumElements(), BitWidthConstant
)
185 Value
*RotateModVal
=
186 IRB
.CreateURem(/*Rotate*/ FSHFunc
->getArg(2), BitWidthForInsts
);
187 Value
*FirstShift
= nullptr, *SecShift
= nullptr;
188 if (FSHIntrinsic
->getIntrinsicID() == Intrinsic::fshr
) {
189 // Shift the less significant number right, the "rotate" number of bits
190 // will be 0-filled on the left as a result of this regular shift.
191 FirstShift
= IRB
.CreateLShr(FSHFunc
->getArg(1), RotateModVal
);
193 // Shift the more significant number left, the "rotate" number of bits
194 // will be 0-filled on the right as a result of this regular shift.
195 FirstShift
= IRB
.CreateShl(FSHFunc
->getArg(0), RotateModVal
);
197 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
198 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
199 // Therefore, subtract the "rotate" number from the integer bitsize...
200 Value
*SubRotateVal
= IRB
.CreateSub(BitWidthForInsts
, RotateModVal
);
201 if (FSHIntrinsic
->getIntrinsicID() == Intrinsic::fshr
) {
202 // ...and left-shift the more significant int by this number, zero-filling
204 SecShift
= IRB
.CreateShl(FSHFunc
->getArg(0), SubRotateVal
);
206 // ...and right-shift the less significant int by this number, zero-filling
208 SecShift
= IRB
.CreateLShr(FSHFunc
->getArg(1), SubRotateVal
);
210 // A simple binary addition of the shifted ints yields the final result.
211 IRB
.CreateRet(IRB
.CreateOr(FirstShift
, SecShift
));
213 FSHIntrinsic
->setCalledFunction(FSHFunc
);
216 static void buildUMulWithOverflowFunc(Function
*UMulFunc
) {
217 // The function body is already created.
218 if (!UMulFunc
->empty())
221 BasicBlock
*EntryBB
= BasicBlock::Create(UMulFunc
->getParent()->getContext(),
223 IRBuilder
<> IRB(EntryBB
);
224 // Build the actual unsigned multiplication logic with the overflow
225 // indication. Do unsigned multiplication Mul = A * B. Then check
226 // if unsigned division Div = Mul / A is not equal to B. If so,
227 // then overflow has happened.
228 Value
*Mul
= IRB
.CreateNUWMul(UMulFunc
->getArg(0), UMulFunc
->getArg(1));
229 Value
*Div
= IRB
.CreateUDiv(Mul
, UMulFunc
->getArg(0));
230 Value
*Overflow
= IRB
.CreateICmpNE(UMulFunc
->getArg(0), Div
);
232 // umul.with.overflow intrinsic return a structure, where the first element
233 // is the multiplication result, and the second is an overflow bit.
234 Type
*StructTy
= UMulFunc
->getReturnType();
235 Value
*Agg
= IRB
.CreateInsertValue(PoisonValue::get(StructTy
), Mul
, {0});
236 Value
*Res
= IRB
.CreateInsertValue(Agg
, Overflow
, {1});
240 static void lowerExpectAssume(IntrinsicInst
*II
) {
241 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242 // ignore the intrinsic and move on. It should be removed later on by LLVM.
243 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
245 // For @llvm.assume we have OpAssumeTrueKHR.
246 // For @llvm.expect we have OpExpectKHR.
248 // We need to lower this into a builtin and then the builtin into a SPIR-V
250 if (II
->getIntrinsicID() == Intrinsic::assume
) {
251 Function
*F
= Intrinsic::getDeclaration(
252 II
->getModule(), Intrinsic::SPVIntrinsics::spv_assume
);
253 II
->setCalledFunction(F
);
254 } else if (II
->getIntrinsicID() == Intrinsic::expect
) {
255 Function
*F
= Intrinsic::getDeclaration(
256 II
->getModule(), Intrinsic::SPVIntrinsics::spv_expect
,
257 {II
->getOperand(0)->getType()});
258 II
->setCalledFunction(F
);
260 llvm_unreachable("Unknown intrinsic");
266 static bool toSpvOverloadedIntrinsic(IntrinsicInst
*II
, Intrinsic::ID NewID
,
267 ArrayRef
<unsigned> OpNos
) {
268 Function
*F
= nullptr;
270 F
= Intrinsic::getDeclaration(II
->getModule(), NewID
);
272 SmallVector
<Type
*, 4> Tys
;
273 for (unsigned OpNo
: OpNos
)
274 Tys
.push_back(II
->getOperand(OpNo
)->getType());
275 F
= Intrinsic::getDeclaration(II
->getModule(), NewID
, Tys
);
277 II
->setCalledFunction(F
);
281 static void lowerUMulWithOverflow(IntrinsicInst
*UMulIntrinsic
) {
282 // Get a separate function - otherwise, we'd have to rework the CFG of the
283 // current one. Then simply replace the intrinsic uses with a call to the new
285 Module
*M
= UMulIntrinsic
->getModule();
286 FunctionType
*UMulFuncTy
= UMulIntrinsic
->getFunctionType();
287 Type
*FSHLRetTy
= UMulFuncTy
->getReturnType();
288 const std::string FuncName
= lowerLLVMIntrinsicName(UMulIntrinsic
);
290 getOrCreateFunction(M
, FSHLRetTy
, UMulFuncTy
->params(), FuncName
);
291 buildUMulWithOverflowFunc(UMulFunc
);
292 UMulIntrinsic
->setCalledFunction(UMulFunc
);
295 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
296 // or calls to proper generated functions. Returns True if F was modified.
297 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function
*F
) {
298 bool Changed
= false;
299 for (BasicBlock
&BB
: *F
) {
300 for (Instruction
&I
: BB
) {
301 auto Call
= dyn_cast
<CallInst
>(&I
);
304 Function
*CF
= Call
->getCalledFunction();
305 if (!CF
|| !CF
->isIntrinsic())
307 auto *II
= cast
<IntrinsicInst
>(Call
);
308 switch (II
->getIntrinsicID()) {
309 case Intrinsic::memset
:
310 case Intrinsic::bswap
:
311 Changed
|= lowerIntrinsicToFunction(II
);
313 case Intrinsic::fshl
:
314 case Intrinsic::fshr
:
315 lowerFunnelShifts(II
);
318 case Intrinsic::umul_with_overflow
:
319 lowerUMulWithOverflow(II
);
322 case Intrinsic::assume
:
323 case Intrinsic::expect
: {
324 const SPIRVSubtarget
&STI
= TM
.getSubtarget
<SPIRVSubtarget
>(*F
);
325 if (STI
.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume
))
326 lowerExpectAssume(II
);
329 case Intrinsic::lifetime_start
:
330 Changed
|= toSpvOverloadedIntrinsic(
331 II
, Intrinsic::SPVIntrinsics::spv_lifetime_start
, {1});
333 case Intrinsic::lifetime_end
:
334 Changed
|= toSpvOverloadedIntrinsic(
335 II
, Intrinsic::SPVIntrinsics::spv_lifetime_end
, {1});
343 // Returns F if aggregate argument/return types are not present or cloned F
344 // function with the types replaced by i32 types. The change in types is
345 // noted in 'spv.cloned_funcs' metadata for later restoration.
347 SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function
*F
) {
348 IRBuilder
<> B(F
->getContext());
350 bool IsRetAggr
= F
->getReturnType()->isAggregateType();
352 std::any_of(F
->arg_begin(), F
->arg_end(), [](Argument
&Arg
) {
353 return Arg
.getType()->isAggregateType();
355 bool DoClone
= IsRetAggr
|| HasAggrArg
;
358 SmallVector
<std::pair
<int, Type
*>, 4> ChangedTypes
;
359 Type
*RetType
= IsRetAggr
? B
.getInt32Ty() : F
->getReturnType();
361 ChangedTypes
.push_back(std::pair
<int, Type
*>(-1, F
->getReturnType()));
362 SmallVector
<Type
*, 4> ArgTypes
;
363 for (const auto &Arg
: F
->args()) {
364 if (Arg
.getType()->isAggregateType()) {
365 ArgTypes
.push_back(B
.getInt32Ty());
366 ChangedTypes
.push_back(
367 std::pair
<int, Type
*>(Arg
.getArgNo(), Arg
.getType()));
369 ArgTypes
.push_back(Arg
.getType());
371 FunctionType
*NewFTy
=
372 FunctionType::get(RetType
, ArgTypes
, F
->getFunctionType()->isVarArg());
374 Function::Create(NewFTy
, F
->getLinkage(), F
->getName(), *F
->getParent());
376 ValueToValueMapTy VMap
;
377 auto NewFArgIt
= NewF
->arg_begin();
378 for (auto &Arg
: F
->args()) {
379 StringRef ArgName
= Arg
.getName();
380 NewFArgIt
->setName(ArgName
);
381 VMap
[&Arg
] = &(*NewFArgIt
++);
383 SmallVector
<ReturnInst
*, 8> Returns
;
385 CloneFunctionInto(NewF
, F
, VMap
, CloneFunctionChangeType::LocalChangesOnly
,
389 NamedMDNode
*FuncMD
=
390 F
->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
391 SmallVector
<Metadata
*, 2> MDArgs
;
392 MDArgs
.push_back(MDString::get(B
.getContext(), NewF
->getName()));
393 for (auto &ChangedTyP
: ChangedTypes
)
394 MDArgs
.push_back(MDNode::get(
396 {ConstantAsMetadata::get(B
.getInt32(ChangedTyP
.first
)),
397 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP
.second
))}));
398 MDNode
*ThisFuncMD
= MDNode::get(B
.getContext(), MDArgs
);
399 FuncMD
->addOperand(ThisFuncMD
);
401 for (auto *U
: make_early_inc_range(F
->users())) {
402 if (auto *CI
= dyn_cast
<CallInst
>(U
))
403 CI
->mutateFunctionType(NewF
->getFunctionType());
404 U
->replaceUsesOfWith(F
, NewF
);
409 bool SPIRVPrepareFunctions::runOnModule(Module
&M
) {
410 bool Changed
= false;
411 for (Function
&F
: M
)
412 Changed
|= substituteIntrinsicCalls(&F
);
414 std::vector
<Function
*> FuncsWorklist
;
416 FuncsWorklist
.push_back(&F
);
418 for (auto *F
: FuncsWorklist
) {
419 Function
*NewF
= removeAggregateTypesFromSignature(F
);
422 F
->eraseFromParent();
430 llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine
&TM
) {
431 return new SPIRVPrepareFunctions(TM
);