//===- NVPTXUtilities.cpp - Utility Functions -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions for the NVPTX backend,
// chiefly queries over the "nvvm.annotations" module metadata.
//
//===----------------------------------------------------------------------===//
13 #include "NVPTXUtilities.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/IR/Function.h"
17 #include "llvm/IR/GlobalVariable.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/Operator.h"
21 #include "llvm/Support/ManagedStatic.h"
22 #include "llvm/Support/MutexGuard.h"
32 typedef std::map
<std::string
, std::vector
<unsigned> > key_val_pair_t
;
33 typedef std::map
<const GlobalValue
*, key_val_pair_t
> global_val_annot_t
;
34 typedef std::map
<const Module
*, global_val_annot_t
> per_module_annot_t
;
35 } // anonymous namespace
37 static ManagedStatic
<per_module_annot_t
> annotationCache
;
38 static sys::Mutex Lock
;
40 void clearAnnotationCache(const Module
*Mod
) {
41 MutexGuard
Guard(Lock
);
42 annotationCache
->erase(Mod
);
45 static void cacheAnnotationFromMD(const MDNode
*md
, key_val_pair_t
&retval
) {
46 MutexGuard
Guard(Lock
);
47 assert(md
&& "Invalid mdnode for annotation");
48 assert((md
->getNumOperands() % 2) == 1 && "Invalid number of operands");
49 // start index = 1, to skip the global variable key
50 // increment = 2, to skip the value for each property-value pairs
51 for (unsigned i
= 1, e
= md
->getNumOperands(); i
!= e
; i
+= 2) {
53 const MDString
*prop
= dyn_cast
<MDString
>(md
->getOperand(i
));
54 assert(prop
&& "Annotation property not a string");
57 ConstantInt
*Val
= mdconst::dyn_extract
<ConstantInt
>(md
->getOperand(i
+ 1));
58 assert(Val
&& "Value operand not a constant int");
60 std::string keyname
= prop
->getString().str();
61 if (retval
.find(keyname
) != retval
.end())
62 retval
[keyname
].push_back(Val
->getZExtValue());
64 std::vector
<unsigned> tmp
;
65 tmp
.push_back(Val
->getZExtValue());
66 retval
[keyname
] = tmp
;
71 static void cacheAnnotationFromMD(const Module
*m
, const GlobalValue
*gv
) {
72 MutexGuard
Guard(Lock
);
73 NamedMDNode
*NMD
= m
->getNamedMetadata("nvvm.annotations");
77 for (unsigned i
= 0, e
= NMD
->getNumOperands(); i
!= e
; ++i
) {
78 const MDNode
*elem
= NMD
->getOperand(i
);
81 mdconst::dyn_extract_or_null
<GlobalValue
>(elem
->getOperand(0));
82 // entity may be null due to DCE
88 // accumulate annotations for entity in tmp
89 cacheAnnotationFromMD(elem
, tmp
);
92 if (tmp
.empty()) // no annotations for this gv
95 if ((*annotationCache
).find(m
) != (*annotationCache
).end())
96 (*annotationCache
)[m
][gv
] = std::move(tmp
);
98 global_val_annot_t tmp1
;
99 tmp1
[gv
] = std::move(tmp
);
100 (*annotationCache
)[m
] = std::move(tmp1
);
104 bool findOneNVVMAnnotation(const GlobalValue
*gv
, const std::string
&prop
,
106 MutexGuard
Guard(Lock
);
107 const Module
*m
= gv
->getParent();
108 if ((*annotationCache
).find(m
) == (*annotationCache
).end())
109 cacheAnnotationFromMD(m
, gv
);
110 else if ((*annotationCache
)[m
].find(gv
) == (*annotationCache
)[m
].end())
111 cacheAnnotationFromMD(m
, gv
);
112 if ((*annotationCache
)[m
][gv
].find(prop
) == (*annotationCache
)[m
][gv
].end())
114 retval
= (*annotationCache
)[m
][gv
][prop
][0];
118 bool findAllNVVMAnnotation(const GlobalValue
*gv
, const std::string
&prop
,
119 std::vector
<unsigned> &retval
) {
120 MutexGuard
Guard(Lock
);
121 const Module
*m
= gv
->getParent();
122 if ((*annotationCache
).find(m
) == (*annotationCache
).end())
123 cacheAnnotationFromMD(m
, gv
);
124 else if ((*annotationCache
)[m
].find(gv
) == (*annotationCache
)[m
].end())
125 cacheAnnotationFromMD(m
, gv
);
126 if ((*annotationCache
)[m
][gv
].find(prop
) == (*annotationCache
)[m
][gv
].end())
128 retval
= (*annotationCache
)[m
][gv
][prop
];
132 bool isTexture(const Value
&val
) {
133 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
135 if (findOneNVVMAnnotation(gv
, "texture", annot
)) {
136 assert((annot
== 1) && "Unexpected annotation on a texture symbol");
143 bool isSurface(const Value
&val
) {
144 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
146 if (findOneNVVMAnnotation(gv
, "surface", annot
)) {
147 assert((annot
== 1) && "Unexpected annotation on a surface symbol");
154 bool isSampler(const Value
&val
) {
155 const char *AnnotationName
= "sampler";
157 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
159 if (findOneNVVMAnnotation(gv
, AnnotationName
, annot
)) {
160 assert((annot
== 1) && "Unexpected annotation on a sampler symbol");
164 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
165 const Function
*func
= arg
->getParent();
166 std::vector
<unsigned> annot
;
167 if (findAllNVVMAnnotation(func
, AnnotationName
, annot
)) {
168 if (is_contained(annot
, arg
->getArgNo()))
175 bool isImageReadOnly(const Value
&val
) {
176 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
177 const Function
*func
= arg
->getParent();
178 std::vector
<unsigned> annot
;
179 if (findAllNVVMAnnotation(func
, "rdoimage", annot
)) {
180 if (is_contained(annot
, arg
->getArgNo()))
187 bool isImageWriteOnly(const Value
&val
) {
188 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
189 const Function
*func
= arg
->getParent();
190 std::vector
<unsigned> annot
;
191 if (findAllNVVMAnnotation(func
, "wroimage", annot
)) {
192 if (is_contained(annot
, arg
->getArgNo()))
199 bool isImageReadWrite(const Value
&val
) {
200 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
201 const Function
*func
= arg
->getParent();
202 std::vector
<unsigned> annot
;
203 if (findAllNVVMAnnotation(func
, "rdwrimage", annot
)) {
204 if (is_contained(annot
, arg
->getArgNo()))
211 bool isImage(const Value
&val
) {
212 return isImageReadOnly(val
) || isImageWriteOnly(val
) || isImageReadWrite(val
);
215 bool isManaged(const Value
&val
) {
216 if(const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
218 if (findOneNVVMAnnotation(gv
, "managed", annot
)) {
219 assert((annot
== 1) && "Unexpected annotation on a managed symbol");
226 std::string
getTextureName(const Value
&val
) {
227 assert(val
.hasName() && "Found texture variable with no name");
228 return val
.getName();
231 std::string
getSurfaceName(const Value
&val
) {
232 assert(val
.hasName() && "Found surface variable with no name");
233 return val
.getName();
236 std::string
getSamplerName(const Value
&val
) {
237 assert(val
.hasName() && "Found sampler variable with no name");
238 return val
.getName();
241 bool getMaxNTIDx(const Function
&F
, unsigned &x
) {
242 return findOneNVVMAnnotation(&F
, "maxntidx", x
);
245 bool getMaxNTIDy(const Function
&F
, unsigned &y
) {
246 return findOneNVVMAnnotation(&F
, "maxntidy", y
);
249 bool getMaxNTIDz(const Function
&F
, unsigned &z
) {
250 return findOneNVVMAnnotation(&F
, "maxntidz", z
);
253 bool getReqNTIDx(const Function
&F
, unsigned &x
) {
254 return findOneNVVMAnnotation(&F
, "reqntidx", x
);
257 bool getReqNTIDy(const Function
&F
, unsigned &y
) {
258 return findOneNVVMAnnotation(&F
, "reqntidy", y
);
261 bool getReqNTIDz(const Function
&F
, unsigned &z
) {
262 return findOneNVVMAnnotation(&F
, "reqntidz", z
);
265 bool getMinCTASm(const Function
&F
, unsigned &x
) {
266 return findOneNVVMAnnotation(&F
, "minctasm", x
);
269 bool getMaxNReg(const Function
&F
, unsigned &x
) {
270 return findOneNVVMAnnotation(&F
, "maxnreg", x
);
273 bool isKernelFunction(const Function
&F
) {
275 bool retval
= findOneNVVMAnnotation(&F
, "kernel", x
);
277 // There is no NVVM metadata, check the calling convention
278 return F
.getCallingConv() == CallingConv::PTX_Kernel
;
283 bool getAlign(const Function
&F
, unsigned index
, unsigned &align
) {
284 std::vector
<unsigned> Vs
;
285 bool retval
= findAllNVVMAnnotation(&F
, "align", Vs
);
288 for (int i
= 0, e
= Vs
.size(); i
< e
; i
++) {
290 if ((v
>> 16) == index
) {
298 bool getAlign(const CallInst
&I
, unsigned index
, unsigned &align
) {
299 if (MDNode
*alignNode
= I
.getMetadata("callalign")) {
300 for (int i
= 0, n
= alignNode
->getNumOperands(); i
< n
; i
++) {
301 if (const ConstantInt
*CI
=
302 mdconst::dyn_extract
<ConstantInt
>(alignNode
->getOperand(i
))) {
303 unsigned v
= CI
->getZExtValue();
304 if ((v
>> 16) == index
) {
308 if ((v
>> 16) > index
) {