1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/MachineValueType.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/NativeFormatting.h"
79 #include "llvm/Support/Path.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Target/TargetLoweringObjectFile.h"
82 #include "llvm/Target/TargetMachine.h"
83 #include "llvm/TargetParser/Triple.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
96 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
97 cl::desc("Lower GPU ctor / dtors to globals on the device."),
98 cl::init(false), cl::Hidden
);
100 #define DEPOTNAME "__local_depot"
102 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
105 DiscoverDependentGlobals(const Value
*V
,
106 DenseSet
<const GlobalVariable
*> &Globals
) {
107 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(V
))
110 if (const User
*U
= dyn_cast
<User
>(V
)) {
111 for (unsigned i
= 0, e
= U
->getNumOperands(); i
!= e
; ++i
) {
112 DiscoverDependentGlobals(U
->getOperand(i
), Globals
);
118 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
119 /// instances to be emitted, but only after any dependents have been added
122 VisitGlobalVariableForEmission(const GlobalVariable
*GV
,
123 SmallVectorImpl
<const GlobalVariable
*> &Order
,
124 DenseSet
<const GlobalVariable
*> &Visited
,
125 DenseSet
<const GlobalVariable
*> &Visiting
) {
126 // Have we already visited this one?
127 if (Visited
.count(GV
))
130 // Do we have a circular dependency?
131 if (!Visiting
.insert(GV
).second
)
132 report_fatal_error("Circular dependency found in global variable set");
134 // Make sure we visit all dependents first
135 DenseSet
<const GlobalVariable
*> Others
;
136 for (unsigned i
= 0, e
= GV
->getNumOperands(); i
!= e
; ++i
)
137 DiscoverDependentGlobals(GV
->getOperand(i
), Others
);
139 for (const GlobalVariable
*GV
: Others
)
140 VisitGlobalVariableForEmission(GV
, Order
, Visited
, Visiting
);
142 // Now we can visit ourself
148 void NVPTXAsmPrinter::emitInstruction(const MachineInstr
*MI
) {
149 NVPTX_MC::verifyInstructionPredicates(MI
->getOpcode(),
150 getSubtargetInfo().getFeatureBits());
153 lowerToMCInst(MI
, Inst
);
154 EmitToStreamer(*OutStreamer
, Inst
);
157 // Handle symbol backtracking for targets that do not support image handles
158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr
*MI
,
159 unsigned OpNo
, MCOperand
&MCOp
) {
160 const MachineOperand
&MO
= MI
->getOperand(OpNo
);
161 const MCInstrDesc
&MCID
= MI
->getDesc();
163 if (MCID
.TSFlags
& NVPTXII::IsTexFlag
) {
164 // This is a texture fetch, so operand 4 is a texref and operand 5 is
166 if (OpNo
== 4 && MO
.isImm()) {
167 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
170 if (OpNo
== 5 && MO
.isImm() && !(MCID
.TSFlags
& NVPTXII::IsTexModeUnifiedFlag
)) {
171 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
176 } else if (MCID
.TSFlags
& NVPTXII::IsSuldMask
) {
178 1 << (((MCID
.TSFlags
& NVPTXII::IsSuldMask
) >> NVPTXII::IsSuldShift
) - 1);
180 // For a surface load of vector size N, the Nth operand will be the surfref
181 if (OpNo
== VecSize
&& MO
.isImm()) {
182 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
187 } else if (MCID
.TSFlags
& NVPTXII::IsSustFlag
) {
188 // This is a surface store, so operand 0 is a surfref
189 if (OpNo
== 0 && MO
.isImm()) {
190 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
195 } else if (MCID
.TSFlags
& NVPTXII::IsSurfTexQueryFlag
) {
196 // This is a query, so operand 1 is a surfref/texref
197 if (OpNo
== 1 && MO
.isImm()) {
198 lowerImageHandleSymbol(MO
.getImm(), MCOp
);
208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index
, MCOperand
&MCOp
) {
210 LLVMTargetMachine
&TM
= const_cast<LLVMTargetMachine
&>(MF
->getTarget());
211 NVPTXTargetMachine
&nvTM
= static_cast<NVPTXTargetMachine
&>(TM
);
212 const NVPTXMachineFunctionInfo
*MFI
= MF
->getInfo
<NVPTXMachineFunctionInfo
>();
213 const char *Sym
= MFI
->getImageHandleSymbol(Index
);
214 StringRef SymName
= nvTM
.getStrPool().save(Sym
);
215 MCOp
= GetSymbolRef(OutContext
.getOrCreateSymbol(SymName
));
218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr
*MI
, MCInst
&OutMI
) {
219 OutMI
.setOpcode(MI
->getOpcode());
220 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
221 if (MI
->getOpcode() == NVPTX::CALL_PROTOTYPE
) {
222 const MachineOperand
&MO
= MI
->getOperand(0);
223 OutMI
.addOperand(GetSymbolRef(
224 OutContext
.getOrCreateSymbol(Twine(MO
.getSymbolName()))));
228 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
229 for (unsigned i
= 0, e
= MI
->getNumOperands(); i
!= e
; ++i
) {
230 const MachineOperand
&MO
= MI
->getOperand(i
);
233 if (!STI
.hasImageHandles()) {
234 if (lowerImageHandleOperand(MI
, i
, MCOp
)) {
235 OutMI
.addOperand(MCOp
);
240 if (lowerOperand(MO
, MCOp
))
241 OutMI
.addOperand(MCOp
);
245 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand
&MO
,
247 switch (MO
.getType()) {
248 default: llvm_unreachable("unknown operand type");
249 case MachineOperand::MO_Register
:
250 MCOp
= MCOperand::createReg(encodeVirtualRegister(MO
.getReg()));
252 case MachineOperand::MO_Immediate
:
253 MCOp
= MCOperand::createImm(MO
.getImm());
255 case MachineOperand::MO_MachineBasicBlock
:
256 MCOp
= MCOperand::createExpr(MCSymbolRefExpr::create(
257 MO
.getMBB()->getSymbol(), OutContext
));
259 case MachineOperand::MO_ExternalSymbol
:
260 MCOp
= GetSymbolRef(GetExternalSymbolSymbol(MO
.getSymbolName()));
262 case MachineOperand::MO_GlobalAddress
:
263 MCOp
= GetSymbolRef(getSymbol(MO
.getGlobal()));
265 case MachineOperand::MO_FPImmediate
: {
266 const ConstantFP
*Cnt
= MO
.getFPImm();
267 const APFloat
&Val
= Cnt
->getValueAPF();
269 switch (Cnt
->getType()->getTypeID()) {
270 default: report_fatal_error("Unsupported FP type"); break;
272 MCOp
= MCOperand::createExpr(
273 NVPTXFloatMCExpr::createConstantFPHalf(Val
, OutContext
));
275 case Type::BFloatTyID
:
276 MCOp
= MCOperand::createExpr(
277 NVPTXFloatMCExpr::createConstantBFPHalf(Val
, OutContext
));
279 case Type::FloatTyID
:
280 MCOp
= MCOperand::createExpr(
281 NVPTXFloatMCExpr::createConstantFPSingle(Val
, OutContext
));
283 case Type::DoubleTyID
:
284 MCOp
= MCOperand::createExpr(
285 NVPTXFloatMCExpr::createConstantFPDouble(Val
, OutContext
));
294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg
) {
295 if (Register::isVirtualRegister(Reg
)) {
296 const TargetRegisterClass
*RC
= MRI
->getRegClass(Reg
);
298 DenseMap
<unsigned, unsigned> &RegMap
= VRegMapping
[RC
];
299 unsigned RegNum
= RegMap
[Reg
];
301 // Encode the register class in the upper 4 bits
302 // Must be kept in sync with NVPTXInstPrinter::printRegName
304 if (RC
== &NVPTX::Int1RegsRegClass
) {
306 } else if (RC
== &NVPTX::Int16RegsRegClass
) {
308 } else if (RC
== &NVPTX::Int32RegsRegClass
) {
310 } else if (RC
== &NVPTX::Int64RegsRegClass
) {
312 } else if (RC
== &NVPTX::Float32RegsRegClass
) {
314 } else if (RC
== &NVPTX::Float64RegsRegClass
) {
317 report_fatal_error("Bad register class");
320 // Insert the vreg number
321 Ret
|= (RegNum
& 0x0FFFFFFF);
324 // Some special-use registers are actually physical registers.
325 // Encode this as the register class ID of 0 and the real register ID.
326 return Reg
& 0x0FFFFFFF;
330 MCOperand
NVPTXAsmPrinter::GetSymbolRef(const MCSymbol
*Symbol
) {
332 Expr
= MCSymbolRefExpr::create(Symbol
, MCSymbolRefExpr::VK_None
,
334 return MCOperand::createExpr(Expr
);
337 static bool ShouldPassAsArray(Type
*Ty
) {
338 return Ty
->isAggregateType() || Ty
->isVectorTy() || Ty
->isIntegerTy(128) ||
339 Ty
->isHalfTy() || Ty
->isBFloatTy();
342 void NVPTXAsmPrinter::printReturnValStr(const Function
*F
, raw_ostream
&O
) {
343 const DataLayout
&DL
= getDataLayout();
344 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
345 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
347 Type
*Ty
= F
->getReturnType();
349 bool isABI
= (STI
.getSmVersion() >= 20);
351 if (Ty
->getTypeID() == Type::VoidTyID
)
356 if ((Ty
->isFloatingPointTy() || Ty
->isIntegerTy()) &&
357 !ShouldPassAsArray(Ty
)) {
359 if (auto *ITy
= dyn_cast
<IntegerType
>(Ty
)) {
360 size
= ITy
->getBitWidth();
362 assert(Ty
->isFloatingPointTy() && "Floating point type expected here");
363 size
= Ty
->getPrimitiveSizeInBits();
365 size
= promoteScalarArgumentSize(size
);
366 O
<< ".param .b" << size
<< " func_retval0";
367 } else if (isa
<PointerType
>(Ty
)) {
368 O
<< ".param .b" << TLI
->getPointerTy(DL
).getSizeInBits()
370 } else if (ShouldPassAsArray(Ty
)) {
371 unsigned totalsz
= DL
.getTypeAllocSize(Ty
);
372 unsigned retAlignment
= 0;
373 if (!getAlign(*F
, 0, retAlignment
))
374 retAlignment
= TLI
->getFunctionParamOptimizedAlign(F
, Ty
, DL
).value();
375 O
<< ".param .align " << retAlignment
<< " .b8 func_retval0[" << totalsz
378 llvm_unreachable("Unknown return type");
380 SmallVector
<EVT
, 16> vtparts
;
381 ComputeValueVTs(*TLI
, DL
, Ty
, vtparts
);
383 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
385 EVT elemtype
= vtparts
[i
];
386 if (vtparts
[i
].isVector()) {
387 elems
= vtparts
[i
].getVectorNumElements();
388 elemtype
= vtparts
[i
].getVectorElementType();
391 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
392 unsigned sz
= elemtype
.getSizeInBits();
393 if (elemtype
.isInteger())
394 sz
= promoteScalarArgumentSize(sz
);
395 O
<< ".reg .b" << sz
<< " func_retval" << idx
;
407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction
&MF
,
409 const Function
&F
= MF
.getFunction();
410 printReturnValStr(&F
, O
);
413 // Return true if MBB is the header of a loop marked with
414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
416 const MachineBasicBlock
&MBB
) const {
417 MachineLoopInfo
&LI
= getAnalysis
<MachineLoopInfo
>();
418 // We insert .pragma "nounroll" only to the loop header.
419 if (!LI
.isLoopHeader(&MBB
))
422 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
423 // we iterate through each back edge of the loop with header MBB, and check
424 // whether its metadata contains llvm.loop.unroll.disable.
425 for (const MachineBasicBlock
*PMBB
: MBB
.predecessors()) {
426 if (LI
.getLoopFor(PMBB
) != LI
.getLoopFor(&MBB
)) {
427 // Edges from other loops to MBB are not back edges.
430 if (const BasicBlock
*PBB
= PMBB
->getBasicBlock()) {
432 PBB
->getTerminator()->getMetadata(LLVMContext::MD_loop
)) {
433 if (GetUnrollMetadata(LoopID
, "llvm.loop.unroll.disable"))
435 if (MDNode
*UnrollCountMD
=
436 GetUnrollMetadata(LoopID
, "llvm.loop.unroll.count")) {
437 if (mdconst::extract
<ConstantInt
>(UnrollCountMD
->getOperand(1))
447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock
&MBB
) {
448 AsmPrinter::emitBasicBlockStart(MBB
);
449 if (isLoopHeaderOfNoUnroll(MBB
))
450 OutStreamer
->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
453 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
454 SmallString
<128> Str
;
455 raw_svector_ostream
O(Str
);
457 if (!GlobalsEmitted
) {
458 emitGlobals(*MF
->getFunction().getParent());
459 GlobalsEmitted
= true;
463 MRI
= &MF
->getRegInfo();
464 F
= &MF
->getFunction();
465 emitLinkageDirective(F
, O
);
466 if (isKernelFunction(*F
))
470 printReturnValStr(*MF
, O
);
473 CurrentFnSym
->print(O
, MAI
);
475 emitFunctionParamList(F
, O
);
478 if (isKernelFunction(*F
))
479 emitKernelFunctionDirectives(*F
, O
);
481 if (shouldEmitPTXNoReturn(F
, TM
))
484 OutStreamer
->emitRawText(O
.str());
487 // Emit open brace for function body.
488 OutStreamer
->emitRawText(StringRef("{\n"));
489 setAndEmitFunctionVirtualRegisters(*MF
);
490 // Emit initial .loc debug directive for correct relocation symbol data.
491 if (MMI
&& MMI
->hasDebugInfo())
492 emitInitialRawDwarfLocDirective(*MF
);
495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction
&F
) {
496 bool Result
= AsmPrinter::runOnMachineFunction(F
);
497 // Emit closing brace for the body of function F.
498 // The closing brace must be emitted here because we need to emit additional
499 // debug labels/data after the last basic block.
500 // We need to emit the closing brace here because we don't have function that
501 // finished emission of the function body.
502 OutStreamer
->emitRawText(StringRef("}\n"));
506 void NVPTXAsmPrinter::emitFunctionBodyStart() {
507 SmallString
<128> Str
;
508 raw_svector_ostream
O(Str
);
509 emitDemotedVars(&MF
->getFunction(), O
);
510 OutStreamer
->emitRawText(O
.str());
513 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
517 const MCSymbol
*NVPTXAsmPrinter::getFunctionFrameSymbol() const {
518 SmallString
<128> Str
;
519 raw_svector_ostream(Str
) << DEPOTNAME
<< getFunctionNumber();
520 return OutContext
.getOrCreateSymbol(Str
);
523 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr
*MI
) const {
524 Register RegNo
= MI
->getOperand(0).getReg();
525 if (RegNo
.isVirtual()) {
526 OutStreamer
->AddComment(Twine("implicit-def: ") +
527 getVirtualRegisterName(RegNo
));
529 const NVPTXSubtarget
&STI
= MI
->getMF()->getSubtarget
<NVPTXSubtarget
>();
530 OutStreamer
->AddComment(Twine("implicit-def: ") +
531 STI
.getRegisterInfo()->getName(RegNo
));
533 OutStreamer
->addBlankLine();
536 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function
&F
,
537 raw_ostream
&O
) const {
538 // If the NVVM IR has some of reqntid* specified, then output
539 // the reqntid directive, and set the unspecified ones to 1.
540 // If none of Reqntid* is specified, don't output reqntid directive.
541 unsigned Reqntidx
, Reqntidy
, Reqntidz
;
542 Reqntidx
= Reqntidy
= Reqntidz
= 1;
543 bool ReqSpecified
= false;
544 ReqSpecified
|= getReqNTIDx(F
, Reqntidx
);
545 ReqSpecified
|= getReqNTIDy(F
, Reqntidy
);
546 ReqSpecified
|= getReqNTIDz(F
, Reqntidz
);
549 O
<< ".reqntid " << Reqntidx
<< ", " << Reqntidy
<< ", " << Reqntidz
552 // If the NVVM IR has some of maxntid* specified, then output
553 // the maxntid directive, and set the unspecified ones to 1.
554 // If none of maxntid* is specified, don't output maxntid directive.
555 unsigned Maxntidx
, Maxntidy
, Maxntidz
;
556 Maxntidx
= Maxntidy
= Maxntidz
= 1;
557 bool MaxSpecified
= false;
558 MaxSpecified
|= getMaxNTIDx(F
, Maxntidx
);
559 MaxSpecified
|= getMaxNTIDy(F
, Maxntidy
);
560 MaxSpecified
|= getMaxNTIDz(F
, Maxntidz
);
563 O
<< ".maxntid " << Maxntidx
<< ", " << Maxntidy
<< ", " << Maxntidz
567 if (getMinCTASm(F
, Mincta
))
568 O
<< ".minnctapersm " << Mincta
<< "\n";
570 unsigned Maxnreg
= 0;
571 if (getMaxNReg(F
, Maxnreg
))
572 O
<< ".maxnreg " << Maxnreg
<< "\n";
574 // .maxclusterrank directive requires SM_90 or higher, make sure that we
575 // filter it out for lower SM versions, as it causes a hard ptxas crash.
576 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
577 const auto *STI
= static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
578 unsigned Maxclusterrank
= 0;
579 if (getMaxClusterRank(F
, Maxclusterrank
) && STI
->getSmVersion() >= 90)
580 O
<< ".maxclusterrank " << Maxclusterrank
<< "\n";
583 std::string
NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg
) const {
584 const TargetRegisterClass
*RC
= MRI
->getRegClass(Reg
);
587 raw_string_ostream
NameStr(Name
);
589 VRegRCMap::const_iterator I
= VRegMapping
.find(RC
);
590 assert(I
!= VRegMapping
.end() && "Bad register class");
591 const DenseMap
<unsigned, unsigned> &RegMap
= I
->second
;
593 VRegMap::const_iterator VI
= RegMap
.find(Reg
);
594 assert(VI
!= RegMap
.end() && "Bad virtual register");
595 unsigned MappedVR
= VI
->second
;
597 NameStr
<< getNVPTXRegClassStr(RC
) << MappedVR
;
603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr
,
605 O
<< getVirtualRegisterName(vr
);
608 void NVPTXAsmPrinter::emitDeclaration(const Function
*F
, raw_ostream
&O
) {
609 emitLinkageDirective(F
, O
);
610 if (isKernelFunction(*F
))
614 printReturnValStr(F
, O
);
615 getSymbol(F
)->print(O
, MAI
);
617 emitFunctionParamList(F
, O
);
619 if (shouldEmitPTXNoReturn(F
, TM
))
624 static bool usedInGlobalVarDef(const Constant
*C
) {
628 if (const GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(C
)) {
629 return GV
->getName() != "llvm.used";
632 for (const User
*U
: C
->users())
633 if (const Constant
*C
= dyn_cast
<Constant
>(U
))
634 if (usedInGlobalVarDef(C
))
640 static bool usedInOneFunc(const User
*U
, Function
const *&oneFunc
) {
641 if (const GlobalVariable
*othergv
= dyn_cast
<GlobalVariable
>(U
)) {
642 if (othergv
->getName() == "llvm.used")
646 if (const Instruction
*instr
= dyn_cast
<Instruction
>(U
)) {
647 if (instr
->getParent() && instr
->getParent()->getParent()) {
648 const Function
*curFunc
= instr
->getParent()->getParent();
649 if (oneFunc
&& (curFunc
!= oneFunc
))
657 for (const User
*UU
: U
->users())
658 if (!usedInOneFunc(UU
, oneFunc
))
664 /* Find out if a global variable can be demoted to local scope.
665 * Currently, this is valid for CUDA shared variables, which have local
666 * scope and global lifetime. So the conditions to check are :
667 * 1. Is the global variable in shared address space?
668 * 2. Does it have local linkage?
669 * 3. Is the global variable referenced only in one function?
671 static bool canDemoteGlobalVar(const GlobalVariable
*gv
, Function
const *&f
) {
672 if (!gv
->hasLocalLinkage())
674 PointerType
*Pty
= gv
->getType();
675 if (Pty
->getAddressSpace() != ADDRESS_SPACE_SHARED
)
678 const Function
*oneFunc
= nullptr;
680 bool flag
= usedInOneFunc(gv
, oneFunc
);
689 static bool useFuncSeen(const Constant
*C
,
690 DenseMap
<const Function
*, bool> &seenMap
) {
691 for (const User
*U
: C
->users()) {
692 if (const Constant
*cu
= dyn_cast
<Constant
>(U
)) {
693 if (useFuncSeen(cu
, seenMap
))
695 } else if (const Instruction
*I
= dyn_cast
<Instruction
>(U
)) {
696 const BasicBlock
*bb
= I
->getParent();
699 const Function
*caller
= bb
->getParent();
702 if (seenMap
.contains(caller
))
709 void NVPTXAsmPrinter::emitDeclarations(const Module
&M
, raw_ostream
&O
) {
710 DenseMap
<const Function
*, bool> seenMap
;
711 for (const Function
&F
: M
) {
712 if (F
.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
713 emitDeclaration(&F
, O
);
717 if (F
.isDeclaration()) {
720 if (F
.getIntrinsicID())
722 emitDeclaration(&F
, O
);
725 for (const User
*U
: F
.users()) {
726 if (const Constant
*C
= dyn_cast
<Constant
>(U
)) {
727 if (usedInGlobalVarDef(C
)) {
728 // The use is in the initialization of a global variable
729 // that is a function pointer, so print a declaration
730 // for the original function
731 emitDeclaration(&F
, O
);
734 // Emit a declaration of this function if the function that
735 // uses this constant expr has already been seen.
736 if (useFuncSeen(C
, seenMap
)) {
737 emitDeclaration(&F
, O
);
742 if (!isa
<Instruction
>(U
))
744 const Instruction
*instr
= cast
<Instruction
>(U
);
745 const BasicBlock
*bb
= instr
->getParent();
748 const Function
*caller
= bb
->getParent();
752 // If a caller has already been seen, then the caller is
753 // appearing in the module before the callee. so print out
754 // a declaration for the callee.
755 if (seenMap
.contains(caller
)) {
756 emitDeclaration(&F
, O
);
764 static bool isEmptyXXStructor(GlobalVariable
*GV
) {
765 if (!GV
) return true;
766 const ConstantArray
*InitList
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
767 if (!InitList
) return true; // Not an array; we don't know how to parse.
768 return InitList
->getNumOperands() == 0;
771 void NVPTXAsmPrinter::emitStartOfAsmFile(Module
&M
) {
772 // Construct a default subtarget off of the TargetMachine defaults. The
773 // rest of NVPTX isn't friendly to change subtargets per function and
774 // so the default TargetMachine will have all of the options.
775 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
776 const auto* STI
= static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
777 SmallString
<128> Str1
;
778 raw_svector_ostream
OS1(Str1
);
780 // Emit header before any dwarf directives are emitted below.
781 emitHeader(M
, OS1
, *STI
);
782 OutStreamer
->emitRawText(OS1
.str());
785 bool NVPTXAsmPrinter::doInitialization(Module
&M
) {
786 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
787 const NVPTXSubtarget
&STI
=
788 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
789 if (M
.alias_size() && (STI
.getPTXVersion() < 63 || STI
.getSmVersion() < 30))
790 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
792 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_ctors")) &&
795 "Module has a nontrivial global ctor, which NVPTX does not support.");
796 return true; // error
798 if (!isEmptyXXStructor(M
.getNamedGlobal("llvm.global_dtors")) &&
801 "Module has a nontrivial global dtor, which NVPTX does not support.");
802 return true; // error
805 // We need to call the parent's one explicitly.
806 bool Result
= AsmPrinter::doInitialization(M
);
808 GlobalsEmitted
= false;
813 void NVPTXAsmPrinter::emitGlobals(const Module
&M
) {
814 SmallString
<128> Str2
;
815 raw_svector_ostream
OS2(Str2
);
817 emitDeclarations(M
, OS2
);
819 // As ptxas does not support forward references of globals, we need to first
820 // sort the list of module-level globals in def-use order. We visit each
821 // global variable in order, and ensure that we emit it *after* its dependent
822 // globals. We use a little extra memory maintaining both a set and a list to
823 // have fast searches while maintaining a strict ordering.
824 SmallVector
<const GlobalVariable
*, 8> Globals
;
825 DenseSet
<const GlobalVariable
*> GVVisited
;
826 DenseSet
<const GlobalVariable
*> GVVisiting
;
828 // Visit each global variable, in order
829 for (const GlobalVariable
&I
: M
.globals())
830 VisitGlobalVariableForEmission(&I
, Globals
, GVVisited
, GVVisiting
);
832 assert(GVVisited
.size() == M
.global_size() && "Missed a global variable");
833 assert(GVVisiting
.size() == 0 && "Did not fully process a global variable");
835 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
836 const NVPTXSubtarget
&STI
=
837 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
839 // Print out module-level global variables in proper order
840 for (unsigned i
= 0, e
= Globals
.size(); i
!= e
; ++i
)
841 printModuleLevelGV(Globals
[i
], OS2
, /*processDemoted=*/false, STI
);
845 OutStreamer
->emitRawText(OS2
.str());
848 void NVPTXAsmPrinter::emitGlobalAlias(const Module
&M
, const GlobalAlias
&GA
) {
849 SmallString
<128> Str
;
850 raw_svector_ostream
OS(Str
);
852 MCSymbol
*Name
= getSymbol(&GA
);
853 const Function
*F
= dyn_cast
<Function
>(GA
.getAliasee());
854 if (!F
|| isKernelFunction(*F
))
855 report_fatal_error("NVPTX aliasee must be a non-kernel function");
857 if (GA
.hasLinkOnceLinkage() || GA
.hasWeakLinkage() ||
858 GA
.hasAvailableExternallyLinkage() || GA
.hasCommonLinkage())
859 report_fatal_error("NVPTX aliasee must not be '.weak'");
862 emitLinkageDirective(F
, OS
);
864 printReturnValStr(F
, OS
);
865 OS
<< Name
->getName();
866 emitFunctionParamList(F
, OS
);
867 if (shouldEmitPTXNoReturn(F
, TM
))
871 OS
<< ".alias " << Name
->getName() << ", " << F
->getName() << ";\n";
873 OutStreamer
->emitRawText(OS
.str());
876 void NVPTXAsmPrinter::emitHeader(Module
&M
, raw_ostream
&O
,
877 const NVPTXSubtarget
&STI
) {
879 O
<< "// Generated by LLVM NVPTX Back-End\n";
883 unsigned PTXVersion
= STI
.getPTXVersion();
884 O
<< ".version " << (PTXVersion
/ 10) << "." << (PTXVersion
% 10) << "\n";
887 O
<< STI
.getTargetName();
889 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
890 if (NTM
.getDrvInterface() == NVPTX::NVCL
)
891 O
<< ", texmode_independent";
893 bool HasFullDebugInfo
= false;
894 for (DICompileUnit
*CU
: M
.debug_compile_units()) {
895 switch(CU
->getEmissionKind()) {
896 case DICompileUnit::NoDebug
:
897 case DICompileUnit::DebugDirectivesOnly
:
899 case DICompileUnit::LineTablesOnly
:
900 case DICompileUnit::FullDebug
:
901 HasFullDebugInfo
= true;
904 if (HasFullDebugInfo
)
907 if (MMI
&& MMI
->hasDebugInfo() && HasFullDebugInfo
)
912 O
<< ".address_size ";
922 bool NVPTXAsmPrinter::doFinalization(Module
&M
) {
923 bool HasDebugInfo
= MMI
&& MMI
->hasDebugInfo();
925 // If we did not emit any functions, then the global declarations have not
927 if (!GlobalsEmitted
) {
929 GlobalsEmitted
= true;
932 // If we have any aliases we emit them at the end.
933 SmallVector
<GlobalAlias
*> AliasesToRemove
;
934 for (GlobalAlias
&Alias
: M
.aliases()) {
935 emitGlobalAlias(M
, Alias
);
936 AliasesToRemove
.push_back(&Alias
);
939 for (GlobalAlias
*A
: AliasesToRemove
)
940 A
->eraseFromParent();
942 // call doFinalization
943 bool ret
= AsmPrinter::doFinalization(M
);
945 clearAnnotationCache(&M
);
948 static_cast<NVPTXTargetStreamer
*>(OutStreamer
->getTargetStreamer());
949 // Close the last emitted section
951 TS
->closeLastSection();
952 // Emit empty .debug_loc section for better support of the empty files.
953 OutStreamer
->emitRawText("\t.section\t.debug_loc\t{\t}");
956 // Output last DWARF .file directives, if any.
957 TS
->outputDwarfFileDirectives();
962 // This function emits appropriate linkage directives for
963 // functions and global variables.
965 // extern function declaration -> .extern
966 // extern function definition -> .visible
967 // external global variable with init -> .visible
968 // external without init -> .extern
969 // appending -> not allowed, assert.
970 // for any linkage other than
971 // internal, private, linker_private,
972 // linker_private_weak, linker_private_weak_def_auto,
975 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue
*V
,
977 if (static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() == NVPTX::CUDA
) {
978 if (V
->hasExternalLinkage()) {
979 if (isa
<GlobalVariable
>(V
)) {
980 const GlobalVariable
*GVar
= cast
<GlobalVariable
>(V
);
982 if (GVar
->hasInitializer())
987 } else if (V
->isDeclaration())
991 } else if (V
->hasAppendingLinkage()) {
993 msg
.append("Error: ");
994 msg
.append("Symbol ");
996 msg
.append(std::string(V
->getName()));
997 msg
.append("has unsupported appending linkage type");
998 llvm_unreachable(msg
.c_str());
999 } else if (!V
->hasInternalLinkage() &&
1000 !V
->hasPrivateLinkage()) {
1006 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable
*GVar
,
1007 raw_ostream
&O
, bool processDemoted
,
1008 const NVPTXSubtarget
&STI
) {
1010 if (GVar
->hasSection()) {
1011 if (GVar
->getSection() == "llvm.metadata")
1015 // Skip LLVM intrinsic global variables
1016 if (GVar
->getName().startswith("llvm.") ||
1017 GVar
->getName().startswith("nvvm."))
1020 const DataLayout
&DL
= getDataLayout();
1022 // GlobalVariables are always constant pointers themselves.
1023 PointerType
*PTy
= GVar
->getType();
1024 Type
*ETy
= GVar
->getValueType();
1026 if (GVar
->hasExternalLinkage()) {
1027 if (GVar
->hasInitializer())
1031 } else if (GVar
->hasLinkOnceLinkage() || GVar
->hasWeakLinkage() ||
1032 GVar
->hasAvailableExternallyLinkage() ||
1033 GVar
->hasCommonLinkage()) {
1037 if (isTexture(*GVar
)) {
1038 O
<< ".global .texref " << getTextureName(*GVar
) << ";\n";
1042 if (isSurface(*GVar
)) {
1043 O
<< ".global .surfref " << getSurfaceName(*GVar
) << ";\n";
1047 if (GVar
->isDeclaration()) {
1048 // (extern) declarations, no definition or initializer
1049 // Currently the only known declaration is for an automatic __local
1050 // (.shared) promoted to global.
1051 emitPTXGlobalVariable(GVar
, O
, STI
);
1056 if (isSampler(*GVar
)) {
1057 O
<< ".global .samplerref " << getSamplerName(*GVar
);
1059 const Constant
*Initializer
= nullptr;
1060 if (GVar
->hasInitializer())
1061 Initializer
= GVar
->getInitializer();
1062 const ConstantInt
*CI
= nullptr;
1064 CI
= dyn_cast
<ConstantInt
>(Initializer
);
1066 unsigned sample
= CI
->getZExtValue();
1071 addr
= ((sample
& __CLK_ADDRESS_MASK
) >> __CLK_ADDRESS_BASE
);
1073 O
<< "addr_mode_" << i
<< " = ";
1079 O
<< "clamp_to_border";
1082 O
<< "clamp_to_edge";
1093 O
<< "filter_mode = ";
1094 switch ((sample
& __CLK_FILTER_MASK
) >> __CLK_FILTER_BASE
) {
1102 llvm_unreachable("Anisotropic filtering is not supported");
1107 if (!((sample
& __CLK_NORMALIZED_MASK
) >> __CLK_NORMALIZED_BASE
)) {
1108 O
<< ", force_unnormalized_coords = 1";
1117 if (GVar
->hasPrivateLinkage()) {
1118 if (strncmp(GVar
->getName().data(), "unrollpragma", 12) == 0)
1121 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1122 if (strncmp(GVar
->getName().data(), "filename", 8) == 0)
1124 if (GVar
->use_empty())
1128 const Function
*demotedFunc
= nullptr;
1129 if (!processDemoted
&& canDemoteGlobalVar(GVar
, demotedFunc
)) {
1130 O
<< "// " << GVar
->getName() << " has been demoted\n";
1131 if (localDecls
.find(demotedFunc
) != localDecls
.end())
1132 localDecls
[demotedFunc
].push_back(GVar
);
1134 std::vector
<const GlobalVariable
*> temp
;
1135 temp
.push_back(GVar
);
1136 localDecls
[demotedFunc
] = temp
;
1142 emitPTXAddressSpace(PTy
->getAddressSpace(), O
);
1144 if (isManaged(*GVar
)) {
1145 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1147 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1149 O
<< " .attribute(.managed)";
1152 if (MaybeAlign A
= GVar
->getAlign())
1153 O
<< " .align " << A
->value();
1155 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1157 if (ETy
->isFloatingPointTy() || ETy
->isPointerTy() ||
1158 (ETy
->isIntegerTy() && ETy
->getScalarSizeInBits() <= 64)) {
1160 // Special case: ABI requires that we use .u8 for predicates
1161 if (ETy
->isIntegerTy(1))
1164 O
<< getPTXFundamentalTypeStr(ETy
, false);
1166 getSymbol(GVar
)->print(O
, MAI
);
1168 // Ptx allows variable initilization only for constant and global state
1170 if (GVar
->hasInitializer()) {
1171 if ((PTy
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1172 (PTy
->getAddressSpace() == ADDRESS_SPACE_CONST
)) {
1173 const Constant
*Initializer
= GVar
->getInitializer();
1174 // 'undef' is treated as there is no value specified.
1175 if (!Initializer
->isNullValue() && !isa
<UndefValue
>(Initializer
)) {
1177 printScalarConstant(Initializer
, O
);
1180 // The frontend adds zero-initializer to device and constant variables
1181 // that don't have an initial value, and UndefValue to shared
1182 // variables, so skip warning for this case.
1183 if (!GVar
->getInitializer()->isNullValue() &&
1184 !isa
<UndefValue
>(GVar
->getInitializer())) {
1185 report_fatal_error("initial value of '" + GVar
->getName() +
1186 "' is not allowed in addrspace(" +
1187 Twine(PTy
->getAddressSpace()) + ")");
1192 uint64_t ElementSize
= 0;
1194 // Although PTX has direct support for struct type and array type and
1195 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1196 // targets that support these high level field accesses. Structs, arrays
1197 // and vectors are lowered into arrays of bytes.
1198 switch (ETy
->getTypeID()) {
1199 case Type::IntegerTyID
: // Integers larger than 64 bits
1200 case Type::StructTyID
:
1201 case Type::ArrayTyID
:
1202 case Type::FixedVectorTyID
:
1203 ElementSize
= DL
.getTypeStoreSize(ETy
);
1204 // Ptx allows variable initilization only for constant and
1205 // global state spaces.
1206 if (((PTy
->getAddressSpace() == ADDRESS_SPACE_GLOBAL
) ||
1207 (PTy
->getAddressSpace() == ADDRESS_SPACE_CONST
)) &&
1208 GVar
->hasInitializer()) {
1209 const Constant
*Initializer
= GVar
->getInitializer();
1210 if (!isa
<UndefValue
>(Initializer
) && !Initializer
->isNullValue()) {
1211 AggBuffer
aggBuffer(ElementSize
, *this);
1212 bufferAggregateConstant(Initializer
, &aggBuffer
);
1213 if (aggBuffer
.numSymbols()) {
1214 unsigned int ptrSize
= MAI
->getCodePointerSize();
1215 if (ElementSize
% ptrSize
||
1216 !aggBuffer
.allSymbolsAligned(ptrSize
)) {
1217 // Print in bytes and use the mask() operator for pointers.
1218 if (!STI
.hasMaskOperator())
1220 "initialized packed aggregate with pointers '" +
1222 "' requires at least PTX ISA version 7.1");
1224 getSymbol(GVar
)->print(O
, MAI
);
1225 O
<< "[" << ElementSize
<< "] = {";
1226 aggBuffer
.printBytes(O
);
1229 O
<< " .u" << ptrSize
* 8 << " ";
1230 getSymbol(GVar
)->print(O
, MAI
);
1231 O
<< "[" << ElementSize
/ ptrSize
<< "] = {";
1232 aggBuffer
.printWords(O
);
1237 getSymbol(GVar
)->print(O
, MAI
);
1238 O
<< "[" << ElementSize
<< "] = {";
1239 aggBuffer
.printBytes(O
);
1244 getSymbol(GVar
)->print(O
, MAI
);
1253 getSymbol(GVar
)->print(O
, MAI
);
1262 llvm_unreachable("type not supported yet");
1268 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym
, raw_ostream
&os
) {
1269 const Value
*v
= Symbols
[nSym
];
1270 const Value
*v0
= SymbolsBeforeStripping
[nSym
];
1271 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(v
)) {
1272 MCSymbol
*Name
= AP
.getSymbol(GVar
);
1273 PointerType
*PTy
= dyn_cast
<PointerType
>(v0
->getType());
1274 // Is v0 a generic pointer?
1275 bool isGenericPointer
= PTy
&& PTy
->getAddressSpace() == 0;
1276 if (EmitGeneric
&& isGenericPointer
&& !isa
<Function
>(v
)) {
1278 Name
->print(os
, AP
.MAI
);
1281 Name
->print(os
, AP
.MAI
);
1283 } else if (const ConstantExpr
*CExpr
= dyn_cast
<ConstantExpr
>(v0
)) {
1284 const MCExpr
*Expr
= AP
.lowerConstantForGV(cast
<Constant
>(CExpr
), false);
1285 AP
.printMCExpr(*Expr
, os
);
1287 llvm_unreachable("symbol type unknown");
1290 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream
&os
) {
1291 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1292 symbolPosInBuffer
.push_back(size
);
1293 unsigned int nSym
= 0;
1294 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1295 for (unsigned int pos
= 0; pos
< size
;) {
1298 if (pos
!= nextSymbolPos
) {
1299 os
<< (unsigned int)buffer
[pos
];
1303 // Generate a per-byte mask() operator for the symbol, which looks like:
1304 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1305 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1306 std::string symText
;
1307 llvm::raw_string_ostream
oss(symText
);
1308 printSymbol(nSym
, oss
);
1309 for (unsigned i
= 0; i
< ptrSize
; ++i
) {
1312 llvm::write_hex(os
, 0xFFULL
<< i
* 8, HexPrintStyle::PrefixUpper
);
1313 os
<< "(" << symText
<< ")";
1316 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1317 assert(nextSymbolPos
>= pos
);
1321 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream
&os
) {
1322 unsigned int ptrSize
= AP
.MAI
->getCodePointerSize();
1323 symbolPosInBuffer
.push_back(size
);
1324 unsigned int nSym
= 0;
1325 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
1326 assert(nextSymbolPos
% ptrSize
== 0);
1327 for (unsigned int pos
= 0; pos
< size
; pos
+= ptrSize
) {
1330 if (pos
== nextSymbolPos
) {
1331 printSymbol(nSym
, os
);
1332 nextSymbolPos
= symbolPosInBuffer
[++nSym
];
1333 assert(nextSymbolPos
% ptrSize
== 0);
1334 assert(nextSymbolPos
>= pos
+ ptrSize
);
1335 } else if (ptrSize
== 4)
1336 os
<< support::endian::read32le(&buffer
[pos
]);
1338 os
<< support::endian::read64le(&buffer
[pos
]);
1342 void NVPTXAsmPrinter::emitDemotedVars(const Function
*f
, raw_ostream
&O
) {
1343 if (localDecls
.find(f
) == localDecls
.end())
1346 std::vector
<const GlobalVariable
*> &gvars
= localDecls
[f
];
1348 const NVPTXTargetMachine
&NTM
= static_cast<const NVPTXTargetMachine
&>(TM
);
1349 const NVPTXSubtarget
&STI
=
1350 *static_cast<const NVPTXSubtarget
*>(NTM
.getSubtargetImpl());
1352 for (const GlobalVariable
*GV
: gvars
) {
1353 O
<< "\t// demoted variable\n\t";
1354 printModuleLevelGV(GV
, O
, /*processDemoted=*/true, STI
);
1358 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace
,
1359 raw_ostream
&O
) const {
1360 switch (AddressSpace
) {
1361 case ADDRESS_SPACE_LOCAL
:
1364 case ADDRESS_SPACE_GLOBAL
:
1367 case ADDRESS_SPACE_CONST
:
1370 case ADDRESS_SPACE_SHARED
:
1374 report_fatal_error("Bad address space found while emitting PTX: " +
1375 llvm::Twine(AddressSpace
));
1381 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type
*Ty
, bool useB4PTR
) const {
1382 switch (Ty
->getTypeID()) {
1383 case Type::IntegerTyID
: {
1384 unsigned NumBits
= cast
<IntegerType
>(Ty
)->getBitWidth();
1387 else if (NumBits
<= 64) {
1388 std::string name
= "u";
1389 return name
+ utostr(NumBits
);
1391 llvm_unreachable("Integer too large");
1396 case Type::BFloatTyID
:
1397 case Type::HalfTyID
:
1398 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1401 case Type::FloatTyID
:
1403 case Type::DoubleTyID
:
1405 case Type::PointerTyID
: {
1406 unsigned PtrSize
= TM
.getPointerSizeInBits(Ty
->getPointerAddressSpace());
1407 assert((PtrSize
== 64 || PtrSize
== 32) && "Unexpected pointer size");
1422 llvm_unreachable("unexpected type");
1425 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable
*GVar
,
1427 const NVPTXSubtarget
&STI
) {
1428 const DataLayout
&DL
= getDataLayout();
1430 // GlobalVariables are always constant pointers themselves.
1431 Type
*ETy
= GVar
->getValueType();
1434 emitPTXAddressSpace(GVar
->getType()->getAddressSpace(), O
);
1435 if (isManaged(*GVar
)) {
1436 if (STI
.getPTXVersion() < 40 || STI
.getSmVersion() < 30) {
1438 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1440 O
<< " .attribute(.managed)";
1442 if (MaybeAlign A
= GVar
->getAlign())
1443 O
<< " .align " << A
->value();
1445 O
<< " .align " << (int)DL
.getPrefTypeAlign(ETy
).value();
1447 // Special case for i128
1448 if (ETy
->isIntegerTy(128)) {
1450 getSymbol(GVar
)->print(O
, MAI
);
1455 if (ETy
->isFloatingPointTy() || ETy
->isIntOrPtrTy()) {
1457 O
<< getPTXFundamentalTypeStr(ETy
);
1459 getSymbol(GVar
)->print(O
, MAI
);
1463 int64_t ElementSize
= 0;
1465 // Although PTX has direct support for struct type and array type and LLVM IR
1466 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1467 // support these high level field accesses. Structs and arrays are lowered
1468 // into arrays of bytes.
1469 switch (ETy
->getTypeID()) {
1470 case Type::StructTyID
:
1471 case Type::ArrayTyID
:
1472 case Type::FixedVectorTyID
:
1473 ElementSize
= DL
.getTypeStoreSize(ETy
);
1475 getSymbol(GVar
)->print(O
, MAI
);
1483 llvm_unreachable("type not supported yet");
1487 void NVPTXAsmPrinter::emitFunctionParamList(const Function
*F
, raw_ostream
&O
) {
1488 const DataLayout
&DL
= getDataLayout();
1489 const AttributeList
&PAL
= F
->getAttributes();
1490 const NVPTXSubtarget
&STI
= TM
.getSubtarget
<NVPTXSubtarget
>(*F
);
1491 const auto *TLI
= cast
<NVPTXTargetLowering
>(STI
.getTargetLowering());
1493 Function::const_arg_iterator I
, E
;
1494 unsigned paramIndex
= 0;
1496 bool isKernelFunc
= isKernelFunction(*F
);
1497 bool isABI
= (STI
.getSmVersion() >= 20);
1498 bool hasImageHandles
= STI
.hasImageHandles();
1500 if (F
->arg_empty() && !F
->isVarArg()) {
1507 for (I
= F
->arg_begin(), E
= F
->arg_end(); I
!= E
; ++I
, paramIndex
++) {
1508 Type
*Ty
= I
->getType();
1515 // Handle image/sampler parameters
1516 if (isKernelFunction(*F
)) {
1517 if (isSampler(*I
) || isImage(*I
)) {
1519 std::string sname
= std::string(I
->getName());
1520 if (isImageWriteOnly(*I
) || isImageReadWrite(*I
)) {
1521 if (hasImageHandles
)
1522 O
<< "\t.param .u64 .ptr .surfref ";
1524 O
<< "\t.param .surfref ";
1525 O
<< TLI
->getParamName(F
, paramIndex
);
1527 else { // Default image is read_only
1528 if (hasImageHandles
)
1529 O
<< "\t.param .u64 .ptr .texref ";
1531 O
<< "\t.param .texref ";
1532 O
<< TLI
->getParamName(F
, paramIndex
);
1535 if (hasImageHandles
)
1536 O
<< "\t.param .u64 .ptr .samplerref ";
1538 O
<< "\t.param .samplerref ";
1539 O
<< TLI
->getParamName(F
, paramIndex
);
1545 auto getOptimalAlignForParam
= [TLI
, &DL
, &PAL
, F
,
1546 paramIndex
](Type
*Ty
) -> Align
{
1547 Align TypeAlign
= TLI
->getFunctionParamOptimizedAlign(F
, Ty
, DL
);
1548 MaybeAlign ParamAlign
= PAL
.getParamAlignment(paramIndex
);
1549 return std::max(TypeAlign
, ParamAlign
.valueOrOne());
1552 if (!PAL
.hasParamAttr(paramIndex
, Attribute::ByVal
)) {
1553 if (ShouldPassAsArray(Ty
)) {
1554 // Just print .param .align <a> .b8 .param[size];
1555 // <a> = optimal alignment for the element type; always multiple of
1556 // PAL.getParamAlignment
1557 // size = typeallocsize of element type
1558 Align OptimalAlign
= getOptimalAlignForParam(Ty
);
1560 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1561 O
<< TLI
->getParamName(F
, paramIndex
);
1562 O
<< "[" << DL
.getTypeAllocSize(Ty
) << "]";
1567 auto *PTy
= dyn_cast
<PointerType
>(Ty
);
1568 unsigned PTySizeInBits
= 0;
1571 TLI
->getPointerTy(DL
, PTy
->getAddressSpace()).getSizeInBits();
1572 assert(PTySizeInBits
&& "Invalid pointer size");
1577 // Special handling for pointer arguments to kernel
1578 O
<< "\t.param .u" << PTySizeInBits
<< " ";
1580 if (static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() !=
1582 int addrSpace
= PTy
->getAddressSpace();
1583 switch (addrSpace
) {
1587 case ADDRESS_SPACE_CONST
:
1588 O
<< ".ptr .const ";
1590 case ADDRESS_SPACE_SHARED
:
1591 O
<< ".ptr .shared ";
1593 case ADDRESS_SPACE_GLOBAL
:
1594 O
<< ".ptr .global ";
1597 Align ParamAlign
= I
->getParamAlign().valueOrOne();
1598 O
<< ".align " << ParamAlign
.value() << " ";
1600 O
<< TLI
->getParamName(F
, paramIndex
);
1604 // non-pointer scalar to kernel func
1606 // Special case: predicate operands become .u8 types
1607 if (Ty
->isIntegerTy(1))
1610 O
<< getPTXFundamentalTypeStr(Ty
);
1612 O
<< TLI
->getParamName(F
, paramIndex
);
1615 // Non-kernel function, just print .param .b<size> for ABI
1616 // and .reg .b<size> for non-ABI
1618 if (isa
<IntegerType
>(Ty
)) {
1619 sz
= cast
<IntegerType
>(Ty
)->getBitWidth();
1620 sz
= promoteScalarArgumentSize(sz
);
1622 assert(PTySizeInBits
&& "Invalid pointer size");
1625 sz
= Ty
->getPrimitiveSizeInBits();
1627 O
<< "\t.param .b" << sz
<< " ";
1629 O
<< "\t.reg .b" << sz
<< " ";
1630 O
<< TLI
->getParamName(F
, paramIndex
);
1634 // param has byVal attribute.
1635 Type
*ETy
= PAL
.getParamByValType(paramIndex
);
1636 assert(ETy
&& "Param should have byval type");
1638 if (isABI
|| isKernelFunc
) {
1639 // Just print .param .align <a> .b8 .param[size];
1640 // <a> = optimal alignment for the element type; always multiple of
1641 // PAL.getParamAlignment
1642 // size = typeallocsize of element type
1643 Align OptimalAlign
=
1645 ? getOptimalAlignForParam(ETy
)
1646 : TLI
->getFunctionByValParamAlign(
1647 F
, ETy
, PAL
.getParamAlignment(paramIndex
).valueOrOne(), DL
);
1649 unsigned sz
= DL
.getTypeAllocSize(ETy
);
1650 O
<< "\t.param .align " << OptimalAlign
.value() << " .b8 ";
1651 O
<< TLI
->getParamName(F
, paramIndex
);
1652 O
<< "[" << sz
<< "]";
1655 // Split the ETy into constituent parts and
1656 // print .param .b<size> <name> for each part.
1657 // Further, if a part is vector, print the above for
1658 // each vector element.
1659 SmallVector
<EVT
, 16> vtparts
;
1660 ComputeValueVTs(*TLI
, DL
, ETy
, vtparts
);
1661 for (unsigned i
= 0, e
= vtparts
.size(); i
!= e
; ++i
) {
1663 EVT elemtype
= vtparts
[i
];
1664 if (vtparts
[i
].isVector()) {
1665 elems
= vtparts
[i
].getVectorNumElements();
1666 elemtype
= vtparts
[i
].getVectorElementType();
1669 for (unsigned j
= 0, je
= elems
; j
!= je
; ++j
) {
1670 unsigned sz
= elemtype
.getSizeInBits();
1671 if (elemtype
.isInteger())
1672 sz
= promoteScalarArgumentSize(sz
);
1673 O
<< "\t.reg .b" << sz
<< " ";
1674 O
<< TLI
->getParamName(F
, paramIndex
);
1687 if (F
->isVarArg()) {
1690 O
<< "\t.param .align " << STI
.getMaxRequiredAlignment();
1692 O
<< TLI
->getParamName(F
, /* vararg */ -1) << "[]";
1698 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1699 const MachineFunction
&MF
) {
1700 SmallString
<128> Str
;
1701 raw_svector_ostream
O(Str
);
1703 // Map the global virtual register number to a register class specific
1704 // virtual register number starting from 1 with that class.
1705 const TargetRegisterInfo
*TRI
= MF
.getSubtarget().getRegisterInfo();
1706 //unsigned numRegClasses = TRI->getNumRegClasses();
1708 // Emit the Fake Stack Object
1709 const MachineFrameInfo
&MFI
= MF
.getFrameInfo();
1710 int NumBytes
= (int) MFI
.getStackSize();
1712 O
<< "\t.local .align " << MFI
.getMaxAlign().value() << " .b8 \t"
1713 << DEPOTNAME
<< getFunctionNumber() << "[" << NumBytes
<< "];\n";
1714 if (static_cast<const NVPTXTargetMachine
&>(MF
.getTarget()).is64Bit()) {
1715 O
<< "\t.reg .b64 \t%SP;\n";
1716 O
<< "\t.reg .b64 \t%SPL;\n";
1718 O
<< "\t.reg .b32 \t%SP;\n";
1719 O
<< "\t.reg .b32 \t%SPL;\n";
1723 // Go through all virtual registers to establish the mapping between the
1725 // register number and the per class virtual register number.
1726 // We use the per class virtual register number in the ptx output.
1727 unsigned int numVRs
= MRI
->getNumVirtRegs();
1728 for (unsigned i
= 0; i
< numVRs
; i
++) {
1729 Register vr
= Register::index2VirtReg(i
);
1730 const TargetRegisterClass
*RC
= MRI
->getRegClass(vr
);
1731 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1732 int n
= regmap
.size();
1733 regmap
.insert(std::make_pair(vr
, n
+ 1));
1736 // Emit register declarations
1737 // @TODO: Extract out the real register usage
1738 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1739 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1740 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1741 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1742 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1743 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1744 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1746 // Emit declaration of the virtual registers or 'physical' registers for
1747 // each register class
1748 for (unsigned i
=0; i
< TRI
->getNumRegClasses(); i
++) {
1749 const TargetRegisterClass
*RC
= TRI
->getRegClass(i
);
1750 DenseMap
<unsigned, unsigned> ®map
= VRegMapping
[RC
];
1751 std::string rcname
= getNVPTXRegClassName(RC
);
1752 std::string rcStr
= getNVPTXRegClassStr(RC
);
1753 int n
= regmap
.size();
1755 // Only declare those registers that may be used.
1757 O
<< "\t.reg " << rcname
<< " \t" << rcStr
<< "<" << (n
+1)
1762 OutStreamer
->emitRawText(O
.str());
1765 void NVPTXAsmPrinter::printFPConstant(const ConstantFP
*Fp
, raw_ostream
&O
) {
1766 APFloat APF
= APFloat(Fp
->getValueAPF()); // make a copy
1768 unsigned int numHex
;
1771 if (Fp
->getType()->getTypeID() == Type::FloatTyID
) {
1774 APF
.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven
, &ignored
);
1775 } else if (Fp
->getType()->getTypeID() == Type::DoubleTyID
) {
1778 APF
.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven
, &ignored
);
1780 llvm_unreachable("unsupported fp type");
1782 APInt API
= APF
.bitcastToAPInt();
1783 O
<< lead
<< format_hex_no_prefix(API
.getZExtValue(), numHex
, /*Upper=*/true);
1786 void NVPTXAsmPrinter::printScalarConstant(const Constant
*CPV
, raw_ostream
&O
) {
1787 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1788 O
<< CI
->getValue();
1791 if (const ConstantFP
*CFP
= dyn_cast
<ConstantFP
>(CPV
)) {
1792 printFPConstant(CFP
, O
);
1795 if (isa
<ConstantPointerNull
>(CPV
)) {
1799 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1800 bool IsNonGenericPointer
= false;
1801 if (GVar
->getType()->getAddressSpace() != 0) {
1802 IsNonGenericPointer
= true;
1804 if (EmitGeneric
&& !isa
<Function
>(CPV
) && !IsNonGenericPointer
) {
1806 getSymbol(GVar
)->print(O
, MAI
);
1809 getSymbol(GVar
)->print(O
, MAI
);
1813 if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1814 const MCExpr
*E
= lowerConstantForGV(cast
<Constant
>(Cexpr
), false);
1818 llvm_unreachable("Not scalar type found in printScalarConstant()");
1821 void NVPTXAsmPrinter::bufferLEByte(const Constant
*CPV
, int Bytes
,
1822 AggBuffer
*AggBuffer
) {
1823 const DataLayout
&DL
= getDataLayout();
1824 int AllocSize
= DL
.getTypeAllocSize(CPV
->getType());
1825 if (isa
<UndefValue
>(CPV
) || CPV
->isNullValue()) {
1826 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1827 // only the space allocated by CPV.
1828 AggBuffer
->addZeros(Bytes
? Bytes
: AllocSize
);
1832 // Helper for filling AggBuffer with APInts.
1833 auto AddIntToBuffer
= [AggBuffer
, Bytes
](const APInt
&Val
) {
1834 size_t NumBytes
= (Val
.getBitWidth() + 7) / 8;
1835 SmallVector
<unsigned char, 16> Buf(NumBytes
);
1836 for (unsigned I
= 0; I
< NumBytes
; ++I
) {
1837 Buf
[I
] = Val
.extractBitsAsZExtValue(8, I
* 8);
1839 AggBuffer
->addBytes(Buf
.data(), NumBytes
, Bytes
);
1842 switch (CPV
->getType()->getTypeID()) {
1843 case Type::IntegerTyID
:
1844 if (const auto CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1845 AddIntToBuffer(CI
->getValue());
1848 if (const auto *Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1849 if (const auto *CI
=
1850 dyn_cast
<ConstantInt
>(ConstantFoldConstant(Cexpr
, DL
))) {
1851 AddIntToBuffer(CI
->getValue());
1854 if (Cexpr
->getOpcode() == Instruction::PtrToInt
) {
1855 Value
*V
= Cexpr
->getOperand(0)->stripPointerCasts();
1856 AggBuffer
->addSymbol(V
, Cexpr
->getOperand(0));
1857 AggBuffer
->addZeros(AllocSize
);
1861 llvm_unreachable("unsupported integer const type");
1864 case Type::HalfTyID
:
1865 case Type::BFloatTyID
:
1866 case Type::FloatTyID
:
1867 case Type::DoubleTyID
:
1868 AddIntToBuffer(cast
<ConstantFP
>(CPV
)->getValueAPF().bitcastToAPInt());
1871 case Type::PointerTyID
: {
1872 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(CPV
)) {
1873 AggBuffer
->addSymbol(GVar
, GVar
);
1874 } else if (const ConstantExpr
*Cexpr
= dyn_cast
<ConstantExpr
>(CPV
)) {
1875 const Value
*v
= Cexpr
->stripPointerCasts();
1876 AggBuffer
->addSymbol(v
, Cexpr
);
1878 AggBuffer
->addZeros(AllocSize
);
1882 case Type::ArrayTyID
:
1883 case Type::FixedVectorTyID
:
1884 case Type::StructTyID
: {
1885 if (isa
<ConstantAggregate
>(CPV
) || isa
<ConstantDataSequential
>(CPV
)) {
1886 bufferAggregateConstant(CPV
, AggBuffer
);
1887 if (Bytes
> AllocSize
)
1888 AggBuffer
->addZeros(Bytes
- AllocSize
);
1889 } else if (isa
<ConstantAggregateZero
>(CPV
))
1890 AggBuffer
->addZeros(Bytes
);
1892 llvm_unreachable("Unexpected Constant type");
1897 llvm_unreachable("unsupported type");
1901 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant
*CPV
,
1902 AggBuffer
*aggBuffer
) {
1903 const DataLayout
&DL
= getDataLayout();
1906 // Integers of arbitrary width
1907 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CPV
)) {
1908 APInt Val
= CI
->getValue();
1909 for (unsigned I
= 0, E
= DL
.getTypeAllocSize(CPV
->getType()); I
< E
; ++I
) {
1910 uint8_t Byte
= Val
.getLoBits(8).getZExtValue();
1911 aggBuffer
->addBytes(&Byte
, 1, 1);
1918 if (isa
<ConstantArray
>(CPV
) || isa
<ConstantVector
>(CPV
)) {
1919 if (CPV
->getNumOperands())
1920 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
)
1921 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), 0, aggBuffer
);
1925 if (const ConstantDataSequential
*CDS
=
1926 dyn_cast
<ConstantDataSequential
>(CPV
)) {
1927 if (CDS
->getNumElements())
1928 for (unsigned i
= 0; i
< CDS
->getNumElements(); ++i
)
1929 bufferLEByte(cast
<Constant
>(CDS
->getElementAsConstant(i
)), 0,
1934 if (isa
<ConstantStruct
>(CPV
)) {
1935 if (CPV
->getNumOperands()) {
1936 StructType
*ST
= cast
<StructType
>(CPV
->getType());
1937 for (unsigned i
= 0, e
= CPV
->getNumOperands(); i
!= e
; ++i
) {
1939 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(0) +
1940 DL
.getTypeAllocSize(ST
) -
1941 DL
.getStructLayout(ST
)->getElementOffset(i
);
1943 Bytes
= DL
.getStructLayout(ST
)->getElementOffset(i
+ 1) -
1944 DL
.getStructLayout(ST
)->getElementOffset(i
);
1945 bufferLEByte(cast
<Constant
>(CPV
->getOperand(i
)), Bytes
, aggBuffer
);
1950 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1953 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1954 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1955 /// expressions that are representable in PTX and create
1956 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1958 NVPTXAsmPrinter::lowerConstantForGV(const Constant
*CV
, bool ProcessingGeneric
) {
1959 MCContext
&Ctx
= OutContext
;
1961 if (CV
->isNullValue() || isa
<UndefValue
>(CV
))
1962 return MCConstantExpr::create(0, Ctx
);
1964 if (const ConstantInt
*CI
= dyn_cast
<ConstantInt
>(CV
))
1965 return MCConstantExpr::create(CI
->getZExtValue(), Ctx
);
1967 if (const GlobalValue
*GV
= dyn_cast
<GlobalValue
>(CV
)) {
1968 const MCSymbolRefExpr
*Expr
=
1969 MCSymbolRefExpr::create(getSymbol(GV
), Ctx
);
1970 if (ProcessingGeneric
) {
1971 return NVPTXGenericMCSymbolRefExpr::create(Expr
, Ctx
);
1977 const ConstantExpr
*CE
= dyn_cast
<ConstantExpr
>(CV
);
1979 llvm_unreachable("Unknown constant value to lower!");
1982 switch (CE
->getOpcode()) {
1986 case Instruction::AddrSpaceCast
: {
1987 // Strip the addrspacecast and pass along the operand
1988 PointerType
*DstTy
= cast
<PointerType
>(CE
->getType());
1989 if (DstTy
->getAddressSpace() == 0)
1990 return lowerConstantForGV(cast
<const Constant
>(CE
->getOperand(0)), true);
1995 case Instruction::GetElementPtr
: {
1996 const DataLayout
&DL
= getDataLayout();
1998 // Generate a symbolic expression for the byte address
1999 APInt
OffsetAI(DL
.getPointerTypeSizeInBits(CE
->getType()), 0);
2000 cast
<GEPOperator
>(CE
)->accumulateConstantOffset(DL
, OffsetAI
);
2002 const MCExpr
*Base
= lowerConstantForGV(CE
->getOperand(0),
2007 int64_t Offset
= OffsetAI
.getSExtValue();
2008 return MCBinaryExpr::createAdd(Base
, MCConstantExpr::create(Offset
, Ctx
),
2012 case Instruction::Trunc
:
2013 // We emit the value and depend on the assembler to truncate the generated
2014 // expression properly. This is important for differences between
2015 // blockaddress labels. Since the two labels are in the same function, it
2016 // is reasonable to treat their delta as a 32-bit value.
2018 case Instruction::BitCast
:
2019 return lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2021 case Instruction::IntToPtr
: {
2022 const DataLayout
&DL
= getDataLayout();
2024 // Handle casts to pointers by changing them into casts to the appropriate
2025 // integer type. This promotes constant folding and simplifies this code.
2026 Constant
*Op
= CE
->getOperand(0);
2027 Op
= ConstantFoldIntegerCast(Op
, DL
.getIntPtrType(CV
->getType()),
2028 /*IsSigned*/ false, DL
);
2030 return lowerConstantForGV(Op
, ProcessingGeneric
);
2035 case Instruction::PtrToInt
: {
2036 const DataLayout
&DL
= getDataLayout();
2038 // Support only foldable casts to/from pointers that can be eliminated by
2039 // changing the pointer to the appropriately sized integer type.
2040 Constant
*Op
= CE
->getOperand(0);
2041 Type
*Ty
= CE
->getType();
2043 const MCExpr
*OpExpr
= lowerConstantForGV(Op
, ProcessingGeneric
);
2045 // We can emit the pointer value into this slot if the slot is an
2046 // integer slot equal to the size of the pointer.
2047 if (DL
.getTypeAllocSize(Ty
) == DL
.getTypeAllocSize(Op
->getType()))
2050 // Otherwise the pointer is smaller than the resultant integer, mask off
2051 // the high bits so we are sure to get a proper truncation if the input is
2053 unsigned InBits
= DL
.getTypeAllocSizeInBits(Op
->getType());
2054 const MCExpr
*MaskExpr
= MCConstantExpr::create(~0ULL >> (64-InBits
), Ctx
);
2055 return MCBinaryExpr::createAnd(OpExpr
, MaskExpr
, Ctx
);
2058 // The MC library also has a right-shift operator, but it isn't consistently
2059 // signed or unsigned between different targets.
2060 case Instruction::Add
: {
2061 const MCExpr
*LHS
= lowerConstantForGV(CE
->getOperand(0), ProcessingGeneric
);
2062 const MCExpr
*RHS
= lowerConstantForGV(CE
->getOperand(1), ProcessingGeneric
);
2063 switch (CE
->getOpcode()) {
2064 default: llvm_unreachable("Unknown binary operator constant cast expr");
2065 case Instruction::Add
: return MCBinaryExpr::createAdd(LHS
, RHS
, Ctx
);
2070 // If the code isn't optimized, there may be outstanding folding
2071 // opportunities. Attempt to fold the expression using DataLayout as a
2072 // last resort before giving up.
2073 Constant
*C
= ConstantFoldConstant(CE
, getDataLayout());
2075 return lowerConstantForGV(C
, ProcessingGeneric
);
2077 // Otherwise report the problem to the user.
2079 raw_string_ostream
OS(S
);
2080 OS
<< "Unsupported expression in static initializer: ";
2081 CE
->printAsOperand(OS
, /*PrintType=*/false,
2082 !MF
? nullptr : MF
->getFunction().getParent());
2083 report_fatal_error(Twine(OS
.str()));
2086 // Copy of MCExpr::print customized for NVPTX
2087 void NVPTXAsmPrinter::printMCExpr(const MCExpr
&Expr
, raw_ostream
&OS
) {
2088 switch (Expr
.getKind()) {
2089 case MCExpr::Target
:
2090 return cast
<MCTargetExpr
>(&Expr
)->printImpl(OS
, MAI
);
2091 case MCExpr::Constant
:
2092 OS
<< cast
<MCConstantExpr
>(Expr
).getValue();
2095 case MCExpr::SymbolRef
: {
2096 const MCSymbolRefExpr
&SRE
= cast
<MCSymbolRefExpr
>(Expr
);
2097 const MCSymbol
&Sym
= SRE
.getSymbol();
2102 case MCExpr::Unary
: {
2103 const MCUnaryExpr
&UE
= cast
<MCUnaryExpr
>(Expr
);
2104 switch (UE
.getOpcode()) {
2105 case MCUnaryExpr::LNot
: OS
<< '!'; break;
2106 case MCUnaryExpr::Minus
: OS
<< '-'; break;
2107 case MCUnaryExpr::Not
: OS
<< '~'; break;
2108 case MCUnaryExpr::Plus
: OS
<< '+'; break;
2110 printMCExpr(*UE
.getSubExpr(), OS
);
2114 case MCExpr::Binary
: {
2115 const MCBinaryExpr
&BE
= cast
<MCBinaryExpr
>(Expr
);
2117 // Only print parens around the LHS if it is non-trivial.
2118 if (isa
<MCConstantExpr
>(BE
.getLHS()) || isa
<MCSymbolRefExpr
>(BE
.getLHS()) ||
2119 isa
<NVPTXGenericMCSymbolRefExpr
>(BE
.getLHS())) {
2120 printMCExpr(*BE
.getLHS(), OS
);
2123 printMCExpr(*BE
.getLHS(), OS
);
2127 switch (BE
.getOpcode()) {
2128 case MCBinaryExpr::Add
:
2129 // Print "X-42" instead of "X+-42".
2130 if (const MCConstantExpr
*RHSC
= dyn_cast
<MCConstantExpr
>(BE
.getRHS())) {
2131 if (RHSC
->getValue() < 0) {
2132 OS
<< RHSC
->getValue();
2139 default: llvm_unreachable("Unhandled binary operator");
2142 // Only print parens around the LHS if it is non-trivial.
2143 if (isa
<MCConstantExpr
>(BE
.getRHS()) || isa
<MCSymbolRefExpr
>(BE
.getRHS())) {
2144 printMCExpr(*BE
.getRHS(), OS
);
2147 printMCExpr(*BE
.getRHS(), OS
);
2154 llvm_unreachable("Invalid expression kind!");
2157 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2159 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr
*MI
, unsigned OpNo
,
2160 const char *ExtraCode
, raw_ostream
&O
) {
2161 if (ExtraCode
&& ExtraCode
[0]) {
2162 if (ExtraCode
[1] != 0)
2163 return true; // Unknown modifier.
2165 switch (ExtraCode
[0]) {
2167 // See if this is a generic print operand
2168 return AsmPrinter::PrintAsmOperand(MI
, OpNo
, ExtraCode
, O
);
2174 printOperand(MI
, OpNo
, O
);
2179 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr
*MI
,
2181 const char *ExtraCode
,
2183 if (ExtraCode
&& ExtraCode
[0])
2184 return true; // Unknown modifier
2187 printMemOperand(MI
, OpNo
, O
);
2193 void NVPTXAsmPrinter::printOperand(const MachineInstr
*MI
, unsigned OpNum
,
2195 const MachineOperand
&MO
= MI
->getOperand(OpNum
);
2196 switch (MO
.getType()) {
2197 case MachineOperand::MO_Register
:
2198 if (MO
.getReg().isPhysical()) {
2199 if (MO
.getReg() == NVPTX::VRDepot
)
2200 O
<< DEPOTNAME
<< getFunctionNumber();
2202 O
<< NVPTXInstPrinter::getRegisterName(MO
.getReg());
2204 emitVirtualRegister(MO
.getReg(), O
);
2208 case MachineOperand::MO_Immediate
:
2212 case MachineOperand::MO_FPImmediate
:
2213 printFPConstant(MO
.getFPImm(), O
);
2216 case MachineOperand::MO_GlobalAddress
:
2217 PrintSymbolOperand(MO
, O
);
2220 case MachineOperand::MO_MachineBasicBlock
:
2221 MO
.getMBB()->getSymbol()->print(O
, MAI
);
2225 llvm_unreachable("Operand type not supported.");
2229 void NVPTXAsmPrinter::printMemOperand(const MachineInstr
*MI
, unsigned OpNum
,
2230 raw_ostream
&O
, const char *Modifier
) {
2231 printOperand(MI
, OpNum
, O
);
2233 if (Modifier
&& strcmp(Modifier
, "add") == 0) {
2235 printOperand(MI
, OpNum
+ 1, O
);
2237 if (MI
->getOperand(OpNum
+ 1).isImm() &&
2238 MI
->getOperand(OpNum
+ 1).getImm() == 0)
2239 return; // don't print ',0' or '+0'
2241 printOperand(MI
, OpNum
+ 1, O
);
2245 // Force static initialization.
2246 extern "C" LLVM_EXTERNAL_VISIBILITY
void LLVMInitializeNVPTXAsmPrinter() {
2247 RegisterAsmPrinter
<NVPTXAsmPrinter
> X(getTheNVPTXTarget32());
2248 RegisterAsmPrinter
<NVPTXAsmPrinter
> Y(getTheNVPTXTarget64());