1 //===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
15 #define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
18 #include "NVPTXSubtarget.h"
19 #include "NVPTXTargetMachine.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/CodeGen/AsmPrinter.h"
24 #include "llvm/CodeGen/MachineFunction.h"
25 #include "llvm/CodeGen/MachineLoopInfo.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DebugLoc.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/GlobalValue.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/MC/MCExpr.h"
33 #include "llvm/MC/MCStreamer.h"
34 #include "llvm/MC/MCSymbol.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Compiler.h"
38 #include "llvm/Support/ErrorHandling.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "llvm/Target/TargetMachine.h"
48 // The PTX syntax and format are very different from what is usually seen in a .s
50 // therefore we are not able to use the MCAsmStreamer interface here.
52 // We are handcrafting the output method here.
54 // A better approach is to clone the MCAsmStreamer to a MCPTXAsmStreamer
55 // (subclass of MCStreamer).
61 class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter
: public AsmPrinter
{
64 // Used to buffer the emitted string for initializing global
67 // Normally an aggregate (array, vector or structure) is emitted
68 // as a u8[]. However, if one element/field of the aggregate
69 // is a non-NULL address, then the aggregate is emitted as u32[]
72 // We first lay out the aggregate in 'buffer' byte by byte, except for
73 // those symbol addresses. For the i-th symbol address in the
74 // aggregate, its corresponding 4-byte or 8-byte elements in 'buffer'
75 // are filled with 0s. symbolPosInBuffer[i-1] records its position
76 // in 'buffer', and Symbols[i-1] records the Value*.
78 // Once we have this AggBuffer set up, we can choose how to print
81 unsigned numSymbols
; // number of symbol addresses
84 const unsigned size
; // size of the buffer in bytes
85 std::vector
<unsigned char> buffer
; // the buffer
86 SmallVector
<unsigned, 4> symbolPosInBuffer
;
87 SmallVector
<const Value
*, 4> Symbols
;
88 // SymbolsBeforeStripping[i] is the original form of Symbols[i] before
89 // stripping pointer casts, i.e.,
90 // Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts().
92 // We need to keep these values because AggBuffer::print decides whether to
93 // emit a "generic()" cast for Symbols[i] depending on the address space of
94 // SymbolsBeforeStripping[i].
95 SmallVector
<const Value
*, 4> SymbolsBeforeStripping
;
102 AggBuffer(unsigned size
, raw_ostream
&O
, NVPTXAsmPrinter
&AP
)
103 : size(size
), buffer(size
), O(O
), AP(AP
) {
106 EmitGeneric
= AP
.EmitGeneric
;
109 unsigned addBytes(unsigned char *Ptr
, int Num
, int Bytes
) {
110 assert((curpos
+ Num
) <= size
);
111 assert((curpos
+ Bytes
) <= size
);
112 for (int i
= 0; i
< Num
; ++i
) {
113 buffer
[curpos
] = Ptr
[i
];
116 for (int i
= Num
; i
< Bytes
; ++i
) {
123 unsigned addZeros(int Num
) {
124 assert((curpos
+ Num
) <= size
);
125 for (int i
= 0; i
< Num
; ++i
) {
132 void addSymbol(const Value
*GVar
, const Value
*GVarBeforeStripping
) {
133 symbolPosInBuffer
.push_back(curpos
);
134 Symbols
.push_back(GVar
);
135 SymbolsBeforeStripping
.push_back(GVarBeforeStripping
);
140 if (numSymbols
== 0) {
141 // print out in bytes
142 for (unsigned i
= 0; i
< size
; i
++) {
145 O
<< (unsigned int) buffer
[i
];
148 // print out in 4-bytes or 8-bytes
149 unsigned int pos
= 0;
150 unsigned int nSym
= 0;
151 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
152 unsigned int nBytes
= 4;
153 if (static_cast<const NVPTXTargetMachine
&>(AP
.TM
).is64Bit())
155 for (pos
= 0; pos
< size
; pos
+= nBytes
) {
158 if (pos
== nextSymbolPos
) {
159 const Value
*v
= Symbols
[nSym
];
160 const Value
*v0
= SymbolsBeforeStripping
[nSym
];
161 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(v
)) {
162 MCSymbol
*Name
= AP
.getSymbol(GVar
);
163 PointerType
*PTy
= dyn_cast
<PointerType
>(v0
->getType());
164 bool IsNonGenericPointer
= false; // Is v0 a non-generic pointer?
165 if (PTy
&& PTy
->getAddressSpace() != 0) {
166 IsNonGenericPointer
= true;
168 if (EmitGeneric
&& !isa
<Function
>(v
) && !IsNonGenericPointer
) {
170 Name
->print(O
, AP
.MAI
);
173 Name
->print(O
, AP
.MAI
);
175 } else if (const ConstantExpr
*CExpr
= dyn_cast
<ConstantExpr
>(v0
)) {
177 AP
.lowerConstantForGV(cast
<Constant
>(CExpr
), false);
178 AP
.printMCExpr(*Expr
, O
);
180 llvm_unreachable("symbol type unknown");
182 if (nSym
>= numSymbols
)
183 nextSymbolPos
= size
+ 1;
185 nextSymbolPos
= symbolPosInBuffer
[nSym
];
186 } else if (nBytes
== 4)
187 O
<< *(unsigned int *)(&buffer
[pos
]);
189 O
<< *(unsigned long long *)(&buffer
[pos
]);
195 friend class AggBuffer
;
// Pass-name hook: always returns the fixed literal "NVPTX Assembly Printer".
198 StringRef
getPassName() const override
{ return "NVPTX Assembly Printer"; }
201 std::string CurrentFnName
;
203 void emitStartOfAsmFile(Module
&M
) override
;
204 void emitBasicBlockStart(const MachineBasicBlock
&MBB
) override
;
205 void emitFunctionEntryLabel() override
;
206 void emitFunctionBodyStart() override
;
207 void emitFunctionBodyEnd() override
;
208 void emitImplicitDef(const MachineInstr
*MI
) const override
;
210 void emitInstruction(const MachineInstr
*) override
;
211 void lowerToMCInst(const MachineInstr
*MI
, MCInst
&OutMI
);
212 bool lowerOperand(const MachineOperand
&MO
, MCOperand
&MCOp
);
213 MCOperand
GetSymbolRef(const MCSymbol
*Symbol
);
214 unsigned encodeVirtualRegister(unsigned Reg
);
216 void printMemOperand(const MachineInstr
*MI
, int opNum
, raw_ostream
&O
,
217 const char *Modifier
= nullptr);
218 void printModuleLevelGV(const GlobalVariable
*GVar
, raw_ostream
&O
,
220 void printParamName(Function::const_arg_iterator I
, int paramIndex
,
222 void emitGlobals(const Module
&M
);
223 void emitHeader(Module
&M
, raw_ostream
&O
, const NVPTXSubtarget
&STI
);
224 void emitKernelFunctionDirectives(const Function
&F
, raw_ostream
&O
) const;
225 void emitVirtualRegister(unsigned int vr
, raw_ostream
&);
226 void emitFunctionParamList(const Function
*, raw_ostream
&O
);
227 void emitFunctionParamList(const MachineFunction
&MF
, raw_ostream
&O
);
228 void setAndEmitFunctionVirtualRegisters(const MachineFunction
&MF
);
229 void printReturnValStr(const Function
*, raw_ostream
&O
);
230 void printReturnValStr(const MachineFunction
&MF
, raw_ostream
&O
);
231 bool PrintAsmOperand(const MachineInstr
*MI
, unsigned OpNo
,
232 const char *ExtraCode
, raw_ostream
&) override
;
233 void printOperand(const MachineInstr
*MI
, int opNum
, raw_ostream
&O
);
234 bool PrintAsmMemoryOperand(const MachineInstr
*MI
, unsigned OpNo
,
235 const char *ExtraCode
, raw_ostream
&) override
;
237 const MCExpr
*lowerConstantForGV(const Constant
*CV
, bool ProcessingGeneric
);
238 void printMCExpr(const MCExpr
&Expr
, raw_ostream
&OS
);
241 bool doInitialization(Module
&M
) override
;
242 bool doFinalization(Module
&M
) override
;
247 // This is specific per MachineFunction.
248 const MachineRegisterInfo
*MRI
;
249 // The contents are specific for each
250 // MachineFunction. But the size of the
252 typedef DenseMap
<unsigned, unsigned> VRegMap
;
253 typedef DenseMap
<const TargetRegisterClass
*, VRegMap
> VRegRCMap
;
254 VRegRCMap VRegMapping
;
256 // List of variables demoted to a function scope.
257 std::map
<const Function
*, std::vector
<const GlobalVariable
*>> localDecls
;
259 void emitPTXGlobalVariable(const GlobalVariable
*GVar
, raw_ostream
&O
);
260 void emitPTXAddressSpace(unsigned int AddressSpace
, raw_ostream
&O
) const;
261 std::string
getPTXFundamentalTypeStr(Type
*Ty
, bool = true) const;
262 void printScalarConstant(const Constant
*CPV
, raw_ostream
&O
);
263 void printFPConstant(const ConstantFP
*Fp
, raw_ostream
&O
);
264 void bufferLEByte(const Constant
*CPV
, int Bytes
, AggBuffer
*aggBuffer
);
265 void bufferAggregateConstant(const Constant
*CV
, AggBuffer
*aggBuffer
);
267 void emitLinkageDirective(const GlobalValue
*V
, raw_ostream
&O
);
268 void emitDeclarations(const Module
&, raw_ostream
&O
);
269 void emitDeclaration(const Function
*, raw_ostream
&O
);
270 void emitDemotedVars(const Function
*, raw_ostream
&);
272 bool lowerImageHandleOperand(const MachineInstr
*MI
, unsigned OpNo
,
274 void lowerImageHandleSymbol(unsigned Index
, MCOperand
&MCOp
);
276 bool isLoopHeaderOfNoUnroll(const MachineBasicBlock
&MBB
) const;
278 // Used to control the need to emit .generic() in the initializer of
279 // module scope variables.
280 // Although PTX supports a hybrid mode like the following,
283 // .global .u32 addr[] = {a, generic(b)}
284 // we have difficulty representing the difference in the NVVM IR.
286 // Since the address value should always be generic in CUDA C and always
287 // be specific in OpenCL, we use this simple control here.
292 NVPTXAsmPrinter(TargetMachine
&TM
, std::unique_ptr
<MCStreamer
> Streamer
)
293 : AsmPrinter(TM
, std::move(Streamer
)),
294 EmitGeneric(static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() ==
297 bool runOnMachineFunction(MachineFunction
&F
) override
;
299 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
300 AU
.addRequired
<MachineLoopInfo
>();
301 AsmPrinter::getAnalysisUsage(AU
);
304 std::string
getVirtualRegisterName(unsigned) const;
306 const MCSymbol
*getFunctionFrameSymbol() const override
;
309 } // end namespace llvm
311 #endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H