1 //===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
15 #define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
18 #include "NVPTXSubtarget.h"
19 #include "NVPTXTargetMachine.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/CodeGen/AsmPrinter.h"
24 #include "llvm/CodeGen/MachineFunction.h"
25 #include "llvm/CodeGen/MachineLoopInfo.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DebugLoc.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/GlobalValue.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/MC/MCExpr.h"
33 #include "llvm/MC/MCStreamer.h"
34 #include "llvm/MC/MCSymbol.h"
35 #include "llvm/PassAnalysisSupport.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Compiler.h"
38 #include "llvm/Support/ErrorHandling.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "llvm/Target/TargetMachine.h"
// The PTX syntax and format is very different from that usually seen in a .s
// file, therefore we are not able to use the MCAsmStreamer interface here.
//
// We are handcrafting the output method here.
//
// A better approach is to clone the MCAsmStreamer to a MCPTXAsmStreamer
// (subclass of MCStreamer).
61 class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter
: public AsmPrinter
{
// Used to buffer the emitted string for initializing global
// aggregates.
//
// Normally an aggregate (array, vector or structure) is emitted
// as a u8[]. However, if one element/field of the aggregate
// is a non-NULL address, then the aggregate is emitted as u32[]
// or u64[], depending on the target pointer size.
//
// We first lay out the aggregate in 'buffer' in bytes, except for
// those symbol addresses. For the i-th symbol address in the
// aggregate, its corresponding 4-byte or 8-byte elements in 'buffer'
// are filled with 0s. symbolPosInBuffer[i-1] records its position
// in 'buffer', and Symbols[i-1] records the Value*.
//
// Once this AggBuffer is set up, we can choose how to print it out.
81 unsigned numSymbols
; // number of symbol addresses
84 const unsigned size
; // size of the buffer in bytes
85 std::vector
<unsigned char> buffer
; // the buffer
86 SmallVector
<unsigned, 4> symbolPosInBuffer
;
87 SmallVector
<const Value
*, 4> Symbols
;
// SymbolsBeforeStripping[i] is the original form of Symbols[i] before
// stripping pointer casts, i.e.,
// Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts().
//
// We need to keep these values because AggBuffer::print decides whether to
// emit a "generic()" cast for Symbols[i] depending on the address space of
// SymbolsBeforeStripping[i].
95 SmallVector
<const Value
*, 4> SymbolsBeforeStripping
;
102 AggBuffer(unsigned size
, raw_ostream
&O
, NVPTXAsmPrinter
&AP
)
103 : size(size
), buffer(size
), O(O
), AP(AP
) {
106 EmitGeneric
= AP
.EmitGeneric
;
109 unsigned addBytes(unsigned char *Ptr
, int Num
, int Bytes
) {
110 assert((curpos
+ Num
) <= size
);
111 assert((curpos
+ Bytes
) <= size
);
112 for (int i
= 0; i
< Num
; ++i
) {
113 buffer
[curpos
] = Ptr
[i
];
116 for (int i
= Num
; i
< Bytes
; ++i
) {
123 unsigned addZeros(int Num
) {
124 assert((curpos
+ Num
) <= size
);
125 for (int i
= 0; i
< Num
; ++i
) {
132 void addSymbol(const Value
*GVar
, const Value
*GVarBeforeStripping
) {
133 symbolPosInBuffer
.push_back(curpos
);
134 Symbols
.push_back(GVar
);
135 SymbolsBeforeStripping
.push_back(GVarBeforeStripping
);
140 if (numSymbols
== 0) {
141 // print out in bytes
142 for (unsigned i
= 0; i
< size
; i
++) {
145 O
<< (unsigned int) buffer
[i
];
148 // print out in 4-bytes or 8-bytes
149 unsigned int pos
= 0;
150 unsigned int nSym
= 0;
151 unsigned int nextSymbolPos
= symbolPosInBuffer
[nSym
];
152 unsigned int nBytes
= 4;
153 if (static_cast<const NVPTXTargetMachine
&>(AP
.TM
).is64Bit())
155 for (pos
= 0; pos
< size
; pos
+= nBytes
) {
158 if (pos
== nextSymbolPos
) {
159 const Value
*v
= Symbols
[nSym
];
160 const Value
*v0
= SymbolsBeforeStripping
[nSym
];
161 if (const GlobalValue
*GVar
= dyn_cast
<GlobalValue
>(v
)) {
162 MCSymbol
*Name
= AP
.getSymbol(GVar
);
163 PointerType
*PTy
= dyn_cast
<PointerType
>(v0
->getType());
164 bool IsNonGenericPointer
= false; // Is v0 a non-generic pointer?
165 if (PTy
&& PTy
->getAddressSpace() != 0) {
166 IsNonGenericPointer
= true;
168 if (EmitGeneric
&& !isa
<Function
>(v
) && !IsNonGenericPointer
) {
170 Name
->print(O
, AP
.MAI
);
173 Name
->print(O
, AP
.MAI
);
175 } else if (const ConstantExpr
*CExpr
= dyn_cast
<ConstantExpr
>(v0
)) {
177 AP
.lowerConstantForGV(cast
<Constant
>(CExpr
), false);
178 AP
.printMCExpr(*Expr
, O
);
180 llvm_unreachable("symbol type unknown");
182 if (nSym
>= numSymbols
)
183 nextSymbolPos
= size
+ 1;
185 nextSymbolPos
= symbolPosInBuffer
[nSym
];
186 } else if (nBytes
== 4)
187 O
<< *(unsigned int *)(&buffer
[pos
]);
189 O
<< *(unsigned long long *)(&buffer
[pos
]);
195 friend class AggBuffer
;
198 StringRef
getPassName() const override
{ return "NVPTX Assembly Printer"; }
201 std::string CurrentFnName
;
203 void EmitBasicBlockStart(const MachineBasicBlock
&MBB
) override
;
204 void EmitFunctionEntryLabel() override
;
205 void EmitFunctionBodyStart() override
;
206 void EmitFunctionBodyEnd() override
;
207 void emitImplicitDef(const MachineInstr
*MI
) const override
;
209 void EmitInstruction(const MachineInstr
*) override
;
210 void lowerToMCInst(const MachineInstr
*MI
, MCInst
&OutMI
);
211 bool lowerOperand(const MachineOperand
&MO
, MCOperand
&MCOp
);
212 MCOperand
GetSymbolRef(const MCSymbol
*Symbol
);
213 unsigned encodeVirtualRegister(unsigned Reg
);
215 void printMemOperand(const MachineInstr
*MI
, int opNum
, raw_ostream
&O
,
216 const char *Modifier
= nullptr);
217 void printModuleLevelGV(const GlobalVariable
*GVar
, raw_ostream
&O
,
219 void printParamName(Function::const_arg_iterator I
, int paramIndex
,
221 void emitGlobals(const Module
&M
);
222 void emitHeader(Module
&M
, raw_ostream
&O
, const NVPTXSubtarget
&STI
);
223 void emitKernelFunctionDirectives(const Function
&F
, raw_ostream
&O
) const;
224 void emitVirtualRegister(unsigned int vr
, raw_ostream
&);
225 void emitFunctionParamList(const Function
*, raw_ostream
&O
);
226 void emitFunctionParamList(const MachineFunction
&MF
, raw_ostream
&O
);
227 void setAndEmitFunctionVirtualRegisters(const MachineFunction
&MF
);
228 void printReturnValStr(const Function
*, raw_ostream
&O
);
229 void printReturnValStr(const MachineFunction
&MF
, raw_ostream
&O
);
230 bool PrintAsmOperand(const MachineInstr
*MI
, unsigned OpNo
,
231 const char *ExtraCode
, raw_ostream
&) override
;
232 void printOperand(const MachineInstr
*MI
, int opNum
, raw_ostream
&O
);
233 bool PrintAsmMemoryOperand(const MachineInstr
*MI
, unsigned OpNo
,
234 const char *ExtraCode
, raw_ostream
&) override
;
236 const MCExpr
*lowerConstantForGV(const Constant
*CV
, bool ProcessingGeneric
);
237 void printMCExpr(const MCExpr
&Expr
, raw_ostream
&OS
);
240 bool doInitialization(Module
&M
) override
;
241 bool doFinalization(Module
&M
) override
;
246 // This is specific per MachineFunction.
247 const MachineRegisterInfo
*MRI
;
248 // The contents are specific for each
249 // MachineFunction. But the size of the
251 typedef DenseMap
<unsigned, unsigned> VRegMap
;
252 typedef DenseMap
<const TargetRegisterClass
*, VRegMap
> VRegRCMap
;
253 VRegRCMap VRegMapping
;
255 // List of variables demoted to a function scope.
256 std::map
<const Function
*, std::vector
<const GlobalVariable
*>> localDecls
;
258 void emitPTXGlobalVariable(const GlobalVariable
*GVar
, raw_ostream
&O
);
259 void emitPTXAddressSpace(unsigned int AddressSpace
, raw_ostream
&O
) const;
260 std::string
getPTXFundamentalTypeStr(Type
*Ty
, bool = true) const;
261 void printScalarConstant(const Constant
*CPV
, raw_ostream
&O
);
262 void printFPConstant(const ConstantFP
*Fp
, raw_ostream
&O
);
263 void bufferLEByte(const Constant
*CPV
, int Bytes
, AggBuffer
*aggBuffer
);
264 void bufferAggregateConstant(const Constant
*CV
, AggBuffer
*aggBuffer
);
266 void emitLinkageDirective(const GlobalValue
*V
, raw_ostream
&O
);
267 void emitDeclarations(const Module
&, raw_ostream
&O
);
268 void emitDeclaration(const Function
*, raw_ostream
&O
);
269 void emitDemotedVars(const Function
*, raw_ostream
&);
271 bool lowerImageHandleOperand(const MachineInstr
*MI
, unsigned OpNo
,
273 void lowerImageHandleSymbol(unsigned Index
, MCOperand
&MCOp
);
275 bool isLoopHeaderOfNoUnroll(const MachineBasicBlock
&MBB
) const;
// Used to control the need to emit generic() in the initializer of
// module scope variables.
// Although PTX supports a hybrid mode that mixes specific and generic
// addresses in one initializer, like the following,
//    .global .u32 addr[] = {a, generic(b)}
// (where a and b are module scope variables), we have difficulty
// representing the difference in the NVVM IR.
//
// Since the address value should always be generic in CUDA C and always
// be specific in OpenCL, we use this simple control here.
291 NVPTXAsmPrinter(TargetMachine
&TM
, std::unique_ptr
<MCStreamer
> Streamer
)
292 : AsmPrinter(TM
, std::move(Streamer
)),
293 EmitGeneric(static_cast<NVPTXTargetMachine
&>(TM
).getDrvInterface() ==
296 bool runOnMachineFunction(MachineFunction
&F
) override
;
298 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
299 AU
.addRequired
<MachineLoopInfo
>();
300 AsmPrinter::getAnalysisUsage(AU
);
303 std::string
getVirtualRegisterName(unsigned) const;
305 const MCSymbol
*getFunctionFrameSymbol() const override
;
308 } // end namespace llvm
310 #endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H