Revert " [LoongArch][ISel] Check the number of sign bits in `PatGprGpr_32` (#107432)"
[llvm-project.git] / llvm / lib / Target / AMDGPU / R600OpenCLImageTypeLoweringPass.cpp
blob604a4cb1bf881b0dcb73e3e830742d69e42377be
1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
12 ///
13 /// Image attributes (size and format) are expected to be passed to the kernel
14 /// as kernel arguments immediately following the image argument itself,
15 /// therefore this pass adds image size and format arguments to the kernel
16 /// functions in the module. The kernel functions with image arguments are
17 /// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19 /// Note: this pass may invalidate pointers to functions.
20 ///
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
#include "R600.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <iterator>
38 using namespace llvm;
40 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
41 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
42 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
43 static StringRef GetSamplerResourceIDFunc =
44 "llvm.OpenCL.sampler.get.resource.id";
46 static StringRef ImageSizeArgMDType = "__llvm_image_size";
47 static StringRef ImageFormatArgMDType = "__llvm_image_format";
49 static StringRef KernelsMDNodeName = "opencl.kernels";
50 static StringRef KernelArgMDNodeNames[] = {
51 "kernel_arg_addr_space",
52 "kernel_arg_access_qual",
53 "kernel_arg_type",
54 "kernel_arg_base_type",
55 "kernel_arg_type_qual"};
56 static const unsigned NumKernelArgMDNodes = 5;
58 namespace {
60 using MDVector = SmallVector<Metadata *, 8>;
61 struct KernelArgMD {
62 MDVector ArgVector[NumKernelArgMDNodes];
65 } // end anonymous namespace
67 static inline bool
68 IsImageType(StringRef TypeString) {
69 return TypeString == "image2d_t" || TypeString == "image3d_t";
72 static inline bool
73 IsSamplerType(StringRef TypeString) {
74 return TypeString == "sampler_t";
77 static Function *
78 GetFunctionFromMDNode(MDNode *Node) {
79 if (!Node)
80 return nullptr;
82 size_t NumOps = Node->getNumOperands();
83 if (NumOps != NumKernelArgMDNodes + 1)
84 return nullptr;
86 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
87 if (!F)
88 return nullptr;
90 // Validation checks.
91 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
92 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
93 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
94 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
95 return nullptr;
96 if (!ArgNode->getOperand(0))
97 return nullptr;
99 // FIXME: It should be possible to do image lowering when some metadata
100 // args missing or not in the expected order.
101 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
102 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
103 return nullptr;
106 return F;
109 static StringRef
110 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
111 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
112 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
115 static StringRef
116 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
117 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
118 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
121 static MDVector
122 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
123 MDVector Res;
124 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
125 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
126 Res.push_back(Node->getOperand(OpIdx));
128 return Res;
131 static void
132 PushArgMD(KernelArgMD &MD, const MDVector &V) {
133 assert(V.size() == NumKernelArgMDNodes);
134 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
135 MD.ArgVector[i].push_back(V[i]);
139 namespace {
141 class R600OpenCLImageTypeLoweringPass : public ModulePass {
142 static char ID;
144 LLVMContext *Context;
145 Type *Int32Type;
146 Type *ImageSizeType;
147 Type *ImageFormatType;
148 SmallVector<Instruction *, 4> InstsToErase;
150 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
151 Argument &ImageSizeArg,
152 Argument &ImageFormatArg) {
153 bool Modified = false;
155 for (auto &Use : ImageArg.uses()) {
156 auto Inst = dyn_cast<CallInst>(Use.getUser());
157 if (!Inst) {
158 continue;
161 Function *F = Inst->getCalledFunction();
162 if (!F)
163 continue;
165 Value *Replacement = nullptr;
166 StringRef Name = F->getName();
167 if (Name.starts_with(GetImageResourceIDFunc)) {
168 Replacement = ConstantInt::get(Int32Type, ResourceID);
169 } else if (Name.starts_with(GetImageSizeFunc)) {
170 Replacement = &ImageSizeArg;
171 } else if (Name.starts_with(GetImageFormatFunc)) {
172 Replacement = &ImageFormatArg;
173 } else {
174 continue;
177 Inst->replaceAllUsesWith(Replacement);
178 InstsToErase.push_back(Inst);
179 Modified = true;
182 return Modified;
185 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
186 bool Modified = false;
188 for (const auto &Use : SamplerArg.uses()) {
189 auto Inst = dyn_cast<CallInst>(Use.getUser());
190 if (!Inst) {
191 continue;
194 Function *F = Inst->getCalledFunction();
195 if (!F)
196 continue;
198 Value *Replacement = nullptr;
199 StringRef Name = F->getName();
200 if (Name == GetSamplerResourceIDFunc) {
201 Replacement = ConstantInt::get(Int32Type, ResourceID);
202 } else {
203 continue;
206 Inst->replaceAllUsesWith(Replacement);
207 InstsToErase.push_back(Inst);
208 Modified = true;
211 return Modified;
214 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
215 uint32_t NumReadOnlyImageArgs = 0;
216 uint32_t NumWriteOnlyImageArgs = 0;
217 uint32_t NumSamplerArgs = 0;
219 bool Modified = false;
220 InstsToErase.clear();
221 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
222 Argument &Arg = *ArgI;
223 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
225 // Handle image types.
226 if (IsImageType(Type)) {
227 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
228 uint32_t ResourceID;
229 if (AccessQual == "read_only") {
230 ResourceID = NumReadOnlyImageArgs++;
231 } else if (AccessQual == "write_only") {
232 ResourceID = NumWriteOnlyImageArgs++;
233 } else {
234 llvm_unreachable("Wrong image access qualifier.");
237 Argument &SizeArg = *(++ArgI);
238 Argument &FormatArg = *(++ArgI);
239 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
241 // Handle sampler type.
242 } else if (IsSamplerType(Type)) {
243 uint32_t ResourceID = NumSamplerArgs++;
244 Modified |= replaceSamplerUses(Arg, ResourceID);
247 for (auto *Inst : InstsToErase)
248 Inst->eraseFromParent();
250 return Modified;
253 std::tuple<Function *, MDNode *>
254 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255 bool Modified = false;
257 FunctionType *FT = F->getFunctionType();
258 SmallVector<Type *, 8> ArgTypes;
260 // Metadata operands for new MDNode.
261 KernelArgMD NewArgMDs;
262 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
264 // Add implicit arguments to the signature.
265 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266 ArgTypes.push_back(FT->getParamType(i));
267 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268 PushArgMD(NewArgMDs, ArgMD);
270 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271 continue;
273 // Add size implicit argument.
274 ArgTypes.push_back(ImageSizeType);
275 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276 PushArgMD(NewArgMDs, ArgMD);
278 // Add format implicit argument.
279 ArgTypes.push_back(ImageFormatType);
280 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281 PushArgMD(NewArgMDs, ArgMD);
283 Modified = true;
285 if (!Modified) {
286 return std::tuple(nullptr, nullptr);
289 // Create function with new signature and clone the old body into it.
290 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292 ValueToValueMapTy VMap;
293 auto NewFArgIt = NewF->arg_begin();
294 for (auto &Arg: F->args()) {
295 auto ArgName = Arg.getName();
296 NewFArgIt->setName(ArgName);
297 VMap[&Arg] = &(*NewFArgIt++);
298 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
303 SmallVector<ReturnInst*, 8> Returns;
304 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
305 Returns);
307 // Build new MDNode.
308 SmallVector<Metadata *, 6> KernelMDArgs;
309 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
310 for (const MDVector &MDV : NewArgMDs.ArgVector)
311 KernelMDArgs.push_back(MDNode::get(*Context, MDV));
312 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
314 return std::tuple(NewF, NewMDNode);
317 bool transformKernels(Module &M) {
318 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
319 if (!KernelsMDNode)
320 return false;
322 bool Modified = false;
323 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
324 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
325 Function *F = GetFunctionFromMDNode(KernelMDNode);
326 if (!F)
327 continue;
329 Function *NewF;
330 MDNode *NewMDNode;
331 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
332 if (NewF) {
333 // Replace old function and metadata with new ones.
334 F->eraseFromParent();
335 M.getFunctionList().push_back(NewF);
336 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
337 NewF->getAttributes());
338 KernelsMDNode->setOperand(i, NewMDNode);
340 F = NewF;
341 KernelMDNode = NewMDNode;
342 Modified = true;
345 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
348 return Modified;
351 public:
352 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
354 bool runOnModule(Module &M) override {
355 Context = &M.getContext();
356 Int32Type = Type::getInt32Ty(M.getContext());
357 ImageSizeType = ArrayType::get(Int32Type, 3);
358 ImageFormatType = ArrayType::get(Int32Type, 2);
360 return transformKernels(M);
363 StringRef getPassName() const override {
364 return "R600 OpenCL Image Type Pass";
368 } // end anonymous namespace
370 char R600OpenCLImageTypeLoweringPass::ID = 0;
372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
373 return new R600OpenCLImageTypeLoweringPass();