1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/Path.h"
82 #include "llvm/Support/raw_ostream.h"
83 #include "llvm/Target/TargetLoweringObjectFile.h"
84 #include "llvm/Target/TargetMachine.h"
85 #include "llvm/TargetParser/Triple.h"
86 #include "llvm/Transforms/Utils/UnrollLoop.h"
98 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
99 cl::desc("Lower GPU ctor / dtors to globals on the device."),
100 cl::init(false), cl::Hidden
);
102 #define DEPOTNAME "__local_depot"
104 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
107 DiscoverDependentGlobals(const Value
*V
,
108 DenseSet
<const GlobalVariable
*> &Globals
) {
109 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(V
))
112 if (const User
*U
= dyn_cast
<User
>(V
)) {
113 for (unsigned i
= 0, e
= U
->getNumOperands(); i
!= e
; ++i
) {
114 DiscoverDependentGlobals(U
->getOperand(i
), Globals
);
120 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
121 /// instances to be emitted, but only after any dependents have been added
124 VisitGlobalVariableForEmission(const GlobalVariable
*GV
,
125 SmallVectorImpl
<const GlobalVariable
*> &Order
,
126 DenseSet
<const GlobalVariable
*> &Visited
,
127 DenseSet
<const GlobalVariable
*> &Visiting
) {
128 // Have we already visited this one?
129 if (Visited
.count(GV
))
132 // Do we have a circular dependency?
133 if (!Visiting
.insert(GV
).second
)
134 report_fatal_error("Circular dependency found in global variable set");
136 // Make sure we visit all dependents first
137 DenseSet
<const GlobalVariable
*> Others
;
138 for (unsigned i
= 0, e
= GV
->getNumOperands(); i
!= e
; ++i
)
139 DiscoverDependentGlobals(GV
->getOperand(i
), Others
);
141 for (const GlobalVariable
*GV
: Others
)
142 VisitGlobalVariableForEmission(GV
, Order
, Visited
, Visiting
);
144 // Now we can visit ourself
150 void NVPTXAsmPrinter::emitInstruction(const MachineInstr
*MI
) {
151 NVPTX_MC::verifyInstructionPredicates(MI
->getOpcode(),
152 getSubtargetInfo().getFeatureBits());
155 lowerToMCInst(MI
, Inst
);
156 EmitToStreamer(*OutStreamer
, Inst
);
159 // Handle symbol backtracking for targets that do not support image handles
160 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr
*MI
,
161 unsigned OpNo
, MCOperand
&MCOp
) {
162 const MachineOperand
&MO
= MI
->getOperand(OpNo
);
163 const MCInstrDesc
&MCID
= MI
->getDesc();
165 if (MCID
.TSFlags
& NVPTXII::IsTexFlag
) {
166 // This is a texture fetch, so operand 4 is a texref and operand 5 is
168 if (OpNo
== 4 && MO
.isImm()) {
169 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
172 if (OpNo
== 5 && MO
.isImm() && !(MCID
.TSFlags
& NVPTXII::IsTexModeUnifiedFlag
)) {
173 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
178 } else if (MCID
.TSFlags
& NVPTXII::IsSuldMask
) {
180 1 << (((MCID
.TSFlags
& NVPTXII::IsSuldMask
) >> NVPTXII::IsSuldShift
) - 1);
182 // For a surface load of vector size N, the Nth operand will be the surfref
183 if (OpNo
== VecSize
&& MO
.isImm()) {
184 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
189 } else if (MCID
.TSFlags
& NVPTXII::IsSustFlag
) {
190 // This is a surface store, so operand 0 is a surfref
191 if (OpNo
== 0 && MO
.isImm()) {
192 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
197 } else if (MCID
.TSFlags
& NVPTXII::IsSurfTexQueryFlag
) {
198 // This is a query, so operand 1 is a surfref/texref
199 if (OpNo
== 1 && MO
.isImm()) {
200 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
210 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index
, MCOperand
&MCOp
) {
212 LLVMTargetMachine
&TM
= const_cast<LLVMTargetMachine
&>(MF
->getTarget());
213 NVPTXTargetMachine
&nvTM
= static_cast<NVPTXTargetMachine
&>(TM
);
214 const NVPTXMachineFunctionInfo
*MFI
= MF
->getInfo
<NVPTXMachineFunctionInfo
>();
215 const char *Sym
= MFI
->getImageHandleSymbol(Index
);
216 StringRef SymName
= nvTM
.getStrPool().save(Sym
);
217 MCOp
= GetSymbolRef(OutContext
.getOrCreateSymbol(SymName
));
220 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr
*MI
, MCInst
&OutMI
) {
221 OutMI
.setOpcode(MI
->getOpcode());
222 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
223 if (MI
->getOpcode() == NVPTX::CALL_PROTOTYPE
) {
224 const MachineOperand
&MO
= MI
->getOperand(0);
225 OutMI
.addOperand(GetSymbolRef(
226 OutContext
.getOrCreateSymbol(Twine(MO
.getSymbolName()))));
230 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
231 for (unsigned i
= 0, e
= MI
->getNumOperands(); i
!= e
; ++i
) {
232 const MachineOperand
&MO
= MI
->getOperand(i
);
235 if (!STI
.hasImageHandles()) {
236 if (lowerImageHandleOperand(MI
, i
, MCOp
)) {
237 OutMI
.addOperand(MCOp
);
242 if (lowerOperand(MO
, MCOp
))
243 OutMI
.addOperand(MCOp
);
247 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand
&MO
,
249 switch (MO
.getType()) {
250 default: llvm_unreachable("unknown operand type");
251 case MachineOperand::MO_Register
:
252 MCOp
= MCOperand::createReg(encodeVirtualRegister(MO
.getReg()));
254 case MachineOperand::MO_Immediate
:
255 MCOp
= MCOperand::createImm(MO
.getImm());
257 case MachineOperand::MO_MachineBasicBlock
:
258 MCOp
= MCOperand::createExpr(MCSymbolRefExpr::create(
259 MO
.getMBB()->getSymbol(), OutContext
));
261 case MachineOperand::MO_ExternalSymbol
:
262 MCOp
= GetSymbolRef(GetExternalSymbolSymbol(MO
.getSymbolName()));
264 case MachineOperand::MO_GlobalAddress
:
265 MCOp
= GetSymbolRef(getSymbol(MO
.getGlobal()));
267 case MachineOperand::MO_FPImmediate
: {
268 const ConstantFP
*Cnt
= MO
.getFPImm();
269 const APFloat
&Val
= Cnt
->getValueAPF();
271 switch (Cnt
->getType()->getTypeID()) {
272 default: report_fatal_error("Unsupported FP type"); break;
274 MCOp
= MCOperand::createExpr(
275 NVPTXFloatMCExpr::createConstantFPHalf(Val
, OutContext
));
277 case Type::BFloatTyID
:
278 MCOp
= MCOperand::createExpr(
279 NVPTXFloatMCExpr::createConstantBFPHalf(Val
, OutContext
));
281 case Type::FloatTyID
:
282 MCOp
= MCOperand::createExpr(
283 NVPTXFloatMCExpr::createConstantFPSingle(Val
, OutContext
));
285 case Type::DoubleTyID
:
286 MCOp
= MCOperand::createExpr(
287 NVPTXFloatMCExpr::createConstantFPDouble(Val
, OutContext
));
296 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg
) {
297 if (Register::isVirtualRegister(Reg
)) {
298 const TargetRegisterClass
*RC
= MRI
->getRegClass(Reg
);
300 DenseMap
<unsigned, unsigned> &RegMap
= VRegMapping
[RC
];
301 unsigned RegNum
= RegMap
[Reg
];
303 // Encode the register class in the upper 4 bits
304 // Must be kept in sync with NVPTXInstPrinter::printRegName
306 if (RC
== &NVPTX::Int1RegsRegClass
) {
308 } else if (RC
== &NVPTX::Int16RegsRegClass
) {
310 } else if (RC
== &NVPTX::Int32RegsRegClass
) {
312 } else if (RC
== &NVPTX::Int64RegsRegClass
) {
314 } else if (RC
== &NVPTX::Float32RegsRegClass
) {
316 } else if (RC
== &NVPTX::Float64RegsRegClass
) {
318 } else if (RC
== &NVPTX::Int128RegsRegClass
) {
321 report_fatal_error("Bad register class");
324 // Insert the vreg number
325 Ret
|= (RegNum
& 0x0FFFFFFF);
328 // Some special-use registers are actually physical registers.
329 // Encode this as the register class ID of 0 and the real register ID.
330 return Reg
& 0x0FFFFFFF;
334 MCOperand
NVPTXAsmPrinter::GetSymbolRef(const MCSymbol
*Symbol
) {
336 Expr
= MCSymbolRefExpr::create(Symbol
, MCSymbolRefExpr::VK_None
,
338 return MCOperand::createExpr(Expr
);
341 static bool ShouldPassAsArray(Type
*Ty
) {
342 return Ty
->isAggregateType() || Ty
->isVectorTy() || Ty
->isIntegerTy(128) ||
343 Ty
->isHalfTy() || Ty
->isBFloatTy();
346 void NVPTXAsmPrinter::printReturnValStr(const Function
*F
, raw_ostream
&O
) {
347 const DataLayout
&DL
= getDataLayout();
348 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
349 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
351 Type
*Ty
= F
->getReturnType();
353 bool isABI
= (STI
.getSmVersion() >= 20);
355 if (Ty
->getTypeID() == Type::VoidTyID
)
360 if ((Ty
->isFloatingPointTy() || Ty
->isIntegerTy()) &&
361 !ShouldPassAsArray(Ty
)) {
363 if (auto *ITy
= dyn_cast
<IntegerType
>(Ty
)) {
364 size
= ITy
->getBitWidth();
366 assert(Ty
->isFloatingPointTy() && "Floating point type expected here");
367 size
= Ty
->getPrimitiveSizeInBits();
369 size
= promoteScalarArgumentSize(size
);
370 O
<< ".param .b" << size
<< " func_retval0";
371 } else if (isa
<PointerType
>(Ty
)) {
372 O
<< ".param .b" << TLI
->getPointerTy(DL
).getSizeInBits()
374 } else if (ShouldPassAsArray(Ty
)) {
375 unsigned totalsz
= DL
.getTypeAllocSize(Ty
);
376 Align RetAlignment
= TLI
->getFunctionArgumentAlignment(
377 F
, Ty
, AttributeList::ReturnIndex
, DL
);
378 O
<< ".param .align " << RetAlignment
.value() << " .b8 func_retval0["
381 llvm_unreachable("Unknown return type");
383 SmallVector
<EVT
, 16> vtparts
;
384 ComputeValueVTs(*TLI
, DL
, Ty
, vtparts
);
386 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
388 EVT elemtype
= vtparts
[i
];
389 if (vtparts
[i
].isVector()) {
390 elems
= vtparts
[i
].getVectorNumElements();
391 elemtype
= vtparts
[i
].getVectorElementType();
394 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
395 unsigned sz
= elemtype
.getSizeInBits();
396 if (elemtype
.isInteger())
397 sz
= promoteScalarArgumentSize(sz
);
398 O
<< ".reg .b" << sz
<< " func_retval" << idx
;
410 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction
&MF
,
412 const Function
&F
= MF
.getFunction();
413 printReturnValStr(&F
, O
);
416 // Return true if MBB is the header of a loop marked with
417 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
418 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
419 const MachineBasicBlock
&MBB
) const {
420 MachineLoopInfo
&LI
= getAnalysis
<MachineLoopInfoWrapperPass
>().getLI();
421 // We insert .pragma "nounroll" only to the loop header.
422 if (!LI
.isLoopHeader(&MBB
))
425 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
426 // we iterate through each back edge of the loop with header MBB, and check
427 // whether its metadata contains llvm.loop.unroll.disable.
428 for (const MachineBasicBlock
*PMBB
: MBB
.predecessors()) {
429 if (LI
.getLoopFor(PMBB
) != LI
.getLoopFor(&MBB
)) {
430 // Edges from other loops to MBB are not back edges.
433 if (const BasicBlock
*PBB
= PMBB
->getBasicBlock()) {
435 PBB
->getTerminator()->getMetadata(LLVMContext::MD_loop
)) {
436 if (GetUnrollMetadata(LoopID
, "llvm.loop.unroll.disable"))
438 if (MDNode
*UnrollCountMD
=
439 GetUnrollMetadata(LoopID
, "llvm.loop.unroll.count")) {
440 if (mdconst::extract
<ConstantInt
>(UnrollCountMD
->getOperand(1))
450 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock
&MBB
) {
451 AsmPrinter::emitBasicBlockStart(MBB
);
452 if (isLoopHeaderOfNoUnroll(MBB
))
453 OutStreamer
->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
456 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
457 SmallString
<128> Str
;
458 raw_svector_ostream
O(Str
);
460 if (!GlobalsEmitted
) {
461 emitGlobals(*MF
->getFunction().getParent());
462 GlobalsEmitted
= true;
466 MRI
= &MF
->getRegInfo();
467 F
= &MF
->getFunction();
468 emitLinkageDirective(F
, O
);
469 if (isKernelFunction(*F
))
473 printReturnValStr(*MF
, O
);
476 CurrentFnSym
->print(O
, MAI
);
478 emitFunctionParamList(F
, O
);
481 if (isKernelFunction(*F
))
482 emitKernelFunctionDirectives(*F
, O
);
484 if (shouldEmitPTXNoReturn(F
, TM
))
487 OutStreamer
->emitRawText(O
.str());
490 // Emit open brace for function body.
491 OutStreamer
->emitRawText(StringRef("{\n"));
492 setAndEmitFunctionVirtualRegisters(*MF
);
493 // Emit initial .loc debug directive for correct relocation symbol data.
494 if (const DISubprogram
*SP
= MF
->getFunction().getSubprogram()) {
495 assert(SP
->getUnit());
496 if (!SP
->getUnit()->isDebugDirectivesOnly() && MMI
&& MMI
->hasDebugInfo())
497 emitInitialRawDwarfLocDirective(*MF
);
501 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction
&F
) {
502 bool Result
= AsmPrinter::runOnMachineFunction(F
);
503 // Emit closing brace for the body of function F.
504 // The closing brace must be emitted here because we need to emit additional
505 // debug labels/data after the last basic block.
506 // We need to emit the closing brace here because there is no function that
507 // is called once emission of the function body has finished.
508 OutStreamer
->emitRawText(StringRef("}\n"));
512 void NVPTXAsmPrinter::emitFunctionBodyStart() {
513 SmallString
<128> Str
;
514 raw_svector_ostream
O(Str
);
515 emitDemotedVars(&MF
->getFunction(), O
);
516 OutStreamer
->emitRawText(O
.str());
519 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
523 const MCSymbol
*NVPTXAsmPrinter::getFunctionFrameSymbol() const {
524 SmallString
<128> Str
;
525 raw_svector_ostream(Str
) << DEPOTNAME
<< getFunctionNumber();
526 return OutContext
.getOrCreateSymbol(Str
);
529 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr
*MI
) const {
530 Register RegNo
= MI
->getOperand(0).getReg();
531 if (RegNo
.isVirtual()) {
532 OutStreamer
->AddComment(Twine("implicit-def: ") +
533 getVirtualRegisterName(RegNo
));
535 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
536 OutStreamer
->AddComment(Twine("implicit-def: ") +
537 STI
.getRegisterInfo()->getName(RegNo
));
539 OutStreamer
->addBlankLine();
542 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function
&F
,
543 raw_ostream
&O
) const {
544 // If the NVVM IR has some of reqntid* specified, then output
545 // the reqntid directive, and set the unspecified ones to 1.
546 // If none of Reqntid* is specified, don't output reqntid directive.
547 std::optional
<unsigned> Reqntidx
= getReqNTIDx(F
);
548 std::optional
<unsigned> Reqntidy
= getReqNTIDy(F
);
549 std::optional
<unsigned> Reqntidz
= getReqNTIDz(F
);
551 if (Reqntidx
|| Reqntidy
|| Reqntidz
)
552 O
<< ".reqntid " << Reqntidx
.value_or(1) << ", " << Reqntidy
.value_or(1)
553 << ", " << Reqntidz
.value_or(1) << "\n";
555 // If the NVVM IR has some of maxntid* specified, then output
556 // the maxntid directive, and set the unspecified ones to 1.
557 // If none of maxntid* is specified, don't output maxntid directive.
558 std::optional
<unsigned> Maxntidx
= getMaxNTIDx(F
);
559 std::optional
<unsigned> Maxntidy
= getMaxNTIDy(F
);
560 std::optional
<unsigned> Maxntidz
= getMaxNTIDz(F
);
562 if (Maxntidx
|| Maxntidy
|| Maxntidz
)
563 O
<< ".maxntid " << Maxntidx
.value_or(1) << ", " << Maxntidy
.value_or(1)
564 << ", " << Maxntidz
.value_or(1) << "\n";
567 if (getMinCTASm(F
, Mincta
))
568 O
<< ".minnctapersm " << Mincta
<< "\n";
570 unsigned Maxnreg
= 0;
571 if (getMaxNReg(F
, Maxnreg
))
572 O
<< ".maxnreg " << Maxnreg
<< "\n";
574 // .maxclusterrank directive requires SM_90 or higher, make sure that we
575 // filter it out for lower SM versions, as it causes a hard ptxas crash.
576 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
577 const auto *STI
= static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
578 unsigned Maxclusterrank
= 0;
579 if (getMaxClusterRank(F
, Maxclusterrank
) && STI
->getSmVersion() >= 90)
580 O
<< ".maxclusterrank " << Maxclusterrank
<< "\n";
583 std::string
NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg
) const {
584 const TargetRegisterClass
*RC
= MRI
->getRegClass(Reg
);
587 raw_string_ostream
NameStr(Name
);
589 VRegRCMap::const_iterator I
= VRegMapping
.find(RC
);
590 assert(I
!= VRegMapping
.end() && "Bad register class");
591 const DenseMap
<unsigned, unsigned> &RegMap
= I
->second
;
593 VRegMap::const_iterator VI
= RegMap
.find(Reg
);
594 assert(VI
!= RegMap
.end() && "Bad virtual register");
595 unsigned MappedVR
= VI
->second
;
597 NameStr
<< getNVPTXRegClassStr(RC
) << MappedVR
;
603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr
,
605 O
<< getVirtualRegisterName(vr
);
608 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias
*GA
,
610 const Function
*F
= dyn_cast_or_null
<Function
>(GA
->getAliaseeObject());
611 if (!F
|| isKernelFunction(*F
) || F
->isDeclaration())
613 "NVPTX aliasee must be a non-kernel function definition");
615 if (GA
->hasLinkOnceLinkage() || GA
->hasWeakLinkage() ||
616 GA
->hasAvailableExternallyLinkage() || GA
->hasCommonLinkage())
617 report_fatal_error("NVPTX aliasee must not be '.weak'");
619 emitDeclarationWithName(F
, getSymbol(GA
), O
);
622 void NVPTXAsmPrinter::emitDeclaration(const Function
*F
, raw_ostream
&O
) {
623 emitDeclarationWithName(F
, getSymbol(F
), O
);
626 void NVPTXAsmPrinter::emitDeclarationWithName(const Function
*F
, MCSymbol
*S
,
628 emitLinkageDirective(F
, O
);
629 if (isKernelFunction(*F
))
633 printReturnValStr(F
, O
);
636 emitFunctionParamList(F
, O
);
638 if (shouldEmitPTXNoReturn(F
, TM
))
643 static bool usedInGlobalVarDef(const Constant
*C
) {
647 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(C
)) {
648 return GV
->getName() != "llvm.used";
651 for (const User
*U
: C
->users())
652 if (const Constant
*C
= dyn_cast
<Constant
>(U
))
653 if (usedInGlobalVarDef(C
))
659 static bool usedInOneFunc(const User
*U
, Function
const *&oneFunc
) {
660 if (const GlobalVariable
*othergv
= dyn_cast
<GlobalVariable
>(U
)) {
661 if (othergv
->getName() == "llvm.used")
665 if (const Instruction
*instr
= dyn_cast
<Instruction
>(U
)) {
666 if (instr
->getParent() && instr
->getParent()->getParent()) {
667 const Function
*curFunc
= instr
->getParent()->getParent();
668 if (oneFunc
&& (curFunc
!= oneFunc
))
676 for (const User
*UU
: U
->users())
677 if (!usedInOneFunc(UU
, oneFunc
))
683 /* Find out if a global variable can be demoted to local scope.
684 * Currently, this is valid for CUDA shared variables, which have local
685 * scope and global lifetime. So the conditions to check are :
686 * 1. Is the global variable in shared address space?
687 * 2. Does it have local linkage?
688 * 3. Is the global variable referenced only in one function?
690 static bool canDemoteGlobalVar(const GlobalVariable
*gv
, Function
const *&f
) {
691 if (!gv
->hasLocalLinkage())
693 PointerType
*Pty
= gv
->getType();
694 if (Pty
->getAddressSpace() != ADDRESS_SPACE_SHARED
)
697 const Function
*oneFunc
= nullptr;
699 bool flag
= usedInOneFunc(gv
, oneFunc
);
708 static bool useFuncSeen(const Constant
*C
,
709 DenseMap
<const Function
*, bool> &seenMap
) {
710 for (const User
*U
: C
->users()) {
711 if (const Constant
*cu
= dyn_cast
<Constant
>(U
)) {
712 if (useFuncSeen(cu
, seenMap
))
714 } else if (const Instruction
*I
= dyn_cast
<Instruction
>(U
)) {
715 const BasicBlock
*bb
= I
->getParent();
718 const Function
*caller
= bb
->getParent();
721 if (seenMap
.contains(caller
))
728 void NVPTXAsmPrinter::emitDeclarations(const Module
&M
, raw_ostream
&O
) {
729 DenseMap
<const Function
*, bool> seenMap
;
730 for (const Function
&F
: M
) {
731 if (F
.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
732 emitDeclaration(&F
, O
);
736 if (F
.isDeclaration()) {
739 if (F
.getIntrinsicID())
741 emitDeclaration(&F
, O
);
744 for (const User
*U
: F
.users()) {
745 if (const Constant
*C
= dyn_cast
<Constant
>(U
)) {
746 if (usedInGlobalVarDef(C
)) {
747 // The use is in the initialization of a global variable
748 // that is a function pointer, so print a declaration
749 // for the original function
750 emitDeclaration(&F
, O
);
753 // Emit a declaration of this function if the function that
754 // uses this constant expr has already been seen.
755 if (useFuncSeen(C
, seenMap
)) {
756 emitDeclaration(&F
, O
);
761 if (!isa
<Instruction
>(U
))
763 const Instruction
*instr
= cast
<Instruction
>(U
);
764 const BasicBlock
*bb
= instr
->getParent();
767 const Function
*caller
= bb
->getParent();
771 // If a caller has already been seen, then the caller
772 // appears in the module before the callee, so print out
773 // a declaration for the callee.
774 if (seenMap
.contains(caller
)) {
775 emitDeclaration(&F
, O
);
781 for (const GlobalAlias
&GA
: M
.aliases())
782 emitAliasDeclaration(&GA
, O
);
785 static bool isEmptyXXStructor(GlobalVariable
*GV
) {
786 if (!GV
) return true;
787 const ConstantArray
*InitList
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
788 if (!InitList
) return true; // Not an array; we don't know how to parse.
789 return InitList
->getNumOperands() == 0;
792 void NVPTXAsmPrinter::emitStartOfAsmFile(Module
&M
) {
793 // Construct a default subtarget off of the TargetMachine defaults. The
794 // rest of NVPTX isn't friendly to change subtargets per function and
795 // so the default TargetMachine will have all of the options.
796 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
797 const auto* STI
= static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
798 SmallString
<128> Str1
;
799 raw_svector_ostream
OS1(Str1
);
801 // Emit header before any dwarf directives are emitted below.
802 emitHeader(M
, OS1
, *STI
);
803 OutStreamer
->emitRawText(OS1
.str());
806 bool NVPTXAsmPrinter::doInitialization(Module
&M
) {
807 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
808 const NVPTXSubtarget
&STI
=
809 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
810 if (M
.alias_size() && (STI
.getPTXVersion() < 63 || STI
.getSmVersion() < 30))
811 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
813 // OpenMP supports NVPTX global constructors and destructors.
814 bool IsOpenMP
= M
.getModuleFlag("openmp") != nullptr;
816 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_ctors")) &&
817 !LowerCtorDtor
&& !IsOpenMP
) {
819 "Module has a nontrivial global ctor, which NVPTX does not support.");
820 return true; // error
822 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_dtors")) &&
823 !LowerCtorDtor
&& !IsOpenMP
) {
825 "Module has a nontrivial global dtor, which NVPTX does not support.");
826 return true; // error
829 // We need to call the parent's one explicitly.
830 bool Result
= AsmPrinter::doInitialization(M
);
832 GlobalsEmitted
= false;
837 void NVPTXAsmPrinter::emitGlobals(const Module
&M
) {
838 SmallString
<128> Str2
;
839 raw_svector_ostream
OS2(Str2
);
841 emitDeclarations(M
, OS2
);
843 // As ptxas does not support forward references of globals, we need to first
844 // sort the list of module-level globals in def-use order. We visit each
845 // global variable in order, and ensure that we emit it *after* its dependent
846 // globals. We use a little extra memory maintaining both a set and a list to
847 // have fast searches while maintaining a strict ordering.
848 SmallVector
<const GlobalVariable
*, 8> Globals
;
849 DenseSet
<const GlobalVariable
*> GVVisited
;
850 DenseSet
<const GlobalVariable
*> GVVisiting
;
852 // Visit each global variable, in order
853 for (const GlobalVariable
&I
: M
.globals())
854 VisitGlobalVariableForEmission(&I
, Globals
, GVVisited
, GVVisiting
);
856 assert(GVVisited
.size() == M
.global_size() && "Missed a global variable");
857 assert(GVVisiting
.size() == 0 && "Did not fully process a global variable");
859 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
860 const NVPTXSubtarget
&STI
=
861 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
863 // Print out module-level global variables in proper order
864 for (const GlobalVariable
*GV
: Globals
)
865 printModuleLevelGV(GV
, OS2
, /*processDemoted=*/false, STI
);
869 OutStreamer
->emitRawText(OS2
.str());
872 void NVPTXAsmPrinter::emitGlobalAlias(const Module
&M
, const GlobalAlias
&GA
) {
873 SmallString
<128> Str
;
874 raw_svector_ostream
OS(Str
);
876 MCSymbol
*Name
= getSymbol(&GA
);
878 OS
<< ".alias " << Name
->getName() << ", " << GA
.getAliaseeObject()->getName()
881 OutStreamer
->emitRawText(OS
.str());
884 void NVPTXAsmPrinter::emitHeader(Module
&M
, raw_ostream
&O
,
885 const NVPTXSubtarget
&STI
) {
887 O
<< "// Generated by LLVM NVPTX Back-End\n";
891 unsigned PTXVersion
= STI
.getPTXVersion();
892 O
<< ".version " << (PTXVersion
/ 10) << "." << (PTXVersion
% 10) << "\n";
895 O
<< STI
.getTargetName();
897 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
898 if (NTM
.getDrvInterface() == NVPTX::NVCL
)
899 O
<< ", texmode_independent";
901 bool HasFullDebugInfo
= false;
902 for (DICompileUnit
*CU
: M
.debug_compile_units()) {
903 switch(CU
->getEmissionKind()) {
904 case DICompileUnit::NoDebug
:
905 case DICompileUnit::DebugDirectivesOnly
:
907 case DICompileUnit::LineTablesOnly
:
908 case DICompileUnit::FullDebug
:
909 HasFullDebugInfo
= true;
912 if (HasFullDebugInfo
)
915 if (MMI
&& MMI
->hasDebugInfo() && HasFullDebugInfo
)
920 O
<< ".address_size ";
930 bool NVPTXAsmPrinter::doFinalization(Module
&M
) {
931 bool HasDebugInfo
= MMI
&& MMI
->hasDebugInfo();
933 // If we did not emit any functions, then the global declarations have not
935 if (!GlobalsEmitted
) {
937 GlobalsEmitted
= true;
940 // call doFinalization
941 bool ret
= AsmPrinter::doFinalization(M
);
943 clearAnnotationCache(&M
);
946 static_cast<NVPTXTargetStreamer
*>(OutStreamer
->getTargetStreamer());
947 // Close the last emitted section
949 TS
->closeLastSection();
950 // Emit empty .debug_loc section for better support of the empty files.
951 OutStreamer
->emitRawText("\t.section\t.debug_loc\t{\t}");
954 // Output last DWARF .file directives, if any.
955 TS
->outputDwarfFileDirectives();
960 // This function emits appropriate linkage directives for
961 // functions and global variables.
963 // extern function declaration -> .extern
964 // extern function definition -> .visible
965 // external global variable with init -> .visible
966 // external without init -> .extern
967 // appending -> not allowed, assert.
968 // for any linkage other than
969 // internal, private, linker_private,
970 // linker_private_weak, linker_private_weak_def_auto,
973 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue
*V
,
975 if (static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() == NVPTX::CUDA
) {
976 if (V
->hasExternalLinkage()) {
977 if (isa
<GlobalVariable
>(V
)) {
978 const GlobalVariable
*GVar
= cast
<GlobalVariable
>(V
);
980 if (GVar
->hasInitializer())
985 } else if (V
->isDeclaration())
989 } else if (V
->hasAppendingLinkage()) {
991 msg
.append("Error: ");
992 msg
.append("Symbol ");
994 msg
.append(std::string(V
->getName()));
995 msg
.append("has unsupported appending linkage type");
996 llvm_unreachable(msg
.c_str());
997 } else if (!V
->hasInternalLinkage() &&
998 !V
->hasPrivateLinkage()) {
1004 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable
*GVar
,
1005 raw_ostream
&O
, bool processDemoted
,
1006 const NVPTXSubtarget
&STI
) {
1008 if (GVar
->hasSection()) {
1009 if (GVar
->getSection() == "llvm.metadata")
1013 // Skip LLVM intrinsic global variables
1014 if (GVar
->getName().starts_with("llvm.") ||
1015 GVar
->getName().starts_with("nvvm."))
1018 const DataLayout
&DL
= getDataLayout();
1020 // GlobalVariables are always constant pointers themselves.
1021 Type
*ETy
= GVar
->getValueType();
1023 if (GVar
->hasExternalLinkage()) {
1024 if (GVar
->hasInitializer())
1028 } else if (STI
.getPTXVersion() >= 50 && GVar
->hasCommonLinkage() &&
1029 GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) {
1031 } else if (GVar
->hasLinkOnceLinkage() || GVar
->hasWeakLinkage() ||
1032 GVar
->hasAvailableExternallyLinkage() ||
1033 GVar
->hasCommonLinkage()) {
1037 if (isTexture(*GVar
)) {
1038 O
<< ".global .texref " << getTextureName(*GVar
) << ";\n";
1042 if (isSurface(*GVar
)) {
1043 O
<< ".global .surfref " << getSurfaceName(*GVar
) << ";\n";
1047 if (GVar
->isDeclaration()) {
1048 // (extern) declarations, no definition or initializer
1049 // Currently the only known declaration is for an automatic __local
1050 // (.shared) promoted to global.
1051 emitPTXGlobalVariable(GVar
, O
, STI
);
1056 if (isSampler(*GVar
)) {
1057 O
<< ".global .samplerref " << getSamplerName(*GVar
);
1059 const Constant
*Initializer
= nullptr;
1060 if (GVar
->hasInitializer())
1061 Initializer
= GVar
->getInitializer();
1062 const ConstantInt
*CI
= nullptr;
1064 CI
= dyn_cast
<ConstantInt
>(Initializer
);
1066 unsigned sample
= CI
->getZExtValue();
1071 addr
= ((sample
& __CLK_ADDRESS_MASK
) >> __CLK_ADDRESS_BASE
);
1073 O
<< "addr_mode_" << i
<< " = ";
1079 O
<< "clamp_to_border";
1082 O
<< "clamp_to_edge";
1093 O
<< "filter_mode = ";
1094 switch ((sample
& __CLK_FILTER_MASK
) >> __CLK_FILTER_BASE
) {
1102 llvm_unreachable("Anisotropic filtering is not supported");
1107 if (!((sample
& __CLK_NORMALIZED_MASK
) >> __CLK_NORMALIZED_BASE
)) {
1108 O
<< ", force_unnormalized_coords = 1";
1117 if (GVar
->hasPrivateLinkage()) {
1118 if (strncmp(GVar
->getName().data(), "unrollpragma", 12) == 0)
1121 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1122 if (strncmp(GVar
->getName().data(), "filename", 8) == 0)
1124 if (GVar
->use_empty())
1128 const Function
*demotedFunc
= nullptr;
1129 if (!processDemoted
&& canDemoteGlobalVar(GVar
, demotedFunc
)) {
1130 O
<< "// " << GVar
->getName() << " has been demoted\n";
1131 if (localDecls
.find(demotedFunc
) != localDecls
.end())
1132 localDecls
[demotedFunc
].push_back(GVar
);
1134 std::vector
<const GlobalVariable
*> temp
;
1135 temp
.push_back(GVar
);
1136 localDecls
[demotedFunc
] = temp
;
1142 emitPTXAddressSpace(GVar
->getAddressSpace(), O
);
1144 if (isManaged(*GVar
)) {
1145 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1147 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1149 O
<< " .attribute(.managed)";
1152 if (MaybeAlign A
= GVar
->getAlign())
1153 O
<< " .align " << A
->value();
1155 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1157 if (ETy
->isFloatingPointTy() || ETy
->isPointerTy() ||
1158 (ETy
->isIntegerTy() && ETy
->getScalarSizeInBits() <= 64)) {
1160 // Special case: ABI requires that we use .u8 for predicates
1161 if (ETy
->isIntegerTy(1))
1164 O
<< getPTXFundamentalTypeStr(ETy
, false);
1166 getSymbol(GVar
)->print(O
, MAI
);
1168 // PTX allows variable initialization only for constant and global state
1170 if (GVar
->hasInitializer()) {
1171 if ((GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1172 (GVar
->getAddressSpace() == ADDRESS_SPACE_CONST
)) {
1173 const Constant
*Initializer
= GVar
->getInitializer();
1174 // 'undef' is treated as there is no value specified.
1175 if (!Initializer
->isNullValue() && !isa
<UndefValue
>(Initializer
)) {
1177 printScalarConstant(Initializer
, O
);
1180 // The frontend adds zero-initializer to device and constant variables
1181 // that don't have an initial value, and UndefValue to shared
1182 // variables, so skip warning for this case.
1183 if (!GVar
->getInitializer()->isNullValue() &&
1184 !isa
<UndefValue
>(GVar
->getInitializer())) {
1185 report_fatal_error("initial value of '" + GVar
->getName() +
1186 "' is not allowed in addrspace(" +
1187 Twine(GVar
->getAddressSpace()) + ")");
1192 uint64_t ElementSize
= 0;
1194 // Although PTX has direct support for struct and array types, and
1195 // LLVM IR is very similar to PTX, LLVM CodeGen has no support for
1196 // targets that provide these high-level field accesses. Structs, arrays
1197 // and vectors are lowered into arrays of bytes.
1198 switch (ETy
->getTypeID()) {
1199 case Type::IntegerTyID
: // Integers larger than 64 bits
1200 case Type::StructTyID
:
1201 case Type::ArrayTyID
:
1202 case Type::FixedVectorTyID
:
1203 ElementSize
= DL
.getTypeStoreSize(ETy
);
1204 // PTX allows variable initialization only for the constant and
1205 // global state spaces.
1206 if (((GVar
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1207 (GVar
->getAddressSpace() == ADDRESS_SPACE_CONST
)) &&
1208 GVar
->hasInitializer()) {
1209 const Constant
*Initializer
= GVar
->getInitializer();
1210 if (!isa
<UndefValue
>(Initializer
) && !Initializer
->isNullValue()) {
1211 AggBuffer
aggBuffer(ElementSize
, *this);
1212 bufferAggregateConstant(Initializer
, &aggBuffer
);
1213 if (aggBuffer
.numSymbols()) {
1214 unsigned int ptrSize
= MAI
->getCodePointerSize();
1215 if (ElementSize
% ptrSize
||
1216 !aggBuffer
.allSymbolsAligned(ptrSize
)) {
1217 // Print in bytes and use the mask() operator for pointers.
1218 if (!STI
.hasMaskOperator())
1220 "initialized packed aggregate with pointers '" +
1222 "' requires at least PTX ISA version 7.1");
1224 getSymbol(GVar
)->print(O
, MAI
);
1225 O
<< "[" << ElementSize
<< "] = {";
1226 aggBuffer
.printBytes(O
);
1229 O
<< " .u" << ptrSize
* 8 << " ";
1230 getSymbol(GVar
)->print(O
, MAI
);
1231 O
<< "[" << ElementSize
/ ptrSize
<< "] = {";
1232 aggBuffer
.printWords(O
);
1237 getSymbol(GVar
)->print(O
, MAI
);
1238 O
<< "[" << ElementSize
<< "] = {";
1239 aggBuffer
.printBytes(O
);
1244 getSymbol(GVar
)->print(O
, MAI
);
1253 getSymbol(GVar
)->print(O
, MAI
);
1262 llvm_unreachable("type not supported yet");
1268 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym
, raw_ostream
&os
) {
1269 const Value
*v
= Symbols
[nSym
];
1270 const Value
*v0
= SymbolsBeforeStripping
[nSym
];
1271 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(v
)) {
1272 MCSymbol
*Name
= AP
.getSymbol(GVar
);
1273 PointerType
*PTy
= dyn_cast
<PointerType
>(v0
->getType());
1274 // Is v0 a generic pointer?
1275 bool isGenericPointer
= PTy
&& PTy
->getAddressSpace() == 0;
1276 if (EmitGeneric
&& isGenericPointer
&& !isa
<Function
>(v
)) {
1278 Name
->print(os
, AP
.MAI
);
1281 Name
->print(os
, AP
.MAI
);
1283 } else if (const ConstantExpr
*CExpr
= dyn_cast
<ConstantExpr
>(v0
)) {
1284 const MCExpr
*Expr
= AP
.lowerConstantForGV(cast
<Constant
>(CExpr
), false);
1285 AP
.printMCExpr(*Expr
, os
);
1287 llvm_unreachable("symbol type unknown");
1290 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream
&os
) {
1291 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1292 // Do not emit trailing zero initializers. They will be zero-initialized by
1293 // ptxas. This saves on both space requirements for the generated PTX and on
1294 // memory use by ptxas. (See:
1295 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1296 unsigned int InitializerCount
= size
;
1297 // TODO: symbols make this harder, but it would still be good to trim trailing
1298 // 0s for aggs with symbols as well.
1299 if (numSymbols() == 0)
1300 while (InitializerCount
>= 1 && !buffer
[InitializerCount
- 1])
1303 symbolPosInBuffer
.push_back(InitializerCount
);
1304 unsigned int nSym
= 0;
1305 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1306 for (unsigned int pos
= 0; pos
< InitializerCount
;) {
1309 if (pos
!= nextSymbolPos
) {
1310 os
<< (unsigned int)buffer
[pos
];
1314 // Generate a per-byte mask() operator for the symbol, which looks like:
1315 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1316 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1317 std::string symText
;
1318 llvm::raw_string_ostream
oss(symText
);
1319 printSymbol(nSym
, oss
);
1320 for (unsigned i
= 0; i
< ptrSize
; ++i
) {
1323 llvm::write_hex(os
, 0xFFULL
<< i
* 8, HexPrintStyle::PrefixUpper
);
1324 os
<< "(" << symText
<< ")";
1327 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1328 assert(nextSymbolPos
>= pos
);
1332 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream
&os
) {
1333 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1334 symbolPosInBuffer
.push_back(size
);
1335 unsigned int nSym
= 0;
1336 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1337 assert(nextSymbolPos
% ptrSize
== 0);
1338 for (unsigned int pos
= 0; pos
< size
; pos
+= ptrSize
) {
1341 if (pos
== nextSymbolPos
) {
1342 printSymbol(nSym
, os
);
1343 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1344 assert(nextSymbolPos
% ptrSize
== 0);
1345 assert(nextSymbolPos
>= pos
+ ptrSize
);
1346 } else if (ptrSize
== 4)
1347 os
<< support::endian::read32le(&buffer
[pos
]);
1349 os
<< support::endian::read64le(&buffer
[pos
]);
1353 void NVPTXAsmPrinter::emitDemotedVars(const Function
*f
, raw_ostream
&O
) {
1354 if (localDecls
.find(f
) == localDecls
.end())
1357 std::vector
<const GlobalVariable
*> &gvars
= localDecls
[f
];
1359 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
1360 const NVPTXSubtarget
&STI
=
1361 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
1363 for (const GlobalVariable
*GV
: gvars
) {
1364 O
<< "\t// demoted variable\n\t";
1365 printModuleLevelGV(GV
, O
, /*processDemoted=*/true, STI
);
1369 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace
,
1370 raw_ostream
&O
) const {
1371 switch (AddressSpace
) {
1372 case ADDRESS_SPACE_LOCAL
:
1375 case ADDRESS_SPACE_GLOBAL
:
1378 case ADDRESS_SPACE_CONST
:
1381 case ADDRESS_SPACE_SHARED
:
1385 report_fatal_error("Bad address space found while emitting PTX: " +
1386 llvm::Twine(AddressSpace
));
1392 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type
*Ty
, bool useB4PTR
) const {
1393 switch (Ty
->getTypeID()) {
1394 case Type::IntegerTyID
: {
1395 unsigned NumBits
= cast
<IntegerType
>(Ty
)->getBitWidth();
1398 else if (NumBits
<= 64) {
1399 std::string name
= "u";
1400 return name
+ utostr(NumBits
);
1402 llvm_unreachable("Integer too large");
1407 case Type::BFloatTyID
:
1408 case Type::HalfTyID
:
1409 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1412 case Type::FloatTyID
:
1414 case Type::DoubleTyID
:
1416 case Type::PointerTyID
: {
1417 unsigned PtrSize
= TM
.getPointerSizeInBits(Ty
->getPointerAddressSpace());
1418 assert((PtrSize
== 64 || PtrSize
== 32) && "Unexpected pointer size");
1433 llvm_unreachable("unexpected type");
1436 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable
*GVar
,
1438 const NVPTXSubtarget
&STI
) {
1439 const DataLayout
&DL
= getDataLayout();
1441 // GlobalVariables are always constant pointers themselves.
1442 Type
*ETy
= GVar
->getValueType();
1445 emitPTXAddressSpace(GVar
->getType()->getAddressSpace(), O
);
1446 if (isManaged(*GVar
)) {
1447 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1449 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1451 O
<< " .attribute(.managed)";
1453 if (MaybeAlign A
= GVar
->getAlign())
1454 O
<< " .align " << A
->value();
1456 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1458 // Special case for i128
1459 if (ETy
->isIntegerTy(128)) {
1461 getSymbol(GVar
)->print(O
, MAI
);
1466 if (ETy
->isFloatingPointTy() || ETy
->isIntOrPtrTy()) {
1468 O
<< getPTXFundamentalTypeStr(ETy
);
1470 getSymbol(GVar
)->print(O
, MAI
);
1474 int64_t ElementSize
= 0;
1476 // Although PTX has direct support for struct and array types, and LLVM IR
1477 // is very similar to PTX, LLVM CodeGen has no support for targets that
1478 // provide these high-level field accesses. Structs and arrays are lowered
1479 // into arrays of bytes.
1480 switch (ETy
->getTypeID()) {
1481 case Type::StructTyID
:
1482 case Type::ArrayTyID
:
1483 case Type::FixedVectorTyID
:
1484 ElementSize
= DL
.getTypeStoreSize(ETy
);
1486 getSymbol(GVar
)->print(O
, MAI
);
1494 llvm_unreachable("type not supported yet");
1498 void NVPTXAsmPrinter::emitFunctionParamList(const Function
*F
, raw_ostream
&O
) {
1499 const DataLayout
&DL
= getDataLayout();
1500 const AttributeList
&PAL
= F
->getAttributes();
1501 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
1502 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
1504 Function::const_arg_iterator I
, E
;
1505 unsigned paramIndex
= 0;
1507 bool isKernelFunc
= isKernelFunction(*F
);
1508 bool isABI
= (STI
.getSmVersion() >= 20);
1509 bool hasImageHandles
= STI
.hasImageHandles();
1511 if (F
->arg_empty() && !F
->isVarArg()) {
1518 for (I
= F
->arg_begin(), E
= F
->arg_end(); I
!= E
; ++I
, paramIndex
++) {
1519 Type
*Ty
= I
->getType();
1526 // Handle image/sampler parameters
1527 if (isKernelFunction(*F
)) {
1528 if (isSampler(*I
) || isImage(*I
)) {
1530 if (isImageWriteOnly(*I
) || isImageReadWrite(*I
)) {
1531 if (hasImageHandles
)
1532 O
<< "\t.param .u64 .ptr .surfref ";
1534 O
<< "\t.param .surfref ";
1535 O
<< TLI
->getParamName(F
, paramIndex
);
1537 else { // Default image is read_only
1538 if (hasImageHandles
)
1539 O
<< "\t.param .u64 .ptr .texref ";
1541 O
<< "\t.param .texref ";
1542 O
<< TLI
->getParamName(F
, paramIndex
);
1545 if (hasImageHandles
)
1546 O
<< "\t.param .u64 .ptr .samplerref ";
1548 O
<< "\t.param .samplerref ";
1549 O
<< TLI
->getParamName(F
, paramIndex
);
1555 auto getOptimalAlignForParam
= [TLI
, &DL
, &PAL
, F
,
1556 paramIndex
](Type
*Ty
) -> Align
{
1557 if (MaybeAlign StackAlign
=
1558 getAlign(*F
, paramIndex
+ AttributeList::FirstArgIndex
))
1559 return StackAlign
.value();
1561 Align TypeAlign
= TLI
->getFunctionParamOptimizedAlign(F
, Ty
, DL
);
1562 MaybeAlign ParamAlign
= PAL
.getParamAlignment(paramIndex
);
1563 return std::max(TypeAlign
, ParamAlign
.valueOrOne());
1566 if (!PAL
.hasParamAttr(paramIndex
, Attribute::ByVal
)) {
1567 if (ShouldPassAsArray(Ty
)) {
1568 // Just print .param .align <a> .b8 .param[size];
1569 // <a> = optimal alignment for the element type; always multiple of
1570 // PAL.getParamAlignment
1571 // size = typeallocsize of element type
1572 Align OptimalAlign
= getOptimalAlignForParam(Ty
);
1574 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1575 O
<< TLI
->getParamName(F
, paramIndex
);
1576 O
<< "[" << DL
.getTypeAllocSize(Ty
) << "]";
1581 auto *PTy
= dyn_cast
<PointerType
>(Ty
);
1582 unsigned PTySizeInBits
= 0;
1585 TLI
->getPointerTy(DL
, PTy
->getAddressSpace()).getSizeInBits();
1586 assert(PTySizeInBits
&& "Invalid pointer size");
1591 // Special handling for pointer arguments to kernel
1592 O
<< "\t.param .u" << PTySizeInBits
<< " ";
1594 if (static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() !=
1596 int addrSpace
= PTy
->getAddressSpace();
1597 switch (addrSpace
) {
1601 case ADDRESS_SPACE_CONST
:
1602 O
<< ".ptr .const ";
1604 case ADDRESS_SPACE_SHARED
:
1605 O
<< ".ptr .shared ";
1607 case ADDRESS_SPACE_GLOBAL
:
1608 O
<< ".ptr .global ";
1611 Align ParamAlign
= I
->getParamAlign().valueOrOne();
1612 O
<< ".align " << ParamAlign
.value() << " ";
1614 O
<< TLI
->getParamName(F
, paramIndex
);
1618 // non-pointer scalar to kernel func
1620 // Special case: predicate operands become .u8 types
1621 if (Ty
->isIntegerTy(1))
1624 O
<< getPTXFundamentalTypeStr(Ty
);
1626 O
<< TLI
->getParamName(F
, paramIndex
);
1629 // Non-kernel function, just print .param .b<size> for ABI
1630 // and .reg .b<size> for non-ABI
1632 if (isa
<IntegerType
>(Ty
)) {
1633 sz
= cast
<IntegerType
>(Ty
)->getBitWidth();
1634 sz
= promoteScalarArgumentSize(sz
);
1636 assert(PTySizeInBits
&& "Invalid pointer size");
1639 sz
= Ty
->getPrimitiveSizeInBits();
1641 O
<< "\t.param .b" << sz
<< " ";
1643 O
<< "\t.reg .b" << sz
<< " ";
1644 O
<< TLI
->getParamName(F
, paramIndex
);
1648 // param has byVal attribute.
1649 Type
*ETy
= PAL
.getParamByValType(paramIndex
);
1650 assert(ETy
&& "Param should have byval type");
1652 if (isABI
|| isKernelFunc
) {
1653 // Just print .param .align <a> .b8 .param[size];
1654 // <a> = optimal alignment for the element type; always multiple of
1655 // PAL.getParamAlignment
1656 // size = typeallocsize of element type
1657 Align OptimalAlign
=
1659 ? getOptimalAlignForParam(ETy
)
1660 : TLI
->getFunctionByValParamAlign(
1661 F
, ETy
, PAL
.getParamAlignment(paramIndex
).valueOrOne(), DL
);
1663 unsigned sz
= DL
.getTypeAllocSize(ETy
);
1664 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1665 O
<< TLI
->getParamName(F
, paramIndex
);
1666 O
<< "[" << sz
<< "]";
1669 // Split the ETy into constituent parts and
1670 // print .param .b<size> <name> for each part.
1671 // Further, if a part is vector, print the above for
1672 // each vector element.
1673 SmallVector
<EVT
, 16> vtparts
;
1674 ComputeValueVTs(*TLI
, DL
, ETy
, vtparts
);
1675 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
1677 EVT elemtype
= vtparts
[i
];
1678 if (vtparts
[i
].isVector()) {
1679 elems
= vtparts
[i
].getVectorNumElements();
1680 elemtype
= vtparts
[i
].getVectorElementType();
1683 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
1684 unsigned sz
= elemtype
.getSizeInBits();
1685 if (elemtype
.isInteger())
1686 sz
= promoteScalarArgumentSize(sz
);
1687 O
<< "\t.reg .b" << sz
<< " ";
1688 O
<< TLI
->getParamName(F
, paramIndex
);
1701 if (F
->isVarArg()) {
1704 O
<< "\t.param .align " << STI
.getMaxRequiredAlignment();
1706 O
<< TLI
->getParamName(F
, /* vararg */ -1) << "[]";
1712 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1713 const MachineFunction
&MF
) {
1714 SmallString
<128> Str
;
1715 raw_svector_ostream
O(Str
);
1717 // Map the global virtual register number to a register class specific
1718 // virtual register number starting from 1 with that class.
1719 const TargetRegisterInfo
*TRI
= MF
.getSubtarget().getRegisterInfo();
1720 //unsigned numRegClasses = TRI->getNumRegClasses();
1722 // Emit the Fake Stack Object
1723 const MachineFrameInfo
&MFI
= MF
.getFrameInfo();
1724 int64_t NumBytes
= MFI
.getStackSize();
1726 O
<< "\t.local .align " << MFI
.getMaxAlign().value() << " .b8 \t"
1727 << DEPOTNAME
<< getFunctionNumber() << "[" << NumBytes
<< "];\n";
1728 if (static_cast<const NVPTXTargetMachine
&>(MF
.getTarget()).is64Bit()) {
1729 O
<< "\t.reg .b64 \t%SP;\n";
1730 O
<< "\t.reg .b64 \t%SPL;\n";
1732 O
<< "\t.reg .b32 \t%SP;\n";
1733 O
<< "\t.reg .b32 \t%SPL;\n";
1737 // Go through all virtual registers to establish the mapping between the
1739 // register number and the per class virtual register number.
1740 // We use the per class virtual register number in the ptx output.
1741 unsigned int numVRs
= MRI
->getNumVirtRegs();
1742 for (unsigned i
= 0; i
< numVRs
; i
++) {
1743 Register vr
= Register::index2VirtReg(i
);
1744 const TargetRegisterClass
*RC
= MRI
->getRegClass(vr
);
1745 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1746 int n
= regmap
.size();
1747 regmap
.insert(std::make_pair(vr
, n
+ 1));
1750 // Emit register declarations
1751 // @TODO: Extract out the real register usage
1752 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1753 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1754 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1755 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1756 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1757 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1758 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1760 // Emit declaration of the virtual registers or 'physical' registers for
1761 // each register class
1762 for (unsigned i
=0; i
< TRI
->getNumRegClasses(); i
++) {
1763 const TargetRegisterClass
*RC
= TRI
->getRegClass(i
);
1764 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1765 std::string rcname
= getNVPTXRegClassName(RC
);
1766 std::string rcStr
= getNVPTXRegClassStr(RC
);
1767 int n
= regmap
.size();
1769 // Only declare those registers that may be used.
1771 O
<< "\t.reg " << rcname
<< " \t" << rcStr
<< "<" << (n
+1)
1776 OutStreamer
->emitRawText(O
.str());
1779 void NVPTXAsmPrinter::printFPConstant(const ConstantFP
*Fp
, raw_ostream
&O
) {
1780 APFloat APF
= APFloat(Fp
->getValueAPF()); // make a copy
1782 unsigned int numHex
;
1785 if (Fp
->getType()->getTypeID() == Type::FloatTyID
) {
1788 APF
.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven
, &ignored
);
1789 } else if (Fp
->getType()->getTypeID() == Type::DoubleTyID
) {
1792 APF
.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven
, &ignored
);
1794 llvm_unreachable("unsupported fp type");
1796 APInt API
= APF
.bitcastToAPInt();
1797 O
<< lead
<< format_hex_no_prefix(API
.getZExtValue(), numHex
, /*Upper=*/true);
1800 void NVPTXAsmPrinter::printScalarConstant(const Constant
*CPV
, raw_ostream
&O
) {
1801 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1802 O
<< CI
->getValue();
1805 if (const ConstantFP
*CFP
= dyn_cast
<ConstantFP
>(CPV
)) {
1806 printFPConstant(CFP
, O
);
1809 if (isa
<ConstantPointerNull
>(CPV
)) {
1813 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1814 bool IsNonGenericPointer
= false;
1815 if (GVar
->getType()->getAddressSpace() != 0) {
1816 IsNonGenericPointer
= true;
1818 if (EmitGeneric
&& !isa
<Function
>(CPV
) && !IsNonGenericPointer
) {
1820 getSymbol(GVar
)->print(O
, MAI
);
1823 getSymbol(GVar
)->print(O
, MAI
);
1827 if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1828 const MCExpr
*E
= lowerConstantForGV(cast
<Constant
>(Cexpr
), false);
1832 llvm_unreachable("Not scalar type found in printScalarConstant()");
1835 void NVPTXAsmPrinter::bufferLEByte(const Constant
*CPV
, int Bytes
,
1836 AggBuffer
*AggBuffer
) {
1837 const DataLayout
&DL
= getDataLayout();
1838 int AllocSize
= DL
.getTypeAllocSize(CPV
->getType());
1839 if (isa
<UndefValue
>(CPV
) || CPV
->isNullValue()) {
1840 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1841 // only the space allocated by CPV.
1842 AggBuffer
->addZeros(Bytes
? Bytes
: AllocSize
);
1846 // Helper for filling AggBuffer with APInts.
1847 auto AddIntToBuffer
= [AggBuffer
, Bytes
](const APInt
&Val
) {
1848 size_t NumBytes
= (Val
.getBitWidth() + 7) / 8;
1849 SmallVector
<unsigned char, 16> Buf(NumBytes
);
1850 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1851 // input's bit width, and i1 arrays may not have a length that is a multuple
1852 // of 8. We handle the last byte separately, so we never request out of
1854 for (unsigned I
= 0; I
< NumBytes
- 1; ++I
) {
1855 Buf
[I
] = Val
.extractBitsAsZExtValue(8, I
* 8);
1857 size_t LastBytePosition
= (NumBytes
- 1) * 8;
1858 size_t LastByteBits
= Val
.getBitWidth() - LastBytePosition
;
1860 Val
.extractBitsAsZExtValue(LastByteBits
, LastBytePosition
);
1861 AggBuffer
->addBytes(Buf
.data(), NumBytes
, Bytes
);
1864 switch (CPV
->getType()->getTypeID()) {
1865 case Type::IntegerTyID
:
1866 if (const auto CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1867 AddIntToBuffer(CI
->getValue());
1870 if (const auto *Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1871 if (const auto *CI
=
1872 dyn_cast
<ConstantInt
>(ConstantFoldConstant(Cexpr
, DL
))) {
1873 AddIntToBuffer(CI
->getValue());
1876 if (Cexpr
->getOpcode() == Instruction::PtrToInt
) {
1877 Value
*V
= Cexpr
->getOperand(0)->stripPointerCasts();
1878 AggBuffer
->addSymbol(V
, Cexpr
->getOperand(0));
1879 AggBuffer
->addZeros(AllocSize
);
1883 llvm_unreachable("unsupported integer const type");
1886 case Type::HalfTyID
:
1887 case Type::BFloatTyID
:
1888 case Type::FloatTyID
:
1889 case Type::DoubleTyID
:
1890 AddIntToBuffer(cast
<ConstantFP
>(CPV
)->getValueAPF().bitcastToAPInt());
1893 case Type::PointerTyID
: {
1894 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1895 AggBuffer
->addSymbol(GVar
, GVar
);
1896 } else if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1897 const Value
*v
= Cexpr
->stripPointerCasts();
1898 AggBuffer
->addSymbol(v
, Cexpr
);
1900 AggBuffer
->addZeros(AllocSize
);
1904 case Type::ArrayTyID
:
1905 case Type::FixedVectorTyID
:
1906 case Type::StructTyID
: {
1907 if (isa
<ConstantAggregate
>(CPV
) || isa
<ConstantDataSequential
>(CPV
)) {
1908 bufferAggregateConstant(CPV
, AggBuffer
);
1909 if (Bytes
> AllocSize
)
1910 AggBuffer
->addZeros(Bytes
- AllocSize
);
1911 } else if (isa
<ConstantAggregateZero
>(CPV
))
1912 AggBuffer
->addZeros(Bytes
);
1914 llvm_unreachable("Unexpected Constant type");
1919 llvm_unreachable("unsupported type");
1923 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant
*CPV
,
1924 AggBuffer
*aggBuffer
) {
1925 const DataLayout
&DL
= getDataLayout();
1928 // Integers of arbitrary width
1929 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1930 APInt Val
= CI
->getValue();
1931 for (unsigned I
= 0, E
= DL
.getTypeAllocSize(CPV
->getType()); I
< E
; ++I
) {
1932 uint8_t Byte
= Val
.getLoBits(8).getZExtValue();
1933 aggBuffer
->addBytes(&Byte
, 1, 1);
1940 if (isa
<ConstantArray
>(CPV
) || isa
<ConstantVector
>(CPV
)) {
1941 if (CPV
->getNumOperands())
1942 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
)
1943 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), 0, aggBuffer
);
1947 if (const ConstantDataSequential
*CDS
=
1948 dyn_cast
<ConstantDataSequential
>(CPV
)) {
1949 if (CDS
->getNumElements())
1950 for (unsigned i
= 0; i
< CDS
->getNumElements(); ++i
)
1951 bufferLEByte(cast
<Constant
>(CDS
->getElementAsConstant(i
)), 0,
1956 if (isa
<ConstantStruct
>(CPV
)) {
1957 if (CPV
->getNumOperands()) {
1958 StructType
*ST
= cast
<StructType
>(CPV
->getType());
1959 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
) {
1961 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(0) +
1962 DL
.getTypeAllocSize(ST
) -
1963 DL
.getStructLayout(ST
)->getElementOffset(i
);
1965 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(i
+ 1) -
1966 DL
.getStructLayout(ST
)->getElementOffset(i
);
1967 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), Bytes
, aggBuffer
);
1972 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1975 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1976 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1977 /// expressions that are representable in PTX and create
1978 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1980 NVPTXAsmPrinter::lowerConstantForGV(const Constant
*CV
, bool ProcessingGeneric
) {
1981 MCContext
&Ctx
= OutContext
;
1983 if (CV
->isNullValue() || isa
<UndefValue
>(CV
))
1984 return MCConstantExpr::create(0, Ctx
);
1986 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CV
))
1987 return MCConstantExpr::create(CI
->getZExtValue(), Ctx
);
1989 if (const GlobalValue
*GV
= dyn_cast
<GlobalValue
>(CV
)) {
1990 const MCSymbolRefExpr
*Expr
=
1991 MCSymbolRefExpr::create(getSymbol(GV
), Ctx
);
1992 if (ProcessingGeneric
) {
1993 return NVPTXGenericMCSymbolRefExpr::create(Expr
, Ctx
);
1999 const ConstantExpr
*CE
= dyn_cast
<ConstantExpr
>(CV
);
2001 llvm_unreachable("Unknown constant value to lower!");
2004 switch (CE
->getOpcode()) {
2008 case Instruction::AddrSpaceCast
: {
2009 // Strip the addrspacecast and pass along the operand
2010 PointerType
*DstTy
= cast
<PointerType
>(CE
->getType());
2011 if (DstTy
->getAddressSpace() == 0)
2012 return lowerConstantForGV(cast
<const Constant
>(CE
->getOperand(0)), true);
2017 case Instruction::GetElementPtr
: {
2018 const DataLayout
&DL
= getDataLayout();
2020 // Generate a symbolic expression for the byte address
2021 APInt
OffsetAI(DL
.getPointerTypeSizeInBits(CE
->getType()), 0);
2022 cast
<GEPOperator
>(CE
)->accumulateConstantOffset(DL
, OffsetAI
);
2024 const MCExpr
*Base
= lowerConstantForGV(CE
->getOperand(0),
2029 int64_t Offset
= OffsetAI
.getSExtValue();
2030 return MCBinaryExpr::createAdd(Base
, MCConstantExpr::create(Offset
, Ctx
),
2034 case Instruction::Trunc
:
2035 // We emit the value and depend on the assembler to truncate the generated
2036 // expression properly. This is important for differences between
2037 // blockaddress labels. Since the two labels are in the same function, it
2038 // is reasonable to treat their delta as a 32-bit value.
2040 case Instruction::BitCast
:
2041 return lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2043 case Instruction::IntToPtr
: {
2044 const DataLayout
&DL
= getDataLayout();
2046 // Handle casts to pointers by changing them into casts to the appropriate
2047 // integer type. This promotes constant folding and simplifies this code.
2048 Constant
*Op
= CE
->getOperand(0);
2049 Op
= ConstantFoldIntegerCast(Op
, DL
.getIntPtrType(CV
->getType()),
2050 /*IsSigned*/ false, DL
);
2052 return lowerConstantForGV(Op
, ProcessingGeneric
);
2057 case Instruction::PtrToInt
: {
2058 const DataLayout
&DL
= getDataLayout();
2060 // Support only foldable casts to/from pointers that can be eliminated by
2061 // changing the pointer to the appropriately sized integer type.
2062 Constant
*Op
= CE
->getOperand(0);
2063 Type
*Ty
= CE
->getType();
2065 const MCExpr
*OpExpr
= lowerConstantForGV(Op
, ProcessingGeneric
);
2067 // We can emit the pointer value into this slot if the slot is an
2068 // integer slot equal to the size of the pointer.
2069 if (DL
.getTypeAllocSize(Ty
) == DL
.getTypeAllocSize(Op
->getType()))
2072 // Otherwise the pointer is smaller than the resultant integer, mask off
2073 // the high bits so we are sure to get a proper truncation if the input is
2075 unsigned InBits
= DL
.getTypeAllocSizeInBits(Op
->getType());
2076 const MCExpr
*MaskExpr
= MCConstantExpr::create(~0ULL >> (64-InBits
), Ctx
);
2077 return MCBinaryExpr::createAnd(OpExpr
, MaskExpr
, Ctx
);
2080 // The MC library also has a right-shift operator, but it isn't consistently
2081 // signed or unsigned between different targets.
2082 case Instruction::Add
: {
2083 const MCExpr
*LHS
= lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2084 const MCExpr
*RHS
= lowerConstantForGV(CE
->getOperand(1), ProcessingGeneric
);
2085 switch (CE
->getOpcode()) {
2086 default: llvm_unreachable("Unknown binary operator constant cast expr");
2087 case Instruction::Add
: return MCBinaryExpr::createAdd(LHS
, RHS
, Ctx
);
2092 // If the code isn't optimized, there may be outstanding folding
2093 // opportunities. Attempt to fold the expression using DataLayout as a
2094 // last resort before giving up.
2095 Constant
*C
= ConstantFoldConstant(CE
, getDataLayout());
2097 return lowerConstantForGV(C
, ProcessingGeneric
);
2099 // Otherwise report the problem to the user.
2101 raw_string_ostream
OS(S
);
2102 OS
<< "Unsupported expression in static initializer: ";
2103 CE
->printAsOperand(OS
, /*PrintType=*/false,
2104 !MF
? nullptr : MF
->getFunction().getParent());
2105 report_fatal_error(Twine(OS
.str()));
2108 // Copy of MCExpr::print customized for NVPTX
2109 void NVPTXAsmPrinter::printMCExpr(const MCExpr
&Expr
, raw_ostream
&OS
) {
2110 switch (Expr
.getKind()) {
2111 case MCExpr::Target
:
2112 return cast
<MCTargetExpr
>(&Expr
)->printImpl(OS
, MAI
);
2113 case MCExpr::Constant
:
2114 OS
<< cast
<MCConstantExpr
>(Expr
).getValue();
2117 case MCExpr::SymbolRef
: {
2118 const MCSymbolRefExpr
&SRE
= cast
<MCSymbolRefExpr
>(Expr
);
2119 const MCSymbol
&Sym
= SRE
.getSymbol();
2124 case MCExpr::Unary
: {
2125 const MCUnaryExpr
&UE
= cast
<MCUnaryExpr
>(Expr
);
2126 switch (UE
.getOpcode()) {
2127 case MCUnaryExpr::LNot
: OS
<< '!'; break;
2128 case MCUnaryExpr::Minus
: OS
<< '-'; break;
2129 case MCUnaryExpr::Not
: OS
<< '~'; break;
2130 case MCUnaryExpr::Plus
: OS
<< '+'; break;
2132 printMCExpr(*UE
.getSubExpr(), OS
);
2136 case MCExpr::Binary
: {
2137 const MCBinaryExpr
&BE
= cast
<MCBinaryExpr
>(Expr
);
2139 // Only print parens around the LHS if it is non-trivial.
2140 if (isa
<MCConstantExpr
>(BE
.getLHS()) || isa
<MCSymbolRefExpr
>(BE
.getLHS()) ||
2141 isa
<NVPTXGenericMCSymbolRefExpr
>(BE
.getLHS())) {
2142 printMCExpr(*BE
.getLHS(), OS
);
2145 printMCExpr(*BE
.getLHS(), OS
);
2149 switch (BE
.getOpcode()) {
2150 case MCBinaryExpr::Add
:
2151 // Print "X-42" instead of "X+-42".
2152 if (const MCConstantExpr
*RHSC
= dyn_cast
<MCConstantExpr
>(BE
.getRHS())) {
2153 if (RHSC
->getValue() < 0) {
2154 OS
<< RHSC
->getValue();
2161 default: llvm_unreachable("Unhandled binary operator");
2164 // Only print parens around the RHS if it is non-trivial.
2165 if (isa
<MCConstantExpr
>(BE
.getRHS()) || isa
<MCSymbolRefExpr
>(BE
.getRHS())) {
2166 printMCExpr(*BE
.getRHS(), OS
);
2169 printMCExpr(*BE
.getRHS(), OS
);
2176 llvm_unreachable("Invalid expression kind!");
2179 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2181 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr
*MI
, unsigned OpNo
,
2182 const char *ExtraCode
, raw_ostream
&O
) {
2183 if (ExtraCode
&& ExtraCode
[0]) {
2184 if (ExtraCode
[1] != 0)
2185 return true; // Unknown modifier.
2187 switch (ExtraCode
[0]) {
2189 // See if this is a generic print operand
2190 return AsmPrinter::PrintAsmOperand(MI
, OpNo
, ExtraCode
, O
);
2196 printOperand(MI
, OpNo
, O
);
2201 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr
*MI
,
2203 const char *ExtraCode
,
2205 if (ExtraCode
&& ExtraCode
[0])
2206 return true; // Unknown modifier
2209 printMemOperand(MI
, OpNo
, O
);
2215 void NVPTXAsmPrinter::printOperand(const MachineInstr
*MI
, unsigned OpNum
,
2217 const MachineOperand
&MO
= MI
->getOperand(OpNum
);
2218 switch (MO
.getType()) {
2219 case MachineOperand::MO_Register
:
2220 if (MO
.getReg().isPhysical()) {
2221 if (MO
.getReg() == NVPTX::VRDepot
)
2222 O
<< DEPOTNAME
<< getFunctionNumber();
2224 O
<< NVPTXInstPrinter::getRegisterName(MO
.getReg());
2226 emitVirtualRegister(MO
.getReg(), O
);
2230 case MachineOperand::MO_Immediate
:
2234 case MachineOperand::MO_FPImmediate
:
2235 printFPConstant(MO
.getFPImm(), O
);
2238 case MachineOperand::MO_GlobalAddress
:
2239 PrintSymbolOperand(MO
, O
);
2242 case MachineOperand::MO_MachineBasicBlock
:
2243 MO
.getMBB()->getSymbol()->print(O
, MAI
);
2247 llvm_unreachable("Operand type not supported.");
2251 void NVPTXAsmPrinter::printMemOperand(const MachineInstr
*MI
, unsigned OpNum
,
2252 raw_ostream
&O
, const char *Modifier
) {
2253 printOperand(MI
, OpNum
, O
);
2255 if (Modifier
&& strcmp(Modifier
, "add") == 0) {
2257 printOperand(MI
, OpNum
+ 1, O
);
2259 if (MI
->getOperand(OpNum
+ 1).isImm() &&
2260 MI
->getOperand(OpNum
+ 1).getImm() == 0)
2261 return; // don't print ',0' or '+0'
2263 printOperand(MI
, OpNum
+ 1, O
);
2267 // Force static initialization.
2268 extern "C" LLVM_EXTERNAL_VISIBILITY
void LLVMInitializeNVPTXAsmPrinter() {
2269 RegisterAsmPrinter
<NVPTXAsmPrinter
> X(getTheNVPTXTarget32());
2270 RegisterAsmPrinter
<NVPTXAsmPrinter
> Y(getTheNVPTXTarget64());