1 //===- PatternMatchTest.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/CodeGen/GlobalISel/ConstantFoldingMIRBuilder.h"
10 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
11 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
12 #include "llvm/CodeGen/GlobalISel/Utils.h"
13 #include "llvm/CodeGen/MIRParser/MIRParser.h"
14 #include "llvm/CodeGen/MachineFunction.h"
15 #include "llvm/CodeGen/MachineModuleInfo.h"
16 #include "llvm/CodeGen/TargetFrameLowering.h"
17 #include "llvm/CodeGen/TargetInstrInfo.h"
18 #include "llvm/CodeGen/TargetLowering.h"
19 #include "llvm/CodeGen/TargetSubtargetInfo.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Support/TargetRegistry.h"
22 #include "llvm/Support/TargetSelect.h"
23 #include "llvm/Target/TargetMachine.h"
24 #include "llvm/Target/TargetOptions.h"
25 #include "gtest/gtest.h"
28 using namespace MIPatternMatch
;
33 InitializeAllTargets();
34 InitializeAllTargetMCs();
35 InitializeAllAsmPrinters();
36 InitializeAllAsmParsers();
38 PassRegistry
*Registry
= PassRegistry::getPassRegistry();
39 initializeCore(*Registry
);
40 initializeCodeGen(*Registry
);
43 /// Create a TargetMachine. As we lack a dedicated always available target for
44 /// unittests, we go for "AArch64".
45 std::unique_ptr
<LLVMTargetMachine
> createTargetMachine() {
46 Triple
TargetTriple("aarch64--");
48 const Target
*T
= TargetRegistry::lookupTarget("", TargetTriple
, Error
);
52 TargetOptions Options
;
53 return std::unique_ptr
<LLVMTargetMachine
>(static_cast<LLVMTargetMachine
*>(
54 T
->createTargetMachine("AArch64", "", "", Options
, None
, None
,
55 CodeGenOpt::Aggressive
)));
58 std::unique_ptr
<Module
> parseMIR(LLVMContext
&Context
,
59 std::unique_ptr
<MIRParser
> &MIR
,
60 const TargetMachine
&TM
, StringRef MIRCode
,
61 const char *FuncName
, MachineModuleInfo
&MMI
) {
62 SMDiagnostic Diagnostic
;
63 std::unique_ptr
<MemoryBuffer
> MBuffer
= MemoryBuffer::getMemBuffer(MIRCode
);
64 MIR
= createMIRParser(std::move(MBuffer
), Context
);
68 std::unique_ptr
<Module
> M
= MIR
->parseIRModule();
72 M
->setDataLayout(TM
.createDataLayout());
74 if (MIR
->parseMachineFunctions(*M
, MMI
))
80 std::pair
<std::unique_ptr
<Module
>, std::unique_ptr
<MachineModuleInfo
>>
81 createDummyModule(LLVMContext
&Context
, const LLVMTargetMachine
&TM
,
84 StringRef MIRString
= (Twine(R
"MIR(
98 )MIR") + Twine(MIRFunc
) + Twine("...\n"))
99 .toNullTerminatedStringRef(S
);
100 std::unique_ptr
<MIRParser
> MIR
;
101 auto MMI
= make_unique
<MachineModuleInfo
>(&TM
);
102 std::unique_ptr
<Module
> M
=
103 parseMIR(Context
, MIR
, TM
, MIRString
, "func", *MMI
);
104 return make_pair(std::move(M
), std::move(MMI
));
107 static MachineFunction
*getMFFromMMI(const Module
*M
,
108 const MachineModuleInfo
*MMI
) {
109 Function
*F
= M
->getFunction("func");
110 auto *MF
= MMI
->getMachineFunction(*F
);
114 static void collectCopies(SmallVectorImpl
<unsigned> &Copies
,
115 MachineFunction
*MF
) {
116 for (auto &MBB
: *MF
)
117 for (MachineInstr
&MI
: MBB
) {
118 if (MI
.getOpcode() == TargetOpcode::COPY
)
119 Copies
.push_back(MI
.getOperand(0).getReg());
123 TEST(PatternMatchInstr
, MatchIntConstant
) {
125 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
128 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
129 MachineFunction
*MF
=
130 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
131 SmallVector
<unsigned, 4> Copies
;
132 collectCopies(Copies
, MF
);
133 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
134 MachineIRBuilder
B(*MF
);
135 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
136 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
137 auto MIBCst
= B
.buildConstant(LLT::scalar(64), 42);
139 bool match
= mi_match(MIBCst
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
144 TEST(PatternMatchInstr
, MatchBinaryOp
) {
146 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
149 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
150 MachineFunction
*MF
=
151 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
152 SmallVector
<unsigned, 4> Copies
;
153 collectCopies(Copies
, MF
);
154 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
155 MachineIRBuilder
B(*MF
);
156 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
157 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
158 LLT s64
= LLT::scalar(64);
159 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
160 // Test case for no bind.
162 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
, m_GAdd(m_Reg(), m_Reg()));
164 unsigned Src0
, Src1
, Src2
;
165 match
= mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
166 m_GAdd(m_Reg(Src0
), m_Reg(Src1
)));
168 EXPECT_EQ(Src0
, Copies
[0]);
169 EXPECT_EQ(Src1
, Copies
[1]);
171 // Build MUL(ADD %0, %1), %2
172 auto MIBMul
= B
.buildMul(s64
, MIBAdd
, Copies
[2]);
175 match
= mi_match(MIBMul
->getOperand(0).getReg(), MRI
,
176 m_GMul(m_Reg(Src0
), m_Reg(Src1
)));
178 EXPECT_EQ(Src0
, MIBAdd
->getOperand(0).getReg());
179 EXPECT_EQ(Src1
, Copies
[2]);
181 // Try to match MUL(ADD)
182 match
= mi_match(MIBMul
->getOperand(0).getReg(), MRI
,
183 m_GMul(m_GAdd(m_Reg(Src0
), m_Reg(Src1
)), m_Reg(Src2
)));
185 EXPECT_EQ(Src0
, Copies
[0]);
186 EXPECT_EQ(Src1
, Copies
[1]);
187 EXPECT_EQ(Src2
, Copies
[2]);
189 // Test Commutativity.
190 auto MIBMul2
= B
.buildMul(s64
, Copies
[0], B
.buildConstant(s64
, 42));
191 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate
194 match
= mi_match(MIBMul2
->getOperand(0).getReg(), MRI
,
195 m_GMul(m_ICst(Cst
), m_Reg(Src0
)));
198 EXPECT_EQ(Src0
, Copies
[0]);
200 // Make sure commutative doesn't work with something like SUB.
201 auto MIBSub
= B
.buildSub(s64
, Copies
[0], B
.buildConstant(s64
, 42));
202 match
= mi_match(MIBSub
->getOperand(0).getReg(), MRI
,
203 m_GSub(m_ICst(Cst
), m_Reg(Src0
)));
206 auto MIBFMul
= B
.buildInstr(TargetOpcode::G_FMUL
, {s64
},
207 {Copies
[0], B
.buildConstant(s64
, 42)});
208 // Match and test commutativity for FMUL.
209 match
= mi_match(MIBFMul
->getOperand(0).getReg(), MRI
,
210 m_GFMul(m_ICst(Cst
), m_Reg(Src0
)));
213 EXPECT_EQ(Src0
, Copies
[0]);
216 auto MIBFSub
= B
.buildInstr(TargetOpcode::G_FSUB
, {s64
},
217 {Copies
[0], B
.buildConstant(s64
, 42)});
218 match
= mi_match(MIBFSub
->getOperand(0).getReg(), MRI
,
219 m_GFSub(m_Reg(Src0
), m_Reg()));
221 EXPECT_EQ(Src0
, Copies
[0]);
224 auto MIBAnd
= B
.buildAnd(s64
, Copies
[0], Copies
[1]);
226 match
= mi_match(MIBAnd
->getOperand(0).getReg(), MRI
,
227 m_GAnd(m_Reg(Src0
), m_Reg(Src1
)));
229 EXPECT_EQ(Src0
, Copies
[0]);
230 EXPECT_EQ(Src1
, Copies
[1]);
233 auto MIBOr
= B
.buildOr(s64
, Copies
[0], Copies
[1]);
235 match
= mi_match(MIBOr
->getOperand(0).getReg(), MRI
,
236 m_GOr(m_Reg(Src0
), m_Reg(Src1
)));
238 EXPECT_EQ(Src0
, Copies
[0]);
239 EXPECT_EQ(Src1
, Copies
[1]);
241 // Try to use the FoldableInstructionsBuilder to build binary ops.
242 ConstantFoldingMIRBuilder
CFB(B
.getState());
243 LLT s32
= LLT::scalar(32);
245 CFB
.buildAdd(s32
, CFB
.buildConstant(s32
, 0), CFB
.buildConstant(s32
, 1));
246 // This should be a constant now.
247 match
= mi_match(MIBCAdd
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
251 CFB
.buildInstr(TargetOpcode::G_ADD
, {s32
},
252 {CFB
.buildConstant(s32
, 0), CFB
.buildConstant(s32
, 1)});
253 // This should be a constant now.
254 match
= mi_match(MIBCAdd1
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
258 // Try one of the other constructors of MachineIRBuilder to make sure it's
260 ConstantFoldingMIRBuilder
CFB1(*MF
);
261 CFB1
.setInsertPt(*EntryMBB
, EntryMBB
->end());
263 CFB1
.buildInstr(TargetOpcode::G_SUB
, {s32
},
264 {CFB1
.buildConstant(s32
, 1), CFB1
.buildConstant(s32
, 1)});
265 // This should be a constant now.
266 match
= mi_match(MIBCSub
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
271 TEST(PatternMatchInstr
, MatchFPUnaryOp
) {
273 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
276 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
277 MachineFunction
*MF
=
278 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
279 SmallVector
<unsigned, 4> Copies
;
280 collectCopies(Copies
, MF
);
281 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
282 MachineIRBuilder
B(*MF
);
283 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
284 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
286 // Truncate s64 to s32.
287 LLT s32
= LLT::scalar(32);
288 auto Copy0s32
= B
.buildFPTrunc(s32
, Copies
[0]);
291 auto MIBFabs
= B
.buildInstr(TargetOpcode::G_FABS
, {s32
}, {Copy0s32
});
292 bool match
= mi_match(MIBFabs
->getOperand(0).getReg(), MRI
, m_GFabs(m_Reg()));
296 auto MIBFNeg
= B
.buildInstr(TargetOpcode::G_FNEG
, {s32
}, {Copy0s32
});
297 match
= mi_match(MIBFNeg
->getOperand(0).getReg(), MRI
, m_GFNeg(m_Reg(Src
)));
299 EXPECT_EQ(Src
, Copy0s32
->getOperand(0).getReg());
301 match
= mi_match(MIBFabs
->getOperand(0).getReg(), MRI
, m_GFabs(m_Reg(Src
)));
303 EXPECT_EQ(Src
, Copy0s32
->getOperand(0).getReg());
305 // Build and match FConstant.
306 auto MIBFCst
= B
.buildFConstant(s32
, .5);
307 const ConstantFP
*TmpFP
{};
308 match
= mi_match(MIBFCst
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP
));
311 APFloat
APF((float).5);
312 auto *CFP
= ConstantFP::get(Context
, APF
);
313 EXPECT_EQ(CFP
, TmpFP
);
315 // Build double float.
316 LLT s64
= LLT::scalar(64);
317 auto MIBFCst64
= B
.buildFConstant(s64
, .5);
318 const ConstantFP
*TmpFP64
{};
319 match
= mi_match(MIBFCst64
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP64
));
321 EXPECT_TRUE(TmpFP64
);
323 auto CFP64
= ConstantFP::get(Context
, APF64
);
324 EXPECT_EQ(CFP64
, TmpFP64
);
325 EXPECT_NE(TmpFP64
, TmpFP
);
328 LLT s16
= LLT::scalar(16);
329 auto MIBFCst16
= B
.buildFConstant(s16
, .5);
330 const ConstantFP
*TmpFP16
{};
331 match
= mi_match(MIBFCst16
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP16
));
333 EXPECT_TRUE(TmpFP16
);
336 APF16
.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven
, &Ignored
);
337 auto CFP16
= ConstantFP::get(Context
, APF16
);
338 EXPECT_EQ(TmpFP16
, CFP16
);
339 EXPECT_NE(TmpFP16
, TmpFP
);
342 TEST(PatternMatchInstr
, MatchExtendsTrunc
) {
344 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
347 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
348 MachineFunction
*MF
=
349 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
350 SmallVector
<unsigned, 4> Copies
;
351 collectCopies(Copies
, MF
);
352 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
353 MachineIRBuilder
B(*MF
);
354 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
355 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
356 LLT s64
= LLT::scalar(64);
357 LLT s32
= LLT::scalar(32);
359 auto MIBTrunc
= B
.buildTrunc(s32
, Copies
[0]);
360 auto MIBAExt
= B
.buildAnyExt(s64
, MIBTrunc
);
361 auto MIBZExt
= B
.buildZExt(s64
, MIBTrunc
);
362 auto MIBSExt
= B
.buildSExt(s64
, MIBTrunc
);
365 mi_match(MIBTrunc
->getOperand(0).getReg(), MRI
, m_GTrunc(m_Reg(Src0
)));
367 EXPECT_EQ(Src0
, Copies
[0]);
369 mi_match(MIBAExt
->getOperand(0).getReg(), MRI
, m_GAnyExt(m_Reg(Src0
)));
371 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
373 match
= mi_match(MIBSExt
->getOperand(0).getReg(), MRI
, m_GSExt(m_Reg(Src0
)));
375 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
377 match
= mi_match(MIBZExt
->getOperand(0).getReg(), MRI
, m_GZExt(m_Reg(Src0
)));
379 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
381 // Match ext(trunc src)
382 match
= mi_match(MIBAExt
->getOperand(0).getReg(), MRI
,
383 m_GAnyExt(m_GTrunc(m_Reg(Src0
))));
385 EXPECT_EQ(Src0
, Copies
[0]);
387 match
= mi_match(MIBSExt
->getOperand(0).getReg(), MRI
,
388 m_GSExt(m_GTrunc(m_Reg(Src0
))));
390 EXPECT_EQ(Src0
, Copies
[0]);
392 match
= mi_match(MIBZExt
->getOperand(0).getReg(), MRI
,
393 m_GZExt(m_GTrunc(m_Reg(Src0
))));
395 EXPECT_EQ(Src0
, Copies
[0]);
398 TEST(PatternMatchInstr
, MatchSpecificType
) {
400 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
403 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
404 MachineFunction
*MF
=
405 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
406 SmallVector
<unsigned, 4> Copies
;
407 collectCopies(Copies
, MF
);
408 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
409 MachineIRBuilder
B(*MF
);
410 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
411 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
413 // Try to match a 64bit add.
414 LLT s64
= LLT::scalar(64);
415 LLT s32
= LLT::scalar(32);
416 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
417 EXPECT_FALSE(mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
418 m_GAdd(m_SpecificType(s32
), m_Reg())));
419 EXPECT_TRUE(mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
420 m_GAdd(m_SpecificType(s64
), m_Reg())));
422 // Try to match the destination type of a bitcast.
423 LLT v2s32
= LLT::vector(2, 32);
424 auto MIBCast
= B
.buildCast(v2s32
, Copies
[0]);
426 mi_match(MIBCast
->getOperand(0).getReg(), MRI
, m_GBitcast(m_Reg())));
428 mi_match(MIBCast
->getOperand(0).getReg(), MRI
, m_SpecificType(v2s32
)));
430 mi_match(MIBCast
->getOperand(1).getReg(), MRI
, m_SpecificType(s64
)));
432 // Build a PTRToInt and INTTOPTR and match and test them.
433 LLT PtrTy
= LLT::pointer(0, 64);
434 auto MIBIntToPtr
= B
.buildCast(PtrTy
, Copies
[0]);
435 auto MIBPtrToInt
= B
.buildCast(s64
, MIBIntToPtr
);
438 // match the ptrtoint(inttoptr reg)
439 bool match
= mi_match(MIBPtrToInt
->getOperand(0).getReg(), MRI
,
440 m_GPtrToInt(m_GIntToPtr(m_Reg(Src0
))));
442 EXPECT_EQ(Src0
, Copies
[0]);
445 TEST(PatternMatchInstr
, MatchCombinators
) {
447 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
450 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
451 MachineFunction
*MF
=
452 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
453 SmallVector
<unsigned, 4> Copies
;
454 collectCopies(Copies
, MF
);
455 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
456 MachineIRBuilder
B(*MF
);
457 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
458 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
459 LLT s64
= LLT::scalar(64);
460 LLT s32
= LLT::scalar(32);
461 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
464 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
465 m_all_of(m_SpecificType(s64
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
467 EXPECT_EQ(Src0
, Copies
[0]);
468 EXPECT_EQ(Src1
, Copies
[1]);
469 // Check for s32 (which should fail).
471 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
472 m_all_of(m_SpecificType(s32
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
475 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
476 m_any_of(m_SpecificType(s32
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
478 EXPECT_EQ(Src0
, Copies
[0]);
479 EXPECT_EQ(Src1
, Copies
[1]);
481 // Match a case where none of the predicates hold true.
483 MIBAdd
->getOperand(0).getReg(), MRI
,
484 m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg())));
489 int main(int argc
, char **argv
) {
490 ::testing::InitGoogleTest(&argc
, argv
);
492 return RUN_ALL_TESTS();