1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/Support/FileCheck.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "gtest/gtest.h"
#include <string>

using namespace llvm;
using namespace MIPatternMatch;
35 static inline void initLLVM() {
36 InitializeAllTargets();
37 InitializeAllTargetMCs();
38 InitializeAllAsmPrinters();
39 InitializeAllAsmParsers();
41 PassRegistry
*Registry
= PassRegistry::getPassRegistry();
42 initializeCore(*Registry
);
43 initializeCodeGen(*Registry
);
46 // Define a printers to help debugging when things go wrong.
49 operator<<(std::ostream
&OS
, const LLT Ty
);
52 operator<<(std::ostream
&OS
, const MachineFunction
&MF
);
55 /// Create a TargetMachine. As we lack a dedicated always available target for
56 /// unittests, we go for "AArch64".
57 static std::unique_ptr
<LLVMTargetMachine
> createTargetMachine() {
58 Triple
TargetTriple("aarch64--");
60 const Target
*T
= TargetRegistry::lookupTarget("", TargetTriple
, Error
);
64 TargetOptions Options
;
65 return std::unique_ptr
<LLVMTargetMachine
>(
66 static_cast<LLVMTargetMachine
*>(T
->createTargetMachine(
67 "AArch64", "", "", Options
, None
, None
, CodeGenOpt::Aggressive
)));
70 static std::unique_ptr
<Module
> parseMIR(LLVMContext
&Context
,
71 std::unique_ptr
<MIRParser
> &MIR
,
72 const TargetMachine
&TM
,
73 StringRef MIRCode
, const char *FuncName
,
74 MachineModuleInfo
&MMI
) {
75 SMDiagnostic Diagnostic
;
76 std::unique_ptr
<MemoryBuffer
> MBuffer
= MemoryBuffer::getMemBuffer(MIRCode
);
77 MIR
= createMIRParser(std::move(MBuffer
), Context
);
81 std::unique_ptr
<Module
> M
= MIR
->parseIRModule();
85 M
->setDataLayout(TM
.createDataLayout());
87 if (MIR
->parseMachineFunctions(*M
, MMI
))
93 static std::pair
<std::unique_ptr
<Module
>, std::unique_ptr
<MachineModuleInfo
>>
94 createDummyModule(LLVMContext
&Context
, const LLVMTargetMachine
&TM
,
97 StringRef MIRString
= (Twine(R
"MIR(
102 - { id: 0, class: _ }
103 - { id: 1, class: _ }
104 - { id: 2, class: _ }
105 - { id: 3, class: _ }
111 )MIR") + Twine(MIRFunc
) + Twine("...\n"))
112 .toNullTerminatedStringRef(S
);
113 std::unique_ptr
<MIRParser
> MIR
;
114 auto MMI
= std::make_unique
<MachineModuleInfo
>(&TM
);
115 std::unique_ptr
<Module
> M
=
116 parseMIR(Context
, MIR
, TM
, MIRString
, "func", *MMI
);
117 return make_pair(std::move(M
), std::move(MMI
));
120 static MachineFunction
*getMFFromMMI(const Module
*M
,
121 const MachineModuleInfo
*MMI
) {
122 Function
*F
= M
->getFunction("func");
123 auto *MF
= MMI
->getMachineFunction(*F
);
127 static void collectCopies(SmallVectorImpl
<Register
> &Copies
,
128 MachineFunction
*MF
) {
129 for (auto &MBB
: *MF
)
130 for (MachineInstr
&MI
: MBB
) {
131 if (MI
.getOpcode() == TargetOpcode::COPY
)
132 Copies
.push_back(MI
.getOperand(0).getReg());
136 class GISelMITest
: public ::testing::Test
{
138 GISelMITest() : ::testing::Test() {}
139 void setUp(StringRef ExtraAssembly
= "") {
140 TM
= createTargetMachine();
143 ModuleMMIPair
= createDummyModule(Context
, *TM
, ExtraAssembly
);
144 MF
= getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
145 collectCopies(Copies
, MF
);
146 EntryMBB
= &*MF
->begin();
148 MRI
= &MF
->getRegInfo();
149 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
152 std::unique_ptr
<LLVMTargetMachine
> TM
;
154 std::pair
<std::unique_ptr
<Module
>, std::unique_ptr
<MachineModuleInfo
>>
156 SmallVector
<Register
, 4> Copies
;
157 MachineBasicBlock
*EntryMBB
;
159 MachineRegisterInfo
*MRI
;
/// Define a LegalizerInfo subclass named Name##Info for use in tests.
///
/// The constructor does the following, in order:
/// - brings TargetOpcode into scope;
/// - declares the scalar types s8, s16, s32 and s64 (each cast to void so
///   blocks that do not use every type build without unused-variable
///   warnings);
/// - runs \p SettingUpActionsBlock;
/// - computes the legacy action tables;
/// - verifies the rules against the subtarget's instruction info.
///
/// \p SettingUpActionsBlock is wrapped in do/while(0), so it must be a single
/// statement, typically a braced block.
#define DefineLegalizerInfo(Name, SettingUpActionsBlock)                       \
  class Name##Info : public LegalizerInfo {                                    \
  public:                                                                      \
    Name##Info(const TargetSubtargetInfo &ST) {                                \
      using namespace TargetOpcode;                                            \
      const LLT s8 = LLT::scalar(8);                                           \
      (void)s8;                                                                \
      const LLT s16 = LLT::scalar(16);                                         \
      (void)s16;                                                               \
      const LLT s32 = LLT::scalar(32);                                         \
      (void)s32;                                                               \
      const LLT s64 = LLT::scalar(64);                                         \
      (void)s64;                                                               \
      do                                                                       \
        SettingUpActionsBlock while (0);                                       \
      computeTables();                                                         \
      verify(*ST.getInstrInfo());                                              \
    }                                                                          \
  };
182 static inline bool CheckMachineFunction(const MachineFunction
&MF
,
183 StringRef CheckStr
) {
184 SmallString
<512> Msg
;
185 raw_svector_ostream
OS(Msg
);
187 auto OutputBuf
= MemoryBuffer::getMemBuffer(Msg
, "Output", false);
188 auto CheckBuf
= MemoryBuffer::getMemBuffer(CheckStr
, "");
189 SmallString
<4096> CheckFileBuffer
;
190 FileCheckRequest Req
;
192 StringRef CheckFileText
=
193 FC
.CanonicalizeFile(*CheckBuf
.get(), CheckFileBuffer
);
195 SM
.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText
, "CheckFile"),
197 Regex PrefixRE
= FC
.buildCheckPrefixRegex();
198 std::vector
<FileCheckString
> CheckStrings
;
199 if (FC
.ReadCheckFile(SM
, CheckFileText
, PrefixRE
, CheckStrings
))
202 auto OutBuffer
= OutputBuf
->getBuffer();
203 SM
.AddNewSourceBuffer(std::move(OutputBuf
), SMLoc());
204 return FC
.CheckInput(SM
, OutBuffer
, CheckStrings
);