Revert " [LoongArch][ISel] Check the number of sign bits in `PatGprGpr_32` (#107432)"
[llvm-project.git] / llvm / lib / Target / SPIRV / SPIRVRegularizer.cpp
blob322e051a87db1a695819f7dd366020c5b8606492
1 //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10 // the pass was taken from SPIRV-LLVM translator.
12 //===----------------------------------------------------------------------===//
14 #include "SPIRV.h"
15 #include "SPIRVTargetMachine.h"
16 #include "llvm/Demangle/Demangle.h"
17 #include "llvm/IR/InstIterator.h"
18 #include "llvm/IR/InstVisitor.h"
19 #include "llvm/IR/PassManager.h"
20 #include "llvm/Transforms/Utils/Cloning.h"
22 #include <list>
24 #define DEBUG_TYPE "spirv-regularizer"
26 using namespace llvm;
28 namespace llvm {
29 void initializeSPIRVRegularizerPass(PassRegistry &);
32 namespace {
33 struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
34 DenseMap<Function *, Function *> Old2NewFuncs;
36 public:
37 static char ID;
38 SPIRVRegularizer() : FunctionPass(ID) {
39 initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
41 bool runOnFunction(Function &F) override;
42 StringRef getPassName() const override { return "SPIR-V Regularizer"; }
44 void getAnalysisUsage(AnalysisUsage &AU) const override {
45 FunctionPass::getAnalysisUsage(AU);
47 void visitCallInst(CallInst &CI);
49 private:
50 void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51 StringRef DemangledName);
52 void runLowerConstExpr(Function &F);
54 } // namespace
56 char SPIRVRegularizer::ID = 0;
58 INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
59 false)
61 // Since SPIR-V cannot represent constant expression, constant expressions
62 // in LLVM IR need to be lowered to instructions. For each function,
63 // the constant expressions used by instructions of the function are replaced
64 // by instructions placed in the entry block since it dominates all other BBs.
65 // Each constant expression only needs to be lowered once in each function
66 // and all uses of it by instructions in that function are replaced by
67 // one instruction.
68 // TODO: remove redundant instructions for common subexpression.
69 void SPIRVRegularizer::runLowerConstExpr(Function &F) {
70 LLVMContext &Ctx = F.getContext();
71 std::list<Instruction *> WorkList;
72 for (auto &II : instructions(F))
73 WorkList.push_back(&II);
75 auto FBegin = F.begin();
76 while (!WorkList.empty()) {
77 Instruction *II = WorkList.front();
79 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
80 if (isa<Function>(V))
81 return V;
82 auto *CE = cast<ConstantExpr>(V);
83 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
84 auto ReplInst = CE->getAsInstruction();
85 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86 ReplInst->insertBefore(InsPoint);
87 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
88 std::vector<Instruction *> Users;
89 // Do not replace use during iteration of use. Do it in another loop.
90 for (auto U : CE->users()) {
91 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
92 auto InstUser = dyn_cast<Instruction>(U);
93 // Only replace users in scope of current function.
94 if (InstUser && InstUser->getParent()->getParent() == &F)
95 Users.push_back(InstUser);
97 for (auto &User : Users) {
98 if (ReplInst->getParent() == User->getParent() &&
99 User->comesBefore(ReplInst))
100 ReplInst->moveBefore(User);
101 User->replaceUsesOfWith(CE, ReplInst);
103 return ReplInst;
106 WorkList.pop_front();
107 auto LowerConstantVec = [&II, &LowerOp, &WorkList,
108 &Ctx](ConstantVector *Vec,
109 unsigned NumOfOp) -> Value * {
110 if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
111 return isa<ConstantExpr>(V) || isa<Function>(V);
112 })) {
113 // Expand a vector of constexprs and construct it back with
114 // series of insertelement instructions.
115 std::list<Value *> OpList;
116 std::transform(Vec->op_begin(), Vec->op_end(),
117 std::back_inserter(OpList),
118 [LowerOp](Value *V) { return LowerOp(V); });
119 Value *Repl = nullptr;
120 unsigned Idx = 0;
121 auto *PhiII = dyn_cast<PHINode>(II);
122 Instruction *InsPoint =
123 PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
124 std::list<Instruction *> ReplList;
125 for (auto V : OpList) {
126 if (auto *Inst = dyn_cast<Instruction>(V))
127 ReplList.push_back(Inst);
128 Repl = InsertElementInst::Create(
129 (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130 ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);
132 WorkList.splice(WorkList.begin(), ReplList);
133 return Repl;
135 return nullptr;
137 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
138 auto *Op = II->getOperand(OI);
139 if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
140 Value *ReplInst = LowerConstantVec(Vec, OI);
141 if (ReplInst)
142 II->replaceUsesOfWith(Op, ReplInst);
143 } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
144 WorkList.push_front(cast<Instruction>(LowerOp(CE)));
145 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
146 auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
147 if (!ConstMD)
148 continue;
149 Constant *C = ConstMD->getValue();
150 Value *ReplInst = nullptr;
151 if (auto *Vec = dyn_cast<ConstantVector>(C))
152 ReplInst = LowerConstantVec(Vec, OI);
153 if (auto *CE = dyn_cast<ConstantExpr>(C))
154 ReplInst = LowerOp(CE);
155 if (!ReplInst)
156 continue;
157 Metadata *RepMD = ValueAsMetadata::get(ReplInst);
158 Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
159 II->setOperand(OI, RepMDVal);
160 WorkList.push_front(cast<Instruction>(ReplInst));
166 // It fixes calls to OCL builtins that accept vector arguments and one of them
167 // is actually a scalar splat.
168 void SPIRVRegularizer::visitCallInst(CallInst &CI) {
169 auto F = CI.getCalledFunction();
170 if (!F)
171 return;
173 auto MangledName = F->getName();
174 char *NameStr = itaniumDemangle(F->getName().data());
175 if (!NameStr)
176 return;
177 StringRef DemangledName(NameStr);
179 // TODO: add support for other builtins.
180 if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
181 DemangledName.starts_with("min") || DemangledName.starts_with("max"))
182 visitCallScalToVec(&CI, MangledName, DemangledName);
183 free(NameStr);
186 void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
187 StringRef DemangledName) {
188 // Check if all arguments have the same type - it's simple case.
189 auto Uniform = true;
190 Type *Arg0Ty = CI->getOperand(0)->getType();
191 auto IsArg0Vector = isa<VectorType>(Arg0Ty);
192 for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
193 Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
194 if (Uniform)
195 return;
197 auto *OldF = CI->getCalledFunction();
198 Function *NewF = nullptr;
199 if (!Old2NewFuncs.count(OldF)) {
200 AttributeList Attrs = CI->getCalledFunction()->getAttributes();
201 SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
202 auto *NewFTy =
203 FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
204 NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
205 *OldF->getParent());
206 ValueToValueMapTy VMap;
207 auto NewFArgIt = NewF->arg_begin();
208 for (auto &Arg : OldF->args()) {
209 auto ArgName = Arg.getName();
210 NewFArgIt->setName(ArgName);
211 VMap[&Arg] = &(*NewFArgIt++);
213 SmallVector<ReturnInst *, 8> Returns;
214 CloneFunctionInto(NewF, OldF, VMap,
215 CloneFunctionChangeType::LocalChangesOnly, Returns);
216 NewF->setAttributes(Attrs);
217 Old2NewFuncs[OldF] = NewF;
218 } else {
219 NewF = Old2NewFuncs[OldF];
221 assert(NewF);
223 // This produces an instruction sequence that implements a splat of
224 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
225 // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
226 // For instance (transcoding/OpMin.ll), this call
227 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
228 // is translated to
229 // %8 = OpUndef %v2uint
230 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
231 // ...
232 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
233 // %11 = OpVectorShuffle %v2uint %10 %8 0 0
234 // %call = OpExtInst %v2uint %1 s_min %14 %11
235 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
236 PoisonValue *PVal = PoisonValue::get(Arg0Ty);
237 Instruction *Inst =
238 InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);
239 ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
240 Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
241 Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);
242 CI->setOperand(1, NewVec);
243 CI->replaceUsesOfWith(OldF, NewF);
244 CI->mutateFunctionType(NewF->getFunctionType());
247 bool SPIRVRegularizer::runOnFunction(Function &F) {
248 runLowerConstExpr(F);
249 visit(F);
250 for (auto &OldNew : Old2NewFuncs) {
251 Function *OldF = OldNew.first;
252 Function *NewF = OldNew.second;
253 NewF->takeName(OldF);
254 OldF->eraseFromParent();
256 return true;
259 FunctionPass *llvm::createSPIRVRegularizerPass() {
260 return new SPIRVRegularizer();