1 //===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass replaces all the uses of LDS within non-kernel functions by
10 // corresponding pointer counter-parts.
12 // The main motivation behind this pass is - to *avoid* subsequent LDS lowering
13 // pass from directly packing LDS (assume large LDS) into a struct type which
14 // would otherwise cause allocating huge memory for struct instance within every
17 // Brief sketch of the algorithm implemented in this pass is as below:
19 // 1. Collect all the LDS defined in the module which qualify for pointer
20 // replacement, say it is, LDSGlobals set.
22 // 2. Collect all the reachable callees for each kernel defined in the module,
23 // say it is, KernelToCallees map.
25 // 3. FOR (each global GV from LDSGlobals set) DO
26 // LDSUsedNonKernels = Collect all non-kernel functions which use GV.
27 // FOR (each kernel K in KernelToCallees map) DO
28 // ReachableCallees = KernelToCallees[K]
29 // ReachableAndLDSUsedCallees =
30 // SetIntersect(LDSUsedNonKernels, ReachableCallees)
31 // IF (ReachableAndLDSUsedCallees is not empty) THEN
32 // Pointer = Create a pointer to point-to GV if not created.
33 // Initialize Pointer to point-to GV within kernel K.
36 // Replace all uses of GV within non kernel functions by Pointer.
43 // @lds = internal addrspace(3) global [4 x i32] undef, align 16
45 // define internal void @f0() {
47 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
52 // define protected amdgpu_kernel void @k0() {
60 // @lds = internal addrspace(3) global [4 x i32] undef, align 16
61 // @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
63 // define internal void @f0() {
65 // %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
66 // %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
67 // %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
68 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
73 // define protected amdgpu_kernel void @k0() {
75 // store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
76 // i16 addrspace(3)* @lds.ptr, align 2
81 //===----------------------------------------------------------------------===//
84 #include "GCNSubtarget.h"
85 #include "Utils/AMDGPUBaseInfo.h"
86 #include "Utils/AMDGPULDSUtils.h"
87 #include "llvm/ADT/DenseMap.h"
88 #include "llvm/ADT/STLExtras.h"
89 #include "llvm/ADT/SetOperations.h"
90 #include "llvm/CodeGen/TargetPassConfig.h"
91 #include "llvm/IR/Constants.h"
92 #include "llvm/IR/DerivedTypes.h"
93 #include "llvm/IR/IRBuilder.h"
94 #include "llvm/IR/InlineAsm.h"
95 #include "llvm/IR/Instructions.h"
96 #include "llvm/IR/IntrinsicsAMDGPU.h"
97 #include "llvm/IR/ReplaceConstant.h"
98 #include "llvm/InitializePasses.h"
99 #include "llvm/Pass.h"
100 #include "llvm/Support/Debug.h"
101 #include "llvm/Target/TargetMachine.h"
102 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
103 #include "llvm/Transforms/Utils/ModuleUtils.h"
107 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
109 using namespace llvm
;
113 class ReplaceLDSUseImpl
{
116 const DataLayout
&DL
;
117 Constant
*LDSMemBaseAddr
;
119 DenseMap
<GlobalVariable
*, GlobalVariable
*> LDSToPointer
;
120 DenseMap
<GlobalVariable
*, SmallPtrSet
<Function
*, 8>> LDSToNonKernels
;
121 DenseMap
<Function
*, SmallPtrSet
<Function
*, 8>> KernelToCallees
;
122 DenseMap
<Function
*, SmallPtrSet
<GlobalVariable
*, 8>> KernelToLDSPointers
;
123 DenseMap
<Function
*, BasicBlock
*> KernelToInitBB
;
124 DenseMap
<Function
*, DenseMap
<GlobalVariable
*, Value
*>>
125 FunctionToLDSToReplaceInst
;
127 // Collect LDS which requires their uses to be replaced by pointer.
128 std::vector
<GlobalVariable
*> collectLDSRequiringPointerReplace() {
129 // Collect LDS which requires module lowering.
130 std::vector
<GlobalVariable
*> LDSGlobals
= AMDGPU::findVariablesToLower(M
);
132 // Remove LDS which don't qualify for replacement.
133 LDSGlobals
.erase(std::remove_if(LDSGlobals
.begin(), LDSGlobals
.end(),
134 [&](GlobalVariable
*GV
) {
135 return shouldIgnorePointerReplacement(GV
);
142 // Returns true if uses of given LDS global within non-kernel functions should
143 // be keep as it is without pointer replacement.
144 bool shouldIgnorePointerReplacement(GlobalVariable
*GV
) {
145 // LDS whose size is very small and doesn`t exceed pointer size is not worth
147 if (DL
.getTypeAllocSize(GV
->getValueType()) <= 2)
150 // LDS which is not used from non-kernel function scope or it is used from
151 // global scope does not qualify for replacement.
152 LDSToNonKernels
[GV
] = AMDGPU::collectNonKernelAccessorsOfLDS(GV
);
153 return LDSToNonKernels
[GV
].empty();
155 // FIXME: When GV is used within all (or within most of the kernels), then
156 // it does not make sense to create a pointer for it.
159 // Insert new global LDS pointer which points to LDS.
160 GlobalVariable
*createLDSPointer(GlobalVariable
*GV
) {
161 // LDS pointer which points to LDS is already created? return it.
162 auto PointerEntry
= LDSToPointer
.insert(std::make_pair(GV
, nullptr));
163 if (!PointerEntry
.second
)
164 return PointerEntry
.first
->second
;
166 // We need to create new LDS pointer which points to LDS.
168 // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
169 // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
170 auto *I16Ty
= Type::getInt16Ty(Ctx
);
171 GlobalVariable
*LDSPointer
= new GlobalVariable(
172 M
, I16Ty
, false, GlobalValue::InternalLinkage
, UndefValue::get(I16Ty
),
173 GV
->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal
,
174 AMDGPUAS::LOCAL_ADDRESS
);
176 LDSPointer
->setUnnamedAddr(GlobalValue::UnnamedAddr::Global
);
177 LDSPointer
->setAlignment(AMDGPU::getAlign(DL
, LDSPointer
));
179 // Mark that an associated LDS pointer is created for LDS.
180 LDSToPointer
[GV
] = LDSPointer
;
185 // Split entry basic block in such a way that only lane 0 of each wave does
186 // the LDS pointer initialization, and return newly created basic block.
187 BasicBlock
*activateLaneZero(Function
*K
) {
188 // If the entry basic block of kernel K is already splitted, then return
189 // newly created basic block.
190 auto BasicBlockEntry
= KernelToInitBB
.insert(std::make_pair(K
, nullptr));
191 if (!BasicBlockEntry
.second
)
192 return BasicBlockEntry
.first
->second
;
194 // Split entry basic block of kernel K.
195 auto *EI
= &(*(K
->getEntryBlock().getFirstInsertionPt()));
196 IRBuilder
<> Builder(EI
);
199 Builder
.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo
, {},
200 {Builder
.getInt32(-1), Builder
.getInt32(0)});
201 Value
*Cond
= Builder
.CreateICmpEQ(Mbcnt
, Builder
.getInt32(0));
202 Instruction
*WB
= cast
<Instruction
>(
203 Builder
.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier
, {}, {}));
205 BasicBlock
*NBB
= SplitBlockAndInsertIfThen(Cond
, WB
, false)->getParent();
207 // Mark that the entry basic block of kernel K is splitted.
208 KernelToInitBB
[K
] = NBB
;
213 // Within given kernel, initialize given LDS pointer to point to given LDS.
214 void initializeLDSPointer(Function
*K
, GlobalVariable
*GV
,
215 GlobalVariable
*LDSPointer
) {
216 // If LDS pointer is already initialized within K, then nothing to do.
217 auto PointerEntry
= KernelToLDSPointers
.insert(
218 std::make_pair(K
, SmallPtrSet
<GlobalVariable
*, 8>()));
219 if (!PointerEntry
.second
)
220 if (PointerEntry
.first
->second
.contains(LDSPointer
))
223 // Insert instructions at EI which initialize LDS pointer to point-to LDS
226 // That is, convert pointer type of GV to i16, and then store this converted
227 // i16 value within LDSPointer which is of type i16*.
228 auto *EI
= &(*(activateLaneZero(K
)->getFirstInsertionPt()));
229 IRBuilder
<> Builder(EI
);
230 Builder
.CreateStore(Builder
.CreatePtrToInt(GV
, Type::getInt16Ty(Ctx
)),
233 // Mark that LDS pointer is initialized within kernel K.
234 KernelToLDSPointers
[K
].insert(LDSPointer
);
237 // We have created an LDS pointer for LDS, and initialized it to point-to LDS
238 // within all relevent kernels. Now replace all the uses of LDS within
239 // non-kernel functions by LDS pointer.
240 void replaceLDSUseByPointer(GlobalVariable
*GV
, GlobalVariable
*LDSPointer
) {
241 SmallVector
<User
*, 8> LDSUsers(GV
->users());
242 for (auto *U
: LDSUsers
) {
243 // When `U` is a constant expression, it is possible that same constant
244 // expression exists within multiple instructions, and within multiple
245 // non-kernel functions. Collect all those non-kernel functions and all
246 // those instructions within which `U` exist.
247 auto FunctionToInsts
=
248 AMDGPU::getFunctionToInstsMap(U
, false /*=CollectKernelInsts*/);
250 for (auto FI
= FunctionToInsts
.begin(), FE
= FunctionToInsts
.end();
252 Function
*F
= FI
->first
;
253 auto &Insts
= FI
->second
;
254 for (auto *I
: Insts
) {
255 // If `U` is a constant expression, then we need to break the
256 // associated instruction into a set of separate instructions by
257 // converting constant expressions into instructions.
258 SmallPtrSet
<Instruction
*, 8> UserInsts
;
261 // `U` is an instruction, conversion from constant expression to
262 // set of instructions is *not* required.
265 // `U` is a constant expression, convert it into corresponding set
267 auto *CE
= cast
<ConstantExpr
>(U
);
268 convertConstantExprsToInstructions(I
, CE
, &UserInsts
);
271 // Go through all the user instrutions, if LDS exist within them as an
272 // operand, then replace it by replace instruction.
273 for (auto *II
: UserInsts
) {
274 auto *ReplaceInst
= getReplacementInst(F
, GV
, LDSPointer
);
275 II
->replaceUsesOfWith(GV
, ReplaceInst
);
282 // Create a set of replacement instructions which together replace LDS within
283 // non-kernel function F by accessing LDS indirectly using LDS pointer.
284 Value
*getReplacementInst(Function
*F
, GlobalVariable
*GV
,
285 GlobalVariable
*LDSPointer
) {
286 // If the instruction which replaces LDS within F is already created, then
288 auto LDSEntry
= FunctionToLDSToReplaceInst
.insert(
289 std::make_pair(F
, DenseMap
<GlobalVariable
*, Value
*>()));
290 if (!LDSEntry
.second
) {
291 auto ReplaceInstEntry
=
292 LDSEntry
.first
->second
.insert(std::make_pair(GV
, nullptr));
293 if (!ReplaceInstEntry
.second
)
294 return ReplaceInstEntry
.first
->second
;
297 // Get the instruction insertion point within the beginning of the entry
298 // block of current non-kernel function.
299 auto *EI
= &(*(F
->getEntryBlock().getFirstInsertionPt()));
300 IRBuilder
<> Builder(EI
);
302 // Insert required set of instructions which replace LDS within F.
303 auto *V
= Builder
.CreateBitCast(
305 Builder
.getInt8Ty(), LDSMemBaseAddr
,
306 Builder
.CreateLoad(LDSPointer
->getValueType(), LDSPointer
)),
309 // Mark that the replacement instruction which replace LDS within F is
311 FunctionToLDSToReplaceInst
[F
][GV
] = V
;
317 ReplaceLDSUseImpl(Module
&M
)
318 : M(M
), Ctx(M
.getContext()), DL(M
.getDataLayout()) {
319 LDSMemBaseAddr
= Constant::getIntegerValue(
320 PointerType::get(Type::getInt8Ty(M
.getContext()),
321 AMDGPUAS::LOCAL_ADDRESS
),
325 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
327 bool replaceLDSUse();
330 // For a given LDS from collected LDS globals set, replace its non-kernel
331 // function scope uses by pointer.
332 bool replaceLDSUse(GlobalVariable
*GV
);
335 // For given LDS from collected LDS globals set, replace its non-kernel function
336 // scope uses by pointer.
337 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable
*GV
) {
338 // Holds all those non-kernel functions within which LDS is being accessed.
339 SmallPtrSet
<Function
*, 8> &LDSAccessors
= LDSToNonKernels
[GV
];
341 // The LDS pointer which points to LDS and replaces all the uses of LDS.
342 GlobalVariable
*LDSPointer
= nullptr;
344 // Traverse through each kernel K, check and if required, initialize the
345 // LDS pointer to point to LDS within K.
346 for (auto KI
= KernelToCallees
.begin(), KE
= KernelToCallees
.end(); KI
!= KE
;
348 Function
*K
= KI
->first
;
349 SmallPtrSet
<Function
*, 8> Callees
= KI
->second
;
351 // Compute reachable and LDS used callees for kernel K.
352 set_intersect(Callees
, LDSAccessors
);
354 // None of the LDS accessing non-kernel functions are reachable from
355 // kernel K. Hence, no need to initialize LDS pointer within kernel K.
359 // We have found reachable and LDS used callees for kernel K, and we need to
360 // initialize LDS pointer within kernel K, and we need to replace LDS use
361 // within those callees by LDS pointer.
363 // But, first check if LDS pointer is already created, if not create one.
364 LDSPointer
= createLDSPointer(GV
);
366 // Initialize LDS pointer to point to LDS within kernel K.
367 initializeLDSPointer(K
, GV
, LDSPointer
);
370 // We have not found reachable and LDS used callees for any of the kernels,
371 // and hence we have not created LDS pointer.
375 // We have created an LDS pointer for LDS, and initialized it to point-to LDS
376 // within all relevent kernels. Now replace all the uses of LDS within
377 // non-kernel functions by LDS pointer.
378 replaceLDSUseByPointer(GV
, LDSPointer
);
383 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
385 bool ReplaceLDSUseImpl::replaceLDSUse() {
386 // Collect LDS which requires their uses to be replaced by pointer.
387 std::vector
<GlobalVariable
*> LDSGlobals
=
388 collectLDSRequiringPointerReplace();
390 // No LDS to pointer-replace. Nothing to do.
391 if (LDSGlobals
.empty())
394 // Collect reachable callee set for each kernel defined in the module.
395 AMDGPU::collectReachableCallees(M
, KernelToCallees
);
397 if (KernelToCallees
.empty()) {
398 // Either module does not have any kernel definitions, or none of the kernel
399 // has a call to non-kernel functions, or we could not resolve any of the
400 // call sites to proper non-kernel functions, because of the situations like
401 // inline asm calls. Nothing to replace.
405 // For every LDS from collected LDS globals set, replace its non-kernel
406 // function scope use by pointer.
407 bool Changed
= false;
408 for (auto *GV
: LDSGlobals
)
409 Changed
|= replaceLDSUse(GV
);
414 class AMDGPUReplaceLDSUseWithPointer
: public ModulePass
{
418 AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID
) {
419 initializeAMDGPUReplaceLDSUseWithPointerPass(
420 *PassRegistry::getPassRegistry());
423 bool runOnModule(Module
&M
) override
;
425 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
426 AU
.addRequired
<TargetPassConfig
>();
432 char AMDGPUReplaceLDSUseWithPointer::ID
= 0;
433 char &llvm::AMDGPUReplaceLDSUseWithPointerID
=
434 AMDGPUReplaceLDSUseWithPointer::ID
;
436 INITIALIZE_PASS_BEGIN(
437 AMDGPUReplaceLDSUseWithPointer
, DEBUG_TYPE
,
438 "Replace within non-kernel function use of LDS with pointer",
439 false /*only look at the cfg*/, false /*analysis pass*/)
440 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig
)
442 AMDGPUReplaceLDSUseWithPointer
, DEBUG_TYPE
,
443 "Replace within non-kernel function use of LDS with pointer",
444 false /*only look at the cfg*/, false /*analysis pass*/)
446 bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module
&M
) {
447 ReplaceLDSUseImpl LDSUseReplacer
{M
};
448 return LDSUseReplacer
.replaceLDSUse();
451 ModulePass
*llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
452 return new AMDGPUReplaceLDSUseWithPointer();
456 AMDGPUReplaceLDSUseWithPointerPass::run(Module
&M
, ModuleAnalysisManager
&AM
) {
457 ReplaceLDSUseImpl LDSUseReplacer
{M
};
458 LDSUseReplacer
.replaceLDSUse();
459 return PreservedAnalyses::all();