//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself,
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type set to "__llvm_image_size" or
/// "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
//
//===----------------------------------------------------------------------===//
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/IR/Argument.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/Instruction.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Metadata.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/IR/Use.h"
41 #include "llvm/IR/User.h"
42 #include "llvm/Pass.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/ErrorHandling.h"
45 #include "llvm/Transforms/Utils/Cloning.h"
46 #include "llvm/Transforms/Utils/ValueMapper.h"
54 static StringRef GetImageSizeFunc
= "llvm.OpenCL.image.get.size";
55 static StringRef GetImageFormatFunc
= "llvm.OpenCL.image.get.format";
56 static StringRef GetImageResourceIDFunc
= "llvm.OpenCL.image.get.resource.id";
57 static StringRef GetSamplerResourceIDFunc
=
58 "llvm.OpenCL.sampler.get.resource.id";
60 static StringRef ImageSizeArgMDType
= "__llvm_image_size";
61 static StringRef ImageFormatArgMDType
= "__llvm_image_format";
63 static StringRef KernelsMDNodeName
= "opencl.kernels";
64 static StringRef KernelArgMDNodeNames
[] = {
65 "kernel_arg_addr_space",
66 "kernel_arg_access_qual",
68 "kernel_arg_base_type",
69 "kernel_arg_type_qual"};
70 static const unsigned NumKernelArgMDNodes
= 5;
74 using MDVector
= SmallVector
<Metadata
*, 8>;
76 MDVector ArgVector
[NumKernelArgMDNodes
];
79 } // end anonymous namespace
82 IsImageType(StringRef TypeString
) {
83 return TypeString
== "image2d_t" || TypeString
== "image3d_t";
87 IsSamplerType(StringRef TypeString
) {
88 return TypeString
== "sampler_t";
92 GetFunctionFromMDNode(MDNode
*Node
) {
96 size_t NumOps
= Node
->getNumOperands();
97 if (NumOps
!= NumKernelArgMDNodes
+ 1)
100 auto F
= mdconst::dyn_extract
<Function
>(Node
->getOperand(0));
105 size_t ExpectNumArgNodeOps
= F
->arg_size() + 1;
106 for (size_t i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
107 MDNode
*ArgNode
= dyn_cast_or_null
<MDNode
>(Node
->getOperand(i
+ 1));
108 if (ArgNode
->getNumOperands() != ExpectNumArgNodeOps
)
110 if (!ArgNode
->getOperand(0))
113 // FIXME: It should be possible to do image lowering when some metadata
114 // args missing or not in the expected order.
115 MDString
*StringNode
= dyn_cast
<MDString
>(ArgNode
->getOperand(0));
116 if (!StringNode
|| StringNode
->getString() != KernelArgMDNodeNames
[i
])
124 AccessQualFromMD(MDNode
*KernelMDNode
, unsigned ArgIdx
) {
125 MDNode
*ArgAQNode
= cast
<MDNode
>(KernelMDNode
->getOperand(2));
126 return cast
<MDString
>(ArgAQNode
->getOperand(ArgIdx
+ 1))->getString();
130 ArgTypeFromMD(MDNode
*KernelMDNode
, unsigned ArgIdx
) {
131 MDNode
*ArgTypeNode
= cast
<MDNode
>(KernelMDNode
->getOperand(3));
132 return cast
<MDString
>(ArgTypeNode
->getOperand(ArgIdx
+ 1))->getString();
136 GetArgMD(MDNode
*KernelMDNode
, unsigned OpIdx
) {
138 for (unsigned i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
139 MDNode
*Node
= cast
<MDNode
>(KernelMDNode
->getOperand(i
+ 1));
140 Res
.push_back(Node
->getOperand(OpIdx
));
146 PushArgMD(KernelArgMD
&MD
, const MDVector
&V
) {
147 assert(V
.size() == NumKernelArgMDNodes
);
148 for (unsigned i
= 0; i
< NumKernelArgMDNodes
; ++i
) {
149 MD
.ArgVector
[i
].push_back(V
[i
]);
155 class R600OpenCLImageTypeLoweringPass
: public ModulePass
{
158 LLVMContext
*Context
;
161 Type
*ImageFormatType
;
162 SmallVector
<Instruction
*, 4> InstsToErase
;
164 bool replaceImageUses(Argument
&ImageArg
, uint32_t ResourceID
,
165 Argument
&ImageSizeArg
,
166 Argument
&ImageFormatArg
) {
167 bool Modified
= false;
169 for (auto &Use
: ImageArg
.uses()) {
170 auto Inst
= dyn_cast
<CallInst
>(Use
.getUser());
175 Function
*F
= Inst
->getCalledFunction();
179 Value
*Replacement
= nullptr;
180 StringRef Name
= F
->getName();
181 if (Name
.startswith(GetImageResourceIDFunc
)) {
182 Replacement
= ConstantInt::get(Int32Type
, ResourceID
);
183 } else if (Name
.startswith(GetImageSizeFunc
)) {
184 Replacement
= &ImageSizeArg
;
185 } else if (Name
.startswith(GetImageFormatFunc
)) {
186 Replacement
= &ImageFormatArg
;
191 Inst
->replaceAllUsesWith(Replacement
);
192 InstsToErase
.push_back(Inst
);
199 bool replaceSamplerUses(Argument
&SamplerArg
, uint32_t ResourceID
) {
200 bool Modified
= false;
202 for (const auto &Use
: SamplerArg
.uses()) {
203 auto Inst
= dyn_cast
<CallInst
>(Use
.getUser());
208 Function
*F
= Inst
->getCalledFunction();
212 Value
*Replacement
= nullptr;
213 StringRef Name
= F
->getName();
214 if (Name
== GetSamplerResourceIDFunc
) {
215 Replacement
= ConstantInt::get(Int32Type
, ResourceID
);
220 Inst
->replaceAllUsesWith(Replacement
);
221 InstsToErase
.push_back(Inst
);
228 bool replaceImageAndSamplerUses(Function
*F
, MDNode
*KernelMDNode
) {
229 uint32_t NumReadOnlyImageArgs
= 0;
230 uint32_t NumWriteOnlyImageArgs
= 0;
231 uint32_t NumSamplerArgs
= 0;
233 bool Modified
= false;
234 InstsToErase
.clear();
235 for (auto ArgI
= F
->arg_begin(); ArgI
!= F
->arg_end(); ++ArgI
) {
236 Argument
&Arg
= *ArgI
;
237 StringRef Type
= ArgTypeFromMD(KernelMDNode
, Arg
.getArgNo());
239 // Handle image types.
240 if (IsImageType(Type
)) {
241 StringRef AccessQual
= AccessQualFromMD(KernelMDNode
, Arg
.getArgNo());
243 if (AccessQual
== "read_only") {
244 ResourceID
= NumReadOnlyImageArgs
++;
245 } else if (AccessQual
== "write_only") {
246 ResourceID
= NumWriteOnlyImageArgs
++;
248 llvm_unreachable("Wrong image access qualifier.");
251 Argument
&SizeArg
= *(++ArgI
);
252 Argument
&FormatArg
= *(++ArgI
);
253 Modified
|= replaceImageUses(Arg
, ResourceID
, SizeArg
, FormatArg
);
255 // Handle sampler type.
256 } else if (IsSamplerType(Type
)) {
257 uint32_t ResourceID
= NumSamplerArgs
++;
258 Modified
|= replaceSamplerUses(Arg
, ResourceID
);
261 for (unsigned i
= 0; i
< InstsToErase
.size(); ++i
) {
262 InstsToErase
[i
]->eraseFromParent();
268 std::tuple
<Function
*, MDNode
*>
269 addImplicitArgs(Function
*F
, MDNode
*KernelMDNode
) {
270 bool Modified
= false;
272 FunctionType
*FT
= F
->getFunctionType();
273 SmallVector
<Type
*, 8> ArgTypes
;
275 // Metadata operands for new MDNode.
276 KernelArgMD NewArgMDs
;
277 PushArgMD(NewArgMDs
, GetArgMD(KernelMDNode
, 0));
279 // Add implicit arguments to the signature.
280 for (unsigned i
= 0; i
< FT
->getNumParams(); ++i
) {
281 ArgTypes
.push_back(FT
->getParamType(i
));
282 MDVector ArgMD
= GetArgMD(KernelMDNode
, i
+ 1);
283 PushArgMD(NewArgMDs
, ArgMD
);
285 if (!IsImageType(ArgTypeFromMD(KernelMDNode
, i
)))
288 // Add size implicit argument.
289 ArgTypes
.push_back(ImageSizeType
);
290 ArgMD
[2] = ArgMD
[3] = MDString::get(*Context
, ImageSizeArgMDType
);
291 PushArgMD(NewArgMDs
, ArgMD
);
293 // Add format implicit argument.
294 ArgTypes
.push_back(ImageFormatType
);
295 ArgMD
[2] = ArgMD
[3] = MDString::get(*Context
, ImageFormatArgMDType
);
296 PushArgMD(NewArgMDs
, ArgMD
);
301 return std::make_tuple(nullptr, nullptr);
304 // Create function with new signature and clone the old body into it.
305 auto NewFT
= FunctionType::get(FT
->getReturnType(), ArgTypes
, false);
306 auto NewF
= Function::Create(NewFT
, F
->getLinkage(), F
->getName());
307 ValueToValueMapTy VMap
;
308 auto NewFArgIt
= NewF
->arg_begin();
309 for (auto &Arg
: F
->args()) {
310 auto ArgName
= Arg
.getName();
311 NewFArgIt
->setName(ArgName
);
312 VMap
[&Arg
] = &(*NewFArgIt
++);
313 if (IsImageType(ArgTypeFromMD(KernelMDNode
, Arg
.getArgNo()))) {
314 (NewFArgIt
++)->setName(Twine("__size_") + ArgName
);
315 (NewFArgIt
++)->setName(Twine("__format_") + ArgName
);
318 SmallVector
<ReturnInst
*, 8> Returns
;
319 CloneFunctionInto(NewF
, F
, VMap
, /*ModuleLevelChanges=*/false, Returns
);
322 SmallVector
<Metadata
*, 6> KernelMDArgs
;
323 KernelMDArgs
.push_back(ConstantAsMetadata::get(NewF
));
324 for (unsigned i
= 0; i
< NumKernelArgMDNodes
; ++i
)
325 KernelMDArgs
.push_back(MDNode::get(*Context
, NewArgMDs
.ArgVector
[i
]));
326 MDNode
*NewMDNode
= MDNode::get(*Context
, KernelMDArgs
);
328 return std::make_tuple(NewF
, NewMDNode
);
331 bool transformKernels(Module
&M
) {
332 NamedMDNode
*KernelsMDNode
= M
.getNamedMetadata(KernelsMDNodeName
);
336 bool Modified
= false;
337 for (unsigned i
= 0; i
< KernelsMDNode
->getNumOperands(); ++i
) {
338 MDNode
*KernelMDNode
= KernelsMDNode
->getOperand(i
);
339 Function
*F
= GetFunctionFromMDNode(KernelMDNode
);
345 std::tie(NewF
, NewMDNode
) = addImplicitArgs(F
, KernelMDNode
);
347 // Replace old function and metadata with new ones.
348 F
->eraseFromParent();
349 M
.getFunctionList().push_back(NewF
);
350 M
.getOrInsertFunction(NewF
->getName(), NewF
->getFunctionType(),
351 NewF
->getAttributes());
352 KernelsMDNode
->setOperand(i
, NewMDNode
);
355 KernelMDNode
= NewMDNode
;
359 Modified
|= replaceImageAndSamplerUses(F
, KernelMDNode
);
366 R600OpenCLImageTypeLoweringPass() : ModulePass(ID
) {}
368 bool runOnModule(Module
&M
) override
{
369 Context
= &M
.getContext();
370 Int32Type
= Type::getInt32Ty(M
.getContext());
371 ImageSizeType
= ArrayType::get(Int32Type
, 3);
372 ImageFormatType
= ArrayType::get(Int32Type
, 2);
374 return transformKernels(M
);
377 StringRef
getPassName() const override
{
378 return "R600 OpenCL Image Type Pass";
382 } // end anonymous namespace
384 char R600OpenCLImageTypeLoweringPass::ID
= 0;
386 ModulePass
*llvm::createR600OpenCLImageTypeLoweringPass() {
387 return new R600OpenCLImageTypeLoweringPass();