Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / unittests / CodeGen / GlobalISel / GISelMITest.h
blob27d599671db6ddf51e2172652283b64e1a6a9638
1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
13 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
14 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
15 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
16 #include "llvm/CodeGen/GlobalISel/Utils.h"
17 #include "llvm/CodeGen/MIRParser/MIRParser.h"
18 #include "llvm/CodeGen/MachineFunction.h"
19 #include "llvm/CodeGen/MachineModuleInfo.h"
20 #include "llvm/CodeGen/TargetFrameLowering.h"
21 #include "llvm/CodeGen/TargetInstrInfo.h"
22 #include "llvm/CodeGen/TargetLowering.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/FileCheck/FileCheck.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/MC/TargetRegistry.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Target/TargetMachine.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "gtest/gtest.h"
33 using namespace llvm;
34 using namespace MIPatternMatch;
36 static inline void initLLVM() {
37 InitializeAllTargets();
38 InitializeAllTargetMCs();
39 InitializeAllAsmPrinters();
40 InitializeAllAsmParsers();
42 PassRegistry *Registry = PassRegistry::getPassRegistry();
43 initializeCore(*Registry);
44 initializeCodeGen(*Registry);
47 // Define a printers to help debugging when things go wrong.
48 namespace llvm {
49 std::ostream &
50 operator<<(std::ostream &OS, const LLT Ty);
52 std::ostream &
53 operator<<(std::ostream &OS, const MachineFunction &MF);
56 static std::unique_ptr<Module> parseMIR(LLVMContext &Context,
57 std::unique_ptr<MIRParser> &MIR,
58 const TargetMachine &TM,
59 StringRef MIRCode, const char *FuncName,
60 MachineModuleInfo &MMI) {
61 SMDiagnostic Diagnostic;
62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
63 MIR = createMIRParser(std::move(MBuffer), Context);
64 if (!MIR)
65 return nullptr;
67 std::unique_ptr<Module> M = MIR->parseIRModule();
68 if (!M)
69 return nullptr;
71 M->setDataLayout(TM.createDataLayout());
73 if (MIR->parseMachineFunctions(*M, MMI))
74 return nullptr;
76 return M;
78 static std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
79 createDummyModule(LLVMContext &Context, const LLVMTargetMachine &TM,
80 StringRef MIRString, const char *FuncName) {
81 std::unique_ptr<MIRParser> MIR;
82 auto MMI = std::make_unique<MachineModuleInfo>(&TM);
83 std::unique_ptr<Module> M =
84 parseMIR(Context, MIR, TM, MIRString, FuncName, *MMI);
85 return make_pair(std::move(M), std::move(MMI));
88 static MachineFunction *getMFFromMMI(const Module *M,
89 const MachineModuleInfo *MMI) {
90 Function *F = M->getFunction("func");
91 auto *MF = MMI->getMachineFunction(*F);
92 return MF;
95 static void collectCopies(SmallVectorImpl<Register> &Copies,
96 MachineFunction *MF) {
97 for (auto &MBB : *MF)
98 for (MachineInstr &MI : MBB) {
99 if (MI.getOpcode() == TargetOpcode::COPY)
100 Copies.push_back(MI.getOperand(0).getReg());
104 class GISelMITest : public ::testing::Test {
105 protected:
106 GISelMITest() : ::testing::Test() {}
108 /// Prepare a target specific LLVMTargetMachine.
109 virtual std::unique_ptr<LLVMTargetMachine> createTargetMachine() const = 0;
111 /// Get the stub sample MIR test function.
112 virtual void getTargetTestModuleString(SmallString<512> &S,
113 StringRef MIRFunc) const = 0;
115 void setUp(StringRef ExtraAssembly = "") {
116 TM = createTargetMachine();
117 if (!TM)
118 return;
120 SmallString<512> MIRString;
121 getTargetTestModuleString(MIRString, ExtraAssembly);
123 ModuleMMIPair = createDummyModule(Context, *TM, MIRString, "func");
124 MF = getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
125 collectCopies(Copies, MF);
126 EntryMBB = &*MF->begin();
127 B.setMF(*MF);
128 MRI = &MF->getRegInfo();
129 B.setInsertPt(*EntryMBB, EntryMBB->end());
132 LLVMContext Context;
133 std::unique_ptr<LLVMTargetMachine> TM;
134 MachineFunction *MF;
135 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
136 ModuleMMIPair;
137 SmallVector<Register, 4> Copies;
138 MachineBasicBlock *EntryMBB;
139 MachineIRBuilder B;
140 MachineRegisterInfo *MRI;
143 class AArch64GISelMITest : public GISelMITest {
144 std::unique_ptr<LLVMTargetMachine> createTargetMachine() const override;
145 void getTargetTestModuleString(SmallString<512> &S,
146 StringRef MIRFunc) const override;
149 class AMDGPUGISelMITest : public GISelMITest {
150 std::unique_ptr<LLVMTargetMachine> createTargetMachine() const override;
151 void getTargetTestModuleString(SmallString<512> &S,
152 StringRef MIRFunc) const override;
// DefineLegalizerInfo(Name, SettingUpActionsBlock) declares a LegalizerInfo
// subclass named Name##Info whose constructor performs these steps:
//   * It brings TargetOpcode and the scalar LLTs s8, s16, s32, s64 and s128
//     into scope. The (void) casts keep blocks that use only some of these
//     types free of unused-variable warnings.
//   * It runs SettingUpActionsBlock as the body of a `do ... while (0)`, so
//     the argument must be a single statement, typically a braced block.
//   * It computes the legacy legalizer tables, then calls verify() with the
//     subtarget's instruction info.
#define DefineLegalizerInfo(Name, SettingUpActionsBlock)                       \
  class Name##Info : public LegalizerInfo {                                    \
  public:                                                                      \
    Name##Info(const TargetSubtargetInfo &ST) {                                \
      using namespace TargetOpcode;                                            \
      const LLT s8 = LLT::scalar(8);                                           \
      const LLT s16 = LLT::scalar(16);                                         \
      const LLT s32 = LLT::scalar(32);                                         \
      const LLT s64 = LLT::scalar(64);                                         \
      const LLT s128 = LLT::scalar(128);                                       \
      (void)s8;                                                                \
      (void)s16;                                                               \
      (void)s32;                                                               \
      (void)s64;                                                               \
      (void)s128;                                                              \
      do                                                                       \
        SettingUpActionsBlock while (0);                                       \
      getLegacyLegalizerInfo().computeTables();                                \
      verify(*ST.getInstrInfo());                                              \
    }                                                                          \
  };
177 static inline bool CheckMachineFunction(const MachineFunction &MF,
178 StringRef CheckStr) {
179 SmallString<512> Msg;
180 raw_svector_ostream OS(Msg);
181 MF.print(OS);
182 auto OutputBuf = MemoryBuffer::getMemBuffer(Msg, "Output", false);
183 auto CheckBuf = MemoryBuffer::getMemBuffer(CheckStr, "");
184 SmallString<4096> CheckFileBuffer;
185 FileCheckRequest Req;
186 FileCheck FC(Req);
187 StringRef CheckFileText =
188 FC.CanonicalizeFile(*CheckBuf.get(), CheckFileBuffer);
189 SourceMgr SM;
190 SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText, "CheckFile"),
191 SMLoc());
192 Regex PrefixRE = FC.buildCheckPrefixRegex();
193 if (FC.readCheckFile(SM, CheckFileText, PrefixRE))
194 return false;
196 auto OutBuffer = OutputBuf->getBuffer();
197 SM.AddNewSourceBuffer(std::move(OutputBuf), SMLoc());
198 return FC.checkInput(SM, OutBuffer);
200 #endif