//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value before IRTranslator. Information about the original
// signatures is stored in metadata. It is used during call lowering to
// restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some llvm intrinsic calls with calls to newly
// generated functions (as the Khronos LLVM/SPIR-V Translator does).
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include <algorithm>
#include <charconv>
#include <regex>
#include <string>
#include <system_error>
#include <vector>

using namespace llvm;
39 void initializeSPIRVPrepareFunctionsPass(PassRegistry
&);
44 class SPIRVPrepareFunctions
: public ModulePass
{
45 const SPIRVTargetMachine
&TM
;
46 bool substituteIntrinsicCalls(Function
*F
);
47 Function
*removeAggregateTypesFromSignature(Function
*F
);
51 SPIRVPrepareFunctions(const SPIRVTargetMachine
&TM
) : ModulePass(ID
), TM(TM
) {
52 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
55 bool runOnModule(Module
&M
) override
;
57 StringRef
getPassName() const override
{ return "SPIRV prepare functions"; }
59 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
60 ModulePass::getAnalysisUsage(AU
);
66 char SPIRVPrepareFunctions::ID
= 0;
68 INITIALIZE_PASS(SPIRVPrepareFunctions
, "prepare-functions",
69 "SPIRV prepare functions", false, false)
71 std::string
lowerLLVMIntrinsicName(IntrinsicInst
*II
) {
72 Function
*IntrinsicFunc
= II
->getCalledFunction();
73 assert(IntrinsicFunc
&& "Missing function");
74 std::string FuncName
= IntrinsicFunc
->getName().str();
75 std::replace(FuncName
.begin(), FuncName
.end(), '.', '_');
76 FuncName
= "spirv." + FuncName
;
80 static Function
*getOrCreateFunction(Module
*M
, Type
*RetTy
,
81 ArrayRef
<Type
*> ArgTypes
,
83 FunctionType
*FT
= FunctionType::get(RetTy
, ArgTypes
, false);
84 Function
*F
= M
->getFunction(Name
);
85 if (F
&& F
->getFunctionType() == FT
)
87 Function
*NewF
= Function::Create(FT
, GlobalValue::ExternalLinkage
, Name
, M
);
89 NewF
->setDSOLocal(F
->isDSOLocal());
90 NewF
->setCallingConv(CallingConv::SPIR_FUNC
);
94 static bool lowerIntrinsicToFunction(IntrinsicInst
*Intrinsic
) {
95 // For @llvm.memset.* intrinsic cases with constant value and length arguments
96 // are emulated via "storing" a constant array to the destination. For other
97 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
98 // intrinsic to a loop via expandMemSetAsLoop().
99 if (auto *MSI
= dyn_cast
<MemSetInst
>(Intrinsic
))
100 if (isa
<Constant
>(MSI
->getValue()) && isa
<ConstantInt
>(MSI
->getLength()))
101 return false; // It is handled later using OpCopyMemorySized.
103 Module
*M
= Intrinsic
->getModule();
104 std::string FuncName
= lowerLLVMIntrinsicName(Intrinsic
);
105 if (Intrinsic
->isVolatile())
106 FuncName
+= ".volatile";
107 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
108 Function
*F
= M
->getFunction(FuncName
);
110 Intrinsic
->setCalledFunction(F
);
113 // TODO copy arguments attributes: nocapture writeonly.
115 M
->getOrInsertFunction(FuncName
, Intrinsic
->getFunctionType());
116 auto IntrinsicID
= Intrinsic
->getIntrinsicID();
117 Intrinsic
->setCalledFunction(FC
);
119 F
= dyn_cast
<Function
>(FC
.getCallee());
120 assert(F
&& "Callee must be a function");
122 switch (IntrinsicID
) {
123 case Intrinsic::memset
: {
124 auto *MSI
= static_cast<MemSetInst
*>(Intrinsic
);
125 Argument
*Dest
= F
->getArg(0);
126 Argument
*Val
= F
->getArg(1);
127 Argument
*Len
= F
->getArg(2);
128 Argument
*IsVolatile
= F
->getArg(3);
129 Dest
->setName("dest");
132 IsVolatile
->setName("isvolatile");
133 BasicBlock
*EntryBB
= BasicBlock::Create(M
->getContext(), "entry", F
);
134 IRBuilder
<> IRB(EntryBB
);
135 auto *MemSet
= IRB
.CreateMemSet(Dest
, Val
, Len
, MSI
->getDestAlign(),
138 expandMemSetAsLoop(cast
<MemSetInst
>(MemSet
));
139 MemSet
->eraseFromParent();
142 case Intrinsic::bswap
: {
143 BasicBlock
*EntryBB
= BasicBlock::Create(M
->getContext(), "entry", F
);
144 IRBuilder
<> IRB(EntryBB
);
145 auto *BSwap
= IRB
.CreateIntrinsic(Intrinsic::bswap
, Intrinsic
->getType(),
147 IRB
.CreateRet(BSwap
);
148 IntrinsicLowering
IL(M
->getDataLayout());
149 IL
.LowerIntrinsicCall(BSwap
);
158 static std::string
getAnnotation(Value
*AnnoVal
, Value
*OptAnnoVal
) {
159 if (auto *Ref
= dyn_cast_or_null
<GetElementPtrInst
>(AnnoVal
))
160 AnnoVal
= Ref
->getOperand(0);
161 if (auto *Ref
= dyn_cast_or_null
<BitCastInst
>(OptAnnoVal
))
162 OptAnnoVal
= Ref
->getOperand(0);
165 if (auto *C
= dyn_cast_or_null
<Constant
>(AnnoVal
)) {
167 if (getConstantStringInfo(C
, Str
))
170 // handle optional annotation parameter in a way that Khronos Translator do
171 // (collect integers wrapped in a struct)
172 if (auto *C
= dyn_cast_or_null
<Constant
>(OptAnnoVal
);
173 C
&& C
->getNumOperands()) {
174 Value
*MaybeStruct
= C
->getOperand(0);
175 if (auto *Struct
= dyn_cast
<ConstantStruct
>(MaybeStruct
)) {
176 for (unsigned I
= 0, E
= Struct
->getNumOperands(); I
!= E
; ++I
) {
177 if (auto *CInt
= dyn_cast
<ConstantInt
>(Struct
->getOperand(I
)))
178 Anno
+= (I
== 0 ? ": " : ", ") +
179 std::to_string(CInt
->getType()->getIntegerBitWidth() == 1
180 ? CInt
->getZExtValue()
181 : CInt
->getSExtValue());
183 } else if (auto *Struct
= dyn_cast
<ConstantAggregateZero
>(MaybeStruct
)) {
184 // { i32 i32 ... } zeroinitializer
185 for (unsigned I
= 0, E
= Struct
->getType()->getStructNumElements();
187 Anno
+= I
== 0 ? ": 0" : ", 0";
193 static SmallVector
<Metadata
*> parseAnnotation(Value
*I
,
194 const std::string
&Anno
,
197 // Try to parse the annotation string according to the following rules:
198 // annotation := ({kind} | {kind:value,value,...})+
200 // value := number | string
201 static const std::regex
R(
202 "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
203 SmallVector
<Metadata
*> MDs
;
205 for (std::sregex_iterator
206 It
= std::sregex_iterator(Anno
.begin(), Anno
.end(), R
),
207 ItEnd
= std::sregex_iterator();
209 if (It
->position() != Pos
)
210 return SmallVector
<Metadata
*>{};
211 Pos
= It
->position() + It
->length();
212 std::smatch Match
= *It
;
213 SmallVector
<Metadata
*> MDsItem
;
214 for (std::size_t i
= 1; i
< Match
.size(); ++i
) {
215 std::ssub_match SMatch
= Match
[i
];
216 std::string Item
= SMatch
.str();
217 if (Item
.length() == 0)
219 if (Item
[0] == '"') {
220 Item
= Item
.substr(1, Item
.length() - 2);
221 // Acceptable format of the string snippet is:
222 static const std::regex
RStr("^(\\d+)(?:,(\\d+))*$");
223 if (std::smatch MatchStr
; std::regex_match(Item
, MatchStr
, RStr
)) {
224 for (std::size_t SubIdx
= 1; SubIdx
< MatchStr
.size(); ++SubIdx
)
225 if (std::string SubStr
= MatchStr
[SubIdx
].str(); SubStr
.length())
226 MDsItem
.push_back(ConstantAsMetadata::get(
227 ConstantInt::get(Int32Ty
, std::stoi(SubStr
))));
229 MDsItem
.push_back(MDString::get(Ctx
, Item
));
231 } else if (int32_t Num
;
232 std::from_chars(Item
.data(), Item
.data() + Item
.size(), Num
)
233 .ec
== std::errc
{}) {
235 ConstantAsMetadata::get(ConstantInt::get(Int32Ty
, Num
)));
237 MDsItem
.push_back(MDString::get(Ctx
, Item
));
240 if (MDsItem
.size() == 0)
241 return SmallVector
<Metadata
*>{};
242 MDs
.push_back(MDNode::get(Ctx
, MDsItem
));
244 return Pos
== static_cast<int>(Anno
.length()) ? MDs
245 : SmallVector
<Metadata
*>{};
248 static void lowerPtrAnnotation(IntrinsicInst
*II
) {
249 LLVMContext
&Ctx
= II
->getContext();
250 Type
*Int32Ty
= Type::getInt32Ty(Ctx
);
252 // Retrieve an annotation string from arguments.
253 Value
*PtrArg
= nullptr;
254 if (auto *BI
= dyn_cast
<BitCastInst
>(II
->getArgOperand(0)))
255 PtrArg
= BI
->getOperand(0);
257 PtrArg
= II
->getOperand(0);
259 getAnnotation(II
->getArgOperand(1),
260 4 < II
->arg_size() ? II
->getArgOperand(4) : nullptr);
262 // Parse the annotation.
263 SmallVector
<Metadata
*> MDs
= parseAnnotation(II
, Anno
, Ctx
, Int32Ty
);
265 // If the annotation string is not parsed successfully we don't know the
266 // format used and output it as a general UserSemantic decoration.
267 // Otherwise MDs is a Metadata tuple (a decoration list) in the format
268 // expected by `spirv.Decorations`.
269 if (MDs
.size() == 0) {
270 auto UserSemantic
= ConstantAsMetadata::get(ConstantInt::get(
271 Int32Ty
, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic
)));
272 MDs
.push_back(MDNode::get(Ctx
, {UserSemantic
, MDString::get(Ctx
, Anno
)}));
275 // Build the internal intrinsic function.
276 IRBuilder
<> IRB(II
->getParent());
277 IRB
.SetInsertPoint(II
);
279 Intrinsic::spv_assign_decoration
, {PtrArg
->getType()},
280 {PtrArg
, MetadataAsValue::get(Ctx
, MDNode::get(Ctx
, MDs
))});
281 II
->replaceAllUsesWith(II
->getOperand(0));
284 static void lowerFunnelShifts(IntrinsicInst
*FSHIntrinsic
) {
285 // Get a separate function - otherwise, we'd have to rework the CFG of the
286 // current one. Then simply replace the intrinsic uses with a call to the new
288 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
289 Module
*M
= FSHIntrinsic
->getModule();
290 FunctionType
*FSHFuncTy
= FSHIntrinsic
->getFunctionType();
291 Type
*FSHRetTy
= FSHFuncTy
->getReturnType();
292 const std::string FuncName
= lowerLLVMIntrinsicName(FSHIntrinsic
);
294 getOrCreateFunction(M
, FSHRetTy
, FSHFuncTy
->params(), FuncName
);
296 if (!FSHFunc
->empty()) {
297 FSHIntrinsic
->setCalledFunction(FSHFunc
);
300 BasicBlock
*RotateBB
= BasicBlock::Create(M
->getContext(), "rotate", FSHFunc
);
301 IRBuilder
<> IRB(RotateBB
);
302 Type
*Ty
= FSHFunc
->getReturnType();
303 // Build the actual funnel shift rotate logic.
304 // In the comments, "int" is used interchangeably with "vector of int
306 FixedVectorType
*VectorTy
= dyn_cast
<FixedVectorType
>(Ty
);
307 Type
*IntTy
= VectorTy
? VectorTy
->getElementType() : Ty
;
308 unsigned BitWidth
= IntTy
->getIntegerBitWidth();
309 ConstantInt
*BitWidthConstant
= IRB
.getInt({BitWidth
, BitWidth
});
310 Value
*BitWidthForInsts
=
312 ? IRB
.CreateVectorSplat(VectorTy
->getNumElements(), BitWidthConstant
)
314 Value
*RotateModVal
=
315 IRB
.CreateURem(/*Rotate*/ FSHFunc
->getArg(2), BitWidthForInsts
);
316 Value
*FirstShift
= nullptr, *SecShift
= nullptr;
317 if (FSHIntrinsic
->getIntrinsicID() == Intrinsic::fshr
) {
318 // Shift the less significant number right, the "rotate" number of bits
319 // will be 0-filled on the left as a result of this regular shift.
320 FirstShift
= IRB
.CreateLShr(FSHFunc
->getArg(1), RotateModVal
);
322 // Shift the more significant number left, the "rotate" number of bits
323 // will be 0-filled on the right as a result of this regular shift.
324 FirstShift
= IRB
.CreateShl(FSHFunc
->getArg(0), RotateModVal
);
326 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
327 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
328 // Therefore, subtract the "rotate" number from the integer bitsize...
329 Value
*SubRotateVal
= IRB
.CreateSub(BitWidthForInsts
, RotateModVal
);
330 if (FSHIntrinsic
->getIntrinsicID() == Intrinsic::fshr
) {
331 // ...and left-shift the more significant int by this number, zero-filling
333 SecShift
= IRB
.CreateShl(FSHFunc
->getArg(0), SubRotateVal
);
335 // ...and right-shift the less significant int by this number, zero-filling
337 SecShift
= IRB
.CreateLShr(FSHFunc
->getArg(1), SubRotateVal
);
339 // A simple binary addition of the shifted ints yields the final result.
340 IRB
.CreateRet(IRB
.CreateOr(FirstShift
, SecShift
));
342 FSHIntrinsic
->setCalledFunction(FSHFunc
);
345 static void lowerExpectAssume(IntrinsicInst
*II
) {
346 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
347 // ignore the intrinsic and move on. It should be removed later on by LLVM.
348 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
350 // For @llvm.assume we have OpAssumeTrueKHR.
351 // For @llvm.expect we have OpExpectKHR.
353 // We need to lower this into a builtin and then the builtin into a SPIR-V
355 if (II
->getIntrinsicID() == Intrinsic::assume
) {
356 Function
*F
= Intrinsic::getOrInsertDeclaration(
357 II
->getModule(), Intrinsic::SPVIntrinsics::spv_assume
);
358 II
->setCalledFunction(F
);
359 } else if (II
->getIntrinsicID() == Intrinsic::expect
) {
360 Function
*F
= Intrinsic::getOrInsertDeclaration(
361 II
->getModule(), Intrinsic::SPVIntrinsics::spv_expect
,
362 {II
->getOperand(0)->getType()});
363 II
->setCalledFunction(F
);
365 llvm_unreachable("Unknown intrinsic");
371 static bool toSpvOverloadedIntrinsic(IntrinsicInst
*II
, Intrinsic::ID NewID
,
372 ArrayRef
<unsigned> OpNos
) {
373 Function
*F
= nullptr;
375 F
= Intrinsic::getOrInsertDeclaration(II
->getModule(), NewID
);
377 SmallVector
<Type
*, 4> Tys
;
378 for (unsigned OpNo
: OpNos
)
379 Tys
.push_back(II
->getOperand(OpNo
)->getType());
380 F
= Intrinsic::getOrInsertDeclaration(II
->getModule(), NewID
, Tys
);
382 II
->setCalledFunction(F
);
386 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
387 // or calls to proper generated functions. Returns True if F was modified.
388 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function
*F
) {
389 bool Changed
= false;
390 for (BasicBlock
&BB
: *F
) {
391 for (Instruction
&I
: BB
) {
392 auto Call
= dyn_cast
<CallInst
>(&I
);
395 Function
*CF
= Call
->getCalledFunction();
396 if (!CF
|| !CF
->isIntrinsic())
398 auto *II
= cast
<IntrinsicInst
>(Call
);
399 switch (II
->getIntrinsicID()) {
400 case Intrinsic::memset
:
401 case Intrinsic::bswap
:
402 Changed
|= lowerIntrinsicToFunction(II
);
404 case Intrinsic::fshl
:
405 case Intrinsic::fshr
:
406 lowerFunnelShifts(II
);
409 case Intrinsic::assume
:
410 case Intrinsic::expect
: {
411 const SPIRVSubtarget
&STI
= TM
.getSubtarget
<SPIRVSubtarget
>(*F
);
412 if (STI
.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume
))
413 lowerExpectAssume(II
);
416 case Intrinsic::lifetime_start
:
417 Changed
|= toSpvOverloadedIntrinsic(
418 II
, Intrinsic::SPVIntrinsics::spv_lifetime_start
, {1});
420 case Intrinsic::lifetime_end
:
421 Changed
|= toSpvOverloadedIntrinsic(
422 II
, Intrinsic::SPVIntrinsics::spv_lifetime_end
, {1});
424 case Intrinsic::ptr_annotation
:
425 lowerPtrAnnotation(II
);
434 // Returns F if aggregate argument/return types are not present or cloned F
435 // function with the types replaced by i32 types. The change in types is
436 // noted in 'spv.cloned_funcs' metadata for later restoration.
438 SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function
*F
) {
439 bool IsRetAggr
= F
->getReturnType()->isAggregateType();
440 // Allow intrinsics with aggregate return type to reach GlobalISel
441 if (F
->isIntrinsic() && IsRetAggr
)
444 IRBuilder
<> B(F
->getContext());
447 std::any_of(F
->arg_begin(), F
->arg_end(), [](Argument
&Arg
) {
448 return Arg
.getType()->isAggregateType();
450 bool DoClone
= IsRetAggr
|| HasAggrArg
;
453 SmallVector
<std::pair
<int, Type
*>, 4> ChangedTypes
;
454 Type
*RetType
= IsRetAggr
? B
.getInt32Ty() : F
->getReturnType();
456 ChangedTypes
.push_back(std::pair
<int, Type
*>(-1, F
->getReturnType()));
457 SmallVector
<Type
*, 4> ArgTypes
;
458 for (const auto &Arg
: F
->args()) {
459 if (Arg
.getType()->isAggregateType()) {
460 ArgTypes
.push_back(B
.getInt32Ty());
461 ChangedTypes
.push_back(
462 std::pair
<int, Type
*>(Arg
.getArgNo(), Arg
.getType()));
464 ArgTypes
.push_back(Arg
.getType());
466 FunctionType
*NewFTy
=
467 FunctionType::get(RetType
, ArgTypes
, F
->getFunctionType()->isVarArg());
469 Function::Create(NewFTy
, F
->getLinkage(), F
->getName(), *F
->getParent());
471 ValueToValueMapTy VMap
;
472 auto NewFArgIt
= NewF
->arg_begin();
473 for (auto &Arg
: F
->args()) {
474 StringRef ArgName
= Arg
.getName();
475 NewFArgIt
->setName(ArgName
);
476 VMap
[&Arg
] = &(*NewFArgIt
++);
478 SmallVector
<ReturnInst
*, 8> Returns
;
480 CloneFunctionInto(NewF
, F
, VMap
, CloneFunctionChangeType::LocalChangesOnly
,
484 NamedMDNode
*FuncMD
=
485 F
->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
486 SmallVector
<Metadata
*, 2> MDArgs
;
487 MDArgs
.push_back(MDString::get(B
.getContext(), NewF
->getName()));
488 for (auto &ChangedTyP
: ChangedTypes
)
489 MDArgs
.push_back(MDNode::get(
491 {ConstantAsMetadata::get(B
.getInt32(ChangedTyP
.first
)),
492 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP
.second
))}));
493 MDNode
*ThisFuncMD
= MDNode::get(B
.getContext(), MDArgs
);
494 FuncMD
->addOperand(ThisFuncMD
);
496 for (auto *U
: make_early_inc_range(F
->users())) {
497 if (auto *CI
= dyn_cast
<CallInst
>(U
))
498 CI
->mutateFunctionType(NewF
->getFunctionType());
499 U
->replaceUsesOfWith(F
, NewF
);
502 // register the mutation
503 if (RetType
!= F
->getReturnType())
504 TM
.getSubtarget
<SPIRVSubtarget
>(*F
).getSPIRVGlobalRegistry()->addMutated(
505 NewF
, F
->getReturnType());
509 bool SPIRVPrepareFunctions::runOnModule(Module
&M
) {
510 bool Changed
= false;
511 for (Function
&F
: M
) {
512 Changed
|= substituteIntrinsicCalls(&F
);
513 Changed
|= sortBlocks(F
);
516 std::vector
<Function
*> FuncsWorklist
;
518 FuncsWorklist
.push_back(&F
);
520 for (auto *F
: FuncsWorklist
) {
521 Function
*NewF
= removeAggregateTypesFromSignature(F
);
524 F
->eraseFromParent();
532 llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine
&TM
) {
533 return new SPIRVPrepareFunctions(TM
);