Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / utils / TableGen / VarLenCodeEmitterGen.cpp
blobbfb7e5c333170c2ae7189c181f7c02fa1e630969
1 //===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The CodeEmitterGen component for variable-length instructions.
11 // The basic CodeEmitterGen is almost exclusively designed for fixed-
12 // length instructions. A good analogy for its encoding scheme is how printf
// works: The (immutable) formatting string represents the fixed values in the
14 // encoded instruction. Placeholders (i.e. %something), on the other hand,
15 // represent encoding for instruction operands.
16 // ```
17 // printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
18 // <encoded value for operand `dst`>);
19 // ```
20 // VarLenCodeEmitterGen in this file provides an alternative encoding scheme
21 // that works more like a C++ stream operator:
22 // ```
23 // OS << 0b1101;
24 // if (Cond)
25 // OS << OperandEncoding0;
26 // OS << 0b1001 << OperandEncoding1;
27 // ```
// You are free to concatenate encoding fragments of arbitrary types (and
// sizes) at any bit position, which brings more flexibility in defining
// encodings for variable-length instructions.
// More specifically, the instruction encoding is represented by a DAG-typed
// `Inst` field. Here is an example:
34 // ```
35 // dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
36 // (operand "$dst", 4));
37 // ```
38 // It represents the following instruction encoding:
39 // ```
40 // MSB LSB
41 // 1101<encoding for operand src>1001<encoding for operand dst>
42 // ```
43 // For more details about DAG operators in the above snippet, please
44 // refer to \file include/llvm/Target/Target.td.
// VarLenCodeEmitter will convert the above DAG into the same helper function
// generated by CodeEmitter, `MCCodeEmitter::getBinaryCodeForInstr` (except
// for a few details).
50 //===----------------------------------------------------------------------===//
52 #include "VarLenCodeEmitterGen.h"
53 #include "CodeGenHwModes.h"
54 #include "CodeGenInstruction.h"
55 #include "CodeGenTarget.h"
56 #include "InfoByHwMode.h"
57 #include "llvm/ADT/ArrayRef.h"
58 #include "llvm/ADT/DenseMap.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/TableGen/Error.h"
61 #include "llvm/TableGen/Record.h"
63 #include <algorithm>
65 using namespace llvm;
67 namespace {
69 class VarLenCodeEmitterGen {
70 RecordKeeper &Records;
72 // Representaton of alternative encodings used for HwModes.
73 using AltEncodingTy = int;
74 // Mode identifier when only one encoding is defined.
75 const AltEncodingTy Universal = -1;
76 // The set of alternative instruction encodings with a descriptive
77 // name suffix to improve readability of the generated code.
78 std::map<AltEncodingTy, std::string> Modes;
80 DenseMap<Record *, DenseMap<AltEncodingTy, VarLenInst>> VarLenInsts;
82 // Emit based values (i.e. fixed bits in the encoded instructions)
83 void emitInstructionBaseValues(
84 raw_ostream &OS,
85 ArrayRef<const CodeGenInstruction *> NumberedInstructions,
86 CodeGenTarget &Target, AltEncodingTy Mode);
88 std::string getInstructionCases(Record *R, CodeGenTarget &Target);
89 std::string getInstructionCaseForEncoding(Record *R, AltEncodingTy Mode,
90 const VarLenInst &VLI,
91 CodeGenTarget &Target, int I);
93 public:
94 explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
96 void run(raw_ostream &OS);
98 } // end anonymous namespace
100 // Get the name of custom encoder or decoder, if there is any.
101 // Returns `{encoder name, decoder name}`.
102 static std::pair<StringRef, StringRef> getCustomCoders(ArrayRef<Init *> Args) {
103 std::pair<StringRef, StringRef> Result;
104 for (const auto *Arg : Args) {
105 const auto *DI = dyn_cast<DagInit>(Arg);
106 if (!DI)
107 continue;
108 const Init *Op = DI->getOperator();
109 if (!isa<DefInit>(Op))
110 continue;
111 // syntax: `(<encoder | decoder> "function name")`
112 StringRef OpName = cast<DefInit>(Op)->getDef()->getName();
113 if (OpName != "encoder" && OpName != "decoder")
114 continue;
115 if (!DI->getNumArgs() || !isa<StringInit>(DI->getArg(0)))
116 PrintFatalError("expected '" + OpName +
117 "' directive to be followed by a custom function name.");
118 StringRef FuncName = cast<StringInit>(DI->getArg(0))->getValue();
119 if (OpName == "encoder")
120 Result.first = FuncName;
121 else
122 Result.second = FuncName;
124 return Result;
127 VarLenInst::VarLenInst(const DagInit *DI, const RecordVal *TheDef)
128 : TheDef(TheDef), NumBits(0U), HasDynamicSegment(false) {
129 buildRec(DI);
130 for (const auto &S : Segments)
131 NumBits += S.BitWidth;
134 void VarLenInst::buildRec(const DagInit *DI) {
135 assert(TheDef && "The def record is nullptr ?");
137 std::string Op = DI->getOperator()->getAsString();
139 if (Op == "ascend" || Op == "descend") {
140 bool Reverse = Op == "descend";
141 int i = Reverse ? DI->getNumArgs() - 1 : 0;
142 int e = Reverse ? -1 : DI->getNumArgs();
143 int s = Reverse ? -1 : 1;
144 for (; i != e; i += s) {
145 const Init *Arg = DI->getArg(i);
146 if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
147 if (!BI->isComplete())
148 PrintFatalError(TheDef->getLoc(),
149 "Expecting complete bits init in `" + Op + "`");
150 Segments.push_back({BI->getNumBits(), BI});
151 } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
152 if (!BI->isConcrete())
153 PrintFatalError(TheDef->getLoc(),
154 "Expecting concrete bit init in `" + Op + "`");
155 Segments.push_back({1, BI});
156 } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
157 buildRec(SubDI);
158 } else {
159 PrintFatalError(TheDef->getLoc(), "Unrecognized type of argument in `" +
160 Op + "`: " + Arg->getAsString());
163 } else if (Op == "operand") {
164 // (operand <operand name>, <# of bits>,
165 // [(encoder <custom encoder>)][, (decoder <custom decoder>)])
166 if (DI->getNumArgs() < 2)
167 PrintFatalError(TheDef->getLoc(),
168 "Expecting at least 2 arguments for `operand`");
169 HasDynamicSegment = true;
170 const Init *OperandName = DI->getArg(0), *NumBits = DI->getArg(1);
171 if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBits))
172 PrintFatalError(TheDef->getLoc(), "Invalid argument types for `operand`");
174 auto NumBitsVal = cast<IntInit>(NumBits)->getValue();
175 if (NumBitsVal <= 0)
176 PrintFatalError(TheDef->getLoc(), "Invalid number of bits for `operand`");
178 auto [CustomEncoder, CustomDecoder] =
179 getCustomCoders(DI->getArgs().slice(2));
180 Segments.push_back({static_cast<unsigned>(NumBitsVal), OperandName,
181 CustomEncoder, CustomDecoder});
182 } else if (Op == "slice") {
183 // (slice <operand name>, <high / low bit>, <low / high bit>,
184 // [(encoder <custom encoder>)][, (decoder <custom decoder>)])
185 if (DI->getNumArgs() < 3)
186 PrintFatalError(TheDef->getLoc(),
187 "Expecting at least 3 arguments for `slice`");
188 HasDynamicSegment = true;
189 Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
190 *LoBit = DI->getArg(2);
191 if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
192 !isa<IntInit>(LoBit))
193 PrintFatalError(TheDef->getLoc(), "Invalid argument types for `slice`");
195 auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
196 LoBitVal = cast<IntInit>(LoBit)->getValue();
197 if (HiBitVal < 0 || LoBitVal < 0)
198 PrintFatalError(TheDef->getLoc(), "Invalid bit range for `slice`");
199 bool NeedSwap = false;
200 unsigned NumBits = 0U;
201 if (HiBitVal < LoBitVal) {
202 NeedSwap = true;
203 NumBits = static_cast<unsigned>(LoBitVal - HiBitVal + 1);
204 } else {
205 NumBits = static_cast<unsigned>(HiBitVal - LoBitVal + 1);
208 auto [CustomEncoder, CustomDecoder] =
209 getCustomCoders(DI->getArgs().slice(3));
211 if (NeedSwap) {
212 // Normalization: Hi bit should always be the second argument.
213 Init *const NewArgs[] = {OperandName, LoBit, HiBit};
214 Segments.push_back({NumBits,
215 DagInit::get(DI->getOperator(), nullptr, NewArgs, {}),
216 CustomEncoder, CustomDecoder});
217 } else {
218 Segments.push_back({NumBits, DI, CustomEncoder, CustomDecoder});
223 void VarLenCodeEmitterGen::run(raw_ostream &OS) {
224 CodeGenTarget Target(Records);
225 auto Insts = Records.getAllDerivedDefinitions("Instruction");
227 auto NumberedInstructions = Target.getInstructionsByEnumValue();
229 for (const CodeGenInstruction *CGI : NumberedInstructions) {
230 Record *R = CGI->TheDef;
231 // Create the corresponding VarLenInst instance.
232 if (R->getValueAsString("Namespace") == "TargetOpcode" ||
233 R->getValueAsBit("isPseudo"))
234 continue;
236 // Setup alternative encodings according to HwModes
237 if (const RecordVal *RV = R->getValue("EncodingInfos")) {
238 if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
239 const CodeGenHwModes &HWM = Target.getHwModes();
240 EncodingInfoByHwMode EBM(DI->getDef(), HWM);
241 for (auto &KV : EBM) {
242 AltEncodingTy Mode = KV.first;
243 Modes.insert({Mode, "_" + HWM.getMode(Mode).Name.str()});
244 Record *EncodingDef = KV.second;
245 RecordVal *RV = EncodingDef->getValue("Inst");
246 DagInit *DI = cast<DagInit>(RV->getValue());
247 VarLenInsts[R].insert({Mode, VarLenInst(DI, RV)});
249 continue;
252 RecordVal *RV = R->getValue("Inst");
253 DagInit *DI = cast<DagInit>(RV->getValue());
254 VarLenInsts[R].insert({Universal, VarLenInst(DI, RV)});
257 if (Modes.empty())
258 Modes.insert({Universal, ""}); // Base case, skip suffix.
260 // Emit function declaration
261 OS << "void " << Target.getName()
262 << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
263 << " SmallVectorImpl<MCFixup> &Fixups,\n"
264 << " APInt &Inst,\n"
265 << " APInt &Scratch,\n"
266 << " const MCSubtargetInfo &STI) const {\n";
268 // Emit instruction base values
269 for (const auto &Mode : Modes)
270 emitInstructionBaseValues(OS, NumberedInstructions, Target, Mode.first);
272 if (Modes.size() > 1) {
273 OS << " unsigned Mode = STI.getHwMode();\n";
276 for (const auto &Mode : Modes) {
277 // Emit helper function to retrieve base values.
278 OS << " auto getInstBits" << Mode.second
279 << " = [&](unsigned Opcode) -> APInt {\n"
280 << " unsigned NumBits = Index" << Mode.second << "[Opcode][0];\n"
281 << " if (!NumBits)\n"
282 << " return APInt::getZeroWidth();\n"
283 << " unsigned Idx = Index" << Mode.second << "[Opcode][1];\n"
284 << " ArrayRef<uint64_t> Data(&InstBits" << Mode.second << "[Idx], "
285 << "APInt::getNumWords(NumBits));\n"
286 << " return APInt(NumBits, Data);\n"
287 << " };\n";
290 // Map to accumulate all the cases.
291 std::map<std::string, std::vector<std::string>> CaseMap;
293 // Construct all cases statement for each opcode
294 for (Record *R : Insts) {
295 if (R->getValueAsString("Namespace") == "TargetOpcode" ||
296 R->getValueAsBit("isPseudo"))
297 continue;
298 std::string InstName =
299 (R->getValueAsString("Namespace") + "::" + R->getName()).str();
300 std::string Case = getInstructionCases(R, Target);
302 CaseMap[Case].push_back(std::move(InstName));
305 // Emit initial function code
306 OS << " const unsigned opcode = MI.getOpcode();\n"
307 << " switch (opcode) {\n";
309 // Emit each case statement
310 for (const auto &C : CaseMap) {
311 const std::string &Case = C.first;
312 const auto &InstList = C.second;
314 ListSeparator LS("\n");
315 for (const auto &InstName : InstList)
316 OS << LS << " case " << InstName << ":";
318 OS << " {\n";
319 OS << Case;
320 OS << " break;\n"
321 << " }\n";
323 // Default case: unhandled opcode
324 OS << " default:\n"
325 << " std::string msg;\n"
326 << " raw_string_ostream Msg(msg);\n"
327 << " Msg << \"Not supported instr: \" << MI;\n"
328 << " report_fatal_error(Msg.str().c_str());\n"
329 << " }\n";
330 OS << "}\n\n";
333 static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
334 unsigned &Index) {
335 if (!Bits.getNumWords()) {
336 IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
337 return;
340 IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", "
341 << "/*Index*/" << Index << "},";
343 SS.indent(4);
344 for (unsigned I = 0; I < Bits.getNumWords(); ++I, ++Index)
345 SS << "UINT64_C(" << utostr(Bits.getRawData()[I]) << "),";
348 void VarLenCodeEmitterGen::emitInstructionBaseValues(
349 raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
350 CodeGenTarget &Target, AltEncodingTy Mode) {
351 std::string IndexArray, StorageArray;
352 raw_string_ostream IS(IndexArray), SS(StorageArray);
354 IS << " static const unsigned Index" << Modes[Mode] << "[][2] = {\n";
355 SS << " static const uint64_t InstBits" << Modes[Mode] << "[] = {\n";
357 unsigned NumFixedValueWords = 0U;
358 for (const CodeGenInstruction *CGI : NumberedInstructions) {
359 Record *R = CGI->TheDef;
361 if (R->getValueAsString("Namespace") == "TargetOpcode" ||
362 R->getValueAsBit("isPseudo")) {
363 IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n";
364 continue;
367 const auto InstIt = VarLenInsts.find(R);
368 if (InstIt == VarLenInsts.end())
369 PrintFatalError(R, "VarLenInst not found for this record");
370 auto ModeIt = InstIt->second.find(Mode);
371 if (ModeIt == InstIt->second.end())
372 ModeIt = InstIt->second.find(Universal);
373 if (ModeIt == InstIt->second.end()) {
374 IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\t"
375 << "// " << R->getName() << " no encoding\n";
376 continue;
378 const VarLenInst &VLI = ModeIt->second;
379 unsigned i = 0U, BitWidth = VLI.size();
381 // Start by filling in fixed values.
382 APInt Value(BitWidth, 0);
383 auto SI = VLI.begin(), SE = VLI.end();
384 // Scan through all the segments that have fixed-bits values.
385 while (i < BitWidth && SI != SE) {
386 unsigned SegmentNumBits = SI->BitWidth;
387 if (const auto *BI = dyn_cast<BitsInit>(SI->Value)) {
388 for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
389 auto *B = cast<BitInit>(BI->getBit(Idx));
390 Value.setBitVal(i + Idx, B->getValue());
393 if (const auto *BI = dyn_cast<BitInit>(SI->Value))
394 Value.setBitVal(i, BI->getValue());
396 i += SegmentNumBits;
397 ++SI;
400 emitInstBits(IS, SS, Value, NumFixedValueWords);
401 IS << '\t' << "// " << R->getName() << "\n";
402 if (Value.getNumWords())
403 SS << '\t' << "// " << R->getName() << "\n";
405 IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n };\n";
406 SS.indent(4) << "UINT64_C(0)\n };\n";
408 OS << IS.str() << SS.str();
411 std::string VarLenCodeEmitterGen::getInstructionCases(Record *R,
412 CodeGenTarget &Target) {
413 auto It = VarLenInsts.find(R);
414 if (It == VarLenInsts.end())
415 PrintFatalError(R, "Parsed encoding record not found");
416 const auto &Map = It->second;
418 // Is this instructions encoding universal (same for all modes)?
419 // Allways true if there is only one mode.
420 if (Map.size() == 1 && Map.begin()->first == Universal) {
421 // Universal, just pick the first mode.
422 AltEncodingTy Mode = Modes.begin()->first;
423 const auto &Encoding = Map.begin()->second;
424 return getInstructionCaseForEncoding(R, Mode, Encoding, Target, 6);
427 std::string Case;
428 Case += " switch (Mode) {\n";
429 Case += " default: llvm_unreachable(\"Unhandled Mode\");\n";
430 for (const auto &Mode : Modes) {
431 Case += " case " + itostr(Mode.first) + ": {\n";
432 const auto &It = Map.find(Mode.first);
433 if (It == Map.end()) {
434 Case +=
435 " llvm_unreachable(\"Undefined encoding in this mode\");\n";
436 } else {
437 Case +=
438 getInstructionCaseForEncoding(R, It->first, It->second, Target, 8);
440 Case += " break;\n";
441 Case += " }\n";
443 Case += " }\n";
444 return Case;
447 std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
448 Record *R, AltEncodingTy Mode, const VarLenInst &VLI, CodeGenTarget &Target,
449 int I) {
451 CodeGenInstruction &CGI = Target.getInstruction(R);
453 std::string Case;
454 raw_string_ostream SS(Case);
455 // Populate based value.
456 SS.indent(I) << "Inst = getInstBits" << Modes[Mode] << "(opcode);\n";
458 // Process each segment in VLI.
459 size_t Offset = 0U;
460 unsigned HighScratchAccess = 0U;
461 for (const auto &ES : VLI) {
462 unsigned NumBits = ES.BitWidth;
463 const Init *Val = ES.Value;
464 // If it's a StringInit or DagInit, it's a reference to an operand
465 // or part of an operand.
466 if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
467 StringRef OperandName;
468 unsigned LoBit = 0U;
469 if (const auto *SV = dyn_cast<StringInit>(Val)) {
470 OperandName = SV->getValue();
471 } else {
472 // Normalized: (slice <operand name>, <high bit>, <low bit>)
473 const auto *DV = cast<DagInit>(Val);
474 OperandName = cast<StringInit>(DV->getArg(0))->getValue();
475 LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
478 auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
479 unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
480 StringRef CustomEncoder =
481 CGI.Operands[OpIdx.first].EncoderMethodNames[OpIdx.second];
482 if (ES.CustomEncoder.size())
483 CustomEncoder = ES.CustomEncoder;
485 SS.indent(I) << "Scratch.clearAllBits();\n";
486 SS.indent(I) << "// op: " << OperandName.drop_front(1) << "\n";
487 if (CustomEncoder.empty())
488 SS.indent(I) << "getMachineOpValue(MI, MI.getOperand("
489 << utostr(FlatOpIdx) << ")";
490 else
491 SS.indent(I) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
493 SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
495 SS.indent(I) << "Inst.insertBits("
496 << "Scratch.extractBits(" << utostr(NumBits) << ", "
497 << utostr(LoBit) << ")"
498 << ", " << Offset << ");\n";
500 HighScratchAccess = std::max(HighScratchAccess, NumBits + LoBit);
502 Offset += NumBits;
505 StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
506 if (!PostEmitter.empty())
507 SS.indent(I) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
509 // Resize the scratch buffer if it's to small.
510 std::string ScratchResizeStr;
511 if (VLI.size() && !VLI.isFixedValueOnly()) {
512 raw_string_ostream RS(ScratchResizeStr);
513 RS.indent(I) << "if (Scratch.getBitWidth() < " << HighScratchAccess
514 << ") { Scratch = Scratch.zext(" << HighScratchAccess
515 << "); }\n";
518 return ScratchResizeStr + Case;
521 namespace llvm {
523 void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS) {
524 VarLenCodeEmitterGen(R).run(OS);
527 } // end namespace llvm