//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
// the pass was taken from the SPIRV-LLVM translator.
//
//===----------------------------------------------------------------------===//
#include "SPIRVTargetMachine.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <list>
#include <vector>
24 #define DEBUG_TYPE "spirv-regularizer"
29 void initializeSPIRVRegularizerPass(PassRegistry
&);
33 struct SPIRVRegularizer
: public FunctionPass
, InstVisitor
<SPIRVRegularizer
> {
34 DenseMap
<Function
*, Function
*> Old2NewFuncs
;
38 SPIRVRegularizer() : FunctionPass(ID
) {
39 initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
41 bool runOnFunction(Function
&F
) override
;
42 StringRef
getPassName() const override
{ return "SPIR-V Regularizer"; }
44 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
45 FunctionPass::getAnalysisUsage(AU
);
47 void visitCallInst(CallInst
&CI
);
50 void visitCallScalToVec(CallInst
*CI
, StringRef MangledName
,
51 StringRef DemangledName
);
52 void runLowerConstExpr(Function
&F
);
56 char SPIRVRegularizer::ID
= 0;
58 INITIALIZE_PASS(SPIRVRegularizer
, DEBUG_TYPE
, "SPIR-V Regularizer", false,
// Since SPIR-V cannot represent constant expressions, constant expressions
// in LLVM IR need to be lowered to instructions. For each function,
// the constant expressions used by instructions of the function are replaced
// by instructions placed in the entry block since it dominates all other BBs.
// Each constant expression only needs to be lowered once in each function
// and all uses of it by instructions in that function are replaced by
// one instruction.
// TODO: remove redundant instructions for common subexpression.
69 void SPIRVRegularizer::runLowerConstExpr(Function
&F
) {
70 LLVMContext
&Ctx
= F
.getContext();
71 std::list
<Instruction
*> WorkList
;
72 for (auto &II
: instructions(F
))
73 WorkList
.push_back(&II
);
75 auto FBegin
= F
.begin();
76 while (!WorkList
.empty()) {
77 Instruction
*II
= WorkList
.front();
79 auto LowerOp
= [&II
, &FBegin
, &F
](Value
*V
) -> Value
* {
82 auto *CE
= cast
<ConstantExpr
>(V
);
83 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE
);
84 auto ReplInst
= CE
->getAsInstruction();
85 auto InsPoint
= II
->getParent() == &*FBegin
? II
: &FBegin
->back();
86 ReplInst
->insertBefore(InsPoint
);
87 LLVM_DEBUG(dbgs() << " -> " << *ReplInst
<< '\n');
88 std::vector
<Instruction
*> Users
;
89 // Do not replace use during iteration of use. Do it in another loop.
90 for (auto U
: CE
->users()) {
91 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U
<< '\n');
92 auto InstUser
= dyn_cast
<Instruction
>(U
);
93 // Only replace users in scope of current function.
94 if (InstUser
&& InstUser
->getParent()->getParent() == &F
)
95 Users
.push_back(InstUser
);
97 for (auto &User
: Users
) {
98 if (ReplInst
->getParent() == User
->getParent() &&
99 User
->comesBefore(ReplInst
))
100 ReplInst
->moveBefore(User
);
101 User
->replaceUsesOfWith(CE
, ReplInst
);
106 WorkList
.pop_front();
107 auto LowerConstantVec
= [&II
, &LowerOp
, &WorkList
,
108 &Ctx
](ConstantVector
*Vec
,
109 unsigned NumOfOp
) -> Value
* {
110 if (std::all_of(Vec
->op_begin(), Vec
->op_end(), [](Value
*V
) {
111 return isa
<ConstantExpr
>(V
) || isa
<Function
>(V
);
113 // Expand a vector of constexprs and construct it back with
114 // series of insertelement instructions.
115 std::list
<Value
*> OpList
;
116 std::transform(Vec
->op_begin(), Vec
->op_end(),
117 std::back_inserter(OpList
),
118 [LowerOp
](Value
*V
) { return LowerOp(V
); });
119 Value
*Repl
= nullptr;
121 auto *PhiII
= dyn_cast
<PHINode
>(II
);
122 Instruction
*InsPoint
=
123 PhiII
? &PhiII
->getIncomingBlock(NumOfOp
)->back() : II
;
124 std::list
<Instruction
*> ReplList
;
125 for (auto V
: OpList
) {
126 if (auto *Inst
= dyn_cast
<Instruction
>(V
))
127 ReplList
.push_back(Inst
);
128 Repl
= InsertElementInst::Create(
129 (Repl
? Repl
: PoisonValue::get(Vec
->getType())), V
,
130 ConstantInt::get(Type::getInt32Ty(Ctx
), Idx
++), "", InsPoint
);
132 WorkList
.splice(WorkList
.begin(), ReplList
);
137 for (unsigned OI
= 0, OE
= II
->getNumOperands(); OI
!= OE
; ++OI
) {
138 auto *Op
= II
->getOperand(OI
);
139 if (auto *Vec
= dyn_cast
<ConstantVector
>(Op
)) {
140 Value
*ReplInst
= LowerConstantVec(Vec
, OI
);
142 II
->replaceUsesOfWith(Op
, ReplInst
);
143 } else if (auto CE
= dyn_cast
<ConstantExpr
>(Op
)) {
144 WorkList
.push_front(cast
<Instruction
>(LowerOp(CE
)));
145 } else if (auto MDAsVal
= dyn_cast
<MetadataAsValue
>(Op
)) {
146 auto ConstMD
= dyn_cast
<ConstantAsMetadata
>(MDAsVal
->getMetadata());
149 Constant
*C
= ConstMD
->getValue();
150 Value
*ReplInst
= nullptr;
151 if (auto *Vec
= dyn_cast
<ConstantVector
>(C
))
152 ReplInst
= LowerConstantVec(Vec
, OI
);
153 if (auto *CE
= dyn_cast
<ConstantExpr
>(C
))
154 ReplInst
= LowerOp(CE
);
157 Metadata
*RepMD
= ValueAsMetadata::get(ReplInst
);
158 Value
*RepMDVal
= MetadataAsValue::get(Ctx
, RepMD
);
159 II
->setOperand(OI
, RepMDVal
);
160 WorkList
.push_front(cast
<Instruction
>(ReplInst
));
166 // It fixes calls to OCL builtins that accept vector arguments and one of them
167 // is actually a scalar splat.
168 void SPIRVRegularizer::visitCallInst(CallInst
&CI
) {
169 auto F
= CI
.getCalledFunction();
173 auto MangledName
= F
->getName();
174 char *NameStr
= itaniumDemangle(F
->getName().data());
177 StringRef
DemangledName(NameStr
);
179 // TODO: add support for other builtins.
180 if (DemangledName
.starts_with("fmin") || DemangledName
.starts_with("fmax") ||
181 DemangledName
.starts_with("min") || DemangledName
.starts_with("max"))
182 visitCallScalToVec(&CI
, MangledName
, DemangledName
);
186 void SPIRVRegularizer::visitCallScalToVec(CallInst
*CI
, StringRef MangledName
,
187 StringRef DemangledName
) {
188 // Check if all arguments have the same type - it's simple case.
190 Type
*Arg0Ty
= CI
->getOperand(0)->getType();
191 auto IsArg0Vector
= isa
<VectorType
>(Arg0Ty
);
192 for (unsigned I
= 1, E
= CI
->arg_size(); Uniform
&& (I
!= E
); ++I
)
193 Uniform
= isa
<VectorType
>(CI
->getOperand(I
)->getType()) == IsArg0Vector
;
197 auto *OldF
= CI
->getCalledFunction();
198 Function
*NewF
= nullptr;
199 if (!Old2NewFuncs
.count(OldF
)) {
200 AttributeList Attrs
= CI
->getCalledFunction()->getAttributes();
201 SmallVector
<Type
*, 2> ArgTypes
= {OldF
->getArg(0)->getType(), Arg0Ty
};
203 FunctionType::get(OldF
->getReturnType(), ArgTypes
, OldF
->isVarArg());
204 NewF
= Function::Create(NewFTy
, OldF
->getLinkage(), OldF
->getName(),
206 ValueToValueMapTy VMap
;
207 auto NewFArgIt
= NewF
->arg_begin();
208 for (auto &Arg
: OldF
->args()) {
209 auto ArgName
= Arg
.getName();
210 NewFArgIt
->setName(ArgName
);
211 VMap
[&Arg
] = &(*NewFArgIt
++);
213 SmallVector
<ReturnInst
*, 8> Returns
;
214 CloneFunctionInto(NewF
, OldF
, VMap
,
215 CloneFunctionChangeType::LocalChangesOnly
, Returns
);
216 NewF
->setAttributes(Attrs
);
217 Old2NewFuncs
[OldF
] = NewF
;
219 NewF
= Old2NewFuncs
[OldF
];
223 // This produces an instruction sequence that implements a splat of
224 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
225 // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
226 // For instance (transcoding/OpMin.ll), this call
227 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
229 // %8 = OpUndef %v2uint
230 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
232 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
233 // %11 = OpVectorShuffle %v2uint %10 %8 0 0
234 // %call = OpExtInst %v2uint %1 s_min %14 %11
235 auto ConstInt
= ConstantInt::get(IntegerType::get(CI
->getContext(), 32), 0);
236 PoisonValue
*PVal
= PoisonValue::get(Arg0Ty
);
238 InsertElementInst::Create(PVal
, CI
->getOperand(1), ConstInt
, "", CI
);
239 ElementCount VecElemCount
= cast
<VectorType
>(Arg0Ty
)->getElementCount();
240 Constant
*ConstVec
= ConstantVector::getSplat(VecElemCount
, ConstInt
);
241 Value
*NewVec
= new ShuffleVectorInst(Inst
, PVal
, ConstVec
, "", CI
);
242 CI
->setOperand(1, NewVec
);
243 CI
->replaceUsesOfWith(OldF
, NewF
);
244 CI
->mutateFunctionType(NewF
->getFunctionType());
247 bool SPIRVRegularizer::runOnFunction(Function
&F
) {
248 runLowerConstExpr(F
);
250 for (auto &OldNew
: Old2NewFuncs
) {
251 Function
*OldF
= OldNew
.first
;
252 Function
*NewF
= OldNew
.second
;
253 NewF
->takeName(OldF
);
254 OldF
->eraseFromParent();
259 FunctionPass
*llvm::createSPIRVRegularizerPass() {
260 return new SPIRVRegularizer();