[ORC] Add std::tuple support to SimplePackedSerialization.
[llvm-project.git] / llvm / lib / Target / AMDGPU / AMDGPUReplaceLDSUseWithPointer.cpp
blobdabb4d006d99470ee3eb604990ac3b3b0bc2b562
1 //===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass replaces all the uses of LDS within non-kernel functions by
10 // corresponding pointer counterparts.
12 // The main motivation behind this pass is to *prevent* the subsequent LDS
13 // lowering pass from directly packing LDS (assume large LDS) into a struct
14 // type, which would otherwise cause huge memory to be allocated for the
15 // struct instance within every kernel.
17 // Brief sketch of the algorithm implemented in this pass is as below:
19 // 1. Collect all the LDS defined in the module which qualify for pointer
20 // replacement, say it is, LDSGlobals set.
22 // 2. Collect all the reachable callees for each kernel defined in the module,
23 // say it is, KernelToCallees map.
25 // 3. FOR (each global GV from LDSGlobals set) DO
26 // LDSUsedNonKernels = Collect all non-kernel functions which use GV.
27 // FOR (each kernel K in KernelToCallees map) DO
28 // ReachableCallees = KernelToCallees[K]
29 // ReachableAndLDSUsedCallees =
30 // SetIntersect(LDSUsedNonKernels, ReachableCallees)
31 // IF (ReachableAndLDSUsedCallees is not empty) THEN
32 // Pointer = Create a pointer to point-to GV if not created.
33 // Initialize Pointer to point-to GV within kernel K.
34 // ENDIF
35 // ENDFOR
36 // Replace all uses of GV within non kernel functions by Pointer.
37 // ENDFOR
39 // LLVM IR example:
41 // Input IR:
43 // @lds = internal addrspace(3) global [4 x i32] undef, align 16
45 // define internal void @f0() {
46 // entry:
47 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
48 // i32 0, i32 0
49 // ret void
50 // }
52 // define protected amdgpu_kernel void @k0() {
53 // entry:
54 // call void @f0()
55 // ret void
56 // }
58 // Output IR:
60 // @lds = internal addrspace(3) global [4 x i32] undef, align 16
61 // @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
63 // define internal void @f0() {
64 // entry:
65 // %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
66 // %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
67 // %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
68 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
69 // i32 0, i32 0
70 // ret void
71 // }
73 // define protected amdgpu_kernel void @k0() {
74 // entry:
75 // store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
76 // i16 addrspace(3)* @lds.ptr, align 2
77 // call void @f0()
78 // ret void
79 // }
81 //===----------------------------------------------------------------------===//
83 #include "AMDGPU.h"
84 #include "GCNSubtarget.h"
85 #include "Utils/AMDGPUBaseInfo.h"
86 #include "Utils/AMDGPULDSUtils.h"
87 #include "llvm/ADT/DenseMap.h"
88 #include "llvm/ADT/STLExtras.h"
89 #include "llvm/ADT/SetOperations.h"
90 #include "llvm/CodeGen/TargetPassConfig.h"
91 #include "llvm/IR/Constants.h"
92 #include "llvm/IR/DerivedTypes.h"
93 #include "llvm/IR/IRBuilder.h"
94 #include "llvm/IR/InlineAsm.h"
95 #include "llvm/IR/Instructions.h"
96 #include "llvm/IR/IntrinsicsAMDGPU.h"
97 #include "llvm/IR/ReplaceConstant.h"
98 #include "llvm/InitializePasses.h"
99 #include "llvm/Pass.h"
100 #include "llvm/Support/Debug.h"
101 #include "llvm/Target/TargetMachine.h"
102 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
103 #include "llvm/Transforms/Utils/ModuleUtils.h"
104 #include <algorithm>
105 #include <vector>
107 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
109 using namespace llvm;
111 namespace {
113 class ReplaceLDSUseImpl {
114 Module &M;
115 LLVMContext &Ctx;
116 const DataLayout &DL;
117 Constant *LDSMemBaseAddr;
119 DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
120 DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
121 DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
122 DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
123 DenseMap<Function *, BasicBlock *> KernelToInitBB;
124 DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
125 FunctionToLDSToReplaceInst;
127 // Collect LDS which requires their uses to be replaced by pointer.
128 std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
129 // Collect LDS which requires module lowering.
130 std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);
132 // Remove LDS which don't qualify for replacement.
133 LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(),
134 [&](GlobalVariable *GV) {
135 return shouldIgnorePointerReplacement(GV);
137 LDSGlobals.end());
139 return LDSGlobals;
142 // Returns true if uses of given LDS global within non-kernel functions should
143 // be keep as it is without pointer replacement.
144 bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
145 // LDS whose size is very small and doesn`t exceed pointer size is not worth
146 // replacing.
147 if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
148 return true;
150 // LDS which is not used from non-kernel function scope or it is used from
151 // global scope does not qualify for replacement.
152 LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
153 return LDSToNonKernels[GV].empty();
155 // FIXME: When GV is used within all (or within most of the kernels), then
156 // it does not make sense to create a pointer for it.
159 // Insert new global LDS pointer which points to LDS.
160 GlobalVariable *createLDSPointer(GlobalVariable *GV) {
161 // LDS pointer which points to LDS is already created? return it.
162 auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
163 if (!PointerEntry.second)
164 return PointerEntry.first->second;
166 // We need to create new LDS pointer which points to LDS.
168 // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
169 // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
170 auto *I16Ty = Type::getInt16Ty(Ctx);
171 GlobalVariable *LDSPointer = new GlobalVariable(
172 M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
173 GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
174 AMDGPUAS::LOCAL_ADDRESS);
176 LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
177 LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));
179 // Mark that an associated LDS pointer is created for LDS.
180 LDSToPointer[GV] = LDSPointer;
182 return LDSPointer;
185 // Split entry basic block in such a way that only lane 0 of each wave does
186 // the LDS pointer initialization, and return newly created basic block.
187 BasicBlock *activateLaneZero(Function *K) {
188 // If the entry basic block of kernel K is already splitted, then return
189 // newly created basic block.
190 auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
191 if (!BasicBlockEntry.second)
192 return BasicBlockEntry.first->second;
194 // Split entry basic block of kernel K.
195 auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
196 IRBuilder<> Builder(EI);
198 Value *Mbcnt =
199 Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
200 {Builder.getInt32(-1), Builder.getInt32(0)});
201 Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
202 Instruction *WB = cast<Instruction>(
203 Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));
205 BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();
207 // Mark that the entry basic block of kernel K is splitted.
208 KernelToInitBB[K] = NBB;
210 return NBB;
213 // Within given kernel, initialize given LDS pointer to point to given LDS.
214 void initializeLDSPointer(Function *K, GlobalVariable *GV,
215 GlobalVariable *LDSPointer) {
216 // If LDS pointer is already initialized within K, then nothing to do.
217 auto PointerEntry = KernelToLDSPointers.insert(
218 std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
219 if (!PointerEntry.second)
220 if (PointerEntry.first->second.contains(LDSPointer))
221 return;
223 // Insert instructions at EI which initialize LDS pointer to point-to LDS
224 // within kernel K.
226 // That is, convert pointer type of GV to i16, and then store this converted
227 // i16 value within LDSPointer which is of type i16*.
228 auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
229 IRBuilder<> Builder(EI);
230 Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
231 LDSPointer);
233 // Mark that LDS pointer is initialized within kernel K.
234 KernelToLDSPointers[K].insert(LDSPointer);
237 // We have created an LDS pointer for LDS, and initialized it to point-to LDS
238 // within all relevent kernels. Now replace all the uses of LDS within
239 // non-kernel functions by LDS pointer.
240 void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
241 SmallVector<User *, 8> LDSUsers(GV->users());
242 for (auto *U : LDSUsers) {
243 // When `U` is a constant expression, it is possible that same constant
244 // expression exists within multiple instructions, and within multiple
245 // non-kernel functions. Collect all those non-kernel functions and all
246 // those instructions within which `U` exist.
247 auto FunctionToInsts =
248 AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);
250 for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
251 FI != FE; ++FI) {
252 Function *F = FI->first;
253 auto &Insts = FI->second;
254 for (auto *I : Insts) {
255 // If `U` is a constant expression, then we need to break the
256 // associated instruction into a set of separate instructions by
257 // converting constant expressions into instructions.
258 SmallPtrSet<Instruction *, 8> UserInsts;
260 if (U == I) {
261 // `U` is an instruction, conversion from constant expression to
262 // set of instructions is *not* required.
263 UserInsts.insert(I);
264 } else {
265 // `U` is a constant expression, convert it into corresponding set
266 // of instructions.
267 auto *CE = cast<ConstantExpr>(U);
268 convertConstantExprsToInstructions(I, CE, &UserInsts);
271 // Go through all the user instrutions, if LDS exist within them as an
272 // operand, then replace it by replace instruction.
273 for (auto *II : UserInsts) {
274 auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
275 II->replaceUsesOfWith(GV, ReplaceInst);
282 // Create a set of replacement instructions which together replace LDS within
283 // non-kernel function F by accessing LDS indirectly using LDS pointer.
284 Value *getReplacementInst(Function *F, GlobalVariable *GV,
285 GlobalVariable *LDSPointer) {
286 // If the instruction which replaces LDS within F is already created, then
287 // return it.
288 auto LDSEntry = FunctionToLDSToReplaceInst.insert(
289 std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
290 if (!LDSEntry.second) {
291 auto ReplaceInstEntry =
292 LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
293 if (!ReplaceInstEntry.second)
294 return ReplaceInstEntry.first->second;
297 // Get the instruction insertion point within the beginning of the entry
298 // block of current non-kernel function.
299 auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
300 IRBuilder<> Builder(EI);
302 // Insert required set of instructions which replace LDS within F.
303 auto *V = Builder.CreateBitCast(
304 Builder.CreateGEP(
305 Builder.getInt8Ty(), LDSMemBaseAddr,
306 Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
307 GV->getType());
309 // Mark that the replacement instruction which replace LDS within F is
310 // created.
311 FunctionToLDSToReplaceInst[F][GV] = V;
313 return V;
316 public:
317 ReplaceLDSUseImpl(Module &M)
318 : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
319 LDSMemBaseAddr = Constant::getIntegerValue(
320 PointerType::get(Type::getInt8Ty(M.getContext()),
321 AMDGPUAS::LOCAL_ADDRESS),
322 APInt(32, 0));
325 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
326 // class.
327 bool replaceLDSUse();
329 private:
330 // For a given LDS from collected LDS globals set, replace its non-kernel
331 // function scope uses by pointer.
332 bool replaceLDSUse(GlobalVariable *GV);
335 // For given LDS from collected LDS globals set, replace its non-kernel function
336 // scope uses by pointer.
337 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
338 // Holds all those non-kernel functions within which LDS is being accessed.
339 SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];
341 // The LDS pointer which points to LDS and replaces all the uses of LDS.
342 GlobalVariable *LDSPointer = nullptr;
344 // Traverse through each kernel K, check and if required, initialize the
345 // LDS pointer to point to LDS within K.
346 for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE;
347 ++KI) {
348 Function *K = KI->first;
349 SmallPtrSet<Function *, 8> Callees = KI->second;
351 // Compute reachable and LDS used callees for kernel K.
352 set_intersect(Callees, LDSAccessors);
354 // None of the LDS accessing non-kernel functions are reachable from
355 // kernel K. Hence, no need to initialize LDS pointer within kernel K.
356 if (Callees.empty())
357 continue;
359 // We have found reachable and LDS used callees for kernel K, and we need to
360 // initialize LDS pointer within kernel K, and we need to replace LDS use
361 // within those callees by LDS pointer.
363 // But, first check if LDS pointer is already created, if not create one.
364 LDSPointer = createLDSPointer(GV);
366 // Initialize LDS pointer to point to LDS within kernel K.
367 initializeLDSPointer(K, GV, LDSPointer);
370 // We have not found reachable and LDS used callees for any of the kernels,
371 // and hence we have not created LDS pointer.
372 if (!LDSPointer)
373 return false;
375 // We have created an LDS pointer for LDS, and initialized it to point-to LDS
376 // within all relevent kernels. Now replace all the uses of LDS within
377 // non-kernel functions by LDS pointer.
378 replaceLDSUseByPointer(GV, LDSPointer);
380 return true;
383 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
384 // class.
385 bool ReplaceLDSUseImpl::replaceLDSUse() {
386 // Collect LDS which requires their uses to be replaced by pointer.
387 std::vector<GlobalVariable *> LDSGlobals =
388 collectLDSRequiringPointerReplace();
390 // No LDS to pointer-replace. Nothing to do.
391 if (LDSGlobals.empty())
392 return false;
394 // Collect reachable callee set for each kernel defined in the module.
395 AMDGPU::collectReachableCallees(M, KernelToCallees);
397 if (KernelToCallees.empty()) {
398 // Either module does not have any kernel definitions, or none of the kernel
399 // has a call to non-kernel functions, or we could not resolve any of the
400 // call sites to proper non-kernel functions, because of the situations like
401 // inline asm calls. Nothing to replace.
402 return false;
405 // For every LDS from collected LDS globals set, replace its non-kernel
406 // function scope use by pointer.
407 bool Changed = false;
408 for (auto *GV : LDSGlobals)
409 Changed |= replaceLDSUse(GV);
411 return Changed;
414 class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
415 public:
416 static char ID;
418 AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
419 initializeAMDGPUReplaceLDSUseWithPointerPass(
420 *PassRegistry::getPassRegistry());
423 bool runOnModule(Module &M) override;
425 void getAnalysisUsage(AnalysisUsage &AU) const override {
426 AU.addRequired<TargetPassConfig>();
430 } // namespace
432 char AMDGPUReplaceLDSUseWithPointer::ID = 0;
433 char &llvm::AMDGPUReplaceLDSUseWithPointerID =
434 AMDGPUReplaceLDSUseWithPointer::ID;
436 INITIALIZE_PASS_BEGIN(
437 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
438 "Replace within non-kernel function use of LDS with pointer",
439 false /*only look at the cfg*/, false /*analysis pass*/)
440 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
441 INITIALIZE_PASS_END(
442 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
443 "Replace within non-kernel function use of LDS with pointer",
444 false /*only look at the cfg*/, false /*analysis pass*/)
446 bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
447 ReplaceLDSUseImpl LDSUseReplacer{M};
448 return LDSUseReplacer.replaceLDSUse();
451 ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
452 return new AMDGPUReplaceLDSUseWithPointer();
455 PreservedAnalyses
456 AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
457 ReplaceLDSUseImpl LDSUseReplacer{M};
458 LDSUseReplacer.replaceLDSUse();
459 return PreservedAnalyses::all();