Bump version to 19.1.0-rc3
[llvm-project.git] / llvm / unittests / Transforms / IPO / FunctionSpecializationTest.cpp
blob52bad210b583ed7e482a89db675dd8d521367c6e
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
25 namespace llvm {
27 static void removeSSACopy(Function &F) {
28 for (BasicBlock &BB : F) {
29 for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
30 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
31 if (II->getIntrinsicID() != Intrinsic::ssa_copy)
32 continue;
33 Inst.replaceAllUsesWith(II->getOperand(0));
34 Inst.eraseFromParent();
40 class FunctionSpecializationTest : public testing::Test {
41 protected:
42 LLVMContext Ctx;
43 FunctionAnalysisManager FAM;
44 std::unique_ptr<Module> M;
45 std::unique_ptr<SCCPSolver> Solver;
47 FunctionSpecializationTest() {
48 FAM.registerPass([&] { return TargetLibraryAnalysis(); });
49 FAM.registerPass([&] { return TargetIRAnalysis(); });
50 FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
51 FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
52 FAM.registerPass([&] { return LoopAnalysis(); });
53 FAM.registerPass([&] { return AssumptionAnalysis(); });
54 FAM.registerPass([&] { return DominatorTreeAnalysis(); });
55 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
56 FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
59 Module &parseModule(const char *ModuleString) {
60 SMDiagnostic Err;
61 M = parseAssemblyString(ModuleString, Err, Ctx);
62 EXPECT_TRUE(M);
63 return *M;
66 FunctionSpecializer getSpecializerFor(Function *F) {
67 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
68 return FAM.getResult<TargetLibraryAnalysis>(F);
70 auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
71 return FAM.getResult<TargetIRAnalysis>(F);
73 auto GetAC = [this](Function &F) -> AssumptionCache & {
74 return FAM.getResult<AssumptionAnalysis>(F);
76 auto GetDT = [this](Function &F) -> DominatorTree & {
77 return FAM.getResult<DominatorTreeAnalysis>(F);
79 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
80 return FAM.getResult<BlockFrequencyAnalysis>(F);
83 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
85 DominatorTree &DT = GetDT(*F);
86 AssumptionCache &AC = GetAC(*F);
87 Solver->addPredicateInfo(*F, DT, AC);
89 Solver->markBlockExecutable(&F->front());
90 for (Argument &Arg : F->args())
91 Solver->markOverdefined(&Arg);
92 Solver->solveWhileResolvedUndefsIn(*M);
94 removeSSACopy(*F);
96 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
97 GetAC);
100 Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
101 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
102 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
104 Cost CodeSize =
105 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
107 Cost Latency =
108 SizeOnly
110 : BFI.getBlockFreq(I.getParent()).getFrequency() /
111 BFI.getEntryFreq().getFrequency() *
112 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
114 return {CodeSize, Latency};
118 } // namespace llvm
120 using namespace llvm;
122 TEST_F(FunctionSpecializationTest, SwitchInst) {
123 const char *ModuleString = R"(
124 define void @foo(i32 %a, i32 %b, i32 %i) {
125 entry:
126 br label %loop
127 loop:
128 switch i32 %i, label %default
129 [ i32 1, label %case1
130 i32 2, label %case2 ]
131 case1:
132 %0 = mul i32 %a, 2
133 %1 = sub i32 6, 5
134 br label %bb1
135 case2:
136 %2 = and i32 %b, 3
137 %3 = sdiv i32 8, 2
138 br label %bb2
139 bb1:
140 %4 = add i32 %0, %b
141 br label %loop
142 bb2:
143 %5 = or i32 %2, %a
144 br label %loop
145 default:
146 ret void
150 Module &M = parseModule(ModuleString);
151 Function *F = M.getFunction("foo");
152 FunctionSpecializer Specializer = getSpecializerFor(F);
153 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
155 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
157 auto FuncIter = F->begin();
158 BasicBlock &Loop = *++FuncIter;
159 BasicBlock &Case1 = *++FuncIter;
160 BasicBlock &Case2 = *++FuncIter;
161 BasicBlock &BB1 = *++FuncIter;
162 BasicBlock &BB2 = *++FuncIter;
164 Instruction &Switch = Loop.front();
165 Instruction &Mul = Case1.front();
166 Instruction &And = Case2.front();
167 Instruction &Sdiv = *++Case2.begin();
168 Instruction &BrBB2 = Case2.back();
169 Instruction &Add = BB1.front();
170 Instruction &Or = BB2.front();
171 Instruction &BrLoop = BB2.back();
173 // mul
174 Bonus Ref = getInstCost(Mul);
175 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
176 EXPECT_EQ(Test, Ref);
177 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
179 // and + or + add
180 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
181 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
182 EXPECT_EQ(Test, Ref);
183 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
185 // switch + sdiv + br + br
186 Ref = getInstCost(Switch) +
187 getInstCost(Sdiv, /*SizeOnly =*/ true) +
188 getInstCost(BrBB2, /*SizeOnly =*/ true) +
189 getInstCost(BrLoop, /*SizeOnly =*/ true);
190 Test = Visitor.getSpecializationBonus(F->getArg(2), One);
191 EXPECT_EQ(Test, Ref);
192 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
195 TEST_F(FunctionSpecializationTest, BranchInst) {
196 const char *ModuleString = R"(
197 define void @foo(i32 %a, i32 %b, i1 %cond) {
198 entry:
199 br label %loop
200 loop:
201 br i1 %cond, label %bb0, label %bb3
202 bb0:
203 %0 = mul i32 %a, 2
204 %1 = sub i32 6, 5
205 br i1 %cond, label %bb1, label %bb2
206 bb1:
207 %2 = add i32 %0, %b
208 %3 = sdiv i32 8, 2
209 br label %bb2
210 bb2:
211 br label %loop
212 bb3:
213 ret void
217 Module &M = parseModule(ModuleString);
218 Function *F = M.getFunction("foo");
219 FunctionSpecializer Specializer = getSpecializerFor(F);
220 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
222 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
223 Constant *False = ConstantInt::getFalse(M.getContext());
225 auto FuncIter = F->begin();
226 BasicBlock &Loop = *++FuncIter;
227 BasicBlock &BB0 = *++FuncIter;
228 BasicBlock &BB1 = *++FuncIter;
229 BasicBlock &BB2 = *++FuncIter;
231 Instruction &Branch = Loop.front();
232 Instruction &Mul = BB0.front();
233 Instruction &Sub = *++BB0.begin();
234 Instruction &BrBB1BB2 = BB0.back();
235 Instruction &Add = BB1.front();
236 Instruction &Sdiv = *++BB1.begin();
237 Instruction &BrBB2 = BB1.back();
238 Instruction &BrLoop = BB2.front();
240 // mul
241 Bonus Ref = getInstCost(Mul);
242 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
243 EXPECT_EQ(Test, Ref);
244 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
246 // add
247 Ref = getInstCost(Add);
248 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
249 EXPECT_EQ(Test, Ref);
250 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
252 // branch + sub + br + sdiv + br
253 Ref = getInstCost(Branch) +
254 getInstCost(Sub, /*SizeOnly =*/ true) +
255 getInstCost(BrBB1BB2) +
256 getInstCost(Sdiv, /*SizeOnly =*/ true) +
257 getInstCost(BrBB2, /*SizeOnly =*/ true) +
258 getInstCost(BrLoop, /*SizeOnly =*/ true);
259 Test = Visitor.getSpecializationBonus(F->getArg(2), False);
260 EXPECT_EQ(Test, Ref);
261 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
264 TEST_F(FunctionSpecializationTest, Misc) {
265 const char *ModuleString = R"(
266 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
267 @g = constant %struct_t zeroinitializer, align 16
269 declare i32 @llvm.smax.i32(i32, i32)
270 declare i32 @bar(i32)
272 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
273 %cmp = icmp eq i8 %a, 10
274 %ext = zext i1 %cmp to i64
275 %sel = select i1 %cond, i64 %ext, i64 1
276 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
277 %ld = load i32, ptr %gep
278 %fr = freeze i32 %ld
279 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
280 %call = call i32 @bar(i32 %smax)
281 %fr2 = freeze i32 %c
282 %add = add i32 %call, %fr2
283 ret i32 %add
287 Module &M = parseModule(ModuleString);
288 Function *F = M.getFunction("foo");
289 FunctionSpecializer Specializer = getSpecializerFor(F);
290 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
292 GlobalVariable *GV = M.getGlobalVariable("g");
293 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
294 Constant *True = ConstantInt::getTrue(M.getContext());
295 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
297 auto BlockIter = F->front().begin();
298 Instruction &Icmp = *BlockIter++;
299 Instruction &Zext = *BlockIter++;
300 Instruction &Select = *BlockIter++;
301 Instruction &Gep = *BlockIter++;
302 Instruction &Load = *BlockIter++;
303 Instruction &Freeze = *BlockIter++;
304 Instruction &Smax = *BlockIter++;
306 // icmp + zext
307 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
308 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
309 EXPECT_EQ(Test, Ref);
310 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
312 // select
313 Ref = getInstCost(Select);
314 Test = Visitor.getSpecializationBonus(F->getArg(1), True);
315 EXPECT_EQ(Test, Ref);
316 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
318 // gep + load + freeze + smax
319 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
320 getInstCost(Smax);
321 Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
322 EXPECT_EQ(Test, Ref);
323 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
325 Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
326 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
329 TEST_F(FunctionSpecializationTest, PhiNode) {
330 const char *ModuleString = R"(
331 define void @foo(i32 %a, i32 %b, i32 %i) {
332 entry:
333 br label %loop
334 loop:
335 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
336 switch i32 %i, label %default
337 [ i32 1, label %case1
338 i32 2, label %case2 ]
339 case1:
340 %1 = add i32 %0, 1
341 br label %bb
342 case2:
343 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
344 br label %bb
346 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
347 %4 = icmp eq i32 %3, 1
348 br i1 %4, label %bb, label %loop
349 default:
350 ret void
354 Module &M = parseModule(ModuleString);
355 Function *F = M.getFunction("foo");
356 FunctionSpecializer Specializer = getSpecializerFor(F);
357 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
359 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
361 auto FuncIter = F->begin();
362 BasicBlock &Loop = *++FuncIter;
363 BasicBlock &Case1 = *++FuncIter;
364 BasicBlock &Case2 = *++FuncIter;
365 BasicBlock &BB = *++FuncIter;
367 Instruction &PhiLoop = Loop.front();
368 Instruction &Switch = Loop.back();
369 Instruction &Add = Case1.front();
370 Instruction &PhiCase2 = Case2.front();
371 Instruction &BrBB = Case2.back();
372 Instruction &PhiBB = BB.front();
373 Instruction &Icmp = *++BB.begin();
374 Instruction &Branch = BB.back();
376 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
377 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
379 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
380 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
382 // switch + phi + br
383 Bonus Ref = getInstCost(Switch) +
384 getInstCost(PhiCase2, /*SizeOnly =*/ true) +
385 getInstCost(BrBB, /*SizeOnly =*/ true);
386 Test = Visitor.getSpecializationBonus(F->getArg(2), One);
387 EXPECT_EQ(Test, Ref);
388 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
390 // phi + phi + add + icmp + branch
391 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
392 getInstCost(Icmp) + getInstCost(Branch);
393 Test = Visitor.getBonusFromPendingPHIs();
394 EXPECT_EQ(Test, Ref);
395 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);