[C++20] [Modules] Fix may-be incorrect ADL for module local entities (#123931)
[llvm-project.git] / llvm / lib / Target / NVPTX / MCTargetDesc / NVPTXInstPrinter.cpp
blobd34f45fcac0087509faed8bbcda82746941c892b
1 //===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Print MCInst instructions to .ptx format.
11 //===----------------------------------------------------------------------===//
13 #include "MCTargetDesc/NVPTXInstPrinter.h"
14 #include "NVPTX.h"
15 #include "NVPTXUtilities.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/IR/NVVMIntrinsicUtils.h"
18 #include "llvm/MC/MCExpr.h"
19 #include "llvm/MC/MCInst.h"
20 #include "llvm/MC/MCInstrInfo.h"
21 #include "llvm/MC/MCSubtargetInfo.h"
22 #include "llvm/MC/MCSymbol.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <cctype>
26 using namespace llvm;
28 #define DEBUG_TYPE "asm-printer"
30 #include "NVPTXGenAsmWriter.inc"
32 NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
33 const MCRegisterInfo &MRI)
34 : MCInstPrinter(MAI, MII, MRI) {}
36 void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
37 // Decode the virtual register
38 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
39 unsigned RCId = (Reg.id() >> 28);
40 switch (RCId) {
41 default: report_fatal_error("Bad virtual register encoding");
42 case 0:
43 // This is actually a physical register, so defer to the autogenerated
44 // register printer
45 OS << getRegisterName(Reg);
46 return;
47 case 1:
48 OS << "%p";
49 break;
50 case 2:
51 OS << "%rs";
52 break;
53 case 3:
54 OS << "%r";
55 break;
56 case 4:
57 OS << "%rd";
58 break;
59 case 5:
60 OS << "%f";
61 break;
62 case 6:
63 OS << "%fd";
64 break;
65 case 7:
66 OS << "%rq";
67 break;
70 unsigned VReg = Reg.id() & 0x0FFFFFFF;
71 OS << VReg;
74 void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
75 StringRef Annot, const MCSubtargetInfo &STI,
76 raw_ostream &OS) {
77 printInstruction(MI, Address, OS);
79 // Next always print the annotation.
80 printAnnotation(OS, Annot);
83 void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
84 raw_ostream &O) {
85 const MCOperand &Op = MI->getOperand(OpNo);
86 if (Op.isReg()) {
87 unsigned Reg = Op.getReg();
88 printRegName(O, Reg);
89 } else if (Op.isImm()) {
90 markup(O, Markup::Immediate) << formatImm(Op.getImm());
91 } else {
92 assert(Op.isExpr() && "Unknown operand kind in printOperand");
93 Op.getExpr()->print(O, &MAI);
97 void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
98 const char *M) {
99 const MCOperand &MO = MI->getOperand(OpNum);
100 int64_t Imm = MO.getImm();
101 llvm::StringRef Modifier(M);
103 if (Modifier == "ftz") {
104 // FTZ flag
105 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
106 O << ".ftz";
107 return;
108 } else if (Modifier == "sat") {
109 // SAT flag
110 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
111 O << ".sat";
112 return;
113 } else if (Modifier == "relu") {
114 // RELU flag
115 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
116 O << ".relu";
117 return;
118 } else if (Modifier == "base") {
119 // Default operand
120 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
121 default:
122 return;
123 case NVPTX::PTXCvtMode::NONE:
124 return;
125 case NVPTX::PTXCvtMode::RNI:
126 O << ".rni";
127 return;
128 case NVPTX::PTXCvtMode::RZI:
129 O << ".rzi";
130 return;
131 case NVPTX::PTXCvtMode::RMI:
132 O << ".rmi";
133 return;
134 case NVPTX::PTXCvtMode::RPI:
135 O << ".rpi";
136 return;
137 case NVPTX::PTXCvtMode::RN:
138 O << ".rn";
139 return;
140 case NVPTX::PTXCvtMode::RZ:
141 O << ".rz";
142 return;
143 case NVPTX::PTXCvtMode::RM:
144 O << ".rm";
145 return;
146 case NVPTX::PTXCvtMode::RP:
147 O << ".rp";
148 return;
149 case NVPTX::PTXCvtMode::RNA:
150 O << ".rna";
151 return;
154 llvm_unreachable("Invalid conversion modifier");
157 void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
158 const char *M) {
159 const MCOperand &MO = MI->getOperand(OpNum);
160 int64_t Imm = MO.getImm();
161 llvm::StringRef Modifier(M);
163 if (Modifier == "ftz") {
164 // FTZ flag
165 if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
166 O << ".ftz";
167 return;
168 } else if (Modifier == "base") {
169 switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
170 default:
171 return;
172 case NVPTX::PTXCmpMode::EQ:
173 O << ".eq";
174 return;
175 case NVPTX::PTXCmpMode::NE:
176 O << ".ne";
177 return;
178 case NVPTX::PTXCmpMode::LT:
179 O << ".lt";
180 return;
181 case NVPTX::PTXCmpMode::LE:
182 O << ".le";
183 return;
184 case NVPTX::PTXCmpMode::GT:
185 O << ".gt";
186 return;
187 case NVPTX::PTXCmpMode::GE:
188 O << ".ge";
189 return;
190 case NVPTX::PTXCmpMode::LO:
191 O << ".lo";
192 return;
193 case NVPTX::PTXCmpMode::LS:
194 O << ".ls";
195 return;
196 case NVPTX::PTXCmpMode::HI:
197 O << ".hi";
198 return;
199 case NVPTX::PTXCmpMode::HS:
200 O << ".hs";
201 return;
202 case NVPTX::PTXCmpMode::EQU:
203 O << ".equ";
204 return;
205 case NVPTX::PTXCmpMode::NEU:
206 O << ".neu";
207 return;
208 case NVPTX::PTXCmpMode::LTU:
209 O << ".ltu";
210 return;
211 case NVPTX::PTXCmpMode::LEU:
212 O << ".leu";
213 return;
214 case NVPTX::PTXCmpMode::GTU:
215 O << ".gtu";
216 return;
217 case NVPTX::PTXCmpMode::GEU:
218 O << ".geu";
219 return;
220 case NVPTX::PTXCmpMode::NUM:
221 O << ".num";
222 return;
223 case NVPTX::PTXCmpMode::NotANumber:
224 O << ".nan";
225 return;
228 llvm_unreachable("Empty Modifier");
231 void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
232 raw_ostream &O, const char *M) {
233 llvm::StringRef Modifier(M);
234 const MCOperand &MO = MI->getOperand(OpNum);
235 int Imm = (int)MO.getImm();
236 if (Modifier == "sem") {
237 auto Ordering = NVPTX::Ordering(Imm);
238 switch (Ordering) {
239 case NVPTX::Ordering::NotAtomic:
240 return;
241 case NVPTX::Ordering::Relaxed:
242 O << ".relaxed";
243 return;
244 case NVPTX::Ordering::Acquire:
245 O << ".acquire";
246 return;
247 case NVPTX::Ordering::Release:
248 O << ".release";
249 return;
250 case NVPTX::Ordering::Volatile:
251 O << ".volatile";
252 return;
253 case NVPTX::Ordering::RelaxedMMIO:
254 O << ".mmio.relaxed";
255 return;
256 default:
257 report_fatal_error(formatv(
258 "NVPTX LdStCode Printer does not support \"{}\" sem modifier. "
259 "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.",
260 OrderingToString(Ordering)));
262 } else if (Modifier == "scope") {
263 auto S = NVPTX::Scope(Imm);
264 switch (S) {
265 case NVPTX::Scope::Thread:
266 return;
267 case NVPTX::Scope::System:
268 O << ".sys";
269 return;
270 case NVPTX::Scope::Block:
271 O << ".cta";
272 return;
273 case NVPTX::Scope::Cluster:
274 O << ".cluster";
275 return;
276 case NVPTX::Scope::Device:
277 O << ".gpu";
278 return;
280 report_fatal_error(
281 formatv("NVPTX LdStCode Printer does not support \"{}\" sco modifier.",
282 ScopeToString(S)));
283 } else if (Modifier == "addsp") {
284 auto A = NVPTX::AddressSpace(Imm);
285 switch (A) {
286 case NVPTX::AddressSpace::Generic:
287 return;
288 case NVPTX::AddressSpace::Global:
289 case NVPTX::AddressSpace::Const:
290 case NVPTX::AddressSpace::Shared:
291 case NVPTX::AddressSpace::Param:
292 case NVPTX::AddressSpace::Local:
293 O << "." << A;
294 return;
296 report_fatal_error(formatv(
297 "NVPTX LdStCode Printer does not support \"{}\" addsp modifier.",
298 AddressSpaceToString(A)));
299 } else if (Modifier == "sign") {
300 switch (Imm) {
301 case NVPTX::PTXLdStInstCode::Signed:
302 O << "s";
303 return;
304 case NVPTX::PTXLdStInstCode::Unsigned:
305 O << "u";
306 return;
307 case NVPTX::PTXLdStInstCode::Untyped:
308 O << "b";
309 return;
310 case NVPTX::PTXLdStInstCode::Float:
311 O << "f";
312 return;
313 default:
314 llvm_unreachable("Unknown register type");
316 } else if (Modifier == "vec") {
317 switch (Imm) {
318 case NVPTX::PTXLdStInstCode::V2:
319 O << ".v2";
320 return;
321 case NVPTX::PTXLdStInstCode::V4:
322 O << ".v4";
323 return;
325 // TODO: evaluate whether cases not covered by this switch are bugs
326 return;
328 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
331 void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
332 const char *M) {
333 const MCOperand &MO = MI->getOperand(OpNum);
334 int Imm = (int)MO.getImm();
335 llvm::StringRef Modifier(M);
336 if (Modifier.empty() || Modifier == "version") {
337 O << Imm; // Just print out PTX version
338 return;
339 } else if (Modifier == "aligned") {
340 // PTX63 requires '.aligned' in the name of the instruction.
341 if (Imm >= 63)
342 O << ".aligned";
343 return;
345 llvm_unreachable("Unknown Modifier");
348 void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
349 raw_ostream &O, const char *M) {
350 printOperand(MI, OpNum, O);
351 llvm::StringRef Modifier(M);
353 if (Modifier == "add") {
354 O << ", ";
355 printOperand(MI, OpNum + 1, O);
356 } else {
357 if (MI->getOperand(OpNum + 1).isImm() &&
358 MI->getOperand(OpNum + 1).getImm() == 0)
359 return; // don't print ',0' or '+0'
360 O << "+";
361 printOperand(MI, OpNum + 1, O);
365 void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
366 raw_ostream &O, const char *Modifier) {
367 auto &Op = MI->getOperand(OpNum);
368 assert(Op.isImm() && "Invalid operand");
369 if (Op.getImm() != 0) {
370 O << "+";
371 printOperand(MI, OpNum, O);
375 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
376 raw_ostream &O, const char *Modifier) {
377 int64_t Imm = MI->getOperand(OpNum).getImm();
378 O << formatHex(Imm) << "U";
381 void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
382 raw_ostream &O, const char *Modifier) {
383 const MCOperand &Op = MI->getOperand(OpNum);
384 assert(Op.isExpr() && "Call prototype is not an MCExpr?");
385 const MCExpr *Expr = Op.getExpr();
386 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
387 O << Sym.getName();
390 void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
391 raw_ostream &O, const char *Modifier) {
392 const MCOperand &MO = MI->getOperand(OpNum);
393 int64_t Imm = MO.getImm();
395 switch (Imm) {
396 default:
397 return;
398 case NVPTX::PTXPrmtMode::NONE:
399 return;
400 case NVPTX::PTXPrmtMode::F4E:
401 O << ".f4e";
402 return;
403 case NVPTX::PTXPrmtMode::B4E:
404 O << ".b4e";
405 return;
406 case NVPTX::PTXPrmtMode::RC8:
407 O << ".rc8";
408 return;
409 case NVPTX::PTXPrmtMode::ECL:
410 O << ".ecl";
411 return;
412 case NVPTX::PTXPrmtMode::ECR:
413 O << ".ecr";
414 return;
415 case NVPTX::PTXPrmtMode::RC16:
416 O << ".rc16";
417 return;
421 void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
422 raw_ostream &O,
423 const char *Modifier) {
424 const MCOperand &MO = MI->getOperand(OpNum);
425 using RedTy = llvm::nvvm::TMAReductionOp;
427 switch (static_cast<RedTy>(MO.getImm())) {
428 case RedTy::ADD:
429 O << ".add";
430 return;
431 case RedTy::MIN:
432 O << ".min";
433 return;
434 case RedTy::MAX:
435 O << ".max";
436 return;
437 case RedTy::INC:
438 O << ".inc";
439 return;
440 case RedTy::DEC:
441 O << ".dec";
442 return;
443 case RedTy::AND:
444 O << ".and";
445 return;
446 case RedTy::OR:
447 O << ".or";
448 return;
449 case RedTy::XOR:
450 O << ".xor";
451 return;
453 llvm_unreachable(
454 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");