Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / unittests / CodeGen / GlobalISel / GISelMITest.h
blob4b82f572150e50fadc0463388a1b463721426262
1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
13 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
14 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
15 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
16 #include "llvm/CodeGen/GlobalISel/Utils.h"
17 #include "llvm/CodeGen/MIRParser/MIRParser.h"
18 #include "llvm/CodeGen/MachineFunction.h"
19 #include "llvm/CodeGen/MachineModuleInfo.h"
20 #include "llvm/CodeGen/TargetFrameLowering.h"
21 #include "llvm/CodeGen/TargetInstrInfo.h"
22 #include "llvm/CodeGen/TargetLowering.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/FileCheck/FileCheck.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/MC/TargetRegistry.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Target/TargetMachine.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "gtest/gtest.h"
33 using namespace llvm;
34 using namespace MIPatternMatch;
36 static inline void initLLVM() {
37 InitializeAllTargets();
38 InitializeAllTargetMCs();
39 InitializeAllAsmPrinters();
40 InitializeAllAsmParsers();
42 PassRegistry *Registry = PassRegistry::getPassRegistry();
43 initializeCore(*Registry);
44 initializeCodeGen(*Registry);
47 // Define a printers to help debugging when things go wrong.
48 namespace llvm {
49 std::ostream &
50 operator<<(std::ostream &OS, const LLT Ty);
52 std::ostream &
53 operator<<(std::ostream &OS, const MachineFunction &MF);
56 static std::unique_ptr<Module>
57 parseMIR(LLVMContext &Context, std::unique_ptr<MIRParser> &MIR,
58 const TargetMachine &TM, StringRef MIRCode, MachineModuleInfo &MMI) {
59 SMDiagnostic Diagnostic;
60 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
61 MIR = createMIRParser(std::move(MBuffer), Context);
62 if (!MIR)
63 return nullptr;
65 std::unique_ptr<Module> M = MIR->parseIRModule();
66 if (!M)
67 return nullptr;
69 M->setDataLayout(TM.createDataLayout());
71 if (MIR->parseMachineFunctions(*M, MMI))
72 return nullptr;
74 return M;
76 static std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
77 createDummyModule(LLVMContext &Context, const TargetMachine &TM,
78 StringRef MIRString, const char *FuncName) {
79 std::unique_ptr<MIRParser> MIR;
80 auto MMI = std::make_unique<MachineModuleInfo>(&TM);
81 std::unique_ptr<Module> M = parseMIR(Context, MIR, TM, MIRString, *MMI);
82 return make_pair(std::move(M), std::move(MMI));
85 static MachineFunction *getMFFromMMI(const Module *M,
86 const MachineModuleInfo *MMI) {
87 Function *F = M->getFunction("func");
88 auto *MF = MMI->getMachineFunction(*F);
89 return MF;
92 static void collectCopies(SmallVectorImpl<Register> &Copies,
93 MachineFunction *MF) {
94 for (auto &MBB : *MF)
95 for (MachineInstr &MI : MBB) {
96 if (MI.getOpcode() == TargetOpcode::COPY)
97 Copies.push_back(MI.getOperand(0).getReg());
101 class GISelMITest : public ::testing::Test {
102 protected:
103 GISelMITest() : ::testing::Test() {}
105 /// Prepare a target specific TargetMachine.
106 virtual std::unique_ptr<TargetMachine> createTargetMachine() const = 0;
108 /// Get the stub sample MIR test function.
109 virtual void getTargetTestModuleString(SmallString<512> &S,
110 StringRef MIRFunc) const = 0;
112 void setUp(StringRef ExtraAssembly = "") {
113 TM = createTargetMachine();
114 if (!TM)
115 return;
117 SmallString<512> MIRString;
118 getTargetTestModuleString(MIRString, ExtraAssembly);
120 ModuleMMIPair = createDummyModule(Context, *TM, MIRString, "func");
121 MF = getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
122 collectCopies(Copies, MF);
123 EntryMBB = &*MF->begin();
124 B.setMF(*MF);
125 MRI = &MF->getRegInfo();
126 B.setInsertPt(*EntryMBB, EntryMBB->end());
129 LLVMContext Context;
130 std::unique_ptr<TargetMachine> TM;
131 MachineFunction *MF;
132 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
133 ModuleMMIPair;
134 SmallVector<Register, 4> Copies;
135 MachineBasicBlock *EntryMBB;
136 MachineIRBuilder B;
137 MachineRegisterInfo *MRI;
140 class AArch64GISelMITest : public GISelMITest {
141 std::unique_ptr<TargetMachine> createTargetMachine() const override;
142 void getTargetTestModuleString(SmallString<512> &S,
143 StringRef MIRFunc) const override;
146 class AMDGPUGISelMITest : public GISelMITest {
147 std::unique_ptr<TargetMachine> createTargetMachine() const override;
148 void getTargetTestModuleString(SmallString<512> &S,
149 StringRef MIRFunc) const override;
/// Define a LegalizerInfo subclass named Name##Info.
///
/// Its constructor brings TargetOpcode into scope and defines the scalar LLTs
/// s8, s16, s32, s64 and s128. It then runs SettingUpActionsBlock, which must
/// form a single statement such as a `{ ... }` block. Finally it computes the
/// legacy legalizer tables and verifies the rules against the subtarget's
/// instruction info.
#define DefineLegalizerInfo(Name, SettingUpActionsBlock)                       \
  class Name##Info : public LegalizerInfo {                                    \
  public:                                                                      \
    Name##Info(const TargetSubtargetInfo &ST) {                                \
      using namespace TargetOpcode;                                            \
      /* Not every actions block uses every type. */                           \
      [[maybe_unused]] const LLT s8 = LLT::scalar(8);                          \
      [[maybe_unused]] const LLT s16 = LLT::scalar(16);                        \
      [[maybe_unused]] const LLT s32 = LLT::scalar(32);                        \
      [[maybe_unused]] const LLT s64 = LLT::scalar(64);                        \
      [[maybe_unused]] const LLT s128 = LLT::scalar(128);                      \
      do                                                                       \
        SettingUpActionsBlock while (0);                                       \
      getLegacyLegalizerInfo().computeTables();                                \
      verify(*ST.getInstrInfo());                                              \
    }                                                                          \
  };
174 static inline bool CheckMachineFunction(const MachineFunction &MF,
175 StringRef CheckStr) {
176 SmallString<512> Msg;
177 raw_svector_ostream OS(Msg);
178 MF.print(OS);
179 auto OutputBuf = MemoryBuffer::getMemBuffer(Msg, "Output", false);
180 auto CheckBuf = MemoryBuffer::getMemBuffer(CheckStr, "");
181 SmallString<4096> CheckFileBuffer;
182 FileCheckRequest Req;
183 FileCheck FC(Req);
184 StringRef CheckFileText = FC.CanonicalizeFile(*CheckBuf, CheckFileBuffer);
185 SourceMgr SM;
186 SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText, "CheckFile"),
187 SMLoc());
188 if (FC.readCheckFile(SM, CheckFileText))
189 return false;
191 auto OutBuffer = OutputBuf->getBuffer();
192 SM.AddNewSourceBuffer(std::move(OutputBuf), SMLoc());
193 return FC.checkInput(SM, OutBuffer);
195 #endif