1 //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===---------------------------------------------------------------------===//
9 /// \file This file contains a pass to flatten arrays for the DirectX Backend.
11 //===----------------------------------------------------------------------===//
13 #include "DXILFlattenArrays.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/IR/BasicBlock.h"
18 #include "llvm/IR/DerivedTypes.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstVisitor.h"
21 #include "llvm/IR/ReplaceConstant.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Transforms/Utils/Local.h"
29 #define DEBUG_TYPE "dxil-flatten-arrays"
/// Legacy pass-manager wrapper for array flattening. It derives from
/// ModulePass; runOnModule is only declared here and is defined near the end
/// of this file, where it forwards to flattenArrays(M).
34 class DXILFlattenArraysLegacy
: public ModulePass
{
37 bool runOnModule(Module
&M
) override
;
38 DXILFlattenArraysLegacy() : ModulePass(ID
) {}
// Handed to the ModulePass constructor above; defined as 0 later in this file.
40 static char ID
; // Pass identification.
// The declarations below are the members reached through GEPData objects
// elsewhere in this file (GEPInfo.ParentArrayType, GEPInfo.Indices,
// GEPInfo.Dims, GEPInfo.AllIndicesAreConstInt). The line that opens the
// enclosing type is not part of this excerpt.
// Array type used as the source element type of the rewritten ".flat" GEP.
44 ArrayType
*ParentArrayType
;
// Index values gathered along a GEP chain, one per GEP visited.
46 SmallVector
<Value
*> Indices
;
// Element counts paired position-by-position with Indices.
47 SmallVector
<uint64_t> Dims
;
// Selects genConstFlattenIndices (true) over genInstructionFlattenIndices.
48 bool AllIndicesAreConstInt
;
/// Instruction visitor that rewrites uses of multi-dimensional array types.
/// It is an InstVisitor whose handlers return bool. Only the GEP, alloca,
/// load and store handlers are defined out of line in this file; every other
/// handler declared here has an inline body that returns false.
51 class DXILFlattenArraysVisitor
52 : public InstVisitor
<DXILFlattenArraysVisitor
, bool> {
54 DXILFlattenArraysVisitor() {}
// Per-function entry point; called from flattenArrays for each function that
// is not a declaration.
55 bool visit(Function
&F
);
56 // InstVisitor methods. They return true if the instruction was scalarized,
57 // false if nothing changed.
58 bool visitGetElementPtrInst(GetElementPtrInst
&GEPI
);
59 bool visitAllocaInst(AllocaInst
&AI
);
// The handlers below with inline bodies leave their instruction untouched.
60 bool visitInstruction(Instruction
&I
) { return false; }
61 bool visitSelectInst(SelectInst
&SI
) { return false; }
62 bool visitICmpInst(ICmpInst
&ICI
) { return false; }
63 bool visitFCmpInst(FCmpInst
&FCI
) { return false; }
64 bool visitUnaryOperator(UnaryOperator
&UO
) { return false; }
65 bool visitBinaryOperator(BinaryOperator
&BO
) { return false; }
66 bool visitCastInst(CastInst
&CI
) { return false; }
67 bool visitBitCastInst(BitCastInst
&BCI
) { return false; }
68 bool visitInsertElementInst(InsertElementInst
&IEI
) { return false; }
69 bool visitExtractElementInst(ExtractElementInst
&EEI
) { return false; }
70 bool visitShuffleVectorInst(ShuffleVectorInst
&SVI
) { return false; }
71 bool visitPHINode(PHINode
&PHI
) { return false; }
// Loads and stores are handled so that constant-expression GEP operands can
// be turned into GEP instructions and flattened (see their definitions).
72 bool visitLoadInst(LoadInst
&LI
);
73 bool visitStoreInst(StoreInst
&SI
);
74 bool visitCallInst(CallInst
&ICI
) { return false; }
75 bool visitFreezeInst(FreezeInst
&FI
) { return false; }
// Static helpers; flattenGlobalArrays also calls them through the class name.
76 static bool isMultiDimensionalArray(Type
*T
);
77 static std::pair
<unsigned, Type
*> getElementCountAndType(Type
*ArrayTy
);
// GEPs recorded by visitGetElementPtrInst when they start a chain; finish()
// passes this list to the dead-instruction cleanup helper.
80 SmallVector
<WeakTrackingVH
> PotentiallyDeadInstrs
;
// Flattening data keyed by GEP; filled by recursivelyCollectGEPs and consulted
// first in visitGetElementPtrInst.
81 DenseMap
<GetElementPtrInst
*, GEPData
> GEPChainMap
;
// Folds all-constant indices into one i32 constant.
83 ConstantInt
*genConstFlattenIndices(ArrayRef
<Value
*> Indices
,
84 ArrayRef
<uint64_t> Dims
,
85 IRBuilder
<> &Builder
);
// Emits mul/add instructions computing the flat index from runtime indices.
86 Value
*genInstructionFlattenIndices(ArrayRef
<Value
*> Indices
,
87 ArrayRef
<uint64_t> Dims
,
88 IRBuilder
<> &Builder
);
// Indices, Dims and AllIndicesAreConstInt are by-value accumulators with
// defaults, so the outermost caller only supplies the first four arguments.
// (The return-type line of this declaration is not part of this excerpt.)
90 recursivelyCollectGEPs(GetElementPtrInst
&CurrGEP
,
91 ArrayType
*FlattenedArrayType
, Value
*PtrOperand
,
92 unsigned &GEPChainUseCount
,
93 SmallVector
<Value
*> Indices
= SmallVector
<Value
*>(),
94 SmallVector
<uint64_t> Dims
= SmallVector
<uint64_t>(),
95 bool AllIndicesAreConstInt
= true);
96 bool visitGetElementPtrInstInGEPChain(GetElementPtrInst
&GEP
);
97 bool visitGetElementPtrInstInGEPChainBase(GEPData
&GEPInfo
,
98 GetElementPtrInst
&GEP
);
/// Hands the collected PotentiallyDeadInstrs list to
/// RecursivelyDeleteTriviallyDeadInstructionsPermissive. The statement that
/// produces the bool result is not part of this excerpt.
102 bool DXILFlattenArraysVisitor::finish() {
103 RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs
);
/// When \p T is an ArrayType, reports whether its element type is itself an
/// ArrayType (i.e. the array has more than one dimension). The result for a
/// non-array \p T is produced by a statement not visible in this excerpt.
107 bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type
*T
) {
108 if (ArrayType
*ArrType
= dyn_cast
<ArrayType
>(T
))
109 return isa
<ArrayType
>(ArrType
->getElementType());
/// Peels nested array types starting at \p ArrayTy.
///
/// \returns a pair of (product of the element counts of every array level
/// peeled, the first non-array type reached). If \p ArrayTy is not an array
/// the loop body never runs and the pair is (1, ArrayTy).
///
/// NOTE(review): the running product is held in an `unsigned`; if
/// getNumElements() yields a wider value the product can wrap - confirm that
/// this cannot happen for the array sizes this pass sees.
113 std::pair
<unsigned, Type
*>
114 DXILFlattenArraysVisitor::getElementCountAndType(Type
*ArrayTy
) {
115 unsigned TotalElements
= 1;
116 Type
*CurrArrayTy
= ArrayTy
;
// Each iteration multiplies in one dimension and steps to its element type.
117 while (auto *InnerArrayTy
= dyn_cast
<ArrayType
>(CurrArrayTy
)) {
118 TotalElements
*= InnerArrayTy
->getNumElements();
119 CurrArrayTy
= InnerArrayTy
->getElementType();
121 return std::make_pair(TotalElements
, CurrArrayTy
);
/// Folds a list of constant per-dimension indices into one flat index.
///
/// \param Indices one value per dimension; each must be a ConstantInt.
/// \param Dims    element counts, same length as \p Indices.
/// \param Builder used only to create the resulting i32 constant.
/// \returns the flat index as a 32-bit constant (Builder.getInt32).
///
/// NOTE(review): Dims holds uint64_t but DimSize, Multiplier and FlatIndex are
/// `unsigned`, so the arithmetic is performed at that narrower width.
/// NOTE(review): the dyn_cast result is only checked by assert; in a build
/// where assert is compiled out a non-ConstantInt index would be dereferenced
/// as a null pointer.
124 ConstantInt
*DXILFlattenArraysVisitor::genConstFlattenIndices(
125 ArrayRef
<Value
*> Indices
, ArrayRef
<uint64_t> Dims
, IRBuilder
<> &Builder
) {
126 assert(Indices
.size() == Dims
.size() &&
127 "Indicies and dimmensions should be the same");
128 unsigned FlatIndex
= 0;
129 unsigned Multiplier
= 1;
// Walk from the last index to the first: each index is scaled by the product
// of the Dims entries that come after it, then Multiplier absorbs this entry.
131 for (int I
= Indices
.size() - 1; I
>= 0; --I
) {
132 unsigned DimSize
= Dims
[I
];
133 ConstantInt
*CIndex
= dyn_cast
<ConstantInt
>(Indices
[I
]);
134 assert(CIndex
&& "This function expects all indicies to be ConstantInt");
135 FlatIndex
+= CIndex
->getZExtValue() * Multiplier
;
136 Multiplier
*= DimSize
;
138 return Builder
.getInt32(FlatIndex
);
/// Emits IR that computes a flat index from per-dimension index values.
///
/// \param Indices one value per dimension (need not be constants).
/// \param Dims    element counts indexed in step with \p Indices.
/// \param Builder insertion point for the generated mul/add instructions.
///
/// The statement taken when there is exactly one index, and the final return,
/// are not part of this excerpt.
141 Value
*DXILFlattenArraysVisitor::genInstructionFlattenIndices(
142 ArrayRef
<Value
*> Indices
, ArrayRef
<uint64_t> Dims
, IRBuilder
<> &Builder
) {
143 if (Indices
.size() == 1)
// The running sum starts from an i32 zero constant.
146 Value
*FlatIndex
= Builder
.getInt32(0);
147 unsigned Multiplier
= 1;
// Same last-to-first walk as genConstFlattenIndices, but the scaling and the
// accumulation are emitted as CreateMul / CreateAdd instead of being folded.
149 for (int I
= Indices
.size() - 1; I
>= 0; --I
) {
150 unsigned DimSize
= Dims
[I
];
151 Value
*VMultiplier
= Builder
.getInt32(Multiplier
);
152 Value
*ScaledIndex
= Builder
.CreateMul(Indices
[I
], VMultiplier
);
153 FlatIndex
= Builder
.CreateAdd(FlatIndex
, ScaledIndex
);
154 Multiplier
*= DimSize
;
/// Rewrites a load whose operand is a constant-expression GEP.
///
/// For such an operand the constant expression is turned into a real GEP
/// instruction placed before the load, a replacement load through that GEP is
/// created with the same type, name and alignment, the original load is
/// replaced and erased, and the new GEP is then run through
/// visitGetElementPtrInst. Only the alignment is copied onto the replacement
/// in the visible lines. The return statements and the declaration of NewLoad
/// are not part of this excerpt.
159 bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst
&LI
) {
160 unsigned NumOperands
= LI
.getNumOperands();
161 for (unsigned I
= 0; I
< NumOperands
; ++I
) {
162 Value
*CurrOpperand
= LI
.getOperand(I
);
163 ConstantExpr
*CE
= dyn_cast
<ConstantExpr
>(CurrOpperand
);
164 if (CE
&& CE
->getOpcode() == Instruction::GetElementPtr
) {
// Materialize the constant expression as an instruction so the GEP handler
// can rewrite it like any other GEP.
165 GetElementPtrInst
*OldGEP
=
166 cast
<GetElementPtrInst
>(CE
->getAsInstruction());
167 OldGEP
->insertBefore(LI
.getIterator());
169 IRBuilder
<> Builder(&LI
);
171 Builder
.CreateLoad(LI
.getType(), OldGEP
, LI
.getName());
172 NewLoad
->setAlignment(LI
.getAlign());
173 LI
.replaceAllUsesWith(NewLoad
);
174 LI
.eraseFromParent();
175 visitGetElementPtrInst(*OldGEP
);
/// Rewrites a store that has a constant-expression GEP operand.
///
/// Mirrors visitLoadInst: the constant expression becomes a GEP instruction
/// inserted before the store, a new store of the same value operand through
/// that GEP is created with the original alignment, the old store is replaced
/// and erased, and the new GEP is passed to visitGetElementPtrInst. Only the
/// alignment is copied onto the replacement in the visible lines. The return
/// statements are not part of this excerpt.
182 bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst
&SI
) {
183 unsigned NumOperands
= SI
.getNumOperands();
184 for (unsigned I
= 0; I
< NumOperands
; ++I
) {
185 Value
*CurrOpperand
= SI
.getOperand(I
);
186 ConstantExpr
*CE
= dyn_cast
<ConstantExpr
>(CurrOpperand
);
187 if (CE
&& CE
->getOpcode() == Instruction::GetElementPtr
) {
188 GetElementPtrInst
*OldGEP
=
189 cast
<GetElementPtrInst
>(CE
->getAsInstruction());
190 OldGEP
->insertBefore(SI
.getIterator());
192 IRBuilder
<> Builder(&SI
);
// NOTE(review): the new store always uses OldGEP as its pointer, whichever
// operand index matched - confirm a GEP constant is never the stored value.
193 StoreInst
*NewStore
= Builder
.CreateStore(SI
.getValueOperand(), OldGEP
);
194 NewStore
->setAlignment(SI
.getAlign());
195 SI
.replaceAllUsesWith(NewStore
);
196 SI
.eraseFromParent();
197 visitGetElementPtrInst(*OldGEP
);
/// Replaces an alloca of a multi-dimensional array with an alloca of the
/// equivalent one-dimensional array.
///
/// The replacement is named "<original name>.flat", keeps the original
/// alignment, takes over all uses of the original, and the original alloca is
/// erased. The statement taken when the allocated type is not a
/// multi-dimensional array, and the final return, are not part of this
/// excerpt.
204 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst
&AI
) {
205 if (!isMultiDimensionalArray(AI
.getAllocatedType()))
208 ArrayType
*ArrType
= cast
<ArrayType
>(AI
.getAllocatedType());
209 IRBuilder
<> Builder(&AI
);
210 auto [TotalElements
, BaseType
] = getElementCountAndType(ArrType
);
// One dimension holding every leaf element of the original nested array.
212 ArrayType
*FattenedArrayType
= ArrayType::get(BaseType
, TotalElements
);
213 AllocaInst
*FlatAlloca
=
214 Builder
.CreateAlloca(FattenedArrayType
, nullptr, AI
.getName() + ".flat");
215 FlatAlloca
->setAlignment(AI
.getAlign());
216 AI
.replaceAllUsesWith(FlatAlloca
);
217 AI
.eraseFromParent();
/// Walks a chain of GEPs (a GEP whose users are themselves GEPs), accumulating
/// the index and dimension of each link.
///
/// \param CurrGEP               link being processed.
/// \param FlattenedArrayType    1-D array type carried along unchanged.
/// \param PtrOperand            base pointer carried along unchanged.
/// \param GEPChainUseCount      shared counter, incremented for every nested
///                              GEP user that is followed.
/// \param Indices, Dims         by-value accumulators; each recursive call
///                              receives its own copy.
/// \param AllIndicesAreConstInt cleared once a non-ConstantInt index is seen.
///
/// Several lines of this function (the start of the Dims update, the start of
/// both brace-initialized map updates, and the loop tail) are not part of
/// this excerpt, so the comments below only cover what is visible.
221 void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
222 GetElementPtrInst
&CurrGEP
, ArrayType
*FlattenedArrayType
,
223 Value
*PtrOperand
, unsigned &GEPChainUseCount
, SmallVector
<Value
*> Indices
,
224 SmallVector
<uint64_t> Dims
, bool AllIndicesAreConstInt
) {
// Only the last operand of the GEP is taken as this link's index.
225 Value
*LastIndex
= CurrGEP
.getOperand(CurrGEP
.getNumOperands() - 1);
226 AllIndicesAreConstInt
&= isa
<ConstantInt
>(LastIndex
);
227 Indices
.push_back(LastIndex
);
228 assert(isa
<ArrayType
>(CurrGEP
.getSourceElementType()));
230 cast
<ArrayType
>(CurrGEP
.getSourceElementType())->getNumElements());
231 bool IsMultiDimArr
= isMultiDimensionalArray(CurrGEP
.getSourceElementType());
// A GEP over a one-dimensional source type is the innermost link of a chain.
232 if (!IsMultiDimArr
) {
233 assert(GEPChainUseCount
< FlattenedArrayType
->getNumElements());
236 {std::move(FlattenedArrayType
), PtrOperand
, std::move(Indices
),
237 std::move(Dims
), AllIndicesAreConstInt
}});
240 bool GepUses
= false;
241 for (auto *User
: CurrGEP
.users()) {
242 if (GetElementPtrInst
*NestedGEP
= dyn_cast
<GetElementPtrInst
>(User
)) {
// The counter is bumped before the call; Indices and Dims are copied for it.
243 recursivelyCollectGEPs(*NestedGEP
, FlattenedArrayType
, PtrOperand
,
244 ++GEPChainUseCount
, Indices
, Dims
,
245 AllIndicesAreConstInt
);
249 // This case is just in case the gep chain doesn't end with a 1d array.
250 if (IsMultiDimArr
&& GEPChainUseCount
> 0 && !GepUses
) {
253 {std::move(FlattenedArrayType
), PtrOperand
, std::move(Indices
),
254 std::move(Dims
), AllIndicesAreConstInt
}});
/// Handles a GEP that already has an entry in GEPChainMap: takes a copy of
/// that entry and forwards to visitGetElementPtrInstInGEPChainBase. Uses
/// GEPChainMap.at(), so the caller is expected to have checked that the entry
/// exists (visitGetElementPtrInst does so with find()).
258 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
259 GetElementPtrInst
&GEP
) {
260 GEPData GEPInfo
= GEPChainMap
.at(&GEP
);
261 return visitGetElementPtrInstInGEPChainBase(GEPInfo
, GEP
);
/// Replaces \p GEP with a GEP into the flattened array described by
/// \p GEPInfo.
///
/// The flat index is a folded constant when GEPInfo.AllIndicesAreConstInt is
/// set and generated instructions otherwise. The new GEP uses
/// GEPInfo.ParentArrayType as source element type and GEPInfo.ParendOperand
/// as pointer, is named "<original name>.flat", and keeps the original's
/// in-bounds flag. The original is then replaced and erased. The declarations
/// of FlatIndex and FlatGEP and the final return are not part of this excerpt.
263 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
264 GEPData
&GEPInfo
, GetElementPtrInst
&GEP
) {
265 IRBuilder
<> Builder(&GEP
);
267 if (GEPInfo
.AllIndicesAreConstInt
)
268 FlatIndex
= genConstFlattenIndices(GEPInfo
.Indices
, GEPInfo
.Dims
, Builder
);
271 genInstructionFlattenIndices(GEPInfo
.Indices
, GEPInfo
.Dims
, Builder
);
273 ArrayType
*FlattenedArrayType
= GEPInfo
.ParentArrayType
;
275 Builder
.CreateGEP(FlattenedArrayType
, GEPInfo
.ParendOperand
, FlatIndex
,
276 GEP
.getName() + ".flat", GEP
.isInBounds());
278 GEP
.replaceAllUsesWith(FlatGEP
);
279 GEP
.eraseFromParent();
/// Entry point for GEP instructions.
///
/// 1. A GEP already recorded in GEPChainMap is rewritten from its recorded
///    data.
/// 2. Otherwise the source element type is tested with
///    isMultiDimensionalArray (the statement taken when the test fails is not
///    part of this excerpt).
/// 3. For a multi-dimensional source type the flattened type is computed and
///    recursivelyCollectGEPs records the chain rooted at this GEP.
/// 4. If no nested GEP was followed (count stays 0) the GEP is rewritten on
///    the spot from its own last operand and outer dimension; otherwise it is
///    queued in PotentiallyDeadInstrs.
283 bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst
&GEP
) {
284 auto It
= GEPChainMap
.find(&GEP
);
285 if (It
!= GEPChainMap
.end())
286 return visitGetElementPtrInstInGEPChain(GEP
);
287 if (!isMultiDimensionalArray(GEP
.getSourceElementType()))
290 ArrayType
*ArrType
= cast
<ArrayType
>(GEP
.getSourceElementType());
291 IRBuilder
<> Builder(&GEP
);
292 auto [TotalElements
, BaseType
] = getElementCountAndType(ArrType
);
293 ArrayType
*FlattenedArrayType
= ArrayType::get(BaseType
, TotalElements
);
295 Value
*PtrOperand
= GEP
.getPointerOperand();
297 unsigned GEPChainUseCount
= 0;
298 recursivelyCollectGEPs(GEP
, FlattenedArrayType
, PtrOperand
, GEPChainUseCount
);
300 // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
301 // Here recursion is used to get the length of the GEP chain.
302 // Handle zero uses here because there won't be an update via
303 // a child in the chain later.
304 if (GEPChainUseCount
== 0) {
// Single-link case: one index (the last operand) and one dimension.
305 SmallVector
<Value
*> Indices({GEP
.getOperand(GEP
.getNumOperands() - 1)});
306 SmallVector
<uint64_t> Dims({ArrType
->getNumElements()});
307 bool AllIndicesAreConstInt
= isa
<ConstantInt
>(Indices
[0]);
308 GEPData GEPInfo
{std::move(FlattenedArrayType
), PtrOperand
,
309 std::move(Indices
), std::move(Dims
), AllIndicesAreConstInt
};
310 return visitGetElementPtrInstInGEPChainBase(GEPInfo
, GEP
);
// Chain root: its nested GEPs are rewritten when visited, so this GEP is
// queued for the cleanup done in finish().
313 PotentiallyDeadInstrs
.emplace_back(&GEP
);
/// Visits every instruction of \p F, with blocks taken in reverse post-order,
/// and ORs together the handlers' results. Both loops use
/// make_early_inc_range because handlers may erase the instruction being
/// visited. The tail of the function (after the loops) is not part of this
/// excerpt.
317 bool DXILFlattenArraysVisitor::visit(Function
&F
) {
318 bool MadeChange
= false;
319 ReversePostOrderTraversal
<Function
*> RPOT(&F
);
320 for (BasicBlock
*BB
: make_early_inc_range(RPOT
)) {
321 for (Instruction
&I
: make_early_inc_range(*BB
))
322 MadeChange
|= InstVisitor::visit(I
);
/// Appends the leaf constants of the (possibly nested) array initializer
/// \p Init to \p Elements, in operand order.
///
/// Visible cases: a non-array constant is appended as-is; a
/// ConstantAggregateZero appends one null value per element; a ConstantArray
/// or ConstantDataArray recurses into each element. Some control-flow lines
/// (the non-array test, early returns, the final else) are not part of this
/// excerpt.
///
/// NOTE(review): the zero-aggregate branch appends null values of the array's
/// *element type* without recursing; if that element type is itself an array
/// the appended values are arrays rather than leaves - confirm whether this
/// case can be reached.
328 static void collectElements(Constant
*Init
,
329 SmallVectorImpl
<Constant
*> &Elements
) {
330 // Base case: If Init is not an array, add it directly to the vector.
331 auto *ArrayTy
= dyn_cast
<ArrayType
>(Init
->getType());
333 Elements
.push_back(Init
);
336 unsigned ArrSize
= ArrayTy
->getNumElements();
337 if (isa
<ConstantAggregateZero
>(Init
)) {
338 for (unsigned I
= 0; I
< ArrSize
; ++I
)
339 Elements
.push_back(Constant::getNullValue(ArrayTy
->getElementType()));
343 // Recursive case: Process each element in the array.
344 if (auto *ArrayConstant
= dyn_cast
<ConstantArray
>(Init
)) {
345 for (unsigned I
= 0; I
< ArrayConstant
->getNumOperands(); ++I
) {
346 collectElements(ArrayConstant
->getOperand(I
), Elements
);
348 } else if (auto *DataArrayConstant
= dyn_cast
<ConstantDataArray
>(Init
)) {
349 for (unsigned I
= 0; I
< DataArrayConstant
->getNumElements(); ++I
) {
350 collectElements(DataArrayConstant
->getElementAsConstant(I
), Elements
);
354 "Expected a ConstantArray or ConstantDataArray for array initializer!");
/// Builds the initializer for a flattened global from the original one.
///
/// \param Init          original initializer.
/// \param OrigType      original value type of the global.
/// \param FlattenedType one-dimensional array type of the new global.
/// (The last parameter line is not part of this excerpt; the caller passes
/// the module's LLVMContext there.)
///
/// Zero and undef initializers map directly to the same kind of constant of
/// \p FlattenedType. Otherwise the leaves are gathered with collectElements
/// and assembled into a ConstantArray; their number must equal
/// FlattenedType->getNumElements(). The statement taken when \p OrigType is
/// not an array is not part of this excerpt.
358 static Constant
*transformInitializer(Constant
*Init
, Type
*OrigType
,
359 ArrayType
*FlattenedType
,
361 // Handle ConstantAggregateZero (zero-initialized constants)
362 if (isa
<ConstantAggregateZero
>(Init
))
363 return ConstantAggregateZero::get(FlattenedType
);
365 // Handle UndefValue (undefined constants)
366 if (isa
<UndefValue
>(Init
))
367 return UndefValue::get(FlattenedType
);
369 if (!isa
<ArrayType
>(OrigType
))
372 SmallVector
<Constant
*> FlattenedElements
;
373 collectElements(Init
, FlattenedElements
);
374 assert(FlattenedType
->getNumElements() == FlattenedElements
.size() &&
375 "The number of collected elements should match the FlattenedType");
376 return ConstantArray::get(FlattenedType
, FlattenedElements
);
/// Creates a flattened counterpart for every global in \p M whose value type
/// is a multi-dimensional array and records the pairing in \p GlobalMap
/// (original -> replacement). The originals are not modified or removed here;
/// flattenArrays does the replacement afterwards.
///
/// Each replacement is named "<original name>.1dim" and copies constness,
/// linkage, thread-local mode, address space, the externally-initialized
/// flag, unnamed_addr, alignment (when non-zero) and a transformed
/// initializer (when present). The statement taken for globals that are not
/// multi-dimensional arrays, and the declaration of NewInit, are not part of
/// this excerpt.
380 flattenGlobalArrays(Module
&M
,
381 DenseMap
<GlobalVariable
*, GlobalVariable
*> &GlobalMap
) {
382 LLVMContext
&Ctx
= M
.getContext();
383 for (GlobalVariable
&G
: M
.globals()) {
384 Type
*OrigType
= G
.getValueType();
385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType
))
388 ArrayType
*ArrType
= cast
<ArrayType
>(OrigType
);
389 auto [TotalElements
, BaseType
] =
390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType
);
391 ArrayType
*FattenedArrayType
= ArrayType::get(BaseType
, TotalElements
);
393 // Create a new global variable with the updated type
394 // Note: Initializer is set via transformInitializer
395 GlobalVariable
*NewGlobal
=
396 new GlobalVariable(M
, FattenedArrayType
, G
.isConstant(), G
.getLinkage(),
397 /*Initializer=*/nullptr, G
.getName() + ".1dim", &G
,
398 G
.getThreadLocalMode(), G
.getAddressSpace(),
399 G
.isExternallyInitialized());
401 // Copy relevant attributes
402 NewGlobal
->setUnnamedAddr(G
.getUnnamedAddr());
403 if (G
.getAlignment() > 0) {
404 NewGlobal
->setAlignment(G
.getAlign());
407 if (G
.hasInitializer()) {
408 Constant
*Init
= G
.getInitializer();
410 transformInitializer(Init
, OrigType
, FattenedArrayType
, Ctx
);
411 NewGlobal
->setInitializer(NewInit
);
// Remember the pairing; the caller swaps uses and erases the original.
413 GlobalMap
[&G
] = NewGlobal
;
/// Module-level driver shared by both pass-manager entry points.
///
/// Order of work: (1) create flattened replacements for global arrays,
/// (2) run the instruction visitor over every function that is not a
/// declaration, (3) redirect all uses of each original global to its
/// replacement and erase the original. The statement under the
/// isDeclaration() test and the tail of the function are not part of this
/// excerpt.
417 static bool flattenArrays(Module
&M
) {
418 bool MadeChange
= false;
419 DXILFlattenArraysVisitor Impl
;
420 DenseMap
<GlobalVariable
*, GlobalVariable
*> GlobalMap
;
421 flattenGlobalArrays(M
, GlobalMap
);
422 for (auto &F
: make_early_inc_range(M
.functions())) {
423 if (F
.isDeclaration())
425 MadeChange
|= Impl
.visit(F
);
// Globals are swapped only after all functions have been visited.
427 for (auto &[Old
, New
] : GlobalMap
) {
428 Old
->replaceAllUsesWith(New
);
429 Old
->eraseFromParent();
/// New pass-manager entry point: runs flattenArrays on \p M. The visible
/// lines contain both a return of PreservedAnalyses::all() and a
/// default-constructed PreservedAnalyses; the condition choosing between them
/// and the final return are not part of this excerpt.
435 PreservedAnalyses
DXILFlattenArrays::run(Module
&M
, ModuleAnalysisManager
&) {
436 bool MadeChanges
= flattenArrays(M
);
438 return PreservedAnalyses::all();
439 PreservedAnalyses PA
;
/// Legacy pass-manager entry point: returns the result of flattenArrays(M).
443 bool DXILFlattenArraysLegacy::runOnModule(Module
&M
) {
444 return flattenArrays(M
);
// Definition of the static member declared in DXILFlattenArraysLegacy.
447 char DXILFlattenArraysLegacy::ID
= 0;
// Legacy pass registration under the DEBUG_TYPE name with the description
// "DXIL Array Flattener". (The end of the second macro invocation is not part
// of this excerpt.)
449 INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy
, DEBUG_TYPE
,
450 "DXIL Array Flattener", false, false)
451 INITIALIZE_PASS_END(DXILFlattenArraysLegacy
, DEBUG_TYPE
, "DXIL Array Flattener",
/// Factory for the legacy pass: returns a newly allocated
/// DXILFlattenArraysLegacy as a ModulePass pointer.
454 ModulePass
*llvm::createDXILFlattenArraysLegacyPass() {
455 return new DXILFlattenArraysLegacy();