//===- NVPTXUtilities.cpp - Utility Functions -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions for the NVPTX backend,
// chiefly queries over the "nvvm.annotations" module metadata.
//
//===----------------------------------------------------------------------===//
#include "NVPTXUtilities.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"
#include <cassert>
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
33 typedef std::map
<std::string
, std::vector
<unsigned> > key_val_pair_t
;
34 typedef std::map
<const GlobalValue
*, key_val_pair_t
> global_val_annot_t
;
35 typedef std::map
<const Module
*, global_val_annot_t
> per_module_annot_t
;
36 } // anonymous namespace
38 static ManagedStatic
<per_module_annot_t
> annotationCache
;
39 static sys::Mutex Lock
;
41 void clearAnnotationCache(const Module
*Mod
) {
42 std::lock_guard
<sys::Mutex
> Guard(Lock
);
43 annotationCache
->erase(Mod
);
46 static void cacheAnnotationFromMD(const MDNode
*md
, key_val_pair_t
&retval
) {
47 std::lock_guard
<sys::Mutex
> Guard(Lock
);
48 assert(md
&& "Invalid mdnode for annotation");
49 assert((md
->getNumOperands() % 2) == 1 && "Invalid number of operands");
50 // start index = 1, to skip the global variable key
51 // increment = 2, to skip the value for each property-value pairs
52 for (unsigned i
= 1, e
= md
->getNumOperands(); i
!= e
; i
+= 2) {
54 const MDString
*prop
= dyn_cast
<MDString
>(md
->getOperand(i
));
55 assert(prop
&& "Annotation property not a string");
58 ConstantInt
*Val
= mdconst::dyn_extract
<ConstantInt
>(md
->getOperand(i
+ 1));
59 assert(Val
&& "Value operand not a constant int");
61 std::string keyname
= prop
->getString().str();
62 if (retval
.find(keyname
) != retval
.end())
63 retval
[keyname
].push_back(Val
->getZExtValue());
65 std::vector
<unsigned> tmp
;
66 tmp
.push_back(Val
->getZExtValue());
67 retval
[keyname
] = tmp
;
72 static void cacheAnnotationFromMD(const Module
*m
, const GlobalValue
*gv
) {
73 std::lock_guard
<sys::Mutex
> Guard(Lock
);
74 NamedMDNode
*NMD
= m
->getNamedMetadata("nvvm.annotations");
78 for (unsigned i
= 0, e
= NMD
->getNumOperands(); i
!= e
; ++i
) {
79 const MDNode
*elem
= NMD
->getOperand(i
);
82 mdconst::dyn_extract_or_null
<GlobalValue
>(elem
->getOperand(0));
83 // entity may be null due to DCE
89 // accumulate annotations for entity in tmp
90 cacheAnnotationFromMD(elem
, tmp
);
93 if (tmp
.empty()) // no annotations for this gv
96 if ((*annotationCache
).find(m
) != (*annotationCache
).end())
97 (*annotationCache
)[m
][gv
] = std::move(tmp
);
99 global_val_annot_t tmp1
;
100 tmp1
[gv
] = std::move(tmp
);
101 (*annotationCache
)[m
] = std::move(tmp1
);
105 bool findOneNVVMAnnotation(const GlobalValue
*gv
, const std::string
&prop
,
107 std::lock_guard
<sys::Mutex
> Guard(Lock
);
108 const Module
*m
= gv
->getParent();
109 if ((*annotationCache
).find(m
) == (*annotationCache
).end())
110 cacheAnnotationFromMD(m
, gv
);
111 else if ((*annotationCache
)[m
].find(gv
) == (*annotationCache
)[m
].end())
112 cacheAnnotationFromMD(m
, gv
);
113 if ((*annotationCache
)[m
][gv
].find(prop
) == (*annotationCache
)[m
][gv
].end())
115 retval
= (*annotationCache
)[m
][gv
][prop
][0];
119 bool findAllNVVMAnnotation(const GlobalValue
*gv
, const std::string
&prop
,
120 std::vector
<unsigned> &retval
) {
121 std::lock_guard
<sys::Mutex
> Guard(Lock
);
122 const Module
*m
= gv
->getParent();
123 if ((*annotationCache
).find(m
) == (*annotationCache
).end())
124 cacheAnnotationFromMD(m
, gv
);
125 else if ((*annotationCache
)[m
].find(gv
) == (*annotationCache
)[m
].end())
126 cacheAnnotationFromMD(m
, gv
);
127 if ((*annotationCache
)[m
][gv
].find(prop
) == (*annotationCache
)[m
][gv
].end())
129 retval
= (*annotationCache
)[m
][gv
][prop
];
133 bool isTexture(const Value
&val
) {
134 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
136 if (findOneNVVMAnnotation(gv
, "texture", annot
)) {
137 assert((annot
== 1) && "Unexpected annotation on a texture symbol");
144 bool isSurface(const Value
&val
) {
145 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
147 if (findOneNVVMAnnotation(gv
, "surface", annot
)) {
148 assert((annot
== 1) && "Unexpected annotation on a surface symbol");
155 bool isSampler(const Value
&val
) {
156 const char *AnnotationName
= "sampler";
158 if (const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
160 if (findOneNVVMAnnotation(gv
, AnnotationName
, annot
)) {
161 assert((annot
== 1) && "Unexpected annotation on a sampler symbol");
165 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
166 const Function
*func
= arg
->getParent();
167 std::vector
<unsigned> annot
;
168 if (findAllNVVMAnnotation(func
, AnnotationName
, annot
)) {
169 if (is_contained(annot
, arg
->getArgNo()))
176 bool isImageReadOnly(const Value
&val
) {
177 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
178 const Function
*func
= arg
->getParent();
179 std::vector
<unsigned> annot
;
180 if (findAllNVVMAnnotation(func
, "rdoimage", annot
)) {
181 if (is_contained(annot
, arg
->getArgNo()))
188 bool isImageWriteOnly(const Value
&val
) {
189 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
190 const Function
*func
= arg
->getParent();
191 std::vector
<unsigned> annot
;
192 if (findAllNVVMAnnotation(func
, "wroimage", annot
)) {
193 if (is_contained(annot
, arg
->getArgNo()))
200 bool isImageReadWrite(const Value
&val
) {
201 if (const Argument
*arg
= dyn_cast
<Argument
>(&val
)) {
202 const Function
*func
= arg
->getParent();
203 std::vector
<unsigned> annot
;
204 if (findAllNVVMAnnotation(func
, "rdwrimage", annot
)) {
205 if (is_contained(annot
, arg
->getArgNo()))
212 bool isImage(const Value
&val
) {
213 return isImageReadOnly(val
) || isImageWriteOnly(val
) || isImageReadWrite(val
);
216 bool isManaged(const Value
&val
) {
217 if(const GlobalValue
*gv
= dyn_cast
<GlobalValue
>(&val
)) {
219 if (findOneNVVMAnnotation(gv
, "managed", annot
)) {
220 assert((annot
== 1) && "Unexpected annotation on a managed symbol");
227 std::string
getTextureName(const Value
&val
) {
228 assert(val
.hasName() && "Found texture variable with no name");
229 return std::string(val
.getName());
232 std::string
getSurfaceName(const Value
&val
) {
233 assert(val
.hasName() && "Found surface variable with no name");
234 return std::string(val
.getName());
237 std::string
getSamplerName(const Value
&val
) {
238 assert(val
.hasName() && "Found sampler variable with no name");
239 return std::string(val
.getName());
242 bool getMaxNTIDx(const Function
&F
, unsigned &x
) {
243 return findOneNVVMAnnotation(&F
, "maxntidx", x
);
246 bool getMaxNTIDy(const Function
&F
, unsigned &y
) {
247 return findOneNVVMAnnotation(&F
, "maxntidy", y
);
250 bool getMaxNTIDz(const Function
&F
, unsigned &z
) {
251 return findOneNVVMAnnotation(&F
, "maxntidz", z
);
254 bool getReqNTIDx(const Function
&F
, unsigned &x
) {
255 return findOneNVVMAnnotation(&F
, "reqntidx", x
);
258 bool getReqNTIDy(const Function
&F
, unsigned &y
) {
259 return findOneNVVMAnnotation(&F
, "reqntidy", y
);
262 bool getReqNTIDz(const Function
&F
, unsigned &z
) {
263 return findOneNVVMAnnotation(&F
, "reqntidz", z
);
266 bool getMinCTASm(const Function
&F
, unsigned &x
) {
267 return findOneNVVMAnnotation(&F
, "minctasm", x
);
270 bool getMaxNReg(const Function
&F
, unsigned &x
) {
271 return findOneNVVMAnnotation(&F
, "maxnreg", x
);
274 bool isKernelFunction(const Function
&F
) {
276 bool retval
= findOneNVVMAnnotation(&F
, "kernel", x
);
278 // There is no NVVM metadata, check the calling convention
279 return F
.getCallingConv() == CallingConv::PTX_Kernel
;
/// Looks up the "align" annotations of \p F for parameter \p index, and is
/// presumably meant to report the matching alignment through \p align.
///
/// NOTE(review): this definition is incomplete as it appears in this file.
/// `v` is used below but never declared. `retval` is never checked. The body
/// of the final `if`, the closing braces and every `return` statement are
/// missing. Presumably `v = Vs[i]`, with the parameter index packed into the
/// upper 16 bits (matching the `v >> 16` comparison) and the alignment held
/// in the lower bits. Confirm against the upstream source before restoring
/// the missing lines.
284 bool getAlign(const Function
&F
, unsigned index
, unsigned &align
) {
285 std::vector
<unsigned> Vs
;
286 bool retval
= findAllNVVMAnnotation(&F
, "align", Vs
);
// NOTE(review): `int i` is compared against the vector size; an unsigned
// index or range-for would avoid the signed/unsigned mix.
289 for (int i
= 0, e
= Vs
.size(); i
< e
; i
++) {
// Match entries whose upper 16 bits equal the requested parameter index.
291 if ((v
>> 16) == index
) {
299 bool getAlign(const CallInst
&I
, unsigned index
, unsigned &align
) {
300 if (MDNode
*alignNode
= I
.getMetadata("callalign")) {
301 for (int i
= 0, n
= alignNode
->getNumOperands(); i
< n
; i
++) {
302 if (const ConstantInt
*CI
=
303 mdconst::dyn_extract
<ConstantInt
>(alignNode
->getOperand(i
))) {
304 unsigned v
= CI
->getZExtValue();
305 if ((v
>> 16) == index
) {
309 if ((v
>> 16) > index
) {