[clang] NFC, add a "continue" bailout in the for-loop of
[llvm-project.git] / llvm / lib / Target / DirectX / DXILFlattenArrays.cpp
bloba3163a896964284d73eb15b9f5b34ed3811fb6a9
1 //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 ///
9 /// \file This file contains a pass to flatten arrays for the DirectX Backend.
10 ///
11 //===----------------------------------------------------------------------===//
13 #include "DXILFlattenArrays.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/IR/BasicBlock.h"
18 #include "llvm/IR/DerivedTypes.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstVisitor.h"
21 #include "llvm/IR/ReplaceConstant.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Transforms/Utils/Local.h"
24 #include <cassert>
25 #include <cstddef>
26 #include <cstdint>
27 #include <utility>
29 #define DEBUG_TYPE "dxil-flatten-arrays"
31 using namespace llvm;
32 namespace {
34 class DXILFlattenArraysLegacy : public ModulePass {
36 public:
37 bool runOnModule(Module &M) override;
38 DXILFlattenArraysLegacy() : ModulePass(ID) {}
40 static char ID; // Pass identification.
43 struct GEPData {
44 ArrayType *ParentArrayType;
45 Value *ParendOperand;
46 SmallVector<Value *> Indices;
47 SmallVector<uint64_t> Dims;
48 bool AllIndicesAreConstInt;
51 class DXILFlattenArraysVisitor
52 : public InstVisitor<DXILFlattenArraysVisitor, bool> {
53 public:
54 DXILFlattenArraysVisitor() {}
55 bool visit(Function &F);
56 // InstVisitor methods. They return true if the instruction was scalarized,
57 // false if nothing changed.
58 bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
59 bool visitAllocaInst(AllocaInst &AI);
60 bool visitInstruction(Instruction &I) { return false; }
61 bool visitSelectInst(SelectInst &SI) { return false; }
62 bool visitICmpInst(ICmpInst &ICI) { return false; }
63 bool visitFCmpInst(FCmpInst &FCI) { return false; }
64 bool visitUnaryOperator(UnaryOperator &UO) { return false; }
65 bool visitBinaryOperator(BinaryOperator &BO) { return false; }
66 bool visitCastInst(CastInst &CI) { return false; }
67 bool visitBitCastInst(BitCastInst &BCI) { return false; }
68 bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
69 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
70 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
71 bool visitPHINode(PHINode &PHI) { return false; }
72 bool visitLoadInst(LoadInst &LI);
73 bool visitStoreInst(StoreInst &SI);
74 bool visitCallInst(CallInst &ICI) { return false; }
75 bool visitFreezeInst(FreezeInst &FI) { return false; }
76 static bool isMultiDimensionalArray(Type *T);
77 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
79 private:
80 SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81 DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
82 bool finish();
83 ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
84 ArrayRef<uint64_t> Dims,
85 IRBuilder<> &Builder);
86 Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
87 ArrayRef<uint64_t> Dims,
88 IRBuilder<> &Builder);
89 void
90 recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
91 ArrayType *FlattenedArrayType, Value *PtrOperand,
92 unsigned &GEPChainUseCount,
93 SmallVector<Value *> Indices = SmallVector<Value *>(),
94 SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
95 bool AllIndicesAreConstInt = true);
96 bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
97 bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
98 GetElementPtrInst &GEP);
100 } // namespace
102 bool DXILFlattenArraysVisitor::finish() {
103 RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
104 return true;
107 bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
108 if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109 return isa<ArrayType>(ArrType->getElementType());
110 return false;
113 std::pair<unsigned, Type *>
114 DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
115 unsigned TotalElements = 1;
116 Type *CurrArrayTy = ArrayTy;
117 while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118 TotalElements *= InnerArrayTy->getNumElements();
119 CurrArrayTy = InnerArrayTy->getElementType();
121 return std::make_pair(TotalElements, CurrArrayTy);
124 ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
125 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
126 assert(Indices.size() == Dims.size() &&
127 "Indicies and dimmensions should be the same");
128 unsigned FlatIndex = 0;
129 unsigned Multiplier = 1;
131 for (int I = Indices.size() - 1; I >= 0; --I) {
132 unsigned DimSize = Dims[I];
133 ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
134 assert(CIndex && "This function expects all indicies to be ConstantInt");
135 FlatIndex += CIndex->getZExtValue() * Multiplier;
136 Multiplier *= DimSize;
138 return Builder.getInt32(FlatIndex);
141 Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
142 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
143 if (Indices.size() == 1)
144 return Indices[0];
146 Value *FlatIndex = Builder.getInt32(0);
147 unsigned Multiplier = 1;
149 for (int I = Indices.size() - 1; I >= 0; --I) {
150 unsigned DimSize = Dims[I];
151 Value *VMultiplier = Builder.getInt32(Multiplier);
152 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
153 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
154 Multiplier *= DimSize;
156 return FlatIndex;
159 bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
160 unsigned NumOperands = LI.getNumOperands();
161 for (unsigned I = 0; I < NumOperands; ++I) {
162 Value *CurrOpperand = LI.getOperand(I);
163 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
164 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
165 GetElementPtrInst *OldGEP =
166 cast<GetElementPtrInst>(CE->getAsInstruction());
167 OldGEP->insertBefore(LI.getIterator());
169 IRBuilder<> Builder(&LI);
170 LoadInst *NewLoad =
171 Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
172 NewLoad->setAlignment(LI.getAlign());
173 LI.replaceAllUsesWith(NewLoad);
174 LI.eraseFromParent();
175 visitGetElementPtrInst(*OldGEP);
176 return true;
179 return false;
182 bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
183 unsigned NumOperands = SI.getNumOperands();
184 for (unsigned I = 0; I < NumOperands; ++I) {
185 Value *CurrOpperand = SI.getOperand(I);
186 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
187 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
188 GetElementPtrInst *OldGEP =
189 cast<GetElementPtrInst>(CE->getAsInstruction());
190 OldGEP->insertBefore(SI.getIterator());
192 IRBuilder<> Builder(&SI);
193 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
194 NewStore->setAlignment(SI.getAlign());
195 SI.replaceAllUsesWith(NewStore);
196 SI.eraseFromParent();
197 visitGetElementPtrInst(*OldGEP);
198 return true;
201 return false;
204 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
205 if (!isMultiDimensionalArray(AI.getAllocatedType()))
206 return false;
208 ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
209 IRBuilder<> Builder(&AI);
210 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
212 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
213 AllocaInst *FlatAlloca =
214 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
215 FlatAlloca->setAlignment(AI.getAlign());
216 AI.replaceAllUsesWith(FlatAlloca);
217 AI.eraseFromParent();
218 return true;
221 void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
222 GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
223 Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
224 SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
225 Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
226 AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
227 Indices.push_back(LastIndex);
228 assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
229 Dims.push_back(
230 cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
231 bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
232 if (!IsMultiDimArr) {
233 assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
234 GEPChainMap.insert(
235 {&CurrGEP,
236 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
237 std::move(Dims), AllIndicesAreConstInt}});
238 return;
240 bool GepUses = false;
241 for (auto *User : CurrGEP.users()) {
242 if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
243 recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
244 ++GEPChainUseCount, Indices, Dims,
245 AllIndicesAreConstInt);
246 GepUses = true;
249 // This case is just incase the gep chain doesn't end with a 1d array.
250 if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
251 GEPChainMap.insert(
252 {&CurrGEP,
253 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
254 std::move(Dims), AllIndicesAreConstInt}});
258 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
259 GetElementPtrInst &GEP) {
260 GEPData GEPInfo = GEPChainMap.at(&GEP);
261 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
263 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
264 GEPData &GEPInfo, GetElementPtrInst &GEP) {
265 IRBuilder<> Builder(&GEP);
266 Value *FlatIndex;
267 if (GEPInfo.AllIndicesAreConstInt)
268 FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
269 else
270 FlatIndex =
271 genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
273 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
274 Value *FlatGEP =
275 Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
276 GEP.getName() + ".flat", GEP.isInBounds());
278 GEP.replaceAllUsesWith(FlatGEP);
279 GEP.eraseFromParent();
280 return true;
283 bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
284 auto It = GEPChainMap.find(&GEP);
285 if (It != GEPChainMap.end())
286 return visitGetElementPtrInstInGEPChain(GEP);
287 if (!isMultiDimensionalArray(GEP.getSourceElementType()))
288 return false;
290 ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
291 IRBuilder<> Builder(&GEP);
292 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
293 ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
295 Value *PtrOperand = GEP.getPointerOperand();
297 unsigned GEPChainUseCount = 0;
298 recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
300 // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
301 // Here recursion is used to get the length of the GEP chain.
302 // Handle zero uses here because there won't be an update via
303 // a child in the chain later.
304 if (GEPChainUseCount == 0) {
305 SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
306 SmallVector<uint64_t> Dims({ArrType->getNumElements()});
307 bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
308 GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
309 std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
310 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
313 PotentiallyDeadInstrs.emplace_back(&GEP);
314 return false;
317 bool DXILFlattenArraysVisitor::visit(Function &F) {
318 bool MadeChange = false;
319 ReversePostOrderTraversal<Function *> RPOT(&F);
320 for (BasicBlock *BB : make_early_inc_range(RPOT)) {
321 for (Instruction &I : make_early_inc_range(*BB))
322 MadeChange |= InstVisitor::visit(I);
324 finish();
325 return MadeChange;
328 static void collectElements(Constant *Init,
329 SmallVectorImpl<Constant *> &Elements) {
330 // Base case: If Init is not an array, add it directly to the vector.
331 auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
332 if (!ArrayTy) {
333 Elements.push_back(Init);
334 return;
336 unsigned ArrSize = ArrayTy->getNumElements();
337 if (isa<ConstantAggregateZero>(Init)) {
338 for (unsigned I = 0; I < ArrSize; ++I)
339 Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
340 return;
343 // Recursive case: Process each element in the array.
344 if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
345 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
346 collectElements(ArrayConstant->getOperand(I), Elements);
348 } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
349 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
350 collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
352 } else {
353 llvm_unreachable(
354 "Expected a ConstantArray or ConstantDataArray for array initializer!");
358 static Constant *transformInitializer(Constant *Init, Type *OrigType,
359 ArrayType *FlattenedType,
360 LLVMContext &Ctx) {
361 // Handle ConstantAggregateZero (zero-initialized constants)
362 if (isa<ConstantAggregateZero>(Init))
363 return ConstantAggregateZero::get(FlattenedType);
365 // Handle UndefValue (undefined constants)
366 if (isa<UndefValue>(Init))
367 return UndefValue::get(FlattenedType);
369 if (!isa<ArrayType>(OrigType))
370 return Init;
372 SmallVector<Constant *> FlattenedElements;
373 collectElements(Init, FlattenedElements);
374 assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
375 "The number of collected elements should match the FlattenedType");
376 return ConstantArray::get(FlattenedType, FlattenedElements);
379 static void
380 flattenGlobalArrays(Module &M,
381 DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
382 LLVMContext &Ctx = M.getContext();
383 for (GlobalVariable &G : M.globals()) {
384 Type *OrigType = G.getValueType();
385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
386 continue;
388 ArrayType *ArrType = cast<ArrayType>(OrigType);
389 auto [TotalElements, BaseType] =
390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
391 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
393 // Create a new global variable with the updated type
394 // Note: Initializer is set via transformInitializer
395 GlobalVariable *NewGlobal =
396 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
397 /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
398 G.getThreadLocalMode(), G.getAddressSpace(),
399 G.isExternallyInitialized());
401 // Copy relevant attributes
402 NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
403 if (G.getAlignment() > 0) {
404 NewGlobal->setAlignment(G.getAlign());
407 if (G.hasInitializer()) {
408 Constant *Init = G.getInitializer();
409 Constant *NewInit =
410 transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
411 NewGlobal->setInitializer(NewInit);
413 GlobalMap[&G] = NewGlobal;
417 static bool flattenArrays(Module &M) {
418 bool MadeChange = false;
419 DXILFlattenArraysVisitor Impl;
420 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
421 flattenGlobalArrays(M, GlobalMap);
422 for (auto &F : make_early_inc_range(M.functions())) {
423 if (F.isDeclaration())
424 continue;
425 MadeChange |= Impl.visit(F);
427 for (auto &[Old, New] : GlobalMap) {
428 Old->replaceAllUsesWith(New);
429 Old->eraseFromParent();
430 MadeChange = true;
432 return MadeChange;
435 PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
436 bool MadeChanges = flattenArrays(M);
437 if (!MadeChanges)
438 return PreservedAnalyses::all();
439 PreservedAnalyses PA;
440 return PA;
443 bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
444 return flattenArrays(M);
447 char DXILFlattenArraysLegacy::ID = 0;
449 INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
450 "DXIL Array Flattener", false, false)
451 INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
452 false, false)
454 ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
455 return new DXILFlattenArraysLegacy();