[ELF] Avoid make in elf::writeARMCmseImportLib
[llvm-project.git] / llvm / lib / Target / NVPTX / NVPTXAsmPrinter.cpp
blob7cac4d787778f2f6536d4f0c4fa1be347d0af5be
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <string>
89 #include <utility>
90 #include <vector>
92 using namespace llvm;
94 static cl::opt<bool>
95 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
96 cl::desc("Lower GPU ctor / dtors to globals on the device."),
97 cl::init(false), cl::Hidden);
99 #define DEPOTNAME "__local_depot"
101 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
102 /// depends.
103 static void
104 DiscoverDependentGlobals(const Value *V,
105 DenseSet<const GlobalVariable *> &Globals) {
106 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
107 Globals.insert(GV);
108 else {
109 if (const User *U = dyn_cast<User>(V)) {
110 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
111 DiscoverDependentGlobals(U->getOperand(i), Globals);
117 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
118 /// instances to be emitted, but only after any dependents have been added
119 /// first.s
120 static void
121 VisitGlobalVariableForEmission(const GlobalVariable *GV,
122 SmallVectorImpl<const GlobalVariable *> &Order,
123 DenseSet<const GlobalVariable *> &Visited,
124 DenseSet<const GlobalVariable *> &Visiting) {
125 // Have we already visited this one?
126 if (Visited.count(GV))
127 return;
129 // Do we have a circular dependency?
130 if (!Visiting.insert(GV).second)
131 report_fatal_error("Circular dependency found in global variable set");
133 // Make sure we visit all dependents first
134 DenseSet<const GlobalVariable *> Others;
135 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
136 DiscoverDependentGlobals(GV->getOperand(i), Others);
138 for (const GlobalVariable *GV : Others)
139 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
141 // Now we can visit ourself
142 Order.push_back(GV);
143 Visited.insert(GV);
144 Visiting.erase(GV);
147 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
148 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
149 getSubtargetInfo().getFeatureBits());
151 MCInst Inst;
152 lowerToMCInst(MI, Inst);
153 EmitToStreamer(*OutStreamer, Inst);
156 // Handle symbol backtracking for targets that do not support image handles
157 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
158 unsigned OpNo, MCOperand &MCOp) {
159 const MachineOperand &MO = MI->getOperand(OpNo);
160 const MCInstrDesc &MCID = MI->getDesc();
162 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
163 // This is a texture fetch, so operand 4 is a texref and operand 5 is
164 // a samplerref
165 if (OpNo == 4 && MO.isImm()) {
166 lowerImageHandleSymbol(MO.getImm(), MCOp);
167 return true;
169 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
170 lowerImageHandleSymbol(MO.getImm(), MCOp);
171 return true;
174 return false;
175 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
176 unsigned VecSize =
177 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
179 // For a surface load of vector size N, the Nth operand will be the surfref
180 if (OpNo == VecSize && MO.isImm()) {
181 lowerImageHandleSymbol(MO.getImm(), MCOp);
182 return true;
185 return false;
186 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
187 // This is a surface store, so operand 0 is a surfref
188 if (OpNo == 0 && MO.isImm()) {
189 lowerImageHandleSymbol(MO.getImm(), MCOp);
190 return true;
193 return false;
194 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
195 // This is a query, so operand 1 is a surfref/texref
196 if (OpNo == 1 && MO.isImm()) {
197 lowerImageHandleSymbol(MO.getImm(), MCOp);
198 return true;
201 return false;
204 return false;
207 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
208 // Ewwww
209 TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
210 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
211 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
212 const char *Sym = MFI->getImageHandleSymbol(Index);
213 StringRef SymName = nvTM.getStrPool().save(Sym);
214 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
217 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
218 OutMI.setOpcode(MI->getOpcode());
219 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
220 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
221 const MachineOperand &MO = MI->getOperand(0);
222 OutMI.addOperand(GetSymbolRef(
223 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
224 return;
227 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
228 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
229 const MachineOperand &MO = MI->getOperand(i);
231 MCOperand MCOp;
232 if (!STI.hasImageHandles()) {
233 if (lowerImageHandleOperand(MI, i, MCOp)) {
234 OutMI.addOperand(MCOp);
235 continue;
239 if (lowerOperand(MO, MCOp))
240 OutMI.addOperand(MCOp);
244 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
245 MCOperand &MCOp) {
246 switch (MO.getType()) {
247 default: llvm_unreachable("unknown operand type");
248 case MachineOperand::MO_Register:
249 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
250 break;
251 case MachineOperand::MO_Immediate:
252 MCOp = MCOperand::createImm(MO.getImm());
253 break;
254 case MachineOperand::MO_MachineBasicBlock:
255 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
256 MO.getMBB()->getSymbol(), OutContext));
257 break;
258 case MachineOperand::MO_ExternalSymbol:
259 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
260 break;
261 case MachineOperand::MO_GlobalAddress:
262 MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
263 break;
264 case MachineOperand::MO_FPImmediate: {
265 const ConstantFP *Cnt = MO.getFPImm();
266 const APFloat &Val = Cnt->getValueAPF();
268 switch (Cnt->getType()->getTypeID()) {
269 default: report_fatal_error("Unsupported FP type"); break;
270 case Type::HalfTyID:
271 MCOp = MCOperand::createExpr(
272 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
273 break;
274 case Type::BFloatTyID:
275 MCOp = MCOperand::createExpr(
276 NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
277 break;
278 case Type::FloatTyID:
279 MCOp = MCOperand::createExpr(
280 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
281 break;
282 case Type::DoubleTyID:
283 MCOp = MCOperand::createExpr(
284 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
285 break;
287 break;
290 return true;
293 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
294 if (Register::isVirtualRegister(Reg)) {
295 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
297 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
298 unsigned RegNum = RegMap[Reg];
300 // Encode the register class in the upper 4 bits
301 // Must be kept in sync with NVPTXInstPrinter::printRegName
302 unsigned Ret = 0;
303 if (RC == &NVPTX::Int1RegsRegClass) {
304 Ret = (1 << 28);
305 } else if (RC == &NVPTX::Int16RegsRegClass) {
306 Ret = (2 << 28);
307 } else if (RC == &NVPTX::Int32RegsRegClass) {
308 Ret = (3 << 28);
309 } else if (RC == &NVPTX::Int64RegsRegClass) {
310 Ret = (4 << 28);
311 } else if (RC == &NVPTX::Float32RegsRegClass) {
312 Ret = (5 << 28);
313 } else if (RC == &NVPTX::Float64RegsRegClass) {
314 Ret = (6 << 28);
315 } else if (RC == &NVPTX::Int128RegsRegClass) {
316 Ret = (7 << 28);
317 } else {
318 report_fatal_error("Bad register class");
321 // Insert the vreg number
322 Ret |= (RegNum & 0x0FFFFFFF);
323 return Ret;
324 } else {
325 // Some special-use registers are actually physical registers.
326 // Encode this as the register class ID of 0 and the real register ID.
327 return Reg & 0x0FFFFFFF;
331 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
332 const MCExpr *Expr;
333 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
334 OutContext);
335 return MCOperand::createExpr(Expr);
338 static bool ShouldPassAsArray(Type *Ty) {
339 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
340 Ty->isHalfTy() || Ty->isBFloatTy();
343 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
344 const DataLayout &DL = getDataLayout();
345 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
346 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
348 Type *Ty = F->getReturnType();
350 bool isABI = (STI.getSmVersion() >= 20);
352 if (Ty->getTypeID() == Type::VoidTyID)
353 return;
354 O << " (";
356 if (isABI) {
357 if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
358 !ShouldPassAsArray(Ty)) {
359 unsigned size = 0;
360 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
361 size = ITy->getBitWidth();
362 } else {
363 assert(Ty->isFloatingPointTy() && "Floating point type expected here");
364 size = Ty->getPrimitiveSizeInBits();
366 size = promoteScalarArgumentSize(size);
367 O << ".param .b" << size << " func_retval0";
368 } else if (isa<PointerType>(Ty)) {
369 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
370 << " func_retval0";
371 } else if (ShouldPassAsArray(Ty)) {
372 unsigned totalsz = DL.getTypeAllocSize(Ty);
373 Align RetAlignment = TLI->getFunctionArgumentAlignment(
374 F, Ty, AttributeList::ReturnIndex, DL);
375 O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
376 << totalsz << "]";
377 } else
378 llvm_unreachable("Unknown return type");
379 } else {
380 SmallVector<EVT, 16> vtparts;
381 ComputeValueVTs(*TLI, DL, Ty, vtparts);
382 unsigned idx = 0;
383 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
384 unsigned elems = 1;
385 EVT elemtype = vtparts[i];
386 if (vtparts[i].isVector()) {
387 elems = vtparts[i].getVectorNumElements();
388 elemtype = vtparts[i].getVectorElementType();
391 for (unsigned j = 0, je = elems; j != je; ++j) {
392 unsigned sz = elemtype.getSizeInBits();
393 if (elemtype.isInteger())
394 sz = promoteScalarArgumentSize(sz);
395 O << ".reg .b" << sz << " func_retval" << idx;
396 if (j < je - 1)
397 O << ", ";
398 ++idx;
400 if (i < e - 1)
401 O << ", ";
404 O << ") ";
407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
408 raw_ostream &O) {
409 const Function &F = MF.getFunction();
410 printReturnValStr(&F, O);
413 // Return true if MBB is the header of a loop marked with
414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
416 const MachineBasicBlock &MBB) const {
417 MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
418 // We insert .pragma "nounroll" only to the loop header.
419 if (!LI.isLoopHeader(&MBB))
420 return false;
422 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
423 // we iterate through each back edge of the loop with header MBB, and check
424 // whether its metadata contains llvm.loop.unroll.disable.
425 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
426 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
427 // Edges from other loops to MBB are not back edges.
428 continue;
430 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
431 if (MDNode *LoopID =
432 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
433 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
434 return true;
435 if (MDNode *UnrollCountMD =
436 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
437 if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
438 ->isOne())
439 return true;
444 return false;
447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
448 AsmPrinter::emitBasicBlockStart(MBB);
449 if (isLoopHeaderOfNoUnroll(MBB))
450 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
453 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
454 SmallString<128> Str;
455 raw_svector_ostream O(Str);
457 if (!GlobalsEmitted) {
458 emitGlobals(*MF->getFunction().getParent());
459 GlobalsEmitted = true;
462 // Set up
463 MRI = &MF->getRegInfo();
464 F = &MF->getFunction();
465 emitLinkageDirective(F, O);
466 if (isKernelFunction(*F))
467 O << ".entry ";
468 else {
469 O << ".func ";
470 printReturnValStr(*MF, O);
473 CurrentFnSym->print(O, MAI);
475 emitFunctionParamList(F, O);
476 O << "\n";
478 if (isKernelFunction(*F))
479 emitKernelFunctionDirectives(*F, O);
481 if (shouldEmitPTXNoReturn(F, TM))
482 O << ".noreturn";
484 OutStreamer->emitRawText(O.str());
486 VRegMapping.clear();
487 // Emit open brace for function body.
488 OutStreamer->emitRawText(StringRef("{\n"));
489 setAndEmitFunctionVirtualRegisters(*MF);
490 encodeDebugInfoRegisterNumbers(*MF);
491 // Emit initial .loc debug directive for correct relocation symbol data.
492 if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
493 assert(SP->getUnit());
494 if (!SP->getUnit()->isDebugDirectivesOnly())
495 emitInitialRawDwarfLocDirective(*MF);
499 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
500 bool Result = AsmPrinter::runOnMachineFunction(F);
501 // Emit closing brace for the body of function F.
502 // The closing brace must be emitted here because we need to emit additional
503 // debug labels/data after the last basic block.
504 // We need to emit the closing brace here because we don't have function that
505 // finished emission of the function body.
506 OutStreamer->emitRawText(StringRef("}\n"));
507 return Result;
510 void NVPTXAsmPrinter::emitFunctionBodyStart() {
511 SmallString<128> Str;
512 raw_svector_ostream O(Str);
513 emitDemotedVars(&MF->getFunction(), O);
514 OutStreamer->emitRawText(O.str());
517 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
518 VRegMapping.clear();
521 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
522 SmallString<128> Str;
523 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
524 return OutContext.getOrCreateSymbol(Str);
527 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
528 Register RegNo = MI->getOperand(0).getReg();
529 if (RegNo.isVirtual()) {
530 OutStreamer->AddComment(Twine("implicit-def: ") +
531 getVirtualRegisterName(RegNo));
532 } else {
533 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
534 OutStreamer->AddComment(Twine("implicit-def: ") +
535 STI.getRegisterInfo()->getName(RegNo));
537 OutStreamer->addBlankLine();
540 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
541 raw_ostream &O) const {
542 // If the NVVM IR has some of reqntid* specified, then output
543 // the reqntid directive, and set the unspecified ones to 1.
544 // If none of Reqntid* is specified, don't output reqntid directive.
545 std::optional<unsigned> Reqntidx = getReqNTIDx(F);
546 std::optional<unsigned> Reqntidy = getReqNTIDy(F);
547 std::optional<unsigned> Reqntidz = getReqNTIDz(F);
549 if (Reqntidx || Reqntidy || Reqntidz)
550 O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
551 << ", " << Reqntidz.value_or(1) << "\n";
553 // If the NVVM IR has some of maxntid* specified, then output
554 // the maxntid directive, and set the unspecified ones to 1.
555 // If none of maxntid* is specified, don't output maxntid directive.
556 std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
557 std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
558 std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
560 if (Maxntidx || Maxntidy || Maxntidz)
561 O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
562 << ", " << Maxntidz.value_or(1) << "\n";
564 if (const auto Mincta = getMinCTASm(F))
565 O << ".minnctapersm " << *Mincta << "\n";
567 if (const auto Maxnreg = getMaxNReg(F))
568 O << ".maxnreg " << *Maxnreg << "\n";
570 // .maxclusterrank directive requires SM_90 or higher, make sure that we
571 // filter it out for lower SM versions, as it causes a hard ptxas crash.
572 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
573 const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
575 if (STI->getSmVersion() >= 90) {
576 std::optional<unsigned> ClusterX = getClusterDimx(F);
577 std::optional<unsigned> ClusterY = getClusterDimy(F);
578 std::optional<unsigned> ClusterZ = getClusterDimz(F);
580 if (ClusterX || ClusterY || ClusterZ) {
581 O << ".explicitcluster\n";
582 if (ClusterX.value_or(1) != 0) {
583 assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
584 "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
585 "should be non-zero as well");
587 O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
588 << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
589 } else {
590 assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
591 "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
592 "should be 0 as well");
595 if (const auto Maxclusterrank = getMaxClusterRank(F))
596 O << ".maxclusterrank " << *Maxclusterrank << "\n";
600 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
601 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
603 std::string Name;
604 raw_string_ostream NameStr(Name);
606 VRegRCMap::const_iterator I = VRegMapping.find(RC);
607 assert(I != VRegMapping.end() && "Bad register class");
608 const DenseMap<unsigned, unsigned> &RegMap = I->second;
610 VRegMap::const_iterator VI = RegMap.find(Reg);
611 assert(VI != RegMap.end() && "Bad virtual register");
612 unsigned MappedVR = VI->second;
614 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
616 return Name;
619 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
620 raw_ostream &O) {
621 O << getVirtualRegisterName(vr);
624 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
625 raw_ostream &O) {
626 const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
627 if (!F || isKernelFunction(*F) || F->isDeclaration())
628 report_fatal_error(
629 "NVPTX aliasee must be a non-kernel function definition");
631 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
632 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
633 report_fatal_error("NVPTX aliasee must not be '.weak'");
635 emitDeclarationWithName(F, getSymbol(GA), O);
638 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
639 emitDeclarationWithName(F, getSymbol(F), O);
642 void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
643 raw_ostream &O) {
644 emitLinkageDirective(F, O);
645 if (isKernelFunction(*F))
646 O << ".entry ";
647 else
648 O << ".func ";
649 printReturnValStr(F, O);
650 S->print(O, MAI);
651 O << "\n";
652 emitFunctionParamList(F, O);
653 O << "\n";
654 if (shouldEmitPTXNoReturn(F, TM))
655 O << ".noreturn";
656 O << ";\n";
659 static bool usedInGlobalVarDef(const Constant *C) {
660 if (!C)
661 return false;
663 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
664 return GV->getName() != "llvm.used";
667 for (const User *U : C->users())
668 if (const Constant *C = dyn_cast<Constant>(U))
669 if (usedInGlobalVarDef(C))
670 return true;
672 return false;
675 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
676 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
677 if (othergv->getName() == "llvm.used")
678 return true;
681 if (const Instruction *instr = dyn_cast<Instruction>(U)) {
682 if (instr->getParent() && instr->getParent()->getParent()) {
683 const Function *curFunc = instr->getParent()->getParent();
684 if (oneFunc && (curFunc != oneFunc))
685 return false;
686 oneFunc = curFunc;
687 return true;
688 } else
689 return false;
692 for (const User *UU : U->users())
693 if (!usedInOneFunc(UU, oneFunc))
694 return false;
696 return true;
699 /* Find out if a global variable can be demoted to local scope.
700 * Currently, this is valid for CUDA shared variables, which have local
701 * scope and global lifetime. So the conditions to check are :
702 * 1. Is the global variable in shared address space?
703 * 2. Does it have local linkage?
704 * 3. Is the global variable referenced only in one function?
706 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
707 if (!gv->hasLocalLinkage())
708 return false;
709 PointerType *Pty = gv->getType();
710 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
711 return false;
713 const Function *oneFunc = nullptr;
715 bool flag = usedInOneFunc(gv, oneFunc);
716 if (!flag)
717 return false;
718 if (!oneFunc)
719 return false;
720 f = oneFunc;
721 return true;
724 static bool useFuncSeen(const Constant *C,
725 DenseMap<const Function *, bool> &seenMap) {
726 for (const User *U : C->users()) {
727 if (const Constant *cu = dyn_cast<Constant>(U)) {
728 if (useFuncSeen(cu, seenMap))
729 return true;
730 } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
731 const BasicBlock *bb = I->getParent();
732 if (!bb)
733 continue;
734 const Function *caller = bb->getParent();
735 if (!caller)
736 continue;
737 if (seenMap.contains(caller))
738 return true;
741 return false;
744 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
745 DenseMap<const Function *, bool> seenMap;
746 for (const Function &F : M) {
747 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
748 emitDeclaration(&F, O);
749 continue;
752 if (F.isDeclaration()) {
753 if (F.use_empty())
754 continue;
755 if (F.getIntrinsicID())
756 continue;
757 emitDeclaration(&F, O);
758 continue;
760 for (const User *U : F.users()) {
761 if (const Constant *C = dyn_cast<Constant>(U)) {
762 if (usedInGlobalVarDef(C)) {
763 // The use is in the initialization of a global variable
764 // that is a function pointer, so print a declaration
765 // for the original function
766 emitDeclaration(&F, O);
767 break;
769 // Emit a declaration of this function if the function that
770 // uses this constant expr has already been seen.
771 if (useFuncSeen(C, seenMap)) {
772 emitDeclaration(&F, O);
773 break;
777 if (!isa<Instruction>(U))
778 continue;
779 const Instruction *instr = cast<Instruction>(U);
780 const BasicBlock *bb = instr->getParent();
781 if (!bb)
782 continue;
783 const Function *caller = bb->getParent();
784 if (!caller)
785 continue;
787 // If a caller has already been seen, then the caller is
788 // appearing in the module before the callee. so print out
789 // a declaration for the callee.
790 if (seenMap.contains(caller)) {
791 emitDeclaration(&F, O);
792 break;
795 seenMap[&F] = true;
797 for (const GlobalAlias &GA : M.aliases())
798 emitAliasDeclaration(&GA, O);
801 static bool isEmptyXXStructor(GlobalVariable *GV) {
802 if (!GV) return true;
803 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
804 if (!InitList) return true; // Not an array; we don't know how to parse.
805 return InitList->getNumOperands() == 0;
808 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
809 // Construct a default subtarget off of the TargetMachine defaults. The
810 // rest of NVPTX isn't friendly to change subtargets per function and
811 // so the default TargetMachine will have all of the options.
812 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
813 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
814 SmallString<128> Str1;
815 raw_svector_ostream OS1(Str1);
817 // Emit header before any dwarf directives are emitted below.
818 emitHeader(M, OS1, *STI);
819 OutStreamer->emitRawText(OS1.str());
822 bool NVPTXAsmPrinter::doInitialization(Module &M) {
823 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
824 const NVPTXSubtarget &STI =
825 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
826 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
827 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
829 // OpenMP supports NVPTX global constructors and destructors.
830 bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;
832 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
833 !LowerCtorDtor && !IsOpenMP) {
834 report_fatal_error(
835 "Module has a nontrivial global ctor, which NVPTX does not support.");
836 return true; // error
838 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
839 !LowerCtorDtor && !IsOpenMP) {
840 report_fatal_error(
841 "Module has a nontrivial global dtor, which NVPTX does not support.");
842 return true; // error
845 // We need to call the parent's one explicitly.
846 bool Result = AsmPrinter::doInitialization(M);
848 GlobalsEmitted = false;
850 return Result;
853 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
854 SmallString<128> Str2;
855 raw_svector_ostream OS2(Str2);
857 emitDeclarations(M, OS2);
859 // As ptxas does not support forward references of globals, we need to first
860 // sort the list of module-level globals in def-use order. We visit each
861 // global variable in order, and ensure that we emit it *after* its dependent
862 // globals. We use a little extra memory maintaining both a set and a list to
863 // have fast searches while maintaining a strict ordering.
864 SmallVector<const GlobalVariable *, 8> Globals;
865 DenseSet<const GlobalVariable *> GVVisited;
866 DenseSet<const GlobalVariable *> GVVisiting;
868 // Visit each global variable, in order
869 for (const GlobalVariable &I : M.globals())
870 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
872 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
873 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
875 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
876 const NVPTXSubtarget &STI =
877 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
879 // Print out module-level global variables in proper order
880 for (const GlobalVariable *GV : Globals)
881 printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
883 OS2 << '\n';
885 OutStreamer->emitRawText(OS2.str());
888 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
889 SmallString<128> Str;
890 raw_svector_ostream OS(Str);
892 MCSymbol *Name = getSymbol(&GA);
894 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
895 << ";\n";
897 OutStreamer->emitRawText(OS.str());
900 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
901 const NVPTXSubtarget &STI) {
902 O << "//\n";
903 O << "// Generated by LLVM NVPTX Back-End\n";
904 O << "//\n";
905 O << "\n";
907 unsigned PTXVersion = STI.getPTXVersion();
908 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
910 O << ".target ";
911 O << STI.getTargetName();
913 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
914 if (NTM.getDrvInterface() == NVPTX::NVCL)
915 O << ", texmode_independent";
917 bool HasFullDebugInfo = false;
918 for (DICompileUnit *CU : M.debug_compile_units()) {
919 switch(CU->getEmissionKind()) {
920 case DICompileUnit::NoDebug:
921 case DICompileUnit::DebugDirectivesOnly:
922 break;
923 case DICompileUnit::LineTablesOnly:
924 case DICompileUnit::FullDebug:
925 HasFullDebugInfo = true;
926 break;
928 if (HasFullDebugInfo)
929 break;
931 if (HasFullDebugInfo)
932 O << ", debug";
934 O << "\n";
936 O << ".address_size ";
937 if (NTM.is64Bit())
938 O << "64";
939 else
940 O << "32";
941 O << "\n";
943 O << "\n";
946 bool NVPTXAsmPrinter::doFinalization(Module &M) {
947 // If we did not emit any functions, then the global declarations have not
948 // yet been emitted.
949 if (!GlobalsEmitted) {
950 emitGlobals(M);
951 GlobalsEmitted = true;
954 // call doFinalization
955 bool ret = AsmPrinter::doFinalization(M);
957 clearAnnotationCache(&M);
959 auto *TS =
960 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
961 // Close the last emitted section
962 if (hasDebugInfo()) {
963 TS->closeLastSection();
964 // Emit empty .debug_macinfo section for better support of the empty files.
965 OutStreamer->emitRawText("\t.section\t.debug_macinfo\t{\t}");
968 // Output last DWARF .file directives, if any.
969 TS->outputDwarfFileDirectives();
971 return ret;
974 // This function emits appropriate linkage directives for
975 // functions and global variables.
977 // extern function declaration -> .extern
978 // extern function definition -> .visible
979 // external global variable with init -> .visible
980 // external without init -> .extern
981 // appending -> not allowed, assert.
982 // for any linkage other than
983 // internal, private, linker_private,
984 // linker_private_weak, linker_private_weak_def_auto,
985 // we emit -> .weak.
987 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
988 raw_ostream &O) {
989 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
990 if (V->hasExternalLinkage()) {
991 if (isa<GlobalVariable>(V)) {
992 const GlobalVariable *GVar = cast<GlobalVariable>(V);
993 if (GVar) {
994 if (GVar->hasInitializer())
995 O << ".visible ";
996 else
997 O << ".extern ";
999 } else if (V->isDeclaration())
1000 O << ".extern ";
1001 else
1002 O << ".visible ";
1003 } else if (V->hasAppendingLinkage()) {
1004 std::string msg;
1005 msg.append("Error: ");
1006 msg.append("Symbol ");
1007 if (V->hasName())
1008 msg.append(std::string(V->getName()));
1009 msg.append("has unsupported appending linkage type");
1010 llvm_unreachable(msg.c_str());
1011 } else if (!V->hasInternalLinkage() &&
1012 !V->hasPrivateLinkage()) {
1013 O << ".weak ";
// Print the PTX declaration/definition for one module-level global variable.
// `processDemoted` is true when the variable was previously demoted to a
// function-local declaration and is now being re-emitted inside that
// function's body (see emitDemotedVars); when false, a demotable variable is
// only recorded in `localDecls` and skipped here.
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip meta data
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with("llvm.") ||
      GVar->getName().starts_with("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  // Linkage directive. Note .common is only valid for the global state space
  // and requires PTX >= 5.0.
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
             GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
    O << ".common ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  // Textures and surfaces are emitted as opaque references, never with data.
  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      // The sampler state is packed into one integer; decode the OpenCL
      // CLK_* bit-fields (see cl_common_defines.h) into PTX sampler fields.
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // The same address mode is applied to all three dimensions.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    // NOTE(review): these name-prefix checks look like legacy frontend
    // artifacts being filtered out — confirm against the producers.
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // Defer demotable variables to the function that owns them.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    localDecls[demotedFunc].push_back(GVar);
    return;
  }

  O << ".";
  emitPTXAddressSpace(GVar->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Scalar case: FP, pointer, or integer of at most 64 bits.
  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // Ptx allows variable initilization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(GVar->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    uint64_t ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // Ptx allows variable initilization only for constant and
      // global state spaces.
      if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            // Symbols (addresses) in the initializer: emit as pointer-sized
            // words when the layout allows, otherwise fall back to bytes
            // with the mask() operator (PTX >= 7.1).
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            // Pure data initializer: emit as a byte array.
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          // Zero/undef initializer: size-only declaration.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        // No initializer allowed/present: size-only declaration.
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
1276 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1277 const Value *v = Symbols[nSym];
1278 const Value *v0 = SymbolsBeforeStripping[nSym];
1279 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1280 MCSymbol *Name = AP.getSymbol(GVar);
1281 PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1282 // Is v0 a generic pointer?
1283 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1284 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1285 os << "generic(";
1286 Name->print(os, AP.MAI);
1287 os << ")";
1288 } else {
1289 Name->print(os, AP.MAI);
1291 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1292 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1293 AP.printMCExpr(*Expr, os);
1294 } else
1295 llvm_unreachable("symbol type unknown");
// Print the buffered aggregate initializer as a comma-separated list of
// bytes, expanding each embedded symbol into ptrSize mask()-operator terms.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim trailing
  // 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  // Sentinel past-the-end position so the loop below can always read
  // symbolPosInBuffer[nSym] without a bounds check.
  symbolPosInBuffer.push_back(InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      // Plain data byte.
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies ptrSize consecutive byte positions.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1340 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1341 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1342 symbolPosInBuffer.push_back(size);
1343 unsigned int nSym = 0;
1344 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1345 assert(nextSymbolPos % ptrSize == 0);
1346 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1347 if (pos)
1348 os << ", ";
1349 if (pos == nextSymbolPos) {
1350 printSymbol(nSym, os);
1351 nextSymbolPos = symbolPosInBuffer[++nSym];
1352 assert(nextSymbolPos % ptrSize == 0);
1353 assert(nextSymbolPos >= pos + ptrSize);
1354 } else if (ptrSize == 4)
1355 os << support::endian::read32le(&buffer[pos]);
1356 else
1357 os << support::endian::read64le(&buffer[pos]);
1361 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1362 auto It = localDecls.find(f);
1363 if (It == localDecls.end())
1364 return;
1366 std::vector<const GlobalVariable *> &gvars = It->second;
1368 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1369 const NVPTXSubtarget &STI =
1370 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1372 for (const GlobalVariable *GV : gvars) {
1373 O << "\t// demoted variable\n\t";
1374 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1378 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1379 raw_ostream &O) const {
1380 switch (AddressSpace) {
1381 case ADDRESS_SPACE_LOCAL:
1382 O << "local";
1383 break;
1384 case ADDRESS_SPACE_GLOBAL:
1385 O << "global";
1386 break;
1387 case ADDRESS_SPACE_CONST:
1388 O << "const";
1389 break;
1390 case ADDRESS_SPACE_SHARED:
1391 O << "shared";
1392 break;
1393 default:
1394 report_fatal_error("Bad address space found while emitting PTX: " +
1395 llvm::Twine(AddressSpace));
1396 break;
1400 std::string
1401 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1402 switch (Ty->getTypeID()) {
1403 case Type::IntegerTyID: {
1404 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1405 if (NumBits == 1)
1406 return "pred";
1407 else if (NumBits <= 64) {
1408 std::string name = "u";
1409 return name + utostr(NumBits);
1410 } else {
1411 llvm_unreachable("Integer too large");
1412 break;
1414 break;
1416 case Type::BFloatTyID:
1417 case Type::HalfTyID:
1418 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1419 // PTX assembly.
1420 return "b16";
1421 case Type::FloatTyID:
1422 return "f32";
1423 case Type::DoubleTyID:
1424 return "f64";
1425 case Type::PointerTyID: {
1426 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1427 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1429 if (PtrSize == 64)
1430 if (useB4PTR)
1431 return "b64";
1432 else
1433 return "u64";
1434 else if (useB4PTR)
1435 return "b32";
1436 else
1437 return "u32";
1439 default:
1440 break;
1442 llvm_unreachable("unexpected type");
1445 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1446 raw_ostream &O,
1447 const NVPTXSubtarget &STI) {
1448 const DataLayout &DL = getDataLayout();
1450 // GlobalVariables are always constant pointers themselves.
1451 Type *ETy = GVar->getValueType();
1453 O << ".";
1454 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1455 if (isManaged(*GVar)) {
1456 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1457 report_fatal_error(
1458 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1460 O << " .attribute(.managed)";
1462 if (MaybeAlign A = GVar->getAlign())
1463 O << " .align " << A->value();
1464 else
1465 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1467 // Special case for i128
1468 if (ETy->isIntegerTy(128)) {
1469 O << " .b8 ";
1470 getSymbol(GVar)->print(O, MAI);
1471 O << "[16]";
1472 return;
1475 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1476 O << " .";
1477 O << getPTXFundamentalTypeStr(ETy);
1478 O << " ";
1479 getSymbol(GVar)->print(O, MAI);
1480 return;
1483 int64_t ElementSize = 0;
1485 // Although PTX has direct support for struct type and array type and LLVM IR
1486 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1487 // support these high level field accesses. Structs and arrays are lowered
1488 // into arrays of bytes.
1489 switch (ETy->getTypeID()) {
1490 case Type::StructTyID:
1491 case Type::ArrayTyID:
1492 case Type::FixedVectorTyID:
1493 ElementSize = DL.getTypeStoreSize(ETy);
1494 O << " .b8 ";
1495 getSymbol(GVar)->print(O, MAI);
1496 O << "[";
1497 if (ElementSize) {
1498 O << ElementSize;
1500 O << "]";
1501 break;
1502 default:
1503 llvm_unreachable("type not supported yet");
// Print the parenthesized PTX parameter list for function F. Kernel
// parameters get .param declarations with explicit state-space/alignment
// annotations for pointers and special handling for images/samplers;
// device-function parameters are printed as .param .b<size> (ABI, sm_20+)
// or .reg .b<size> (pre-ABI). Note paramIndex is also advanced inside the
// non-ABI byval expansion at the bottom.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            // Writable images are surface references; with image handles
            // they are passed as a 64-bit pointer to the surfref.
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            O << TLI->getParamName(F, paramIndex);
          }
          else { // Default image is read_only
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            O << TLI->getParamName(F, paramIndex);
          }
        } else {
          // Sampler parameter.
          if (hasImageHandles)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          O << TLI->getParamName(F, paramIndex);
        }
        continue;
      }
    }

    // Best alignment for this parameter: an explicit stack alignment
    // attribute wins; otherwise the max of the target-optimized alignment
    // and the IR parameter alignment.
    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a> = optimal alignment for the element type; always multiple of
        // PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        O << TLI->getParamName(F, paramIndex);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
        if (PTy) {
          // Kernel pointer argument: annotate with the pointee state space
          // (when known) and the parameter alignment.
          O << "\t.param .u" << PTySizeInBits << " .ptr";

          switch (PTy->getAddressSpace()) {
          default:
            break;
          case ADDRESS_SPACE_GLOBAL:
            O << " .global";
            break;
          case ADDRESS_SPACE_SHARED:
            O << " .shared";
            break;
          case ADDRESS_SPACE_CONST:
            O << " .const";
            break;
          case ADDRESS_SPACE_LOCAL:
            O << " .local";
            break;
          }

          O << " .align " << I->getParamAlign().valueOrOne().value();
          O << " " << TLI->getParamName(F, paramIndex);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        O << TLI->getParamName(F, paramIndex);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        // Sub-32-bit integers are widened per the PTX calling convention.
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      O << TLI->getParamName(F, paramIndex);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      O << TLI->getParamName(F, paramIndex);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          O << TLI->getParamName(F, paramIndex);
          if (j < je - 1)
            O << ",\n";
          // Each expanded part consumes its own parameter slot name.
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      // Compensate for the loop's trailing increment: the outer for-loop
      // header will advance paramIndex once more.
      --paramIndex;
      continue;
    }
  }

  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
  }

  O << "\n)";
}
1718 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1719 const MachineFunction &MF) {
1720 SmallString<128> Str;
1721 raw_svector_ostream O(Str);
1723 // Map the global virtual register number to a register class specific
1724 // virtual register number starting from 1 with that class.
1725 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1726 //unsigned numRegClasses = TRI->getNumRegClasses();
1728 // Emit the Fake Stack Object
1729 const MachineFrameInfo &MFI = MF.getFrameInfo();
1730 int64_t NumBytes = MFI.getStackSize();
1731 if (NumBytes) {
1732 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1733 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1734 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1735 O << "\t.reg .b64 \t%SP;\n";
1736 O << "\t.reg .b64 \t%SPL;\n";
1737 } else {
1738 O << "\t.reg .b32 \t%SP;\n";
1739 O << "\t.reg .b32 \t%SPL;\n";
1743 // Go through all virtual registers to establish the mapping between the
1744 // global virtual
1745 // register number and the per class virtual register number.
1746 // We use the per class virtual register number in the ptx output.
1747 unsigned int numVRs = MRI->getNumVirtRegs();
1748 for (unsigned i = 0; i < numVRs; i++) {
1749 Register vr = Register::index2VirtReg(i);
1750 const TargetRegisterClass *RC = MRI->getRegClass(vr);
1751 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1752 int n = regmap.size();
1753 regmap.insert(std::make_pair(vr, n + 1));
1756 // Emit register declarations
1757 // @TODO: Extract out the real register usage
1758 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1759 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1760 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1761 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1762 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1763 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1764 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1766 // Emit declaration of the virtual registers or 'physical' registers for
1767 // each register class
1768 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1769 const TargetRegisterClass *RC = TRI->getRegClass(i);
1770 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1771 std::string rcname = getNVPTXRegClassName(RC);
1772 std::string rcStr = getNVPTXRegClassStr(RC);
1773 int n = regmap.size();
1775 // Only declare those registers that may be used.
1776 if (n) {
1777 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1778 << ">;\n";
1782 OutStreamer->emitRawText(O.str());
1785 /// Translate virtual register numbers in DebugInfo locations to their printed
1786 /// encodings, as used by CUDA-GDB.
1787 void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1788 const MachineFunction &MF) {
1789 const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1790 const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1792 // Clear the old mapping, and add the new one. This mapping is used after the
1793 // printing of the current function is complete, but before the next function
1794 // is printed.
1795 registerInfo->clearDebugRegisterMap();
1797 for (auto &classMap : VRegMapping) {
1798 for (auto &registerMapping : classMap.getSecond()) {
1799 auto reg = registerMapping.getFirst();
1800 registerInfo->addToDebugRegisterMap(reg, getVirtualRegisterName(reg));
1805 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1806 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1807 bool ignored;
1808 unsigned int numHex;
1809 const char *lead;
1811 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1812 numHex = 8;
1813 lead = "0f";
1814 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1815 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1816 numHex = 16;
1817 lead = "0d";
1818 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1819 } else
1820 llvm_unreachable("unsupported fp type");
1822 APInt API = APF.bitcastToAPInt();
1823 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1826 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1827 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1828 O << CI->getValue();
1829 return;
1831 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1832 printFPConstant(CFP, O);
1833 return;
1835 if (isa<ConstantPointerNull>(CPV)) {
1836 O << "0";
1837 return;
1839 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1840 bool IsNonGenericPointer = false;
1841 if (GVar->getType()->getAddressSpace() != 0) {
1842 IsNonGenericPointer = true;
1844 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1845 O << "generic(";
1846 getSymbol(GVar)->print(O, MAI);
1847 O << ")";
1848 } else {
1849 getSymbol(GVar)->print(O, MAI);
1851 return;
1853 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1854 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1855 printMCExpr(*E, O);
1856 return;
1858 llvm_unreachable("Not scalar type found in printScalarConstant()");
// Append the little-endian byte image of CPV to AggBuffer. A non-zero Bytes
// overrides the type's alloc size as the amount of space to fill (used for
// struct padding); symbols (addresses) are recorded in the buffer rather
// than expanded to bytes.
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(CPV->getType());
  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
    // input's bit width, and i1 arrays may not have a length that is a multiple
    // of 8. We handle the last byte separately, so we never request out of
    // bounds bits.
    for (unsigned I = 0; I < NumBytes - 1; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
    }
    size_t LastBytePosition = (NumBytes - 1) * 8;
    size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
    Buf[NumBytes - 1] =
        Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      // Try to constant-fold the expression down to a plain integer first.
      if (const auto *CI =
              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      // ptrtoint of an address: record the underlying symbol and reserve
      // zeroed space for it in the buffer.
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
        AggBuffer->addZeros(AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // FP values are buffered via their raw bit pattern.
    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    // Record the symbol (both stripped and original forms) and reserve
    // zeroed space; the actual address is resolved when printing.
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
      AggBuffer->addSymbol(GVar, GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(v, Cexpr);
    }
    AggBuffer->addZeros(AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
      bufferAggregateConstant(CPV, AggBuffer);
      // Zero-fill any tail padding the caller asked for.
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(CPV))
      AggBuffer->addZeros(Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}
// Append the byte image of an aggregate (or wide-integer) constant to
// aggBuffer by delegating each element to bufferLEByte with the number of
// bytes that element occupies, including any struct padding.
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();
  int Bytes;

  // Integers of arbitrary width: emit little-endian, one byte at a time.
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    APInt Val = CI->getValue();
    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
      uint8_t Byte = Val.getLoBits(8).getZExtValue();
      aggBuffer->addBytes(&Byte, 1, 1);
      Val.lshrInPlace(8);
    }
    return;
  }

  // Old constants
  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
    if (CPV->getNumOperands())
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
    return;
  }

  if (const ConstantDataSequential *CDS =
          dyn_cast<ConstantDataSequential>(CPV)) {
    if (CDS->getNumElements())
      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                     aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(CPV->getType());
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
        // Per-element byte count is the distance to the next element's
        // offset, so each call covers that element plus its padding; the
        // last element extends to the struct's total alloc size (elementOffset(0)
        // is always 0, so the first term is effectively just the alloc size).
        if (i == (e - 1))
          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                  DL.getTypeAllocSize(ST) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        else
          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}
2001 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
2002 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
2003 /// expressions that are representable in PTX and create
2004 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
2005 const MCExpr *
2006 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
2007 MCContext &Ctx = OutContext;
2009 if (CV->isNullValue() || isa<UndefValue>(CV))
2010 return MCConstantExpr::create(0, Ctx);
2012 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
2013 return MCConstantExpr::create(CI->getZExtValue(), Ctx);
2015 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
2016 const MCSymbolRefExpr *Expr =
2017 MCSymbolRefExpr::create(getSymbol(GV), Ctx);
2018 if (ProcessingGeneric) {
2019 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
2020 } else {
2021 return Expr;
2025 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
2026 if (!CE) {
2027 llvm_unreachable("Unknown constant value to lower!");
2030 switch (CE->getOpcode()) {
2031 default:
2032 break; // Error
2034 case Instruction::AddrSpaceCast: {
2035 // Strip the addrspacecast and pass along the operand
2036 PointerType *DstTy = cast<PointerType>(CE->getType());
2037 if (DstTy->getAddressSpace() == 0)
2038 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2040 break; // Error
2043 case Instruction::GetElementPtr: {
2044 const DataLayout &DL = getDataLayout();
2046 // Generate a symbolic expression for the byte address
2047 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2048 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2050 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2051 ProcessingGeneric);
2052 if (!OffsetAI)
2053 return Base;
2055 int64_t Offset = OffsetAI.getSExtValue();
2056 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2057 Ctx);
2060 case Instruction::Trunc:
2061 // We emit the value and depend on the assembler to truncate the generated
2062 // expression properly. This is important for differences between
2063 // blockaddress labels. Since the two labels are in the same function, it
2064 // is reasonable to treat their delta as a 32-bit value.
2065 [[fallthrough]];
2066 case Instruction::BitCast:
2067 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2069 case Instruction::IntToPtr: {
2070 const DataLayout &DL = getDataLayout();
2072 // Handle casts to pointers by changing them into casts to the appropriate
2073 // integer type. This promotes constant folding and simplifies this code.
2074 Constant *Op = CE->getOperand(0);
2075 Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2076 /*IsSigned*/ false, DL);
2077 if (Op)
2078 return lowerConstantForGV(Op, ProcessingGeneric);
2080 break; // Error
2083 case Instruction::PtrToInt: {
2084 const DataLayout &DL = getDataLayout();
2086 // Support only foldable casts to/from pointers that can be eliminated by
2087 // changing the pointer to the appropriately sized integer type.
2088 Constant *Op = CE->getOperand(0);
2089 Type *Ty = CE->getType();
2091 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2093 // We can emit the pointer value into this slot if the slot is an
2094 // integer slot equal to the size of the pointer.
2095 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2096 return OpExpr;
2098 // Otherwise the pointer is smaller than the resultant integer, mask off
2099 // the high bits so we are sure to get a proper truncation if the input is
2100 // a constant expr.
2101 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2102 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2103 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2106 // The MC library also has a right-shift operator, but it isn't consistently
2107 // signed or unsigned between different targets.
2108 case Instruction::Add: {
2109 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2110 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2111 switch (CE->getOpcode()) {
2112 default: llvm_unreachable("Unknown binary operator constant cast expr");
2113 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2118 // If the code isn't optimized, there may be outstanding folding
2119 // opportunities. Attempt to fold the expression using DataLayout as a
2120 // last resort before giving up.
2121 Constant *C = ConstantFoldConstant(CE, getDataLayout());
2122 if (C != CE)
2123 return lowerConstantForGV(C, ProcessingGeneric);
2125 // Otherwise report the problem to the user.
2126 std::string S;
2127 raw_string_ostream OS(S);
2128 OS << "Unsupported expression in static initializer: ";
2129 CE->printAsOperand(OS, /*PrintType=*/false,
2130 !MF ? nullptr : MF->getFunction().getParent());
2131 report_fatal_error(Twine(OS.str()));
2134 // Copy of MCExpr::print customized for NVPTX
2135 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2136 switch (Expr.getKind()) {
2137 case MCExpr::Target:
2138 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2139 case MCExpr::Constant:
2140 OS << cast<MCConstantExpr>(Expr).getValue();
2141 return;
2143 case MCExpr::SymbolRef: {
2144 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2145 const MCSymbol &Sym = SRE.getSymbol();
2146 Sym.print(OS, MAI);
2147 return;
2150 case MCExpr::Unary: {
2151 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2152 switch (UE.getOpcode()) {
2153 case MCUnaryExpr::LNot: OS << '!'; break;
2154 case MCUnaryExpr::Minus: OS << '-'; break;
2155 case MCUnaryExpr::Not: OS << '~'; break;
2156 case MCUnaryExpr::Plus: OS << '+'; break;
2158 printMCExpr(*UE.getSubExpr(), OS);
2159 return;
2162 case MCExpr::Binary: {
2163 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2165 // Only print parens around the LHS if it is non-trivial.
2166 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2167 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2168 printMCExpr(*BE.getLHS(), OS);
2169 } else {
2170 OS << '(';
2171 printMCExpr(*BE.getLHS(), OS);
2172 OS<< ')';
2175 switch (BE.getOpcode()) {
2176 case MCBinaryExpr::Add:
2177 // Print "X-42" instead of "X+-42".
2178 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2179 if (RHSC->getValue() < 0) {
2180 OS << RHSC->getValue();
2181 return;
2185 OS << '+';
2186 break;
2187 default: llvm_unreachable("Unhandled binary operator");
2190 // Only print parens around the LHS if it is non-trivial.
2191 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2192 printMCExpr(*BE.getRHS(), OS);
2193 } else {
2194 OS << '(';
2195 printMCExpr(*BE.getRHS(), OS);
2196 OS << ')';
2198 return;
2202 llvm_unreachable("Invalid expression kind!");
2205 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2207 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2208 const char *ExtraCode, raw_ostream &O) {
2209 if (ExtraCode && ExtraCode[0]) {
2210 if (ExtraCode[1] != 0)
2211 return true; // Unknown modifier.
2213 switch (ExtraCode[0]) {
2214 default:
2215 // See if this is a generic print operand
2216 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2217 case 'r':
2218 break;
2222 printOperand(MI, OpNo, O);
2224 return false;
2227 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2228 unsigned OpNo,
2229 const char *ExtraCode,
2230 raw_ostream &O) {
2231 if (ExtraCode && ExtraCode[0])
2232 return true; // Unknown modifier
2234 O << '[';
2235 printMemOperand(MI, OpNo, O);
2236 O << ']';
2238 return false;
2241 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2242 raw_ostream &O) {
2243 const MachineOperand &MO = MI->getOperand(OpNum);
2244 switch (MO.getType()) {
2245 case MachineOperand::MO_Register:
2246 if (MO.getReg().isPhysical()) {
2247 if (MO.getReg() == NVPTX::VRDepot)
2248 O << DEPOTNAME << getFunctionNumber();
2249 else
2250 O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2251 } else {
2252 emitVirtualRegister(MO.getReg(), O);
2254 break;
2256 case MachineOperand::MO_Immediate:
2257 O << MO.getImm();
2258 break;
2260 case MachineOperand::MO_FPImmediate:
2261 printFPConstant(MO.getFPImm(), O);
2262 break;
2264 case MachineOperand::MO_GlobalAddress:
2265 PrintSymbolOperand(MO, O);
2266 break;
2268 case MachineOperand::MO_MachineBasicBlock:
2269 MO.getMBB()->getSymbol()->print(O, MAI);
2270 break;
2272 default:
2273 llvm_unreachable("Operand type not supported.");
2277 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2278 raw_ostream &O, const char *Modifier) {
2279 printOperand(MI, OpNum, O);
2281 if (Modifier && strcmp(Modifier, "add") == 0) {
2282 O << ", ";
2283 printOperand(MI, OpNum + 1, O);
2284 } else {
2285 if (MI->getOperand(OpNum + 1).isImm() &&
2286 MI->getOperand(OpNum + 1).getImm() == 0)
2287 return; // don't print ',0' or '+0'
2288 O << "+";
2289 printOperand(MI, OpNum + 1, O);
2293 // Force static initialization.
2294 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2295 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2296 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());