1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
21 #include "llvm/Transforms/Utils/SCCPSolver.h"
22 #include "gtest/gtest.h"
27 static void removeSSACopy(Function
&F
) {
28 for (BasicBlock
&BB
: F
) {
29 for (Instruction
&Inst
: llvm::make_early_inc_range(BB
)) {
30 if (auto *II
= dyn_cast
<IntrinsicInst
>(&Inst
)) {
31 if (II
->getIntrinsicID() != Intrinsic::ssa_copy
)
33 Inst
.replaceAllUsesWith(II
->getOperand(0));
34 Inst
.eraseFromParent();
// Test fixture (derived from testing::Test) holding the analysis manager, the
// parsed module and the SCCP solver shared by the tests in this file, plus
// helpers that compute the reference cost values the tests compare against.
// NOTE(review): Ctx and Err are used below but their declarations are not
// part of this listing.
40 class FunctionSpecializationTest
: public testing::Test
{
// Analysis manager from which every analysis result below is fetched.
43 FunctionAnalysisManager FAM
;
// Module produced by parseModule().
44 std::unique_ptr
<Module
> M
;
// Solver created by getSpecializerFor().
45 std::unique_ptr
<SCCPSolver
> Solver
;
// Instructions recorded by getCodeSizeSavings() and later consumed by
// getLatencySavings().
46 SmallVector
<Instruction
*, 8> KnownConstants
;
// Registers with FAM the analyses that the helpers below query.
48 FunctionSpecializationTest() {
49 FAM
.registerPass([&] { return TargetLibraryAnalysis(); });
50 FAM
.registerPass([&] { return TargetIRAnalysis(); });
51 FAM
.registerPass([&] { return BlockFrequencyAnalysis(); });
52 FAM
.registerPass([&] { return BranchProbabilityAnalysis(); });
53 FAM
.registerPass([&] { return LoopAnalysis(); });
54 FAM
.registerPass([&] { return AssumptionAnalysis(); });
55 FAM
.registerPass([&] { return DominatorTreeAnalysis(); });
56 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
// Parses the textual IR in ModuleString with parseAssemblyString and stores
// the result in M. Declared to return a Module reference.
60 Module
&parseModule(const char *ModuleString
) {
62 M
= parseAssemblyString(ModuleString
, Err
, Ctx
);
// Builds a FunctionSpecializer for F: creates the analysis getter callbacks,
// sets up and runs the SCCP solver on the module, then constructs the
// specializer from the solver, the module, FAM and the getters.
67 FunctionSpecializer
getSpecializerFor(Function
*F
) {
// Each getter below returns the corresponding FAM analysis result for the
// function it is given.
68 auto GetTLI
= [this](Function
&F
) -> const TargetLibraryInfo
& {
69 return FAM
.getResult
<TargetLibraryAnalysis
>(F
);
71 auto GetTTI
= [this](Function
&F
) -> TargetTransformInfo
& {
72 return FAM
.getResult
<TargetIRAnalysis
>(F
);
74 auto GetAC
= [this](Function
&F
) -> AssumptionCache
& {
75 return FAM
.getResult
<AssumptionAnalysis
>(F
);
77 auto GetDT
= [this](Function
&F
) -> DominatorTree
& {
78 return FAM
.getResult
<DominatorTreeAnalysis
>(F
);
80 auto GetBFI
= [this](Function
&F
) -> BlockFrequencyInfo
& {
81 return FAM
.getResult
<BlockFrequencyAnalysis
>(F
);
// A fresh solver is created on every call, replacing any previous one.
84 Solver
= std::make_unique
<SCCPSolver
>(M
->getDataLayout(), GetTLI
, Ctx
);
86 DominatorTree
&DT
= GetDT(*F
);
87 AssumptionCache
&AC
= GetAC(*F
);
88 Solver
->addPredicateInfo(*F
, DT
, AC
);
// Seed the solver: the first block of F is executable and every argument of
// F is marked overdefined before solving.
90 Solver
->markBlockExecutable(&F
->front());
91 for (Argument
&Arg
: F
->args())
92 Solver
->markOverdefined(&Arg
);
93 Solver
->solveWhileResolvedUndefsIn(*M
);
// NOTE(review): the argument list of this constructor call continues past
// the trailing comma; the remainder is not part of this listing.
97 return FunctionSpecializer(*Solver
, *M
, &FAM
, GetBFI
, GetTLI
, GetTTI
,
// Queries the TCK_CodeSize cost of I from the TTI of I's function. When
// HasLatencySavings is true (the default), I is also appended to
// KnownConstants so that getLatencySavings() will account for it.
// NOTE(review): presumably returns the queried cost; the return statement is
// not part of this listing.
101 Cost
getCodeSizeSavings(Instruction
&I
, bool HasLatencySavings
= true) {
102 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(*I
.getFunction());
105 TTI
.getInstructionCost(&I
, TargetTransformInfo::TCK_CodeSize
);
107 if (HasLatencySavings
)
108 KnownConstants
.push_back(&I
);
// Accumulates, over every instruction recorded in KnownConstants, the
// frequency of the instruction's block divided by the entry frequency,
// multiplied by the instruction's TCK_Latency cost. The division is
// evaluated before the multiplication.
113 Cost
getLatencySavings(Function
*F
) {
114 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(*F
);
115 auto &BFI
= FAM
.getResult
<BlockFrequencyAnalysis
>(*F
);
118 for (const Instruction
*I
: KnownConstants
)
119 Latency
+= BFI
.getBlockFreq(I
->getParent()).getFrequency() /
120 BFI
.getEntryFreq().getFrequency() *
121 TTI
.getInstructionCost(I
, TargetTransformInfo::TCK_Latency
);
129 using namespace llvm
;
// Compares the InstCostVisitor estimates for @foo(i32 %a, i32 %b, i32 %i),
// which contains a switch on %i, with reference costs computed by the
// fixture helpers. Each argument is specialized to the i32 constant 1 in
// turn, and the accumulated latency estimate is checked at the end.
131 TEST_F(FunctionSpecializationTest
, SwitchInst
) {
132 const char *ModuleString
= R
"(
133 define void @foo(i32 %a, i32 %b, i32 %i) {
137 switch i32 %i, label %default
138 [ i32 1, label %case1
139 i32 2, label %case2 ]
159 Module
&M
= parseModule(ModuleString
);
160 Function
*F
= M
.getFunction("foo");
161 FunctionSpecializer Specializer
= getSpecializerFor(F
);
162 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
164 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
// The iterator is pre-incremented before each dereference, so the first
// block of F is skipped and the references bind to the blocks after it, in
// function order.
166 auto FuncIter
= F
->begin();
167 BasicBlock
&Loop
= *++FuncIter
;
168 BasicBlock
&Case1
= *++FuncIter
;
169 BasicBlock
&Case2
= *++FuncIter
;
170 BasicBlock
&BB1
= *++FuncIter
;
171 BasicBlock
&BB2
= *++FuncIter
;
// Individual instructions, picked by their position inside those blocks.
173 Instruction
&Switch
= Loop
.front();
174 Instruction
&Mul
= Case1
.front();
175 Instruction
&And
= Case2
.front();
176 Instruction
&Sdiv
= *++Case2
.begin();
177 Instruction
&BrBB2
= Case2
.back();
178 Instruction
&Add
= BB1
.front();
179 Instruction
&Or
= BB2
.front();
180 Instruction
&BrLoop
= BB2
.back();
// %a := 1. The expected saving is the cost of Mul.
183 Cost Ref
= getCodeSizeSavings(Mul
);
184 Cost Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), One
);
185 EXPECT_EQ(Test
, Ref
);
186 EXPECT_TRUE(Test
> 0);
// %b := 1. The expected saving is the cost of And + Or + Add.
189 Ref
= getCodeSizeSavings(And
) + getCodeSizeSavings(Or
) +
190 getCodeSizeSavings(Add
);
191 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(1), One
);
192 EXPECT_EQ(Test
, Ref
);
193 EXPECT_TRUE(Test
> 0);
// %i := 1. Only Switch is recorded for the latency reference; the other
// three terms pass HasLatencySavings=false.
195 // switch + sdiv + br + br
196 Ref
= getCodeSizeSavings(Switch
) +
197 getCodeSizeSavings(Sdiv
, /*HasLatencySavings=*/false) +
198 getCodeSizeSavings(BrBB2
, /*HasLatencySavings=*/false) +
199 getCodeSizeSavings(BrLoop
, /*HasLatencySavings=*/false);
200 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(2), One
);
201 EXPECT_EQ(Test
, Ref
);
202 EXPECT_TRUE(Test
> 0);
// Latency reference built from the instructions recorded by the
// getCodeSizeSavings() calls above.
205 Ref
= getLatencySavings(F
);
206 Test
= Visitor
.getLatencySavingsForKnownConstants();
207 EXPECT_EQ(Test
, Ref
);
208 EXPECT_TRUE(Test
> 0);
// Compares the InstCostVisitor estimates for @foo(i32 %a, i32 %b, i1 %cond),
// which contains conditional branches on %cond, with reference costs
// computed by the fixture helpers. %a and %b are specialized to the i32
// constant 1 and %cond to false.
211 TEST_F(FunctionSpecializationTest
, BranchInst
) {
212 const char *ModuleString
= R
"(
213 define void @foo(i32 %a, i32 %b, i1 %cond) {
217 br i1 %cond, label %bb0, label %bb3
221 br i1 %cond, label %bb1, label %bb2
233 Module
&M
= parseModule(ModuleString
);
234 Function
*F
= M
.getFunction("foo");
235 FunctionSpecializer Specializer
= getSpecializerFor(F
);
236 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
238 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
239 Constant
*False
= ConstantInt::getFalse(M
.getContext());
// The iterator is pre-incremented before each dereference, so the first
// block of F is skipped and the references bind to the blocks after it, in
// function order.
241 auto FuncIter
= F
->begin();
242 BasicBlock
&Loop
= *++FuncIter
;
243 BasicBlock
&BB0
= *++FuncIter
;
244 BasicBlock
&BB1
= *++FuncIter
;
245 BasicBlock
&BB2
= *++FuncIter
;
// Individual instructions, picked by their position inside those blocks.
// BrLoop is the first instruction of BB2.
247 Instruction
&Branch
= Loop
.front();
248 Instruction
&Mul
= BB0
.front();
249 Instruction
&Sub
= *++BB0
.begin();
250 Instruction
&BrBB1BB2
= BB0
.back();
251 Instruction
&Add
= BB1
.front();
252 Instruction
&Sdiv
= *++BB1
.begin();
253 Instruction
&BrBB2
= BB1
.back();
254 Instruction
&BrLoop
= BB2
.front();
// %a := 1. The expected saving is the cost of Mul.
257 Cost Ref
= getCodeSizeSavings(Mul
);
258 Cost Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), One
);
259 EXPECT_EQ(Test
, Ref
);
260 EXPECT_TRUE(Test
> 0);
// %b := 1. The expected saving is the cost of Add.
263 Ref
= getCodeSizeSavings(Add
);
264 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(1), One
);
265 EXPECT_EQ(Test
, Ref
);
266 EXPECT_TRUE(Test
> 0);
// %cond := false. Branch and BrBB1BB2 use the default HasLatencySavings and
// are recorded for the latency reference; the remaining terms are not.
268 // branch + sub + br + sdiv + br + br
269 Ref
= getCodeSizeSavings(Branch
) +
270 getCodeSizeSavings(Sub
, /*HasLatencySavings=*/false) +
271 getCodeSizeSavings(BrBB1BB2
) +
272 getCodeSizeSavings(Sdiv
, /*HasLatencySavings=*/false) +
273 getCodeSizeSavings(BrBB2
, /*HasLatencySavings=*/false) +
274 getCodeSizeSavings(BrLoop
, /*HasLatencySavings=*/false);
275 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(2), False
);
276 EXPECT_EQ(Test
, Ref
);
277 EXPECT_TRUE(Test
> 0);
// Latency reference built from the instructions recorded by the
// getCodeSizeSavings() calls above.
280 Ref
= getLatencySavings(F
);
281 Test
= Visitor
.getLatencySavingsForKnownConstants();
282 EXPECT_EQ(Test
, Ref
);
283 EXPECT_TRUE(Test
> 0);
286 TEST_F(FunctionSpecializationTest
, SelectInst
) {
287 const char *ModuleString
= R
"(
288 define i32 @foo(i1 %cond, i32 %a, i32 %b) {
289 %sel = select i1 %cond, i32 %a, i32 %b
294 Module
&M
= parseModule(ModuleString
);
295 Function
*F
= M
.getFunction("foo");
296 FunctionSpecializer Specializer
= getSpecializerFor(F
);
297 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
299 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
300 Constant
*Zero
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 0);
301 Constant
*False
= ConstantInt::getFalse(M
.getContext());
302 Instruction
&Select
= *F
->front().begin();
304 Cost RefCodeSize
= getCodeSizeSavings(Select
);
305 Cost RefLatency
= getLatencySavings(F
);
307 Cost TestCodeSize
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), False
);
308 EXPECT_TRUE(TestCodeSize
== 0);
309 TestCodeSize
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(1), One
);
310 EXPECT_TRUE(TestCodeSize
== 0);
311 Cost TestLatency
= Visitor
.getLatencySavingsForKnownConstants();
312 EXPECT_TRUE(TestLatency
== 0);
314 TestCodeSize
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(2), Zero
);
315 EXPECT_EQ(TestCodeSize
, RefCodeSize
);
316 EXPECT_TRUE(TestCodeSize
> 0);
317 TestLatency
= Visitor
.getLatencySavingsForKnownConstants();
318 EXPECT_EQ(TestLatency
, RefLatency
);
319 EXPECT_TRUE(TestLatency
> 0);
// Compares the InstCostVisitor estimates for @foo(i8 %a, i1 %cond, ptr %b,
// i32 %c) with reference costs computed by the fixture helpers. The
// arguments are specialized to an i8 1, true, the global @g and an i32
// undef respectively; the undef case is expected to save nothing.
322 TEST_F(FunctionSpecializationTest
, Misc
) {
323 const char *ModuleString
= R
"(
324 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
325 @g = constant %struct_t zeroinitializer, align 16
327 declare i32 @llvm.smax.i32(i32, i32)
328 declare i32 @bar(i32)
330 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
331 %cmp = icmp eq i8 %a, 10
332 %ext = zext i1 %cmp to i64
333 %sel = select i1 %cond, i64 %ext, i64 1
334 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
335 %ld = load i32, ptr %gep
337 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
338 %call = call i32 @bar(i32 %smax)
340 %add = add i32 %call, %fr2
345 Module
&M
= parseModule(ModuleString
);
346 Function
*F
= M
.getFunction("foo");
347 FunctionSpecializer Specializer
= getSpecializerFor(F
);
348 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
// Constants substituted for the arguments. One is an i8 here, matching the
// type of %a.
350 GlobalVariable
*GV
= M
.getGlobalVariable("g");
351 Constant
*One
= ConstantInt::get(IntegerType::getInt8Ty(M
.getContext()), 1);
352 Constant
*True
= ConstantInt::getTrue(M
.getContext());
353 Constant
*Undef
= UndefValue::get(IntegerType::getInt32Ty(M
.getContext()));
// Post-increment walk over the first block of F: each reference binds to
// the next instruction, starting with the first one.
355 auto BlockIter
= F
->front().begin();
356 Instruction
&Icmp
= *BlockIter
++;
357 Instruction
&Zext
= *BlockIter
++;
358 Instruction
&Select
= *BlockIter
++;
359 Instruction
&Gep
= *BlockIter
++;
360 Instruction
&Load
= *BlockIter
++;
361 Instruction
&Freeze
= *BlockIter
++;
362 Instruction
&Smax
= *BlockIter
++;
// %a := 1. The expected saving is the cost of Icmp + Zext.
365 Cost Ref
= getCodeSizeSavings(Icmp
) + getCodeSizeSavings(Zext
);
366 Cost Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), One
);
367 EXPECT_EQ(Test
, Ref
);
368 EXPECT_TRUE(Test
> 0);
// %cond := true. The expected saving is the cost of Select.
371 Ref
= getCodeSizeSavings(Select
);
372 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(1), True
);
373 EXPECT_EQ(Test
, Ref
);
374 EXPECT_TRUE(Test
> 0);
// %b := @g.
376 // gep + load + freeze + smax
377 Ref
= getCodeSizeSavings(Gep
) + getCodeSizeSavings(Load
) +
378 getCodeSizeSavings(Freeze
) + getCodeSizeSavings(Smax
);
379 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(2), GV
);
380 EXPECT_EQ(Test
, Ref
);
381 EXPECT_TRUE(Test
> 0);
// %c := undef. No saving is expected.
383 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(3), Undef
);
384 EXPECT_TRUE(Test
== 0);
// Latency reference built from the instructions recorded by the
// getCodeSizeSavings() calls above.
387 Ref
= getLatencySavings(F
);
388 Test
= Visitor
.getLatencySavingsForKnownConstants();
389 EXPECT_EQ(Test
, Ref
);
390 EXPECT_TRUE(Test
> 0);
// Compares the InstCostVisitor estimates for @foo(i32 %a, i32 %b, i32 %i),
// whose blocks are linked by phi nodes, with reference costs computed by the
// fixture helpers. Specializing %a or %b alone is expected to save nothing;
// the phi-related savings are queried separately through
// getCodeSizeSavingsFromPendingPHIs().
393 TEST_F(FunctionSpecializationTest
, PhiNode
) {
394 const char *ModuleString
= R
"(
395 define void @foo(i32 %a, i32 %b, i32 %i) {
399 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
400 switch i32 %i, label %default
401 [ i32 1, label %case1
402 i32 2, label %case2 ]
407 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
410 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
411 %4 = icmp eq i32 %3, 1
412 br i1 %4, label %bb, label %loop
418 Module
&M
= parseModule(ModuleString
);
419 Function
*F
= M
.getFunction("foo");
420 FunctionSpecializer Specializer
= getSpecializerFor(F
);
421 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
423 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
// The iterator is pre-incremented before each dereference, so the first
// block of F is skipped and the references bind to the blocks after it, in
// function order.
425 auto FuncIter
= F
->begin();
426 BasicBlock
&Loop
= *++FuncIter
;
427 BasicBlock
&Case1
= *++FuncIter
;
428 BasicBlock
&Case2
= *++FuncIter
;
429 BasicBlock
&BB
= *++FuncIter
;
// Individual instructions, picked by their position inside those blocks.
431 Instruction
&PhiLoop
= Loop
.front();
432 Instruction
&Switch
= Loop
.back();
433 Instruction
&Add
= Case1
.front();
434 Instruction
&PhiCase2
= Case2
.front();
435 Instruction
&BrBB
= Case2
.back();
436 Instruction
&PhiBB
= BB
.front();
437 Instruction
&Icmp
= *++BB
.begin();
438 Instruction
&Branch
= BB
.back();
// %a := 1. No code-size saving is expected.
440 Cost Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), One
);
441 EXPECT_TRUE(Test
== 0);
// %b := 1. No code-size saving is expected.
443 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(1), One
);
444 EXPECT_TRUE(Test
== 0);
// No latency saving is expected at this point either.
446 Test
= Visitor
.getLatencySavingsForKnownConstants();
447 EXPECT_TRUE(Test
== 0);
// %i := 1. The expected saving is Switch + PhiCase2 + BrBB, of which only
// Switch is recorded for the latency reference.
450 Cost Ref
= getCodeSizeSavings(Switch
) +
451 getCodeSizeSavings(PhiCase2
, /*HasLatencySavings=*/false) +
452 getCodeSizeSavings(BrBB
, /*HasLatencySavings=*/false);
453 Test
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(2), One
);
454 EXPECT_EQ(Test
, Ref
);
// NOTE(review): both operands of && are the same comparison, so the second
// one is redundant.
455 EXPECT_TRUE(Test
> 0 && Test
> 0);
457 // phi + phi + add + icmp + branch
458 Ref
= getCodeSizeSavings(PhiBB
) + getCodeSizeSavings(PhiLoop
) +
459 getCodeSizeSavings(Add
) + getCodeSizeSavings(Icmp
) +
460 getCodeSizeSavings(Branch
);
461 Test
= Visitor
.getCodeSizeSavingsFromPendingPHIs();
462 EXPECT_EQ(Test
, Ref
);
463 EXPECT_TRUE(Test
> 0);
// Latency reference built from the instructions recorded by the
// getCodeSizeSavings() calls above.
466 Ref
= getLatencySavings(F
);
467 Test
= Visitor
.getLatencySavingsForKnownConstants();
468 EXPECT_EQ(Test
, Ref
);
469 EXPECT_TRUE(Test
> 0);
472 TEST_F(FunctionSpecializationTest
, BinOp
) {
473 // Verify that we can handle binary operators even when only one operand is
475 const char *ModuleString
= R
"(
476 define i32 @foo(i1 %a, i1 %b) {
477 %and1 = and i1 %a, %b
478 %and2 = and i1 %b, %and1
479 %sel = select i1 %and2, i32 1, i32 0
484 Module
&M
= parseModule(ModuleString
);
485 Function
*F
= M
.getFunction("foo");
486 FunctionSpecializer Specializer
= getSpecializerFor(F
);
487 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
489 Constant
*False
= ConstantInt::getFalse(M
.getContext());
490 BasicBlock
&BB
= F
->front();
491 Instruction
&And1
= BB
.front();
492 Instruction
&And2
= *++BB
.begin();
493 Instruction
&Select
= *++BB
.begin();
495 Cost RefCodeSize
= getCodeSizeSavings(And1
) + getCodeSizeSavings(And2
) +
496 getCodeSizeSavings(Select
);
497 Cost RefLatency
= getLatencySavings(F
);
499 Cost TestCodeSize
= Visitor
.getCodeSizeSavingsForArg(F
->getArg(0), False
);
500 Cost TestLatency
= Visitor
.getLatencySavingsForKnownConstants();
502 EXPECT_EQ(TestCodeSize
, RefCodeSize
);
503 EXPECT_TRUE(TestCodeSize
> 0);
504 EXPECT_EQ(TestLatency
, RefLatency
);
505 EXPECT_TRUE(TestLatency
> 0);