1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
13 /// Image attributes (size and format) are expected to be passed to the kernel
14 /// as kernel arguments immediately following the image argument itself,
15 /// therefore this pass adds image size and format arguments to the kernel
16 /// functions in the module. The kernel functions with image arguments are
17 /// re-created using the new signature. The new arguments are added to the
18 /// kernel metadata with kernel_arg_type set to "image_size" or "image_format".
19 /// Note: this pass may invalidate pointers to functions.
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/IR/Constants.h"
31 #include "llvm/IR/Function.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/Metadata.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Transforms/Utils/Cloning.h"
40 static StringRef GetImageSizeFunc
= "llvm.OpenCL.image.get.size";
41 static StringRef GetImageFormatFunc
= "llvm.OpenCL.image.get.format";
42 static StringRef GetImageResourceIDFunc
= "llvm.OpenCL.image.get.resource.id";
43 static StringRef GetSamplerResourceIDFunc
=
44 "llvm.OpenCL.sampler.get.resource.id";
46 static StringRef ImageSizeArgMDType
= "__llvm_image_size";
47 static StringRef ImageFormatArgMDType
= "__llvm_image_format";
49 static StringRef KernelsMDNodeName
= "opencl.kernels";
50 static StringRef KernelArgMDNodeNames
[] = {
51 "kernel_arg_addr_space",
52 "kernel_arg_access_qual",
54 "kernel_arg_base_type",
55 "kernel_arg_type_qual"};
56 static const unsigned NumKernelArgMDNodes
= 5;
60 using MDVector
= SmallVector
<Metadata
*, 8>;
62 MDVector ArgVector
[NumKernelArgMDNodes
];
65 } // end anonymous namespace
68 IsImageType(StringRef TypeString
) {
69 return TypeString
== "image2d_t" || TypeString
== "image3d_t";
73 IsSamplerType(StringRef TypeString
) {
74 return TypeString
== "sampler_t";
78 GetFunctionFromMDNode(MDNode
*Node
) {
82 size_t NumOps
= Node
->getNumOperands();
83 if (NumOps
!= NumKernelArgMDNodes
+ 1)
86 auto F
= mdconst::dyn_extract
<Function
>(Node
->getOperand(0));
91 size_t ExpectNumArgNodeOps
= F
->arg_size() + 1;
92 for (size_t i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
93 MDNode
*ArgNode
= dyn_cast_or_null
<MDNode
>(Node
->getOperand(i
+ 1));
94 if (ArgNode
->getNumOperands() != ExpectNumArgNodeOps
)
96 if (!ArgNode
->getOperand(0))
99 // FIXME: It should be possible to do image lowering when some metadata
100 // args missing or not in the expected order.
101 MDString
*StringNode
= dyn_cast
<MDString
>(ArgNode
->getOperand(0));
102 if (!StringNode
|| StringNode
->getString() != KernelArgMDNodeNames
[i
])
110 AccessQualFromMD(MDNode
*KernelMDNode
, unsigned ArgIdx
) {
111 MDNode
*ArgAQNode
= cast
<MDNode
>(KernelMDNode
->getOperand(2));
112 return cast
<MDString
>(ArgAQNode
->getOperand(ArgIdx
+ 1))->getString();
116 ArgTypeFromMD(MDNode
*KernelMDNode
, unsigned ArgIdx
) {
117 MDNode
*ArgTypeNode
= cast
<MDNode
>(KernelMDNode
->getOperand(3));
118 return cast
<MDString
>(ArgTypeNode
->getOperand(ArgIdx
+ 1))->getString();
122 GetArgMD(MDNode
*KernelMDNode
, unsigned OpIdx
) {
124 for (unsigned i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
125 MDNode
*Node
= cast
<MDNode
>(KernelMDNode
->getOperand(i
+ 1));
126 Res
.push_back(Node
->getOperand(OpIdx
));
132 PushArgMD(KernelArgMD
&MD
, const MDVector
&V
) {
133 assert(V
.size() == NumKernelArgMDNodes
);
134 for (unsigned i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
135 MD
.ArgVector
[i
].push_back(V
[i
]);
141 class R600OpenCLImageTypeLoweringPass
: public ModulePass
{
144 LLVMContext
*Context
;
147 Type
*ImageFormatType
;
148 SmallVector
<Instruction
*, 4> InstsToErase
;
150 bool replaceImageUses(Argument
&ImageArg
, uint32_t ResourceID
,
151 Argument
&ImageSizeArg
,
152 Argument
&ImageFormatArg
) {
153 bool Modified
= false;
155 for (auto &Use
: ImageArg
.uses()) {
156 auto Inst
= dyn_cast
<CallInst
>(Use
.getUser());
161 Function
*F
= Inst
->getCalledFunction();
165 Value
*Replacement
= nullptr;
166 StringRef Name
= F
->getName();
167 if (Name
.starts_with(GetImageResourceIDFunc
)) {
168 Replacement
= ConstantInt::get(Int32Type
, ResourceID
);
169 } else if (Name
.starts_with(GetImageSizeFunc
)) {
170 Replacement
= &ImageSizeArg
;
171 } else if (Name
.starts_with(GetImageFormatFunc
)) {
172 Replacement
= &ImageFormatArg
;
177 Inst
->replaceAllUsesWith(Replacement
);
178 InstsToErase
.push_back(Inst
);
185 bool replaceSamplerUses(Argument
&SamplerArg
, uint32_t ResourceID
) {
186 bool Modified
= false;
188 for (const auto &Use
: SamplerArg
.uses()) {
189 auto Inst
= dyn_cast
<CallInst
>(Use
.getUser());
194 Function
*F
= Inst
->getCalledFunction();
198 Value
*Replacement
= nullptr;
199 StringRef Name
= F
->getName();
200 if (Name
== GetSamplerResourceIDFunc
) {
201 Replacement
= ConstantInt::get(Int32Type
, ResourceID
);
206 Inst
->replaceAllUsesWith(Replacement
);
207 InstsToErase
.push_back(Inst
);
214 bool replaceImageAndSamplerUses(Function
*F
, MDNode
*KernelMDNode
) {
215 uint32_t NumReadOnlyImageArgs
= 0;
216 uint32_t NumWriteOnlyImageArgs
= 0;
217 uint32_t NumSamplerArgs
= 0;
219 bool Modified
= false;
220 InstsToErase
.clear();
221 for (auto ArgI
= F
->arg_begin(); ArgI
!= F
->arg_end(); ++ArgI
) {
222 Argument
&Arg
= *ArgI
;
223 StringRef Type
= ArgTypeFromMD(KernelMDNode
, Arg
.getArgNo());
225 // Handle image types.
226 if (IsImageType(Type
)) {
227 StringRef AccessQual
= AccessQualFromMD(KernelMDNode
, Arg
.getArgNo());
229 if (AccessQual
== "read_only") {
230 ResourceID
= NumReadOnlyImageArgs
++;
231 } else if (AccessQual
== "write_only") {
232 ResourceID
= NumWriteOnlyImageArgs
++;
234 llvm_unreachable("Wrong image access qualifier.");
237 Argument
&SizeArg
= *(++ArgI
);
238 Argument
&FormatArg
= *(++ArgI
);
239 Modified
|= replaceImageUses(Arg
, ResourceID
, SizeArg
, FormatArg
);
241 // Handle sampler type.
242 } else if (IsSamplerType(Type
)) {
243 uint32_t ResourceID
= NumSamplerArgs
++;
244 Modified
|= replaceSamplerUses(Arg
, ResourceID
);
247 for (auto *Inst
: InstsToErase
)
248 Inst
->eraseFromParent();
253 std::tuple
<Function
*, MDNode
*>
254 addImplicitArgs(Function
*F
, MDNode
*KernelMDNode
) {
255 bool Modified
= false;
257 FunctionType
*FT
= F
->getFunctionType();
258 SmallVector
<Type
*, 8> ArgTypes
;
260 // Metadata operands for new MDNode.
261 KernelArgMD NewArgMDs
;
262 PushArgMD(NewArgMDs
, GetArgMD(KernelMDNode
, 0));
264 // Add implicit arguments to the signature.
265 for (unsigned i
= 0; i
< FT
->getNumParams(); ++i
) {
266 ArgTypes
.push_back(FT
->getParamType(i
));
267 MDVector ArgMD
= GetArgMD(KernelMDNode
, i
+ 1);
268 PushArgMD(NewArgMDs
, ArgMD
);
270 if (!IsImageType(ArgTypeFromMD(KernelMDNode
, i
)))
273 // Add size implicit argument.
274 ArgTypes
.push_back(ImageSizeType
);
275 ArgMD
[2] = ArgMD
[3] = MDString::get(*Context
, ImageSizeArgMDType
);
276 PushArgMD(NewArgMDs
, ArgMD
);
278 // Add format implicit argument.
279 ArgTypes
.push_back(ImageFormatType
);
280 ArgMD
[2] = ArgMD
[3] = MDString::get(*Context
, ImageFormatArgMDType
);
281 PushArgMD(NewArgMDs
, ArgMD
);
286 return std::tuple(nullptr, nullptr);
289 // Create function with new signature and clone the old body into it.
290 auto NewFT
= FunctionType::get(FT
->getReturnType(), ArgTypes
, false);
291 auto NewF
= Function::Create(NewFT
, F
->getLinkage(), F
->getName());
292 ValueToValueMapTy VMap
;
293 auto NewFArgIt
= NewF
->arg_begin();
294 for (auto &Arg
: F
->args()) {
295 auto ArgName
= Arg
.getName();
296 NewFArgIt
->setName(ArgName
);
297 VMap
[&Arg
] = &(*NewFArgIt
++);
298 if (IsImageType(ArgTypeFromMD(KernelMDNode
, Arg
.getArgNo()))) {
299 (NewFArgIt
++)->setName(Twine("__size_") + ArgName
);
300 (NewFArgIt
++)->setName(Twine("__format_") + ArgName
);
303 SmallVector
<ReturnInst
*, 8> Returns
;
304 CloneFunctionInto(NewF
, F
, VMap
, CloneFunctionChangeType::LocalChangesOnly
,
308 SmallVector
<Metadata
*, 6> KernelMDArgs
;
309 KernelMDArgs
.push_back(ConstantAsMetadata::get(NewF
));
310 for (const MDVector
&MDV
: NewArgMDs
.ArgVector
)
311 KernelMDArgs
.push_back(MDNode::get(*Context
, MDV
));
312 MDNode
*NewMDNode
= MDNode::get(*Context
, KernelMDArgs
);
314 return std::tuple(NewF
, NewMDNode
);
317 bool transformKernels(Module
&M
) {
318 NamedMDNode
*KernelsMDNode
= M
.getNamedMetadata(KernelsMDNodeName
);
322 bool Modified
= false;
323 for (unsigned i
= 0; i
< KernelsMDNode
->getNumOperands(); ++i
) {
324 MDNode
*KernelMDNode
= KernelsMDNode
->getOperand(i
);
325 Function
*F
= GetFunctionFromMDNode(KernelMDNode
);
331 std::tie(NewF
, NewMDNode
) = addImplicitArgs(F
, KernelMDNode
);
333 // Replace old function and metadata with new ones.
334 F
->eraseFromParent();
335 M
.getFunctionList().push_back(NewF
);
336 M
.getOrInsertFunction(NewF
->getName(), NewF
->getFunctionType(),
337 NewF
->getAttributes());
338 KernelsMDNode
->setOperand(i
, NewMDNode
);
341 KernelMDNode
= NewMDNode
;
345 Modified
|= replaceImageAndSamplerUses(F
, KernelMDNode
);
352 R600OpenCLImageTypeLoweringPass() : ModulePass(ID
) {}
354 bool runOnModule(Module
&M
) override
{
355 Context
= &M
.getContext();
356 Int32Type
= Type::getInt32Ty(M
.getContext());
357 ImageSizeType
= ArrayType::get(Int32Type
, 3);
358 ImageFormatType
= ArrayType::get(Int32Type
, 2);
360 return transformKernels(M
);
363 StringRef
getPassName() const override
{
364 return "R600 OpenCL Image Type Pass";
368 } // end anonymous namespace
370 char R600OpenCLImageTypeLoweringPass::ID
= 0;
372 ModulePass
*llvm::createR600OpenCLImageTypeLoweringPass() {
373 return new R600OpenCLImageTypeLoweringPass();