1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
95 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
96 cl::desc("Lower GPU ctor / dtors to globals on the device."),
97 cl::init(false), cl::Hidden
);
99 #define DEPOTNAME "__local_depot"
101 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
104 DiscoverDependentGlobals(const Value
*V
,
105 DenseSet
<const GlobalVariable
*> &Globals
) {
106 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(V
))
109 if (const User
*U
= dyn_cast
<User
>(V
)) {
110 for (unsigned i
= 0, e
= U
->getNumOperands(); i
!= e
; ++i
) {
111 DiscoverDependentGlobals(U
->getOperand(i
), Globals
);
117 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
118 /// instances to be emitted, but only after any dependents have been added
121 VisitGlobalVariableForEmission(const GlobalVariable
*GV
,
122 SmallVectorImpl
<const GlobalVariable
*> &Order
,
123 DenseSet
<const GlobalVariable
*> &Visited
,
124 DenseSet
<const GlobalVariable
*> &Visiting
) {
125 // Have we already visited this one?
126 if (Visited
.count(GV
))
129 // Do we have a circular dependency?
130 if (!Visiting
.insert(GV
).second
)
131 report_fatal_error("Circular dependency found in global variable set");
133 // Make sure we visit all dependents first
134 DenseSet
<const GlobalVariable
*> Others
;
135 for (unsigned i
= 0, e
= GV
->getNumOperands(); i
!= e
; ++i
)
136 DiscoverDependentGlobals(GV
->getOperand(i
), Others
);
138 for (const GlobalVariable
*GV
: Others
)
139 VisitGlobalVariableForEmission(GV
, Order
, Visited
, Visiting
);
141 // Now we can visit ourself
147 void NVPTXAsmPrinter::emitInstruction(const MachineInstr
*MI
) {
148 NVPTX_MC::verifyInstructionPredicates(MI
->getOpcode(),
149 getSubtargetInfo().getFeatureBits());
152 lowerToMCInst(MI
, Inst
);
153 EmitToStreamer(*OutStreamer
, Inst
);
156 // Handle symbol backtracking for targets that do not support image handles
157 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr
*MI
,
158 unsigned OpNo
, MCOperand
&MCOp
) {
159 const MachineOperand
&MO
= MI
->getOperand(OpNo
);
160 const MCInstrDesc
&MCID
= MI
->getDesc();
162 if (MCID
.TSFlags
& NVPTXII::IsTexFlag
) {
163 // This is a texture fetch, so operand 4 is a texref and operand 5 is
165 if (OpNo
== 4 && MO
.isImm()) {
166 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
169 if (OpNo
== 5 && MO
.isImm() && !(MCID
.TSFlags
& NVPTXII::IsTexModeUnifiedFlag
)) {
170 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
175 } else if (MCID
.TSFlags
& NVPTXII::IsSuldMask
) {
177 1 << (((MCID
.TSFlags
& NVPTXII::IsSuldMask
) >> NVPTXII::IsSuldShift
) - 1);
179 // For a surface load of vector size N, the Nth operand will be the surfref
180 if (OpNo
== VecSize
&& MO
.isImm()) {
181 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
186 } else if (MCID
.TSFlags
& NVPTXII::IsSustFlag
) {
187 // This is a surface store, so operand 0 is a surfref
188 if (OpNo
== 0 && MO
.isImm()) {
189 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
194 } else if (MCID
.TSFlags
& NVPTXII::IsSurfTexQueryFlag
) {
195 // This is a query, so operand 1 is a surfref/texref
196 if (OpNo
== 1 && MO
.isImm()) {
197 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
207 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index
, MCOperand
&MCOp
) {
209 TargetMachine
&TM
= const_cast<TargetMachine
&>(MF
->getTarget());
210 NVPTXTargetMachine
&nvTM
= static_cast<NVPTXTargetMachine
&>(TM
);
211 const NVPTXMachineFunctionInfo
*MFI
= MF
->getInfo
<NVPTXMachineFunctionInfo
>();
212 const char *Sym
= MFI
->getImageHandleSymbol(Index
);
213 StringRef SymName
= nvTM
.getStrPool().save(Sym
);
214 MCOp
= GetSymbolRef(OutContext
.getOrCreateSymbol(SymName
));
217 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr
*MI
, MCInst
&OutMI
) {
218 OutMI
.setOpcode(MI
->getOpcode());
219 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
220 if (MI
->getOpcode() == NVPTX::CALL_PROTOTYPE
) {
221 const MachineOperand
&MO
= MI
->getOperand(0);
222 OutMI
.addOperand(GetSymbolRef(
223 OutContext
.getOrCreateSymbol(Twine(MO
.getSymbolName()))));
227 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
228 for (unsigned i
= 0, e
= MI
->getNumOperands(); i
!= e
; ++i
) {
229 const MachineOperand
&MO
= MI
->getOperand(i
);
232 if (!STI
.hasImageHandles()) {
233 if (lowerImageHandleOperand(MI
, i
, MCOp
)) {
234 OutMI
.addOperand(MCOp
);
239 if (lowerOperand(MO
, MCOp
))
240 OutMI
.addOperand(MCOp
);
244 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand
&MO
,
246 switch (MO
.getType()) {
247 default: llvm_unreachable("unknown operand type");
248 case MachineOperand::MO_Register
:
249 MCOp
= MCOperand::createReg(encodeVirtualRegister(MO
.getReg()));
251 case MachineOperand::MO_Immediate
:
252 MCOp
= MCOperand::createImm(MO
.getImm());
254 case MachineOperand::MO_MachineBasicBlock
:
255 MCOp
= MCOperand::createExpr(MCSymbolRefExpr::create(
256 MO
.getMBB()->getSymbol(), OutContext
));
258 case MachineOperand::MO_ExternalSymbol
:
259 MCOp
= GetSymbolRef(GetExternalSymbolSymbol(MO
.getSymbolName()));
261 case MachineOperand::MO_GlobalAddress
:
262 MCOp
= GetSymbolRef(getSymbol(MO
.getGlobal()));
264 case MachineOperand::MO_FPImmediate
: {
265 const ConstantFP
*Cnt
= MO
.getFPImm();
266 const APFloat
&Val
= Cnt
->getValueAPF();
268 switch (Cnt
->getType()->getTypeID()) {
269 default: report_fatal_error("Unsupported FP type"); break;
271 MCOp
= MCOperand::createExpr(
272 NVPTXFloatMCExpr::createConstantFPHalf(Val
, OutContext
));
274 case Type::BFloatTyID
:
275 MCOp
= MCOperand::createExpr(
276 NVPTXFloatMCExpr::createConstantBFPHalf(Val
, OutContext
));
278 case Type::FloatTyID
:
279 MCOp
= MCOperand::createExpr(
280 NVPTXFloatMCExpr::createConstantFPSingle(Val
, OutContext
));
282 case Type::DoubleTyID
:
283 MCOp
= MCOperand::createExpr(
284 NVPTXFloatMCExpr::createConstantFPDouble(Val
, OutContext
));
293 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg
) {
294 if (Register::isVirtualRegister(Reg
)) {
295 const TargetRegisterClass
*RC
= MRI
->getRegClass(Reg
);
297 DenseMap
<unsigned, unsigned> &RegMap
= VRegMapping
[RC
];
298 unsigned RegNum
= RegMap
[Reg
];
300 // Encode the register class in the upper 4 bits
301 // Must be kept in sync with NVPTXInstPrinter::printRegName
303 if (RC
== &NVPTX::Int1RegsRegClass
) {
305 } else if (RC
== &NVPTX::Int16RegsRegClass
) {
307 } else if (RC
== &NVPTX::Int32RegsRegClass
) {
309 } else if (RC
== &NVPTX::Int64RegsRegClass
) {
311 } else if (RC
== &NVPTX::Float32RegsRegClass
) {
313 } else if (RC
== &NVPTX::Float64RegsRegClass
) {
315 } else if (RC
== &NVPTX::Int128RegsRegClass
) {
318 report_fatal_error("Bad register class");
321 // Insert the vreg number
322 Ret
|= (RegNum
& 0x0FFFFFFF);
325 // Some special-use registers are actually physical registers.
326 // Encode this as the register class ID of 0 and the real register ID.
327 return Reg
& 0x0FFFFFFF;
331 MCOperand
NVPTXAsmPrinter::GetSymbolRef(const MCSymbol
*Symbol
) {
333 Expr
= MCSymbolRefExpr::create(Symbol
, MCSymbolRefExpr::VK_None
,
335 return MCOperand::createExpr(Expr
);
338 static bool ShouldPassAsArray(Type
*Ty
) {
339 return Ty
->isAggregateType() || Ty
->isVectorTy() || Ty
->isIntegerTy(128) ||
340 Ty
->isHalfTy() || Ty
->isBFloatTy();
343 void NVPTXAsmPrinter::printReturnValStr(const Function
*F
, raw_ostream
&O
) {
344 const DataLayout
&DL
= getDataLayout();
345 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
346 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
348 Type
*Ty
= F
->getReturnType();
350 bool isABI
= (STI
.getSmVersion() >= 20);
352 if (Ty
->getTypeID() == Type::VoidTyID
)
357 if ((Ty
->isFloatingPointTy() || Ty
->isIntegerTy()) &&
358 !ShouldPassAsArray(Ty
)) {
360 if (auto *ITy
= dyn_cast
<IntegerType
>(Ty
)) {
361 size
= ITy
->getBitWidth();
363 assert(Ty
->isFloatingPointTy() && "Floating point type expected here");
364 size
= Ty
->getPrimitiveSizeInBits();
366 size
= promoteScalarArgumentSize(size
);
367 O
<< ".param .b" << size
<< " func_retval0";
368 } else if (isa
<PointerType
>(Ty
)) {
369 O
<< ".param .b" << TLI
->getPointerTy(DL
).getSizeInBits()
371 } else if (ShouldPassAsArray(Ty
)) {
372 unsigned totalsz
= DL
.getTypeAllocSize(Ty
);
373 Align RetAlignment
= TLI
->getFunctionArgumentAlignment(
374 F
, Ty
, AttributeList::ReturnIndex
, DL
);
375 O
<< ".param .align " << RetAlignment
.value() << " .b8 func_retval0["
378 llvm_unreachable("Unknown return type");
380 SmallVector
<EVT
, 16> vtparts
;
381 ComputeValueVTs(*TLI
, DL
, Ty
, vtparts
);
383 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
385 EVT elemtype
= vtparts
[i
];
386 if (vtparts
[i
].isVector()) {
387 elems
= vtparts
[i
].getVectorNumElements();
388 elemtype
= vtparts
[i
].getVectorElementType();
391 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
392 unsigned sz
= elemtype
.getSizeInBits();
393 if (elemtype
.isInteger())
394 sz
= promoteScalarArgumentSize(sz
);
395 O
<< ".reg .b" << sz
<< " func_retval" << idx
;
407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction
&MF
,
409 const Function
&F
= MF
.getFunction();
410 printReturnValStr(&F
, O
);
413 // Return true if MBB is the header of a loop marked with
414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
416 const MachineBasicBlock
&MBB
) const {
417 MachineLoopInfo
&LI
= getAnalysis
<MachineLoopInfoWrapperPass
>().getLI();
418 // We insert .pragma "nounroll" only to the loop header.
419 if (!LI
.isLoopHeader(&MBB
))
422 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
423 // we iterate through each back edge of the loop with header MBB, and check
424 // whether its metadata contains llvm.loop.unroll.disable.
425 for (const MachineBasicBlock
*PMBB
: MBB
.predecessors()) {
426 if (LI
.getLoopFor(PMBB
) != LI
.getLoopFor(&MBB
)) {
427 // Edges from other loops to MBB are not back edges.
430 if (const BasicBlock
*PBB
= PMBB
->getBasicBlock()) {
432 PBB
->getTerminator()->getMetadata(LLVMContext::MD_loop
)) {
433 if (GetUnrollMetadata(LoopID
, "llvm.loop.unroll.disable"))
435 if (MDNode
*UnrollCountMD
=
436 GetUnrollMetadata(LoopID
, "llvm.loop.unroll.count")) {
437 if (mdconst::extract
<ConstantInt
>(UnrollCountMD
->getOperand(1))
447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock
&MBB
) {
448 AsmPrinter::emitBasicBlockStart(MBB
);
449 if (isLoopHeaderOfNoUnroll(MBB
))
450 OutStreamer
->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
453 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
454 SmallString
<128> Str
;
455 raw_svector_ostream
O(Str
);
457 if (!GlobalsEmitted
) {
458 emitGlobals(*MF
->getFunction().getParent());
459 GlobalsEmitted
= true;
463 MRI
= &MF
->getRegInfo();
464 F
= &MF
->getFunction();
465 emitLinkageDirective(F
, O
);
466 if (isKernelFunction(*F
))
470 printReturnValStr(*MF
, O
);
473 CurrentFnSym
->print(O
, MAI
);
475 emitFunctionParamList(F
, O
);
478 if (isKernelFunction(*F
))
479 emitKernelFunctionDirectives(*F
, O
);
481 if (shouldEmitPTXNoReturn(F
, TM
))
484 OutStreamer
->emitRawText(O
.str());
487 // Emit open brace for function body.
488 OutStreamer
->emitRawText(StringRef("{\n"));
489 setAndEmitFunctionVirtualRegisters(*MF
);
490 encodeDebugInfoRegisterNumbers(*MF
);
491 // Emit initial .loc debug directive for correct relocation symbol data.
492 if (const DISubprogram
*SP
= MF
->getFunction().getSubprogram()) {
493 assert(SP
->getUnit());
494 if (!SP
->getUnit()->isDebugDirectivesOnly())
495 emitInitialRawDwarfLocDirective(*MF
);
499 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction
&F
) {
500 bool Result
= AsmPrinter::runOnMachineFunction(F
);
501 // Emit closing brace for the body of function F.
502 // The closing brace must be emitted here because we need to emit additional
503 // debug labels/data after the last basic block.
504 // We need to emit the closing brace here because we don't have function that
505 // finished emission of the function body.
506 OutStreamer
->emitRawText(StringRef("}\n"));
510 void NVPTXAsmPrinter::emitFunctionBodyStart() {
511 SmallString
<128> Str
;
512 raw_svector_ostream
O(Str
);
513 emitDemotedVars(&MF
->getFunction(), O
);
514 OutStreamer
->emitRawText(O
.str());
517 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
521 const MCSymbol
*NVPTXAsmPrinter::getFunctionFrameSymbol() const {
522 SmallString
<128> Str
;
523 raw_svector_ostream(Str
) << DEPOTNAME
<< getFunctionNumber();
524 return OutContext
.getOrCreateSymbol(Str
);
527 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr
*MI
) const {
528 Register RegNo
= MI
->getOperand(0).getReg();
529 if (RegNo
.isVirtual()) {
530 OutStreamer
->AddComment(Twine("implicit-def: ") +
531 getVirtualRegisterName(RegNo
));
533 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
534 OutStreamer
->AddComment(Twine("implicit-def: ") +
535 STI
.getRegisterInfo()->getName(RegNo
));
537 OutStreamer
->addBlankLine();
// Emit the kernel-level PTX performance directives (.reqntid, .maxntid,
// .minnctapersm, .maxnreg, and on sm_90+ the cluster directives) for F.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  std::optional<unsigned> Reqntidx = getReqNTIDx(F);
  std::optional<unsigned> Reqntidy = getReqNTIDy(F);
  std::optional<unsigned> Reqntidz = getReqNTIDz(F);

  if (Reqntidx || Reqntidy || Reqntidz)
    O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
      << ", " << Reqntidz.value_or(1) << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output maxntid directive.
  std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
  std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
  std::optional<unsigned> Maxntidz = getMaxNTIDz(F);

  if (Maxntidx || Maxntidy || Maxntidz)
    O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
      << ", " << Maxntidz.value_or(1) << "\n";

  if (const auto Mincta = getMinCTASm(F))
    O << ".minnctapersm " << *Mincta << "\n";

  if (const auto Maxnreg = getMaxNReg(F))
    O << ".maxnreg " << *Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  if (STI->getSmVersion() >= 90) {
    std::optional<unsigned> ClusterX = getClusterDimx(F);
    std::optional<unsigned> ClusterY = getClusterDimy(F);
    std::optional<unsigned> ClusterZ = getClusterDimz(F);

    if (ClusterX || ClusterY || ClusterZ) {
      O << ".explicitcluster\n";
      if (ClusterX.value_or(1) != 0) {
        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
               "should be non-zero as well");

        O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
          << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
      } else {
        assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
               "should be 0 as well");
      }
    }
    if (const auto Maxclusterrank = getMaxClusterRank(F))
      O << ".maxclusterrank " << *Maxclusterrank << "\n";
  }
}
// Build the printable name of a virtual register from its class string and
// the per-class number assigned in VRegMapping.
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
  const TargetRegisterClass *RC = MRI->getRegClass(Reg);

  std::string Name;
  raw_string_ostream NameStr(Name);

  VRegRCMap::const_iterator I = VRegMapping.find(RC);
  assert(I != VRegMapping.end() && "Bad register class");
  const DenseMap<unsigned, unsigned> &RegMap = I->second;

  VRegMap::const_iterator VI = RegMap.find(Reg);
  assert(VI != RegMap.end() && "Bad virtual register");
  unsigned MappedVR = VI->second;

  NameStr << getNVPTXRegClassStr(RC) << MappedVR;

  return Name;
}
619 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr
,
621 O
<< getVirtualRegisterName(vr
);
624 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias
*GA
,
626 const Function
*F
= dyn_cast_or_null
<Function
>(GA
->getAliaseeObject());
627 if (!F
|| isKernelFunction(*F
) || F
->isDeclaration())
629 "NVPTX aliasee must be a non-kernel function definition");
631 if (GA
->hasLinkOnceLinkage() || GA
->hasWeakLinkage() ||
632 GA
->hasAvailableExternallyLinkage() || GA
->hasCommonLinkage())
633 report_fatal_error("NVPTX aliasee must not be '.weak'");
635 emitDeclarationWithName(F
, getSymbol(GA
), O
);
638 void NVPTXAsmPrinter::emitDeclaration(const Function
*F
, raw_ostream
&O
) {
639 emitDeclarationWithName(F
, getSymbol(F
), O
);
642 void NVPTXAsmPrinter::emitDeclarationWithName(const Function
*F
, MCSymbol
*S
,
644 emitLinkageDirective(F
, O
);
645 if (isKernelFunction(*F
))
649 printReturnValStr(F
, O
);
652 emitFunctionParamList(F
, O
);
654 if (shouldEmitPTXNoReturn(F
, TM
))
659 static bool usedInGlobalVarDef(const Constant
*C
) {
663 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(C
)) {
664 return GV
->getName() != "llvm.used";
667 for (const User
*U
: C
->users())
668 if (const Constant
*C
= dyn_cast
<Constant
>(U
))
669 if (usedInGlobalVarDef(C
))
675 static bool usedInOneFunc(const User
*U
, Function
const *&oneFunc
) {
676 if (const GlobalVariable
*othergv
= dyn_cast
<GlobalVariable
>(U
)) {
677 if (othergv
->getName() == "llvm.used")
681 if (const Instruction
*instr
= dyn_cast
<Instruction
>(U
)) {
682 if (instr
->getParent() && instr
->getParent()->getParent()) {
683 const Function
*curFunc
= instr
->getParent()->getParent();
684 if (oneFunc
&& (curFunc
!= oneFunc
))
692 for (const User
*UU
: U
->users())
693 if (!usedInOneFunc(UU
, oneFunc
))
699 /* Find out if a global variable can be demoted to local scope.
700 * Currently, this is valid for CUDA shared variables, which have local
701 * scope and global lifetime. So the conditions to check are :
702 * 1. Is the global variable in shared address space?
703 * 2. Does it have local linkage?
704 * 3. Is the global variable referenced only in one function?
706 static bool canDemoteGlobalVar(const GlobalVariable
*gv
, Function
const *&f
) {
707 if (!gv
->hasLocalLinkage())
709 PointerType
*Pty
= gv
->getType();
710 if (Pty
->getAddressSpace() != ADDRESS_SPACE_SHARED
)
713 const Function
*oneFunc
= nullptr;
715 bool flag
= usedInOneFunc(gv
, oneFunc
);
724 static bool useFuncSeen(const Constant
*C
,
725 DenseMap
<const Function
*, bool> &seenMap
) {
726 for (const User
*U
: C
->users()) {
727 if (const Constant
*cu
= dyn_cast
<Constant
>(U
)) {
728 if (useFuncSeen(cu
, seenMap
))
730 } else if (const Instruction
*I
= dyn_cast
<Instruction
>(U
)) {
731 const BasicBlock
*bb
= I
->getParent();
734 const Function
*caller
= bb
->getParent();
737 if (seenMap
.contains(caller
))
744 void NVPTXAsmPrinter::emitDeclarations(const Module
&M
, raw_ostream
&O
) {
745 DenseMap
<const Function
*, bool> seenMap
;
746 for (const Function
&F
: M
) {
747 if (F
.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
748 emitDeclaration(&F
, O
);
752 if (F
.isDeclaration()) {
755 if (F
.getIntrinsicID())
757 emitDeclaration(&F
, O
);
760 for (const User
*U
: F
.users()) {
761 if (const Constant
*C
= dyn_cast
<Constant
>(U
)) {
762 if (usedInGlobalVarDef(C
)) {
763 // The use is in the initialization of a global variable
764 // that is a function pointer, so print a declaration
765 // for the original function
766 emitDeclaration(&F
, O
);
769 // Emit a declaration of this function if the function that
770 // uses this constant expr has already been seen.
771 if (useFuncSeen(C
, seenMap
)) {
772 emitDeclaration(&F
, O
);
777 if (!isa
<Instruction
>(U
))
779 const Instruction
*instr
= cast
<Instruction
>(U
);
780 const BasicBlock
*bb
= instr
->getParent();
783 const Function
*caller
= bb
->getParent();
787 // If a caller has already been seen, then the caller is
788 // appearing in the module before the callee. so print out
789 // a declaration for the callee.
790 if (seenMap
.contains(caller
)) {
791 emitDeclaration(&F
, O
);
797 for (const GlobalAlias
&GA
: M
.aliases())
798 emitAliasDeclaration(&GA
, O
);
801 static bool isEmptyXXStructor(GlobalVariable
*GV
) {
802 if (!GV
) return true;
803 const ConstantArray
*InitList
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
804 if (!InitList
) return true; // Not an array; we don't know how to parse.
805 return InitList
->getNumOperands() == 0;
808 void NVPTXAsmPrinter::emitStartOfAsmFile(Module
&M
) {
809 // Construct a default subtarget off of the TargetMachine defaults. The
810 // rest of NVPTX isn't friendly to change subtargets per function and
811 // so the default TargetMachine will have all of the options.
812 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
813 const auto* STI
= static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
814 SmallString
<128> Str1
;
815 raw_svector_ostream
OS1(Str1
);
817 // Emit header before any dwarf directives are emitted below.
818 emitHeader(M
, OS1
, *STI
);
819 OutStreamer
->emitRawText(OS1
.str());
822 bool NVPTXAsmPrinter::doInitialization(Module
&M
) {
823 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
824 const NVPTXSubtarget
&STI
=
825 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
826 if (M
.alias_size() && (STI
.getPTXVersion() < 63 || STI
.getSmVersion() < 30))
827 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
829 // OpenMP supports NVPTX global constructors and destructors.
830 bool IsOpenMP
= M
.getModuleFlag("openmp") != nullptr;
832 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_ctors")) &&
833 !LowerCtorDtor
&& !IsOpenMP
) {
835 "Module has a nontrivial global ctor, which NVPTX does not support.");
836 return true; // error
838 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_dtors")) &&
839 !LowerCtorDtor
&& !IsOpenMP
) {
841 "Module has a nontrivial global dtor, which NVPTX does not support.");
842 return true; // error
845 // We need to call the parent's one explicitly.
846 bool Result
= AsmPrinter::doInitialization(M
);
848 GlobalsEmitted
= false;
853 void NVPTXAsmPrinter::emitGlobals(const Module
&M
) {
854 SmallString
<128> Str2
;
855 raw_svector_ostream
OS2(Str2
);
857 emitDeclarations(M
, OS2
);
859 // As ptxas does not support forward references of globals, we need to first
860 // sort the list of module-level globals in def-use order. We visit each
861 // global variable in order, and ensure that we emit it *after* its dependent
862 // globals. We use a little extra memory maintaining both a set and a list to
863 // have fast searches while maintaining a strict ordering.
864 SmallVector
<const GlobalVariable
*, 8> Globals
;
865 DenseSet
<const GlobalVariable
*> GVVisited
;
866 DenseSet
<const GlobalVariable
*> GVVisiting
;
868 // Visit each global variable, in order
869 for (const GlobalVariable
&I
: M
.globals())
870 VisitGlobalVariableForEmission(&I
, Globals
, GVVisited
, GVVisiting
);
872 assert(GVVisited
.size() == M
.global_size() && "Missed a global variable");
873 assert(GVVisiting
.size() == 0 && "Did not fully process a global variable");
875 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
876 const NVPTXSubtarget
&STI
=
877 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
879 // Print out module-level global variables in proper order
880 for (const GlobalVariable
*GV
: Globals
)
881 printModuleLevelGV(GV
, OS2
, /*processDemoted=*/false, STI
);
885 OutStreamer
->emitRawText(OS2
.str());
888 void NVPTXAsmPrinter::emitGlobalAlias(const Module
&M
, const GlobalAlias
&GA
) {
889 SmallString
<128> Str
;
890 raw_svector_ostream
OS(Str
);
892 MCSymbol
*Name
= getSymbol(&GA
);
894 OS
<< ".alias " << Name
->getName() << ", " << GA
.getAliaseeObject()->getName()
897 OutStreamer
->emitRawText(OS
.str());
900 void NVPTXAsmPrinter::emitHeader(Module
&M
, raw_ostream
&O
,
901 const NVPTXSubtarget
&STI
) {
903 O
<< "// Generated by LLVM NVPTX Back-End\n";
907 unsigned PTXVersion
= STI
.getPTXVersion();
908 O
<< ".version " << (PTXVersion
/ 10) << "." << (PTXVersion
% 10) << "\n";
911 O
<< STI
.getTargetName();
913 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
914 if (NTM
.getDrvInterface() == NVPTX::NVCL
)
915 O
<< ", texmode_independent";
917 bool HasFullDebugInfo
= false;
918 for (DICompileUnit
*CU
: M
.debug_compile_units()) {
919 switch(CU
->getEmissionKind()) {
920 case DICompileUnit::NoDebug
:
921 case DICompileUnit::DebugDirectivesOnly
:
923 case DICompileUnit::LineTablesOnly
:
924 case DICompileUnit::FullDebug
:
925 HasFullDebugInfo
= true;
928 if (HasFullDebugInfo
)
931 if (HasFullDebugInfo
)
936 O
<< ".address_size ";
946 bool NVPTXAsmPrinter::doFinalization(Module
&M
) {
947 // If we did not emit any functions, then the global declarations have not
949 if (!GlobalsEmitted
) {
951 GlobalsEmitted
= true;
954 // call doFinalization
955 bool ret
= AsmPrinter::doFinalization(M
);
957 clearAnnotationCache(&M
);
960 static_cast<NVPTXTargetStreamer
*>(OutStreamer
->getTargetStreamer());
961 // Close the last emitted section
962 if (hasDebugInfo()) {
963 TS
->closeLastSection();
964 // Emit empty .debug_macinfo section for better support of the empty files.
965 OutStreamer
->emitRawText("\t.section\t.debug_macinfo\t{\t}");
968 // Output last DWARF .file directives, if any.
969 TS
->outputDwarfFileDirectives();
// This function emits appropriate linkage directives for
// functions and global variables.
//
// wrt CUDA:
// extern function declaration -> .extern
// extern function definition -> .visible
// external global variable with init -> .visible
// external without init -> .extern
// appending -> not allowed, assert.
// for any linkage other than
// internal, private, linker_private,
// linker_private_weak, linker_private_weak_def_auto,
// we emit -> .weak.
void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                           raw_ostream &O) {
  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
    if (V->hasExternalLinkage()) {
      if (isa<GlobalVariable>(V)) {
        const GlobalVariable *GVar = cast<GlobalVariable>(V);
        // A definition (has initializer) is .visible; a declaration is
        // .extern.
        if (GVar->hasInitializer())
          O << ".visible ";
        else
          O << ".extern ";
      } else if (V->isDeclaration())
        O << ".extern ";
      else
        O << ".visible ";
    } else if (V->hasAppendingLinkage()) {
      std::string msg;
      msg.append("Error: ");
      msg.append("Symbol ");
      if (V->hasName())
        msg.append(std::string(V->getName()));
      msg.append("has unsupported appending linkage type");
      llvm_unreachable(msg.c_str());
    } else if (!V->hasInternalLinkage() &&
               !V->hasPrivateLinkage()) {
      O << ".weak ";
    }
  }
}
1018 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable
*GVar
,
1019 raw_ostream
&O
, bool processDemoted
,
1020 const NVPTXSubtarget
&STI
) {
1022 if (GVar
->hasSection()) {
1023 if (GVar
->getSection() == "llvm.metadata")
1027 // Skip LLVM intrinsic global variables
1028 if (GVar
->getName().starts_with("llvm.") ||
1029 GVar
->getName().starts_with("nvvm."))
1032 const DataLayout
&DL
= getDataLayout();
1034 // GlobalVariables are always constant pointers themselves.
1035 Type
*ETy
= GVar
->getValueType();
1037 if (GVar
->hasExternalLinkage()) {
1038 if (GVar
->hasInitializer())
1042 } else if (STI
.getPTXVersion() >= 50 && GVar
->hasCommonLinkage() &&
1043 GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) {
1045 } else if (GVar
->hasLinkOnceLinkage() || GVar
->hasWeakLinkage() ||
1046 GVar
->hasAvailableExternallyLinkage() ||
1047 GVar
->hasCommonLinkage()) {
1051 if (isTexture(*GVar
)) {
1052 O
<< ".global .texref " << getTextureName(*GVar
) << ";\n";
1056 if (isSurface(*GVar
)) {
1057 O
<< ".global .surfref " << getSurfaceName(*GVar
) << ";\n";
1061 if (GVar
->isDeclaration()) {
1062 // (extern) declarations, no definition or initializer
1063 // Currently the only known declaration is for an automatic __local
1064 // (.shared) promoted to global.
1065 emitPTXGlobalVariable(GVar
, O
, STI
);
1070 if (isSampler(*GVar
)) {
1071 O
<< ".global .samplerref " << getSamplerName(*GVar
);
1073 const Constant
*Initializer
= nullptr;
1074 if (GVar
->hasInitializer())
1075 Initializer
= GVar
->getInitializer();
1076 const ConstantInt
*CI
= nullptr;
1078 CI
= dyn_cast
<ConstantInt
>(Initializer
);
1080 unsigned sample
= CI
->getZExtValue();
1085 addr
= ((sample
& __CLK_ADDRESS_MASK
) >> __CLK_ADDRESS_BASE
);
1087 O
<< "addr_mode_" << i
<< " = ";
1093 O
<< "clamp_to_border";
1096 O
<< "clamp_to_edge";
1107 O
<< "filter_mode = ";
1108 switch ((sample
& __CLK_FILTER_MASK
) >> __CLK_FILTER_BASE
) {
1116 llvm_unreachable("Anisotropic filtering is not supported");
1121 if (!((sample
& __CLK_NORMALIZED_MASK
) >> __CLK_NORMALIZED_BASE
)) {
1122 O
<< ", force_unnormalized_coords = 1";
1131 if (GVar
->hasPrivateLinkage()) {
1132 if (strncmp(GVar
->getName().data(), "unrollpragma", 12) == 0)
1135 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1136 if (strncmp(GVar
->getName().data(), "filename", 8) == 0)
1138 if (GVar
->use_empty())
1142 const Function
*demotedFunc
= nullptr;
1143 if (!processDemoted
&& canDemoteGlobalVar(GVar
, demotedFunc
)) {
1144 O
<< "// " << GVar
->getName() << " has been demoted\n";
1145 localDecls
[demotedFunc
].push_back(GVar
);
1150 emitPTXAddressSpace(GVar
->getAddressSpace(), O
);
1152 if (isManaged(*GVar
)) {
1153 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1155 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1157 O
<< " .attribute(.managed)";
1160 if (MaybeAlign A
= GVar
->getAlign())
1161 O
<< " .align " << A
->value();
1163 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1165 if (ETy
->isFloatingPointTy() || ETy
->isPointerTy() ||
1166 (ETy
->isIntegerTy() && ETy
->getScalarSizeInBits() <= 64)) {
1168 // Special case: ABI requires that we use .u8 for predicates
1169 if (ETy
->isIntegerTy(1))
1172 O
<< getPTXFundamentalTypeStr(ETy
, false);
1174 getSymbol(GVar
)->print(O
, MAI
);
1176 // Ptx allows variable initilization only for constant and global state
1178 if (GVar
->hasInitializer()) {
1179 if ((GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1180 (GVar
->getAddressSpace() == ADDRESS_SPACE_CONST
)) {
1181 const Constant
*Initializer
= GVar
->getInitializer();
1182 // 'undef' is treated as there is no value specified.
1183 if (!Initializer
->isNullValue() && !isa
<UndefValue
>(Initializer
)) {
1185 printScalarConstant(Initializer
, O
);
1188 // The frontend adds zero-initializer to device and constant variables
1189 // that don't have an initial value, and UndefValue to shared
1190 // variables, so skip warning for this case.
1191 if (!GVar
->getInitializer()->isNullValue() &&
1192 !isa
<UndefValue
>(GVar
->getInitializer())) {
1193 report_fatal_error("initial value of '" + GVar
->getName() +
1194 "' is not allowed in addrspace(" +
1195 Twine(GVar
->getAddressSpace()) + ")");
1200 uint64_t ElementSize
= 0;
1202 // Although PTX has direct support for struct type and array type and
1203 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1204 // targets that support these high level field accesses. Structs, arrays
1205 // and vectors are lowered into arrays of bytes.
1206 switch (ETy
->getTypeID()) {
1207 case Type::IntegerTyID
: // Integers larger than 64 bits
1208 case Type::StructTyID
:
1209 case Type::ArrayTyID
:
1210 case Type::FixedVectorTyID
:
1211 ElementSize
= DL
.getTypeStoreSize(ETy
);
1212 // Ptx allows variable initilization only for constant and
1213 // global state spaces.
1214 if (((GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1215 (GVar
->getAddressSpace() == ADDRESS_SPACE_CONST
)) &&
1216 GVar
->hasInitializer()) {
1217 const Constant
*Initializer
= GVar
->getInitializer();
1218 if (!isa
<UndefValue
>(Initializer
) && !Initializer
->isNullValue()) {
1219 AggBuffer
aggBuffer(ElementSize
, *this);
1220 bufferAggregateConstant(Initializer
, &aggBuffer
);
1221 if (aggBuffer
.numSymbols()) {
1222 unsigned int ptrSize
= MAI
->getCodePointerSize();
1223 if (ElementSize
% ptrSize
||
1224 !aggBuffer
.allSymbolsAligned(ptrSize
)) {
1225 // Print in bytes and use the mask() operator for pointers.
1226 if (!STI
.hasMaskOperator())
1228 "initialized packed aggregate with pointers '" +
1230 "' requires at least PTX ISA version 7.1");
1232 getSymbol(GVar
)->print(O
, MAI
);
1233 O
<< "[" << ElementSize
<< "] = {";
1234 aggBuffer
.printBytes(O
);
1237 O
<< " .u" << ptrSize
* 8 << " ";
1238 getSymbol(GVar
)->print(O
, MAI
);
1239 O
<< "[" << ElementSize
/ ptrSize
<< "] = {";
1240 aggBuffer
.printWords(O
);
1245 getSymbol(GVar
)->print(O
, MAI
);
1246 O
<< "[" << ElementSize
<< "] = {";
1247 aggBuffer
.printBytes(O
);
1252 getSymbol(GVar
)->print(O
, MAI
);
1261 getSymbol(GVar
)->print(O
, MAI
);
1270 llvm_unreachable("type not supported yet");
1276 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym
, raw_ostream
&os
) {
1277 const Value
*v
= Symbols
[nSym
];
1278 const Value
*v0
= SymbolsBeforeStripping
[nSym
];
1279 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(v
)) {
1280 MCSymbol
*Name
= AP
.getSymbol(GVar
);
1281 PointerType
*PTy
= dyn_cast
<PointerType
>(v0
->getType());
1282 // Is v0 a generic pointer?
1283 bool isGenericPointer
= PTy
&& PTy
->getAddressSpace() == 0;
1284 if (EmitGeneric
&& isGenericPointer
&& !isa
<Function
>(v
)) {
1286 Name
->print(os
, AP
.MAI
);
1289 Name
->print(os
, AP
.MAI
);
1291 } else if (const ConstantExpr
*CExpr
= dyn_cast
<ConstantExpr
>(v0
)) {
1292 const MCExpr
*Expr
= AP
.lowerConstantForGV(cast
<Constant
>(CExpr
), false);
1293 AP
.printMCExpr(*Expr
, os
);
1295 llvm_unreachable("symbol type unknown");
1298 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream
&os
) {
1299 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1300 // Do not emit trailing zero initializers. They will be zero-initialized by
1301 // ptxas. This saves on both space requirements for the generated PTX and on
1302 // memory use by ptxas. (See:
1303 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1304 unsigned int InitializerCount
= size
;
1305 // TODO: symbols make this harder, but it would still be good to trim trailing
1306 // 0s for aggs with symbols as well.
1307 if (numSymbols() == 0)
1308 while (InitializerCount
>= 1 && !buffer
[InitializerCount
- 1])
1311 symbolPosInBuffer
.push_back(InitializerCount
);
1312 unsigned int nSym
= 0;
1313 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1314 for (unsigned int pos
= 0; pos
< InitializerCount
;) {
1317 if (pos
!= nextSymbolPos
) {
1318 os
<< (unsigned int)buffer
[pos
];
1322 // Generate a per-byte mask() operator for the symbol, which looks like:
1323 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1324 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1325 std::string symText
;
1326 llvm::raw_string_ostream
oss(symText
);
1327 printSymbol(nSym
, oss
);
1328 for (unsigned i
= 0; i
< ptrSize
; ++i
) {
1331 llvm::write_hex(os
, 0xFFULL
<< i
* 8, HexPrintStyle::PrefixUpper
);
1332 os
<< "(" << symText
<< ")";
1335 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1336 assert(nextSymbolPos
>= pos
);
1340 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream
&os
) {
1341 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1342 symbolPosInBuffer
.push_back(size
);
1343 unsigned int nSym
= 0;
1344 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1345 assert(nextSymbolPos
% ptrSize
== 0);
1346 for (unsigned int pos
= 0; pos
< size
; pos
+= ptrSize
) {
1349 if (pos
== nextSymbolPos
) {
1350 printSymbol(nSym
, os
);
1351 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1352 assert(nextSymbolPos
% ptrSize
== 0);
1353 assert(nextSymbolPos
>= pos
+ ptrSize
);
1354 } else if (ptrSize
== 4)
1355 os
<< support::endian::read32le(&buffer
[pos
]);
1357 os
<< support::endian::read64le(&buffer
[pos
]);
1361 void NVPTXAsmPrinter::emitDemotedVars(const Function
*f
, raw_ostream
&O
) {
1362 auto It
= localDecls
.find(f
);
1363 if (It
== localDecls
.end())
1366 std::vector
<const GlobalVariable
*> &gvars
= It
->second
;
1368 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
1369 const NVPTXSubtarget
&STI
=
1370 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
1372 for (const GlobalVariable
*GV
: gvars
) {
1373 O
<< "\t// demoted variable\n\t";
1374 printModuleLevelGV(GV
, O
, /*processDemoted=*/true, STI
);
1378 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace
,
1379 raw_ostream
&O
) const {
1380 switch (AddressSpace
) {
1381 case ADDRESS_SPACE_LOCAL
:
1384 case ADDRESS_SPACE_GLOBAL
:
1387 case ADDRESS_SPACE_CONST
:
1390 case ADDRESS_SPACE_SHARED
:
1394 report_fatal_error("Bad address space found while emitting PTX: " +
1395 llvm::Twine(AddressSpace
));
1401 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type
*Ty
, bool useB4PTR
) const {
1402 switch (Ty
->getTypeID()) {
1403 case Type::IntegerTyID
: {
1404 unsigned NumBits
= cast
<IntegerType
>(Ty
)->getBitWidth();
1407 else if (NumBits
<= 64) {
1408 std::string name
= "u";
1409 return name
+ utostr(NumBits
);
1411 llvm_unreachable("Integer too large");
1416 case Type::BFloatTyID
:
1417 case Type::HalfTyID
:
1418 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1421 case Type::FloatTyID
:
1423 case Type::DoubleTyID
:
1425 case Type::PointerTyID
: {
1426 unsigned PtrSize
= TM
.getPointerSizeInBits(Ty
->getPointerAddressSpace());
1427 assert((PtrSize
== 64 || PtrSize
== 32) && "Unexpected pointer size");
1442 llvm_unreachable("unexpected type");
1445 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable
*GVar
,
1447 const NVPTXSubtarget
&STI
) {
1448 const DataLayout
&DL
= getDataLayout();
1450 // GlobalVariables are always constant pointers themselves.
1451 Type
*ETy
= GVar
->getValueType();
1454 emitPTXAddressSpace(GVar
->getType()->getAddressSpace(), O
);
1455 if (isManaged(*GVar
)) {
1456 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1458 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1460 O
<< " .attribute(.managed)";
1462 if (MaybeAlign A
= GVar
->getAlign())
1463 O
<< " .align " << A
->value();
1465 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1467 // Special case for i128
1468 if (ETy
->isIntegerTy(128)) {
1470 getSymbol(GVar
)->print(O
, MAI
);
1475 if (ETy
->isFloatingPointTy() || ETy
->isIntOrPtrTy()) {
1477 O
<< getPTXFundamentalTypeStr(ETy
);
1479 getSymbol(GVar
)->print(O
, MAI
);
1483 int64_t ElementSize
= 0;
1485 // Although PTX has direct support for struct type and array type and LLVM IR
1486 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1487 // support these high level field accesses. Structs and arrays are lowered
1488 // into arrays of bytes.
1489 switch (ETy
->getTypeID()) {
1490 case Type::StructTyID
:
1491 case Type::ArrayTyID
:
1492 case Type::FixedVectorTyID
:
1493 ElementSize
= DL
.getTypeStoreSize(ETy
);
1495 getSymbol(GVar
)->print(O
, MAI
);
1503 llvm_unreachable("type not supported yet");
1507 void NVPTXAsmPrinter::emitFunctionParamList(const Function
*F
, raw_ostream
&O
) {
1508 const DataLayout
&DL
= getDataLayout();
1509 const AttributeList
&PAL
= F
->getAttributes();
1510 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
1511 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
1513 Function::const_arg_iterator I
, E
;
1514 unsigned paramIndex
= 0;
1516 bool isKernelFunc
= isKernelFunction(*F
);
1517 bool isABI
= (STI
.getSmVersion() >= 20);
1518 bool hasImageHandles
= STI
.hasImageHandles();
1520 if (F
->arg_empty() && !F
->isVarArg()) {
1527 for (I
= F
->arg_begin(), E
= F
->arg_end(); I
!= E
; ++I
, paramIndex
++) {
1528 Type
*Ty
= I
->getType();
1535 // Handle image/sampler parameters
1536 if (isKernelFunction(*F
)) {
1537 if (isSampler(*I
) || isImage(*I
)) {
1539 if (isImageWriteOnly(*I
) || isImageReadWrite(*I
)) {
1540 if (hasImageHandles
)
1541 O
<< "\t.param .u64 .ptr .surfref ";
1543 O
<< "\t.param .surfref ";
1544 O
<< TLI
->getParamName(F
, paramIndex
);
1546 else { // Default image is read_only
1547 if (hasImageHandles
)
1548 O
<< "\t.param .u64 .ptr .texref ";
1550 O
<< "\t.param .texref ";
1551 O
<< TLI
->getParamName(F
, paramIndex
);
1554 if (hasImageHandles
)
1555 O
<< "\t.param .u64 .ptr .samplerref ";
1557 O
<< "\t.param .samplerref ";
1558 O
<< TLI
->getParamName(F
, paramIndex
);
1564 auto getOptimalAlignForParam
= [TLI
, &DL
, &PAL
, F
,
1565 paramIndex
](Type
*Ty
) -> Align
{
1566 if (MaybeAlign StackAlign
=
1567 getAlign(*F
, paramIndex
+ AttributeList::FirstArgIndex
))
1568 return StackAlign
.value();
1570 Align TypeAlign
= TLI
->getFunctionParamOptimizedAlign(F
, Ty
, DL
);
1571 MaybeAlign ParamAlign
= PAL
.getParamAlignment(paramIndex
);
1572 return std::max(TypeAlign
, ParamAlign
.valueOrOne());
1575 if (!PAL
.hasParamAttr(paramIndex
, Attribute::ByVal
)) {
1576 if (ShouldPassAsArray(Ty
)) {
1577 // Just print .param .align <a> .b8 .param[size];
1578 // <a> = optimal alignment for the element type; always multiple of
1579 // PAL.getParamAlignment
1580 // size = typeallocsize of element type
1581 Align OptimalAlign
= getOptimalAlignForParam(Ty
);
1583 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1584 O
<< TLI
->getParamName(F
, paramIndex
);
1585 O
<< "[" << DL
.getTypeAllocSize(Ty
) << "]";
1590 auto *PTy
= dyn_cast
<PointerType
>(Ty
);
1591 unsigned PTySizeInBits
= 0;
1594 TLI
->getPointerTy(DL
, PTy
->getAddressSpace()).getSizeInBits();
1595 assert(PTySizeInBits
&& "Invalid pointer size");
1600 O
<< "\t.param .u" << PTySizeInBits
<< " .ptr";
1602 switch (PTy
->getAddressSpace()) {
1605 case ADDRESS_SPACE_GLOBAL
:
1608 case ADDRESS_SPACE_SHARED
:
1611 case ADDRESS_SPACE_CONST
:
1614 case ADDRESS_SPACE_LOCAL
:
1619 O
<< " .align " << I
->getParamAlign().valueOrOne().value();
1620 O
<< " " << TLI
->getParamName(F
, paramIndex
);
1624 // non-pointer scalar to kernel func
1626 // Special case: predicate operands become .u8 types
1627 if (Ty
->isIntegerTy(1))
1630 O
<< getPTXFundamentalTypeStr(Ty
);
1632 O
<< TLI
->getParamName(F
, paramIndex
);
1635 // Non-kernel function, just print .param .b<size> for ABI
1636 // and .reg .b<size> for non-ABI
1638 if (isa
<IntegerType
>(Ty
)) {
1639 sz
= cast
<IntegerType
>(Ty
)->getBitWidth();
1640 sz
= promoteScalarArgumentSize(sz
);
1642 assert(PTySizeInBits
&& "Invalid pointer size");
1645 sz
= Ty
->getPrimitiveSizeInBits();
1647 O
<< "\t.param .b" << sz
<< " ";
1649 O
<< "\t.reg .b" << sz
<< " ";
1650 O
<< TLI
->getParamName(F
, paramIndex
);
1654 // param has byVal attribute.
1655 Type
*ETy
= PAL
.getParamByValType(paramIndex
);
1656 assert(ETy
&& "Param should have byval type");
1658 if (isABI
|| isKernelFunc
) {
1659 // Just print .param .align <a> .b8 .param[size];
1660 // <a> = optimal alignment for the element type; always multiple of
1661 // PAL.getParamAlignment
1662 // size = typeallocsize of element type
1663 Align OptimalAlign
=
1665 ? getOptimalAlignForParam(ETy
)
1666 : TLI
->getFunctionByValParamAlign(
1667 F
, ETy
, PAL
.getParamAlignment(paramIndex
).valueOrOne(), DL
);
1669 unsigned sz
= DL
.getTypeAllocSize(ETy
);
1670 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1671 O
<< TLI
->getParamName(F
, paramIndex
);
1672 O
<< "[" << sz
<< "]";
1675 // Split the ETy into constituent parts and
1676 // print .param .b<size> <name> for each part.
1677 // Further, if a part is vector, print the above for
1678 // each vector element.
1679 SmallVector
<EVT
, 16> vtparts
;
1680 ComputeValueVTs(*TLI
, DL
, ETy
, vtparts
);
1681 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
1683 EVT elemtype
= vtparts
[i
];
1684 if (vtparts
[i
].isVector()) {
1685 elems
= vtparts
[i
].getVectorNumElements();
1686 elemtype
= vtparts
[i
].getVectorElementType();
1689 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
1690 unsigned sz
= elemtype
.getSizeInBits();
1691 if (elemtype
.isInteger())
1692 sz
= promoteScalarArgumentSize(sz
);
1693 O
<< "\t.reg .b" << sz
<< " ";
1694 O
<< TLI
->getParamName(F
, paramIndex
);
1707 if (F
->isVarArg()) {
1710 O
<< "\t.param .align " << STI
.getMaxRequiredAlignment();
1712 O
<< TLI
->getParamName(F
, /* vararg */ -1) << "[]";
1718 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1719 const MachineFunction
&MF
) {
1720 SmallString
<128> Str
;
1721 raw_svector_ostream
O(Str
);
1723 // Map the global virtual register number to a register class specific
1724 // virtual register number starting from 1 with that class.
1725 const TargetRegisterInfo
*TRI
= MF
.getSubtarget().getRegisterInfo();
1726 //unsigned numRegClasses = TRI->getNumRegClasses();
1728 // Emit the Fake Stack Object
1729 const MachineFrameInfo
&MFI
= MF
.getFrameInfo();
1730 int64_t NumBytes
= MFI
.getStackSize();
1732 O
<< "\t.local .align " << MFI
.getMaxAlign().value() << " .b8 \t"
1733 << DEPOTNAME
<< getFunctionNumber() << "[" << NumBytes
<< "];\n";
1734 if (static_cast<const NVPTXTargetMachine
&>(MF
.getTarget()).is64Bit()) {
1735 O
<< "\t.reg .b64 \t%SP;\n";
1736 O
<< "\t.reg .b64 \t%SPL;\n";
1738 O
<< "\t.reg .b32 \t%SP;\n";
1739 O
<< "\t.reg .b32 \t%SPL;\n";
1743 // Go through all virtual registers to establish the mapping between the
1745 // register number and the per class virtual register number.
1746 // We use the per class virtual register number in the ptx output.
1747 unsigned int numVRs
= MRI
->getNumVirtRegs();
1748 for (unsigned i
= 0; i
< numVRs
; i
++) {
1749 Register vr
= Register::index2VirtReg(i
);
1750 const TargetRegisterClass
*RC
= MRI
->getRegClass(vr
);
1751 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1752 int n
= regmap
.size();
1753 regmap
.insert(std::make_pair(vr
, n
+ 1));
1756 // Emit register declarations
1757 // @TODO: Extract out the real register usage
1758 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1759 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1760 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1761 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1762 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1763 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1764 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1766 // Emit declaration of the virtual registers or 'physical' registers for
1767 // each register class
1768 for (unsigned i
=0; i
< TRI
->getNumRegClasses(); i
++) {
1769 const TargetRegisterClass
*RC
= TRI
->getRegClass(i
);
1770 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1771 std::string rcname
= getNVPTXRegClassName(RC
);
1772 std::string rcStr
= getNVPTXRegClassStr(RC
);
1773 int n
= regmap
.size();
1775 // Only declare those registers that may be used.
1777 O
<< "\t.reg " << rcname
<< " \t" << rcStr
<< "<" << (n
+1)
1782 OutStreamer
->emitRawText(O
.str());
1785 /// Translate virtual register numbers in DebugInfo locations to their printed
1786 /// encodings, as used by CUDA-GDB.
1787 void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1788 const MachineFunction
&MF
) {
1789 const NVPTXSubtarget
&STI
= MF
.getSubtarget
<NVPTXSubtarget
>();
1790 const NVPTXRegisterInfo
*registerInfo
= STI
.getRegisterInfo();
1792 // Clear the old mapping, and add the new one. This mapping is used after the
1793 // printing of the current function is complete, but before the next function
1795 registerInfo
->clearDebugRegisterMap();
1797 for (auto &classMap
: VRegMapping
) {
1798 for (auto ®isterMapping
: classMap
.getSecond()) {
1799 auto reg
= registerMapping
.getFirst();
1800 registerInfo
->addToDebugRegisterMap(reg
, getVirtualRegisterName(reg
));
1805 void NVPTXAsmPrinter::printFPConstant(const ConstantFP
*Fp
, raw_ostream
&O
) {
1806 APFloat APF
= APFloat(Fp
->getValueAPF()); // make a copy
1808 unsigned int numHex
;
1811 if (Fp
->getType()->getTypeID() == Type::FloatTyID
) {
1814 APF
.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven
, &ignored
);
1815 } else if (Fp
->getType()->getTypeID() == Type::DoubleTyID
) {
1818 APF
.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven
, &ignored
);
1820 llvm_unreachable("unsupported fp type");
1822 APInt API
= APF
.bitcastToAPInt();
1823 O
<< lead
<< format_hex_no_prefix(API
.getZExtValue(), numHex
, /*Upper=*/true);
1826 void NVPTXAsmPrinter::printScalarConstant(const Constant
*CPV
, raw_ostream
&O
) {
1827 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1828 O
<< CI
->getValue();
1831 if (const ConstantFP
*CFP
= dyn_cast
<ConstantFP
>(CPV
)) {
1832 printFPConstant(CFP
, O
);
1835 if (isa
<ConstantPointerNull
>(CPV
)) {
1839 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1840 bool IsNonGenericPointer
= false;
1841 if (GVar
->getType()->getAddressSpace() != 0) {
1842 IsNonGenericPointer
= true;
1844 if (EmitGeneric
&& !isa
<Function
>(CPV
) && !IsNonGenericPointer
) {
1846 getSymbol(GVar
)->print(O
, MAI
);
1849 getSymbol(GVar
)->print(O
, MAI
);
1853 if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1854 const MCExpr
*E
= lowerConstantForGV(cast
<Constant
>(Cexpr
), false);
1858 llvm_unreachable("Not scalar type found in printScalarConstant()");
1861 void NVPTXAsmPrinter::bufferLEByte(const Constant
*CPV
, int Bytes
,
1862 AggBuffer
*AggBuffer
) {
1863 const DataLayout
&DL
= getDataLayout();
1864 int AllocSize
= DL
.getTypeAllocSize(CPV
->getType());
1865 if (isa
<UndefValue
>(CPV
) || CPV
->isNullValue()) {
1866 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1867 // only the space allocated by CPV.
1868 AggBuffer
->addZeros(Bytes
? Bytes
: AllocSize
);
1872 // Helper for filling AggBuffer with APInts.
1873 auto AddIntToBuffer
= [AggBuffer
, Bytes
](const APInt
&Val
) {
1874 size_t NumBytes
= (Val
.getBitWidth() + 7) / 8;
1875 SmallVector
<unsigned char, 16> Buf(NumBytes
);
1876 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1877 // input's bit width, and i1 arrays may not have a length that is a multuple
1878 // of 8. We handle the last byte separately, so we never request out of
1880 for (unsigned I
= 0; I
< NumBytes
- 1; ++I
) {
1881 Buf
[I
] = Val
.extractBitsAsZExtValue(8, I
* 8);
1883 size_t LastBytePosition
= (NumBytes
- 1) * 8;
1884 size_t LastByteBits
= Val
.getBitWidth() - LastBytePosition
;
1886 Val
.extractBitsAsZExtValue(LastByteBits
, LastBytePosition
);
1887 AggBuffer
->addBytes(Buf
.data(), NumBytes
, Bytes
);
1890 switch (CPV
->getType()->getTypeID()) {
1891 case Type::IntegerTyID
:
1892 if (const auto CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1893 AddIntToBuffer(CI
->getValue());
1896 if (const auto *Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1897 if (const auto *CI
=
1898 dyn_cast
<ConstantInt
>(ConstantFoldConstant(Cexpr
, DL
))) {
1899 AddIntToBuffer(CI
->getValue());
1902 if (Cexpr
->getOpcode() == Instruction::PtrToInt
) {
1903 Value
*V
= Cexpr
->getOperand(0)->stripPointerCasts();
1904 AggBuffer
->addSymbol(V
, Cexpr
->getOperand(0));
1905 AggBuffer
->addZeros(AllocSize
);
1909 llvm_unreachable("unsupported integer const type");
1912 case Type::HalfTyID
:
1913 case Type::BFloatTyID
:
1914 case Type::FloatTyID
:
1915 case Type::DoubleTyID
:
1916 AddIntToBuffer(cast
<ConstantFP
>(CPV
)->getValueAPF().bitcastToAPInt());
1919 case Type::PointerTyID
: {
1920 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1921 AggBuffer
->addSymbol(GVar
, GVar
);
1922 } else if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1923 const Value
*v
= Cexpr
->stripPointerCasts();
1924 AggBuffer
->addSymbol(v
, Cexpr
);
1926 AggBuffer
->addZeros(AllocSize
);
1930 case Type::ArrayTyID
:
1931 case Type::FixedVectorTyID
:
1932 case Type::StructTyID
: {
1933 if (isa
<ConstantAggregate
>(CPV
) || isa
<ConstantDataSequential
>(CPV
)) {
1934 bufferAggregateConstant(CPV
, AggBuffer
);
1935 if (Bytes
> AllocSize
)
1936 AggBuffer
->addZeros(Bytes
- AllocSize
);
1937 } else if (isa
<ConstantAggregateZero
>(CPV
))
1938 AggBuffer
->addZeros(Bytes
);
1940 llvm_unreachable("Unexpected Constant type");
1945 llvm_unreachable("unsupported type");
1949 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant
*CPV
,
1950 AggBuffer
*aggBuffer
) {
1951 const DataLayout
&DL
= getDataLayout();
1954 // Integers of arbitrary width
1955 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1956 APInt Val
= CI
->getValue();
1957 for (unsigned I
= 0, E
= DL
.getTypeAllocSize(CPV
->getType()); I
< E
; ++I
) {
1958 uint8_t Byte
= Val
.getLoBits(8).getZExtValue();
1959 aggBuffer
->addBytes(&Byte
, 1, 1);
1966 if (isa
<ConstantArray
>(CPV
) || isa
<ConstantVector
>(CPV
)) {
1967 if (CPV
->getNumOperands())
1968 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
)
1969 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), 0, aggBuffer
);
1973 if (const ConstantDataSequential
*CDS
=
1974 dyn_cast
<ConstantDataSequential
>(CPV
)) {
1975 if (CDS
->getNumElements())
1976 for (unsigned i
= 0; i
< CDS
->getNumElements(); ++i
)
1977 bufferLEByte(cast
<Constant
>(CDS
->getElementAsConstant(i
)), 0,
1982 if (isa
<ConstantStruct
>(CPV
)) {
1983 if (CPV
->getNumOperands()) {
1984 StructType
*ST
= cast
<StructType
>(CPV
->getType());
1985 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
) {
1987 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(0) +
1988 DL
.getTypeAllocSize(ST
) -
1989 DL
.getStructLayout(ST
)->getElementOffset(i
);
1991 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(i
+ 1) -
1992 DL
.getStructLayout(ST
)->getElementOffset(i
);
1993 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), Bytes
, aggBuffer
);
1998 llvm_unreachable("unsupported constant type in printAggregateConstant()");
2001 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
2002 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
2003 /// expressions that are representable in PTX and create
2004 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
2006 NVPTXAsmPrinter::lowerConstantForGV(const Constant
*CV
, bool ProcessingGeneric
) {
2007 MCContext
&Ctx
= OutContext
;
2009 if (CV
->isNullValue() || isa
<UndefValue
>(CV
))
2010 return MCConstantExpr::create(0, Ctx
);
2012 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CV
))
2013 return MCConstantExpr::create(CI
->getZExtValue(), Ctx
);
2015 if (const GlobalValue
*GV
= dyn_cast
<GlobalValue
>(CV
)) {
2016 const MCSymbolRefExpr
*Expr
=
2017 MCSymbolRefExpr::create(getSymbol(GV
), Ctx
);
2018 if (ProcessingGeneric
) {
2019 return NVPTXGenericMCSymbolRefExpr::create(Expr
, Ctx
);
2025 const ConstantExpr
*CE
= dyn_cast
<ConstantExpr
>(CV
);
2027 llvm_unreachable("Unknown constant value to lower!");
2030 switch (CE
->getOpcode()) {
2034 case Instruction::AddrSpaceCast
: {
2035 // Strip the addrspacecast and pass along the operand
2036 PointerType
*DstTy
= cast
<PointerType
>(CE
->getType());
2037 if (DstTy
->getAddressSpace() == 0)
2038 return lowerConstantForGV(cast
<const Constant
>(CE
->getOperand(0)), true);
2043 case Instruction::GetElementPtr
: {
2044 const DataLayout
&DL
= getDataLayout();
2046 // Generate a symbolic expression for the byte address
2047 APInt
OffsetAI(DL
.getPointerTypeSizeInBits(CE
->getType()), 0);
2048 cast
<GEPOperator
>(CE
)->accumulateConstantOffset(DL
, OffsetAI
);
2050 const MCExpr
*Base
= lowerConstantForGV(CE
->getOperand(0),
2055 int64_t Offset
= OffsetAI
.getSExtValue();
2056 return MCBinaryExpr::createAdd(Base
, MCConstantExpr::create(Offset
, Ctx
),
2060 case Instruction::Trunc
:
2061 // We emit the value and depend on the assembler to truncate the generated
2062 // expression properly. This is important for differences between
2063 // blockaddress labels. Since the two labels are in the same function, it
2064 // is reasonable to treat their delta as a 32-bit value.
2066 case Instruction::BitCast
:
2067 return lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2069 case Instruction::IntToPtr
: {
2070 const DataLayout
&DL
= getDataLayout();
2072 // Handle casts to pointers by changing them into casts to the appropriate
2073 // integer type. This promotes constant folding and simplifies this code.
2074 Constant
*Op
= CE
->getOperand(0);
2075 Op
= ConstantFoldIntegerCast(Op
, DL
.getIntPtrType(CV
->getType()),
2076 /*IsSigned*/ false, DL
);
2078 return lowerConstantForGV(Op
, ProcessingGeneric
);
2083 case Instruction::PtrToInt
: {
2084 const DataLayout
&DL
= getDataLayout();
2086 // Support only foldable casts to/from pointers that can be eliminated by
2087 // changing the pointer to the appropriately sized integer type.
2088 Constant
*Op
= CE
->getOperand(0);
2089 Type
*Ty
= CE
->getType();
2091 const MCExpr
*OpExpr
= lowerConstantForGV(Op
, ProcessingGeneric
);
2093 // We can emit the pointer value into this slot if the slot is an
2094 // integer slot equal to the size of the pointer.
2095 if (DL
.getTypeAllocSize(Ty
) == DL
.getTypeAllocSize(Op
->getType()))
2098 // Otherwise the pointer is smaller than the resultant integer, mask off
2099 // the high bits so we are sure to get a proper truncation if the input is
2101 unsigned InBits
= DL
.getTypeAllocSizeInBits(Op
->getType());
2102 const MCExpr
*MaskExpr
= MCConstantExpr::create(~0ULL >> (64-InBits
), Ctx
);
2103 return MCBinaryExpr::createAnd(OpExpr
, MaskExpr
, Ctx
);
2106 // The MC library also has a right-shift operator, but it isn't consistently
2107 // signed or unsigned between different targets.
2108 case Instruction::Add
: {
2109 const MCExpr
*LHS
= lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2110 const MCExpr
*RHS
= lowerConstantForGV(CE
->getOperand(1), ProcessingGeneric
);
2111 switch (CE
->getOpcode()) {
2112 default: llvm_unreachable("Unknown binary operator constant cast expr");
2113 case Instruction::Add
: return MCBinaryExpr::createAdd(LHS
, RHS
, Ctx
);
2118 // If the code isn't optimized, there may be outstanding folding
2119 // opportunities. Attempt to fold the expression using DataLayout as a
2120 // last resort before giving up.
2121 Constant
*C
= ConstantFoldConstant(CE
, getDataLayout());
2123 return lowerConstantForGV(C
, ProcessingGeneric
);
2125 // Otherwise report the problem to the user.
2127 raw_string_ostream
OS(S
);
2128 OS
<< "Unsupported expression in static initializer: ";
2129 CE
->printAsOperand(OS
, /*PrintType=*/false,
2130 !MF
? nullptr : MF
->getFunction().getParent());
2131 report_fatal_error(Twine(OS
.str()));
2134 // Copy of MCExpr::print customized for NVPTX
2135 void NVPTXAsmPrinter::printMCExpr(const MCExpr
&Expr
, raw_ostream
&OS
) {
2136 switch (Expr
.getKind()) {
2137 case MCExpr::Target
:
2138 return cast
<MCTargetExpr
>(&Expr
)->printImpl(OS
, MAI
);
2139 case MCExpr::Constant
:
2140 OS
<< cast
<MCConstantExpr
>(Expr
).getValue();
2143 case MCExpr::SymbolRef
: {
2144 const MCSymbolRefExpr
&SRE
= cast
<MCSymbolRefExpr
>(Expr
);
2145 const MCSymbol
&Sym
= SRE
.getSymbol();
2150 case MCExpr::Unary
: {
2151 const MCUnaryExpr
&UE
= cast
<MCUnaryExpr
>(Expr
);
2152 switch (UE
.getOpcode()) {
2153 case MCUnaryExpr::LNot
: OS
<< '!'; break;
2154 case MCUnaryExpr::Minus
: OS
<< '-'; break;
2155 case MCUnaryExpr::Not
: OS
<< '~'; break;
2156 case MCUnaryExpr::Plus
: OS
<< '+'; break;
2158 printMCExpr(*UE
.getSubExpr(), OS
);
2162 case MCExpr::Binary
: {
2163 const MCBinaryExpr
&BE
= cast
<MCBinaryExpr
>(Expr
);
2165 // Only print parens around the LHS if it is non-trivial.
2166 if (isa
<MCConstantExpr
>(BE
.getLHS()) || isa
<MCSymbolRefExpr
>(BE
.getLHS()) ||
2167 isa
<NVPTXGenericMCSymbolRefExpr
>(BE
.getLHS())) {
2168 printMCExpr(*BE
.getLHS(), OS
);
2171 printMCExpr(*BE
.getLHS(), OS
);
2175 switch (BE
.getOpcode()) {
2176 case MCBinaryExpr::Add
:
2177 // Print "X-42" instead of "X+-42".
2178 if (const MCConstantExpr
*RHSC
= dyn_cast
<MCConstantExpr
>(BE
.getRHS())) {
2179 if (RHSC
->getValue() < 0) {
2180 OS
<< RHSC
->getValue();
2187 default: llvm_unreachable("Unhandled binary operator");
2190 // Only print parens around the LHS if it is non-trivial.
2191 if (isa
<MCConstantExpr
>(BE
.getRHS()) || isa
<MCSymbolRefExpr
>(BE
.getRHS())) {
2192 printMCExpr(*BE
.getRHS(), OS
);
2195 printMCExpr(*BE
.getRHS(), OS
);
2202 llvm_unreachable("Invalid expression kind!");
2205 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2207 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr
*MI
, unsigned OpNo
,
2208 const char *ExtraCode
, raw_ostream
&O
) {
2209 if (ExtraCode
&& ExtraCode
[0]) {
2210 if (ExtraCode
[1] != 0)
2211 return true; // Unknown modifier.
2213 switch (ExtraCode
[0]) {
2215 // See if this is a generic print operand
2216 return AsmPrinter::PrintAsmOperand(MI
, OpNo
, ExtraCode
, O
);
2222 printOperand(MI
, OpNo
, O
);
2227 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr
*MI
,
2229 const char *ExtraCode
,
2231 if (ExtraCode
&& ExtraCode
[0])
2232 return true; // Unknown modifier
2235 printMemOperand(MI
, OpNo
, O
);
2241 void NVPTXAsmPrinter::printOperand(const MachineInstr
*MI
, unsigned OpNum
,
2243 const MachineOperand
&MO
= MI
->getOperand(OpNum
);
2244 switch (MO
.getType()) {
2245 case MachineOperand::MO_Register
:
2246 if (MO
.getReg().isPhysical()) {
2247 if (MO
.getReg() == NVPTX::VRDepot
)
2248 O
<< DEPOTNAME
<< getFunctionNumber();
2250 O
<< NVPTXInstPrinter::getRegisterName(MO
.getReg());
2252 emitVirtualRegister(MO
.getReg(), O
);
2256 case MachineOperand::MO_Immediate
:
2260 case MachineOperand::MO_FPImmediate
:
2261 printFPConstant(MO
.getFPImm(), O
);
2264 case MachineOperand::MO_GlobalAddress
:
2265 PrintSymbolOperand(MO
, O
);
2268 case MachineOperand::MO_MachineBasicBlock
:
2269 MO
.getMBB()->getSymbol()->print(O
, MAI
);
2273 llvm_unreachable("Operand type not supported.");
2277 void NVPTXAsmPrinter::printMemOperand(const MachineInstr
*MI
, unsigned OpNum
,
2278 raw_ostream
&O
, const char *Modifier
) {
2279 printOperand(MI
, OpNum
, O
);
2281 if (Modifier
&& strcmp(Modifier
, "add") == 0) {
2283 printOperand(MI
, OpNum
+ 1, O
);
2285 if (MI
->getOperand(OpNum
+ 1).isImm() &&
2286 MI
->getOperand(OpNum
+ 1).getImm() == 0)
2287 return; // don't print ',0' or '+0'
2289 printOperand(MI
, OpNum
+ 1, O
);
2293 // Force static initialization.
2294 extern "C" LLVM_EXTERNAL_VISIBILITY
void LLVMInitializeNVPTXAsmPrinter() {
2295 RegisterAsmPrinter
<NVPTXAsmPrinter
> X(getTheNVPTXTarget32());
2296 RegisterAsmPrinter
<NVPTXAsmPrinter
> Y(getTheNVPTXTarget64());