//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//
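
// A sketch of the rewrite (illustrative only; the exact offsets, address
// spaces, and attributes depend on the subtarget and data layout): uses of an
// argument in a kernel such as
//
//   define amdgpu_kernel void @f(i32 %x)
//
// are replaced with a load from the kernarg segment:
//
//   %f.kernarg.segment = call i8 addrspace(4)*
//       @llvm.amdgcn.kernarg.segment.ptr()
//   %x.kernarg.offset = getelementptr inbounds i8,
//       i8 addrspace(4)* %f.kernarg.segment, i64 0
//   %x.load = load i32, ... %x.kernarg.offset, align 16, !invariant.load !0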

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*EntryBlock.begin());
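
  // All argument loads are emitted at the builder's insertion point, the start
  // of the entry block, so they dominate every use of the arguments they
  // replace.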

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
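
  // llvm.amdgcn.kernarg.segment.ptr returns the base of the kernarg segment
  // as a constant-address-space pointer; marking it nonnull and
  // dereferenceable up front lets later passes combine and reorder the
  // argument loads. A sketch of the emitted call (illustrative; N is the
  // total kernarg segment size):
  //
  //   %f.kernarg.segment = call nonnull dereferenceable(N) i8 addrspace(4)*
  //       @llvm.amdgcn.kernarg.segment.ptr()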

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned ABITypeAlign = DL.getABITypeAlignment(ArgTy);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
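
    // Worked example (assuming BaseOffset is 0): for a signature (i32, i64),
    // the i32 lands at EltOffset 0 and advances ExplicitArgOffset to 4; the
    // i64 is then aligned up to 8, landing at EltOffset 8 and leaving
    // ExplicitArgOffset at 16.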

    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;

    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }
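
    // Worked example of the shift optimization (illustrative offsets): an i16
    // argument at EltOffset 38 gives AlignDownOffset 36 and OffsetDiff 2, so
    // a naturally aligned i32 is loaded at offset 36 and the argument is
    // recovered further down via an lshr by OffsetDiff * 8 = 16 bits followed
    // by a trunc to i16.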

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }
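
    // Illustrative result of the v3 widening: a <3 x i32> argument becomes
    //
    //   %load = load <4 x i32>, <4 x i32> addrspace(4)* %ptr
    //   %arg.load = shufflevector <4 x i32> %load, <4 x i32> undef,
    //                             <3 x i32> <i32 0, i32 1, i32 2>
    //
    // so SelectionDAG sees a legal v4 load; the shuffle back to three
    // elements is emitted below.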

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign.value());
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }
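
    // The argument's pointer attributes thus survive as load metadata; e.g.
    // (illustrative) an `i32* nonnull dereferenceable(256) align 8` argument
    // yields
    //
    //   %p.load = load ..., !nonnull !0, !dereferenceable !1, !align !2
    //
    // with !1 = !{i64 256} and !2 = !{i64 8}.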

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                ArrayRef<uint32_t>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addAttribute(
      AttributeList::ReturnIndex,
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}