//===- PatternMatchTest.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SmallString.h"
#include "llvm/CodeGen/GlobalISel/ConstantFoldingMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "gtest/gtest.h"

using namespace llvm;
using namespace MIPatternMatch;
33 InitializeAllTargets();
34 InitializeAllTargetMCs();
35 InitializeAllAsmPrinters();
36 InitializeAllAsmParsers();
38 PassRegistry
*Registry
= PassRegistry::getPassRegistry();
39 initializeCore(*Registry
);
40 initializeCodeGen(*Registry
);
43 /// Create a TargetMachine. As we lack a dedicated always available target for
44 /// unittests, we go for "AArch64".
45 std::unique_ptr
<LLVMTargetMachine
> createTargetMachine() {
46 Triple
TargetTriple("aarch64--");
48 const Target
*T
= TargetRegistry::lookupTarget("", TargetTriple
, Error
);
52 TargetOptions Options
;
53 return std::unique_ptr
<LLVMTargetMachine
>(static_cast<LLVMTargetMachine
*>(
54 T
->createTargetMachine("AArch64", "", "", Options
, None
, None
,
55 CodeGenOpt::Aggressive
)));
58 std::unique_ptr
<Module
> parseMIR(LLVMContext
&Context
,
59 std::unique_ptr
<MIRParser
> &MIR
,
60 const TargetMachine
&TM
, StringRef MIRCode
,
61 const char *FuncName
, MachineModuleInfo
&MMI
) {
62 SMDiagnostic Diagnostic
;
63 std::unique_ptr
<MemoryBuffer
> MBuffer
= MemoryBuffer::getMemBuffer(MIRCode
);
64 MIR
= createMIRParser(std::move(MBuffer
), Context
);
68 std::unique_ptr
<Module
> M
= MIR
->parseIRModule();
72 M
->setDataLayout(TM
.createDataLayout());
74 if (MIR
->parseMachineFunctions(*M
, MMI
))
80 std::pair
<std::unique_ptr
<Module
>, std::unique_ptr
<MachineModuleInfo
>>
81 createDummyModule(LLVMContext
&Context
, const LLVMTargetMachine
&TM
,
84 StringRef MIRString
= (Twine(R
"MIR(
98 )MIR") + Twine(MIRFunc
) + Twine("...\n"))
99 .toNullTerminatedStringRef(S
);
100 std::unique_ptr
<MIRParser
> MIR
;
101 auto MMI
= std::make_unique
<MachineModuleInfo
>(&TM
);
102 std::unique_ptr
<Module
> M
=
103 parseMIR(Context
, MIR
, TM
, MIRString
, "func", *MMI
);
104 return make_pair(std::move(M
), std::move(MMI
));
107 static MachineFunction
*getMFFromMMI(const Module
*M
,
108 const MachineModuleInfo
*MMI
) {
109 Function
*F
= M
->getFunction("func");
110 auto *MF
= MMI
->getMachineFunction(*F
);
114 static void collectCopies(SmallVectorImpl
<Register
> &Copies
,
115 MachineFunction
*MF
) {
116 for (auto &MBB
: *MF
)
117 for (MachineInstr
&MI
: MBB
) {
118 if (MI
.getOpcode() == TargetOpcode::COPY
)
119 Copies
.push_back(MI
.getOperand(0).getReg());
123 TEST(PatternMatchInstr
, MatchIntConstant
) {
125 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
128 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
129 MachineFunction
*MF
=
130 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
131 SmallVector
<Register
, 4> Copies
;
132 collectCopies(Copies
, MF
);
133 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
134 MachineIRBuilder
B(*MF
);
135 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
136 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
137 auto MIBCst
= B
.buildConstant(LLT::scalar(64), 42);
139 bool match
= mi_match(MIBCst
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
144 TEST(PatternMatchInstr
, MatchBinaryOp
) {
146 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
149 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
150 MachineFunction
*MF
=
151 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
152 SmallVector
<Register
, 4> Copies
;
153 collectCopies(Copies
, MF
);
154 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
155 MachineIRBuilder
B(*MF
);
156 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
157 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
158 LLT s64
= LLT::scalar(64);
159 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
160 // Test case for no bind.
162 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
, m_GAdd(m_Reg(), m_Reg()));
164 Register Src0
, Src1
, Src2
;
165 match
= mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
166 m_GAdd(m_Reg(Src0
), m_Reg(Src1
)));
168 EXPECT_EQ(Src0
, Copies
[0]);
169 EXPECT_EQ(Src1
, Copies
[1]);
171 // Build MUL(ADD %0, %1), %2
172 auto MIBMul
= B
.buildMul(s64
, MIBAdd
, Copies
[2]);
175 match
= mi_match(MIBMul
->getOperand(0).getReg(), MRI
,
176 m_GMul(m_Reg(Src0
), m_Reg(Src1
)));
178 EXPECT_EQ(Src0
, MIBAdd
->getOperand(0).getReg());
179 EXPECT_EQ(Src1
, Copies
[2]);
181 // Try to match MUL(ADD)
182 match
= mi_match(MIBMul
->getOperand(0).getReg(), MRI
,
183 m_GMul(m_GAdd(m_Reg(Src0
), m_Reg(Src1
)), m_Reg(Src2
)));
185 EXPECT_EQ(Src0
, Copies
[0]);
186 EXPECT_EQ(Src1
, Copies
[1]);
187 EXPECT_EQ(Src2
, Copies
[2]);
189 // Test Commutativity.
190 auto MIBMul2
= B
.buildMul(s64
, Copies
[0], B
.buildConstant(s64
, 42));
191 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate
194 match
= mi_match(MIBMul2
->getOperand(0).getReg(), MRI
,
195 m_GMul(m_ICst(Cst
), m_Reg(Src0
)));
198 EXPECT_EQ(Src0
, Copies
[0]);
200 // Make sure commutative doesn't work with something like SUB.
201 auto MIBSub
= B
.buildSub(s64
, Copies
[0], B
.buildConstant(s64
, 42));
202 match
= mi_match(MIBSub
->getOperand(0).getReg(), MRI
,
203 m_GSub(m_ICst(Cst
), m_Reg(Src0
)));
206 auto MIBFMul
= B
.buildInstr(TargetOpcode::G_FMUL
, {s64
},
207 {Copies
[0], B
.buildConstant(s64
, 42)});
208 // Match and test commutativity for FMUL.
209 match
= mi_match(MIBFMul
->getOperand(0).getReg(), MRI
,
210 m_GFMul(m_ICst(Cst
), m_Reg(Src0
)));
213 EXPECT_EQ(Src0
, Copies
[0]);
216 auto MIBFSub
= B
.buildInstr(TargetOpcode::G_FSUB
, {s64
},
217 {Copies
[0], B
.buildConstant(s64
, 42)});
218 match
= mi_match(MIBFSub
->getOperand(0).getReg(), MRI
,
219 m_GFSub(m_Reg(Src0
), m_Reg()));
221 EXPECT_EQ(Src0
, Copies
[0]);
224 auto MIBAnd
= B
.buildAnd(s64
, Copies
[0], Copies
[1]);
226 match
= mi_match(MIBAnd
->getOperand(0).getReg(), MRI
,
227 m_GAnd(m_Reg(Src0
), m_Reg(Src1
)));
229 EXPECT_EQ(Src0
, Copies
[0]);
230 EXPECT_EQ(Src1
, Copies
[1]);
233 auto MIBOr
= B
.buildOr(s64
, Copies
[0], Copies
[1]);
235 match
= mi_match(MIBOr
->getOperand(0).getReg(), MRI
,
236 m_GOr(m_Reg(Src0
), m_Reg(Src1
)));
238 EXPECT_EQ(Src0
, Copies
[0]);
239 EXPECT_EQ(Src1
, Copies
[1]);
241 // Try to use the FoldableInstructionsBuilder to build binary ops.
242 ConstantFoldingMIRBuilder
CFB(B
.getState());
243 LLT s32
= LLT::scalar(32);
245 CFB
.buildAdd(s32
, CFB
.buildConstant(s32
, 0), CFB
.buildConstant(s32
, 1));
246 // This should be a constant now.
247 match
= mi_match(MIBCAdd
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
251 CFB
.buildInstr(TargetOpcode::G_ADD
, {s32
},
252 {CFB
.buildConstant(s32
, 0), CFB
.buildConstant(s32
, 1)});
253 // This should be a constant now.
254 match
= mi_match(MIBCAdd1
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
258 // Try one of the other constructors of MachineIRBuilder to make sure it's
260 ConstantFoldingMIRBuilder
CFB1(*MF
);
261 CFB1
.setInsertPt(*EntryMBB
, EntryMBB
->end());
263 CFB1
.buildInstr(TargetOpcode::G_SUB
, {s32
},
264 {CFB1
.buildConstant(s32
, 1), CFB1
.buildConstant(s32
, 1)});
265 // This should be a constant now.
266 match
= mi_match(MIBCSub
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
271 CFB1
.buildInstr(TargetOpcode::G_SEXT_INREG
, {s32
},
272 {CFB1
.buildConstant(s32
, 0x01), uint64_t(8)});
273 // This should be a constant now.
274 match
= mi_match(MIBCSext1
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
279 CFB1
.buildInstr(TargetOpcode::G_SEXT_INREG
, {s32
},
280 {CFB1
.buildConstant(s32
, 0x80), uint64_t(8)});
281 // This should be a constant now.
282 match
= mi_match(MIBCSext2
->getOperand(0).getReg(), MRI
, m_ICst(Cst
));
284 EXPECT_EQ(-0x80, Cst
);
287 TEST(PatternMatchInstr
, MatchFPUnaryOp
) {
289 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
292 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
293 MachineFunction
*MF
=
294 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
295 SmallVector
<Register
, 4> Copies
;
296 collectCopies(Copies
, MF
);
297 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
298 MachineIRBuilder
B(*MF
);
299 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
300 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
302 // Truncate s64 to s32.
303 LLT s32
= LLT::scalar(32);
304 auto Copy0s32
= B
.buildFPTrunc(s32
, Copies
[0]);
307 auto MIBFabs
= B
.buildInstr(TargetOpcode::G_FABS
, {s32
}, {Copy0s32
});
308 bool match
= mi_match(MIBFabs
->getOperand(0).getReg(), MRI
, m_GFabs(m_Reg()));
312 auto MIBFNeg
= B
.buildInstr(TargetOpcode::G_FNEG
, {s32
}, {Copy0s32
});
313 match
= mi_match(MIBFNeg
->getOperand(0).getReg(), MRI
, m_GFNeg(m_Reg(Src
)));
315 EXPECT_EQ(Src
, Copy0s32
->getOperand(0).getReg());
317 match
= mi_match(MIBFabs
->getOperand(0).getReg(), MRI
, m_GFabs(m_Reg(Src
)));
319 EXPECT_EQ(Src
, Copy0s32
->getOperand(0).getReg());
321 // Build and match FConstant.
322 auto MIBFCst
= B
.buildFConstant(s32
, .5);
323 const ConstantFP
*TmpFP
{};
324 match
= mi_match(MIBFCst
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP
));
327 APFloat
APF((float).5);
328 auto *CFP
= ConstantFP::get(Context
, APF
);
329 EXPECT_EQ(CFP
, TmpFP
);
331 // Build double float.
332 LLT s64
= LLT::scalar(64);
333 auto MIBFCst64
= B
.buildFConstant(s64
, .5);
334 const ConstantFP
*TmpFP64
{};
335 match
= mi_match(MIBFCst64
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP64
));
337 EXPECT_TRUE(TmpFP64
);
339 auto CFP64
= ConstantFP::get(Context
, APF64
);
340 EXPECT_EQ(CFP64
, TmpFP64
);
341 EXPECT_NE(TmpFP64
, TmpFP
);
344 LLT s16
= LLT::scalar(16);
345 auto MIBFCst16
= B
.buildFConstant(s16
, .5);
346 const ConstantFP
*TmpFP16
{};
347 match
= mi_match(MIBFCst16
->getOperand(0).getReg(), MRI
, m_GFCst(TmpFP16
));
349 EXPECT_TRUE(TmpFP16
);
352 APF16
.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven
, &Ignored
);
353 auto CFP16
= ConstantFP::get(Context
, APF16
);
354 EXPECT_EQ(TmpFP16
, CFP16
);
355 EXPECT_NE(TmpFP16
, TmpFP
);
358 TEST(PatternMatchInstr
, MatchExtendsTrunc
) {
360 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
363 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
364 MachineFunction
*MF
=
365 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
366 SmallVector
<Register
, 4> Copies
;
367 collectCopies(Copies
, MF
);
368 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
369 MachineIRBuilder
B(*MF
);
370 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
371 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
372 LLT s64
= LLT::scalar(64);
373 LLT s32
= LLT::scalar(32);
375 auto MIBTrunc
= B
.buildTrunc(s32
, Copies
[0]);
376 auto MIBAExt
= B
.buildAnyExt(s64
, MIBTrunc
);
377 auto MIBZExt
= B
.buildZExt(s64
, MIBTrunc
);
378 auto MIBSExt
= B
.buildSExt(s64
, MIBTrunc
);
381 mi_match(MIBTrunc
->getOperand(0).getReg(), MRI
, m_GTrunc(m_Reg(Src0
)));
383 EXPECT_EQ(Src0
, Copies
[0]);
385 mi_match(MIBAExt
->getOperand(0).getReg(), MRI
, m_GAnyExt(m_Reg(Src0
)));
387 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
389 match
= mi_match(MIBSExt
->getOperand(0).getReg(), MRI
, m_GSExt(m_Reg(Src0
)));
391 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
393 match
= mi_match(MIBZExt
->getOperand(0).getReg(), MRI
, m_GZExt(m_Reg(Src0
)));
395 EXPECT_EQ(Src0
, MIBTrunc
->getOperand(0).getReg());
397 // Match ext(trunc src)
398 match
= mi_match(MIBAExt
->getOperand(0).getReg(), MRI
,
399 m_GAnyExt(m_GTrunc(m_Reg(Src0
))));
401 EXPECT_EQ(Src0
, Copies
[0]);
403 match
= mi_match(MIBSExt
->getOperand(0).getReg(), MRI
,
404 m_GSExt(m_GTrunc(m_Reg(Src0
))));
406 EXPECT_EQ(Src0
, Copies
[0]);
408 match
= mi_match(MIBZExt
->getOperand(0).getReg(), MRI
,
409 m_GZExt(m_GTrunc(m_Reg(Src0
))));
411 EXPECT_EQ(Src0
, Copies
[0]);
414 TEST(PatternMatchInstr
, MatchSpecificType
) {
416 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
419 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
420 MachineFunction
*MF
=
421 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
422 SmallVector
<Register
, 4> Copies
;
423 collectCopies(Copies
, MF
);
424 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
425 MachineIRBuilder
B(*MF
);
426 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
427 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
429 // Try to match a 64bit add.
430 LLT s64
= LLT::scalar(64);
431 LLT s32
= LLT::scalar(32);
432 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
433 EXPECT_FALSE(mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
434 m_GAdd(m_SpecificType(s32
), m_Reg())));
435 EXPECT_TRUE(mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
436 m_GAdd(m_SpecificType(s64
), m_Reg())));
438 // Try to match the destination type of a bitcast.
439 LLT v2s32
= LLT::vector(2, 32);
440 auto MIBCast
= B
.buildCast(v2s32
, Copies
[0]);
442 mi_match(MIBCast
->getOperand(0).getReg(), MRI
, m_GBitcast(m_Reg())));
444 mi_match(MIBCast
->getOperand(0).getReg(), MRI
, m_SpecificType(v2s32
)));
446 mi_match(MIBCast
->getOperand(1).getReg(), MRI
, m_SpecificType(s64
)));
448 // Build a PTRToInt and INTTOPTR and match and test them.
449 LLT PtrTy
= LLT::pointer(0, 64);
450 auto MIBIntToPtr
= B
.buildCast(PtrTy
, Copies
[0]);
451 auto MIBPtrToInt
= B
.buildCast(s64
, MIBIntToPtr
);
454 // match the ptrtoint(inttoptr reg)
455 bool match
= mi_match(MIBPtrToInt
->getOperand(0).getReg(), MRI
,
456 m_GPtrToInt(m_GIntToPtr(m_Reg(Src0
))));
458 EXPECT_EQ(Src0
, Copies
[0]);
461 TEST(PatternMatchInstr
, MatchCombinators
) {
463 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
466 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
467 MachineFunction
*MF
=
468 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
469 SmallVector
<Register
, 4> Copies
;
470 collectCopies(Copies
, MF
);
471 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
472 MachineIRBuilder
B(*MF
);
473 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
474 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
475 LLT s64
= LLT::scalar(64);
476 LLT s32
= LLT::scalar(32);
477 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
480 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
481 m_all_of(m_SpecificType(s64
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
483 EXPECT_EQ(Src0
, Copies
[0]);
484 EXPECT_EQ(Src1
, Copies
[1]);
485 // Check for s32 (which should fail).
487 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
488 m_all_of(m_SpecificType(s32
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
491 mi_match(MIBAdd
->getOperand(0).getReg(), MRI
,
492 m_any_of(m_SpecificType(s32
), m_GAdd(m_Reg(Src0
), m_Reg(Src1
))));
494 EXPECT_EQ(Src0
, Copies
[0]);
495 EXPECT_EQ(Src1
, Copies
[1]);
497 // Match a case where none of the predicates hold true.
499 MIBAdd
->getOperand(0).getReg(), MRI
,
500 m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg())));
504 TEST(PatternMatchInstr
, MatchMiscellaneous
) {
506 std::unique_ptr
<LLVMTargetMachine
> TM
= createTargetMachine();
509 auto ModuleMMIPair
= createDummyModule(Context
, *TM
, "");
510 MachineFunction
*MF
=
511 getMFFromMMI(ModuleMMIPair
.first
.get(), ModuleMMIPair
.second
.get());
512 SmallVector
<Register
, 4> Copies
;
513 collectCopies(Copies
, MF
);
514 MachineBasicBlock
*EntryMBB
= &*MF
->begin();
515 MachineIRBuilder
B(*MF
);
516 MachineRegisterInfo
&MRI
= MF
->getRegInfo();
517 B
.setInsertPt(*EntryMBB
, EntryMBB
->end());
518 LLT s64
= LLT::scalar(64);
519 auto MIBAdd
= B
.buildAdd(s64
, Copies
[0], Copies
[1]);
520 // Make multiple uses of this add.
521 B
.buildCast(LLT::pointer(0, 32), MIBAdd
);
522 B
.buildCast(LLT::pointer(1, 32), MIBAdd
);
523 bool match
= mi_match(MIBAdd
.getReg(0), MRI
, m_GAdd(m_Reg(), m_Reg()));
525 match
= mi_match(MIBAdd
.getReg(0), MRI
, m_OneUse(m_GAdd(m_Reg(), m_Reg())));
530 int main(int argc
, char **argv
) {
531 ::testing::InitGoogleTest(&argc
, argv
);
533 return RUN_ALL_TESTS();