Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / unittests / Transforms / IPO / FunctionSpecializationTest.cpp
blobd0e8977f1245d5ef9a2a36ea1bf3243f6a3e55d6
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
24 namespace llvm {
26 static void removeSSACopy(Function &F) {
27 for (BasicBlock &BB : F) {
28 for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
29 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
30 if (II->getIntrinsicID() != Intrinsic::ssa_copy)
31 continue;
32 Inst.replaceAllUsesWith(II->getOperand(0));
33 Inst.eraseFromParent();
39 class FunctionSpecializationTest : public testing::Test {
40 protected:
41 LLVMContext Ctx;
42 FunctionAnalysisManager FAM;
43 std::unique_ptr<Module> M;
44 std::unique_ptr<SCCPSolver> Solver;
46 FunctionSpecializationTest() {
47 FAM.registerPass([&] { return TargetLibraryAnalysis(); });
48 FAM.registerPass([&] { return TargetIRAnalysis(); });
49 FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
50 FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
51 FAM.registerPass([&] { return LoopAnalysis(); });
52 FAM.registerPass([&] { return AssumptionAnalysis(); });
53 FAM.registerPass([&] { return DominatorTreeAnalysis(); });
54 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
55 FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
58 Module &parseModule(const char *ModuleString) {
59 SMDiagnostic Err;
60 M = parseAssemblyString(ModuleString, Err, Ctx);
61 EXPECT_TRUE(M);
62 return *M;
65 FunctionSpecializer getSpecializerFor(Function *F) {
66 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
67 return FAM.getResult<TargetLibraryAnalysis>(F);
69 auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
70 return FAM.getResult<TargetIRAnalysis>(F);
72 auto GetAC = [this](Function &F) -> AssumptionCache & {
73 return FAM.getResult<AssumptionAnalysis>(F);
75 auto GetDT = [this](Function &F) -> DominatorTree & {
76 return FAM.getResult<DominatorTreeAnalysis>(F);
78 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
79 return FAM.getResult<BlockFrequencyAnalysis>(F);
82 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
84 DominatorTree &DT = GetDT(*F);
85 AssumptionCache &AC = GetAC(*F);
86 Solver->addPredicateInfo(*F, DT, AC);
88 Solver->markBlockExecutable(&F->front());
89 for (Argument &Arg : F->args())
90 Solver->markOverdefined(&Arg);
91 Solver->solveWhileResolvedUndefsIn(*M);
93 removeSSACopy(*F);
95 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
96 GetAC);
99 Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
100 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
103 Cost CodeSize =
104 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
106 Cost Latency =
107 SizeOnly
109 : BFI.getBlockFreq(I.getParent()).getFrequency() /
110 BFI.getEntryFreq().getFrequency() *
111 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
113 return {CodeSize, Latency};
117 } // namespace llvm
119 using namespace llvm;
121 TEST_F(FunctionSpecializationTest, SwitchInst) {
122 const char *ModuleString = R"(
123 define void @foo(i32 %a, i32 %b, i32 %i) {
124 entry:
125 br label %loop
126 loop:
127 switch i32 %i, label %default
128 [ i32 1, label %case1
129 i32 2, label %case2 ]
130 case1:
131 %0 = mul i32 %a, 2
132 %1 = sub i32 6, 5
133 br label %bb1
134 case2:
135 %2 = and i32 %b, 3
136 %3 = sdiv i32 8, 2
137 br label %bb2
138 bb1:
139 %4 = add i32 %0, %b
140 br label %loop
141 bb2:
142 %5 = or i32 %2, %a
143 br label %loop
144 default:
145 ret void
149 Module &M = parseModule(ModuleString);
150 Function *F = M.getFunction("foo");
151 FunctionSpecializer Specializer = getSpecializerFor(F);
152 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
154 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
156 auto FuncIter = F->begin();
157 BasicBlock &Loop = *++FuncIter;
158 BasicBlock &Case1 = *++FuncIter;
159 BasicBlock &Case2 = *++FuncIter;
160 BasicBlock &BB1 = *++FuncIter;
161 BasicBlock &BB2 = *++FuncIter;
163 Instruction &Switch = Loop.front();
164 Instruction &Mul = Case1.front();
165 Instruction &And = Case2.front();
166 Instruction &Sdiv = *++Case2.begin();
167 Instruction &BrBB2 = Case2.back();
168 Instruction &Add = BB1.front();
169 Instruction &Or = BB2.front();
170 Instruction &BrLoop = BB2.back();
172 // mul
173 Bonus Ref = getInstCost(Mul);
174 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
175 EXPECT_EQ(Test, Ref);
176 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
178 // and + or + add
179 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
180 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
181 EXPECT_EQ(Test, Ref);
182 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
184 // switch + sdiv + br + br
185 Ref = getInstCost(Switch) +
186 getInstCost(Sdiv, /*SizeOnly =*/ true) +
187 getInstCost(BrBB2, /*SizeOnly =*/ true) +
188 getInstCost(BrLoop, /*SizeOnly =*/ true);
189 Test = Visitor.getSpecializationBonus(F->getArg(2), One);
190 EXPECT_EQ(Test, Ref);
191 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
194 TEST_F(FunctionSpecializationTest, BranchInst) {
195 const char *ModuleString = R"(
196 define void @foo(i32 %a, i32 %b, i1 %cond) {
197 entry:
198 br label %loop
199 loop:
200 br i1 %cond, label %bb0, label %bb3
201 bb0:
202 %0 = mul i32 %a, 2
203 %1 = sub i32 6, 5
204 br i1 %cond, label %bb1, label %bb2
205 bb1:
206 %2 = add i32 %0, %b
207 %3 = sdiv i32 8, 2
208 br label %bb2
209 bb2:
210 br label %loop
211 bb3:
212 ret void
216 Module &M = parseModule(ModuleString);
217 Function *F = M.getFunction("foo");
218 FunctionSpecializer Specializer = getSpecializerFor(F);
219 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
221 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
222 Constant *False = ConstantInt::getFalse(M.getContext());
224 auto FuncIter = F->begin();
225 BasicBlock &Loop = *++FuncIter;
226 BasicBlock &BB0 = *++FuncIter;
227 BasicBlock &BB1 = *++FuncIter;
228 BasicBlock &BB2 = *++FuncIter;
230 Instruction &Branch = Loop.front();
231 Instruction &Mul = BB0.front();
232 Instruction &Sub = *++BB0.begin();
233 Instruction &BrBB1BB2 = BB0.back();
234 Instruction &Add = BB1.front();
235 Instruction &Sdiv = *++BB1.begin();
236 Instruction &BrBB2 = BB1.back();
237 Instruction &BrLoop = BB2.front();
239 // mul
240 Bonus Ref = getInstCost(Mul);
241 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
242 EXPECT_EQ(Test, Ref);
243 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
245 // add
246 Ref = getInstCost(Add);
247 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
248 EXPECT_EQ(Test, Ref);
249 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
251 // branch + sub + br + sdiv + br
252 Ref = getInstCost(Branch) +
253 getInstCost(Sub, /*SizeOnly =*/ true) +
254 getInstCost(BrBB1BB2) +
255 getInstCost(Sdiv, /*SizeOnly =*/ true) +
256 getInstCost(BrBB2, /*SizeOnly =*/ true) +
257 getInstCost(BrLoop, /*SizeOnly =*/ true);
258 Test = Visitor.getSpecializationBonus(F->getArg(2), False);
259 EXPECT_EQ(Test, Ref);
260 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
263 TEST_F(FunctionSpecializationTest, Misc) {
264 const char *ModuleString = R"(
265 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
266 @g = constant %struct_t zeroinitializer, align 16
268 declare i32 @llvm.smax.i32(i32, i32)
269 declare i32 @bar(i32)
271 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
272 %cmp = icmp eq i8 %a, 10
273 %ext = zext i1 %cmp to i64
274 %sel = select i1 %cond, i64 %ext, i64 1
275 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
276 %ld = load i32, ptr %gep
277 %fr = freeze i32 %ld
278 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
279 %call = call i32 @bar(i32 %smax)
280 %fr2 = freeze i32 %c
281 %add = add i32 %call, %fr2
282 ret i32 %add
286 Module &M = parseModule(ModuleString);
287 Function *F = M.getFunction("foo");
288 FunctionSpecializer Specializer = getSpecializerFor(F);
289 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
291 GlobalVariable *GV = M.getGlobalVariable("g");
292 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
293 Constant *True = ConstantInt::getTrue(M.getContext());
294 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
296 auto BlockIter = F->front().begin();
297 Instruction &Icmp = *BlockIter++;
298 Instruction &Zext = *BlockIter++;
299 Instruction &Select = *BlockIter++;
300 Instruction &Gep = *BlockIter++;
301 Instruction &Load = *BlockIter++;
302 Instruction &Freeze = *BlockIter++;
303 Instruction &Smax = *BlockIter++;
305 // icmp + zext
306 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
307 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
308 EXPECT_EQ(Test, Ref);
309 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
311 // select
312 Ref = getInstCost(Select);
313 Test = Visitor.getSpecializationBonus(F->getArg(1), True);
314 EXPECT_EQ(Test, Ref);
315 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
317 // gep + load + freeze + smax
318 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
319 getInstCost(Smax);
320 Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
321 EXPECT_EQ(Test, Ref);
322 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
324 Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
325 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
328 TEST_F(FunctionSpecializationTest, PhiNode) {
329 const char *ModuleString = R"(
330 define void @foo(i32 %a, i32 %b, i32 %i) {
331 entry:
332 br label %loop
333 loop:
334 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
335 switch i32 %i, label %default
336 [ i32 1, label %case1
337 i32 2, label %case2 ]
338 case1:
339 %1 = add i32 %0, 1
340 br label %bb
341 case2:
342 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
343 br label %bb
345 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
346 %4 = icmp eq i32 %3, 1
347 br i1 %4, label %bb, label %loop
348 default:
349 ret void
353 Module &M = parseModule(ModuleString);
354 Function *F = M.getFunction("foo");
355 FunctionSpecializer Specializer = getSpecializerFor(F);
356 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
358 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
360 auto FuncIter = F->begin();
361 BasicBlock &Loop = *++FuncIter;
362 BasicBlock &Case1 = *++FuncIter;
363 BasicBlock &Case2 = *++FuncIter;
364 BasicBlock &BB = *++FuncIter;
366 Instruction &PhiLoop = Loop.front();
367 Instruction &Switch = Loop.back();
368 Instruction &Add = Case1.front();
369 Instruction &PhiCase2 = Case2.front();
370 Instruction &BrBB = Case2.back();
371 Instruction &PhiBB = BB.front();
372 Instruction &Icmp = *++BB.begin();
373 Instruction &Branch = BB.back();
375 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
376 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
378 Test = Visitor.getSpecializationBonus(F->getArg(1), One);
379 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
381 // switch + phi + br
382 Bonus Ref = getInstCost(Switch) +
383 getInstCost(PhiCase2, /*SizeOnly =*/ true) +
384 getInstCost(BrBB, /*SizeOnly =*/ true);
385 Test = Visitor.getSpecializationBonus(F->getArg(2), One);
386 EXPECT_EQ(Test, Ref);
387 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
389 // phi + phi + add + icmp + branch
390 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
391 getInstCost(Icmp) + getInstCost(Branch);
392 Test = Visitor.getBonusFromPendingPHIs();
393 EXPECT_EQ(Test, Ref);
394 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);