//===-- AMDGPUCodeGenPrepare.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;
// Scalar load widening needs to run after the load-store-vectorizer, as that
// pass doesn't handle overlapping cases. In addition, this pass enhances the
// widening to handle cases where scalar sub-dword loads are naturally aligned
// only but not dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));
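
// Note: when this option is disabled, visitLoadInst() below bails out and
// sub-dword constant loads are left for the normal lowering path instead of
// being widened here.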

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Map of the conversions from optimal type -> original type, per BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
  /// fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-bb values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a problematic
  /// PHI node, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element type is not an integer, or the element size exceeds the
    // convert-to scalar size, then we can't do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }
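
  // Example: on subtargets where i8 is not a legal type, a value of type
  // <4 x i8> satisfies shouldReplace() (integer elements no wider than the
  // i32 convert-to scalar, and illegal), whereas <2 x i32> is already legal
  // and is left untouched.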

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic blocks
  // will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
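  //
  // Rough example: a <4 x i8> defined in one block and used in another is
  // bitcast to a single i32 next to its definition and bitcast back to
  // <4 x i8> in the using block, so one 32-bit value crosses the block
  // boundary instead of four widened scalars.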
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}
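
// Worked example (with the default i32 convert-to scalar): a <3 x i8> input
// (24 bits) fits into a single i32, while a <6 x i8> input (48 bits) becomes
// <2 x i32>.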
Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}
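
// Worked example: converting a <3 x i8> (24 bits) to i32 first widens the
// value with a shufflevector to <4 x i8> (the extra lane reads from the
// implicit poison operand and acts as padding), then bitcasts the result to
// i32.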
Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector of
  // the same size.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}
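
// Worked example: a <7 x i8> that was packed into a <2 x i32> (64 bits) is
// recovered by bitcasting to <8 x i8> and shuffling out the low seven
// elements; a <3 x i8> that was packed into an i32 is recovered with a trunc
// to i24 followed by a bitcast.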
Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}
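
// optimizeLiveType proceeds in phases: (1) walk the def/use web from I and
// collect cross-block defs, their uses, and any connected PHI nodes; (2)
// materialize a legally typed copy next to each def; (3) recreate the PHI
// nodes with the new type and wire up their incoming values; (4) convert back
// to the original type, typically at the first non-PHI point of each using
// block, and rewrite the operands there.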
bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal) && ValMap[IncVal])
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the Phi
      // from the ValMap and any connected / user Phis.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from maps when
        // unwinding a previous Phi.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}
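
// For example, a uniform, naturally aligned i16 load from the constant address
// space qualifies for widening, while a divergent load or a load from
// flat/global memory does not.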
bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // That load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}
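
// Sketch of the rewrite (IR is illustrative): a qualifying sub-dword load such
// as
//   %v = load i8, ptr addrspace(4) %p, align 1   ; %p is base + 6
// becomes a DWORD load from the aligned address (base + 4) followed by a shift
// and truncate that extract the requested byte:
//   %w = load i32, ptr addrspace(4) %q, align 4
//   %s = lshr i32 %w, 16
//   %v = trunc i32 %s to i8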
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if that load is already aligned on DWORD at least as it's handled in
  // SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If that base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, the original alignment could be promoted with a
    // better one.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}