// [clangd] Fix warnings
// [llvm-project.git] / llvm / unittests / Transforms / IPO / FunctionSpecializationTest.cpp
// blob 9f76e9ff11c3aa24713a08fa79ed05636ab6d7b0
//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>

25 namespace llvm {
27 static void removeSSACopy(Function &F) {
28 for (BasicBlock &BB : F) {
29 for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
30 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
31 if (II->getIntrinsicID() != Intrinsic::ssa_copy)
32 continue;
33 Inst.replaceAllUsesWith(II->getOperand(0));
34 Inst.eraseFromParent();
40 class FunctionSpecializationTest : public testing::Test {
41 protected:
42 LLVMContext Ctx;
43 FunctionAnalysisManager FAM;
44 std::unique_ptr<Module> M;
45 std::unique_ptr<SCCPSolver> Solver;
46 SmallVector<Instruction *, 8> KnownConstants;
48 FunctionSpecializationTest() {
49 FAM.registerPass([&] { return TargetLibraryAnalysis(); });
50 FAM.registerPass([&] { return TargetIRAnalysis(); });
51 FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
52 FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
53 FAM.registerPass([&] { return LoopAnalysis(); });
54 FAM.registerPass([&] { return AssumptionAnalysis(); });
55 FAM.registerPass([&] { return DominatorTreeAnalysis(); });
56 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
57 FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
60 Module &parseModule(const char *ModuleString) {
61 SMDiagnostic Err;
62 M = parseAssemblyString(ModuleString, Err, Ctx);
63 EXPECT_TRUE(M);
64 return *M;
67 FunctionSpecializer getSpecializerFor(Function *F) {
68 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
69 return FAM.getResult<TargetLibraryAnalysis>(F);
71 auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
72 return FAM.getResult<TargetIRAnalysis>(F);
74 auto GetAC = [this](Function &F) -> AssumptionCache & {
75 return FAM.getResult<AssumptionAnalysis>(F);
77 auto GetDT = [this](Function &F) -> DominatorTree & {
78 return FAM.getResult<DominatorTreeAnalysis>(F);
80 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
81 return FAM.getResult<BlockFrequencyAnalysis>(F);
84 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
86 DominatorTree &DT = GetDT(*F);
87 AssumptionCache &AC = GetAC(*F);
88 Solver->addPredicateInfo(*F, DT, AC);
90 Solver->markBlockExecutable(&F->front());
91 for (Argument &Arg : F->args())
92 Solver->markOverdefined(&Arg);
93 Solver->solveWhileResolvedUndefsIn(*M);
95 removeSSACopy(*F);
97 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
98 GetAC);
101 Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) {
102 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
104 Cost CodeSize =
105 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
107 if (HasLatencySavings)
108 KnownConstants.push_back(&I);
110 return CodeSize;
113 Cost getLatencySavings(Function *F) {
114 auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
115 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F);
117 Cost Latency = 0;
118 for (const Instruction *I : KnownConstants)
119 Latency += BFI.getBlockFreq(I->getParent()).getFrequency() /
120 BFI.getEntryFreq().getFrequency() *
121 TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
123 return Latency;
127 } // namespace llvm
129 using namespace llvm;
131 TEST_F(FunctionSpecializationTest, SwitchInst) {
132 const char *ModuleString = R"(
133 define void @foo(i32 %a, i32 %b, i32 %i) {
134 entry:
135 br label %loop
136 loop:
137 switch i32 %i, label %default
138 [ i32 1, label %case1
139 i32 2, label %case2 ]
140 case1:
141 %0 = mul i32 %a, 2
142 %1 = sub i32 6, 5
143 br label %bb1
144 case2:
145 %2 = and i32 %b, 3
146 %3 = sdiv i32 8, 2
147 br label %bb2
148 bb1:
149 %4 = add i32 %0, %b
150 br label %loop
151 bb2:
152 %5 = or i32 %2, %a
153 br label %loop
154 default:
155 ret void
159 Module &M = parseModule(ModuleString);
160 Function *F = M.getFunction("foo");
161 FunctionSpecializer Specializer = getSpecializerFor(F);
162 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
164 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
166 auto FuncIter = F->begin();
167 BasicBlock &Loop = *++FuncIter;
168 BasicBlock &Case1 = *++FuncIter;
169 BasicBlock &Case2 = *++FuncIter;
170 BasicBlock &BB1 = *++FuncIter;
171 BasicBlock &BB2 = *++FuncIter;
173 Instruction &Switch = Loop.front();
174 Instruction &Mul = Case1.front();
175 Instruction &And = Case2.front();
176 Instruction &Sdiv = *++Case2.begin();
177 Instruction &BrBB2 = Case2.back();
178 Instruction &Add = BB1.front();
179 Instruction &Or = BB2.front();
180 Instruction &BrLoop = BB2.back();
182 // mul
183 Cost Ref = getCodeSizeSavings(Mul);
184 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
185 EXPECT_EQ(Test, Ref);
186 EXPECT_TRUE(Test > 0);
188 // and + or + add
189 Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) +
190 getCodeSizeSavings(Add);
191 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
192 EXPECT_EQ(Test, Ref);
193 EXPECT_TRUE(Test > 0);
195 // switch + sdiv + br + br
196 Ref = getCodeSizeSavings(Switch) +
197 getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
198 getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
199 getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
200 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
201 EXPECT_EQ(Test, Ref);
202 EXPECT_TRUE(Test > 0);
204 // Latency.
205 Ref = getLatencySavings(F);
206 Test = Visitor.getLatencySavingsForKnownConstants();
207 EXPECT_EQ(Test, Ref);
208 EXPECT_TRUE(Test > 0);
211 TEST_F(FunctionSpecializationTest, BranchInst) {
212 const char *ModuleString = R"(
213 define void @foo(i32 %a, i32 %b, i1 %cond) {
214 entry:
215 br label %loop
216 loop:
217 br i1 %cond, label %bb0, label %bb3
218 bb0:
219 %0 = mul i32 %a, 2
220 %1 = sub i32 6, 5
221 br i1 %cond, label %bb1, label %bb2
222 bb1:
223 %2 = add i32 %0, %b
224 %3 = sdiv i32 8, 2
225 br label %bb2
226 bb2:
227 br label %loop
228 bb3:
229 ret void
233 Module &M = parseModule(ModuleString);
234 Function *F = M.getFunction("foo");
235 FunctionSpecializer Specializer = getSpecializerFor(F);
236 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
238 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
239 Constant *False = ConstantInt::getFalse(M.getContext());
241 auto FuncIter = F->begin();
242 BasicBlock &Loop = *++FuncIter;
243 BasicBlock &BB0 = *++FuncIter;
244 BasicBlock &BB1 = *++FuncIter;
245 BasicBlock &BB2 = *++FuncIter;
247 Instruction &Branch = Loop.front();
248 Instruction &Mul = BB0.front();
249 Instruction &Sub = *++BB0.begin();
250 Instruction &BrBB1BB2 = BB0.back();
251 Instruction &Add = BB1.front();
252 Instruction &Sdiv = *++BB1.begin();
253 Instruction &BrBB2 = BB1.back();
254 Instruction &BrLoop = BB2.front();
256 // mul
257 Cost Ref = getCodeSizeSavings(Mul);
258 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
259 EXPECT_EQ(Test, Ref);
260 EXPECT_TRUE(Test > 0);
262 // add
263 Ref = getCodeSizeSavings(Add);
264 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
265 EXPECT_EQ(Test, Ref);
266 EXPECT_TRUE(Test > 0);
268 // branch + sub + br + sdiv + br
269 Ref = getCodeSizeSavings(Branch) +
270 getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) +
271 getCodeSizeSavings(BrBB1BB2) +
272 getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
273 getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
274 getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
275 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False);
276 EXPECT_EQ(Test, Ref);
277 EXPECT_TRUE(Test > 0);
279 // Latency.
280 Ref = getLatencySavings(F);
281 Test = Visitor.getLatencySavingsForKnownConstants();
282 EXPECT_EQ(Test, Ref);
283 EXPECT_TRUE(Test > 0);
286 TEST_F(FunctionSpecializationTest, SelectInst) {
287 const char *ModuleString = R"(
288 define i32 @foo(i1 %cond, i32 %a, i32 %b) {
289 %sel = select i1 %cond, i32 %a, i32 %b
290 ret i32 %sel
294 Module &M = parseModule(ModuleString);
295 Function *F = M.getFunction("foo");
296 FunctionSpecializer Specializer = getSpecializerFor(F);
297 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
299 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
300 Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
301 Constant *False = ConstantInt::getFalse(M.getContext());
302 Instruction &Select = *F->front().begin();
304 Cost RefCodeSize = getCodeSizeSavings(Select);
305 Cost RefLatency = getLatencySavings(F);
307 Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
308 EXPECT_TRUE(TestCodeSize == 0);
309 TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
310 EXPECT_TRUE(TestCodeSize == 0);
311 Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
312 EXPECT_TRUE(TestLatency == 0);
314 TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero);
315 EXPECT_EQ(TestCodeSize, RefCodeSize);
316 EXPECT_TRUE(TestCodeSize > 0);
317 TestLatency = Visitor.getLatencySavingsForKnownConstants();
318 EXPECT_EQ(TestLatency, RefLatency);
319 EXPECT_TRUE(TestLatency > 0);
322 TEST_F(FunctionSpecializationTest, Misc) {
323 const char *ModuleString = R"(
324 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
325 @g = constant %struct_t zeroinitializer, align 16
327 declare i32 @llvm.smax.i32(i32, i32)
328 declare i32 @bar(i32)
330 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
331 %cmp = icmp eq i8 %a, 10
332 %ext = zext i1 %cmp to i64
333 %sel = select i1 %cond, i64 %ext, i64 1
334 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
335 %ld = load i32, ptr %gep
336 %fr = freeze i32 %ld
337 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
338 %call = call i32 @bar(i32 %smax)
339 %fr2 = freeze i32 %c
340 %add = add i32 %call, %fr2
341 ret i32 %add
345 Module &M = parseModule(ModuleString);
346 Function *F = M.getFunction("foo");
347 FunctionSpecializer Specializer = getSpecializerFor(F);
348 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
350 GlobalVariable *GV = M.getGlobalVariable("g");
351 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
352 Constant *True = ConstantInt::getTrue(M.getContext());
353 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
355 auto BlockIter = F->front().begin();
356 Instruction &Icmp = *BlockIter++;
357 Instruction &Zext = *BlockIter++;
358 Instruction &Select = *BlockIter++;
359 Instruction &Gep = *BlockIter++;
360 Instruction &Load = *BlockIter++;
361 Instruction &Freeze = *BlockIter++;
362 Instruction &Smax = *BlockIter++;
364 // icmp + zext
365 Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext);
366 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
367 EXPECT_EQ(Test, Ref);
368 EXPECT_TRUE(Test > 0);
370 // select
371 Ref = getCodeSizeSavings(Select);
372 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True);
373 EXPECT_EQ(Test, Ref);
374 EXPECT_TRUE(Test > 0);
376 // gep + load + freeze + smax
377 Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) +
378 getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax);
379 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV);
380 EXPECT_EQ(Test, Ref);
381 EXPECT_TRUE(Test > 0);
383 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef);
384 EXPECT_TRUE(Test == 0);
386 // Latency.
387 Ref = getLatencySavings(F);
388 Test = Visitor.getLatencySavingsForKnownConstants();
389 EXPECT_EQ(Test, Ref);
390 EXPECT_TRUE(Test > 0);
393 TEST_F(FunctionSpecializationTest, PhiNode) {
394 const char *ModuleString = R"(
395 define void @foo(i32 %a, i32 %b, i32 %i) {
396 entry:
397 br label %loop
398 loop:
399 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
400 switch i32 %i, label %default
401 [ i32 1, label %case1
402 i32 2, label %case2 ]
403 case1:
404 %1 = add i32 %0, 1
405 br label %bb
406 case2:
407 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
408 br label %bb
410 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
411 %4 = icmp eq i32 %3, 1
412 br i1 %4, label %bb, label %loop
413 default:
414 ret void
418 Module &M = parseModule(ModuleString);
419 Function *F = M.getFunction("foo");
420 FunctionSpecializer Specializer = getSpecializerFor(F);
421 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
423 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
425 auto FuncIter = F->begin();
426 BasicBlock &Loop = *++FuncIter;
427 BasicBlock &Case1 = *++FuncIter;
428 BasicBlock &Case2 = *++FuncIter;
429 BasicBlock &BB = *++FuncIter;
431 Instruction &PhiLoop = Loop.front();
432 Instruction &Switch = Loop.back();
433 Instruction &Add = Case1.front();
434 Instruction &PhiCase2 = Case2.front();
435 Instruction &BrBB = Case2.back();
436 Instruction &PhiBB = BB.front();
437 Instruction &Icmp = *++BB.begin();
438 Instruction &Branch = BB.back();
440 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
441 EXPECT_TRUE(Test == 0);
443 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
444 EXPECT_TRUE(Test == 0);
446 Test = Visitor.getLatencySavingsForKnownConstants();
447 EXPECT_TRUE(Test == 0);
449 // switch + phi + br
450 Cost Ref = getCodeSizeSavings(Switch) +
451 getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) +
452 getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false);
453 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
454 EXPECT_EQ(Test, Ref);
455 EXPECT_TRUE(Test > 0 && Test > 0);
457 // phi + phi + add + icmp + branch
458 Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) +
459 getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) +
460 getCodeSizeSavings(Branch);
461 Test = Visitor.getCodeSizeSavingsFromPendingPHIs();
462 EXPECT_EQ(Test, Ref);
463 EXPECT_TRUE(Test > 0);
465 // Latency.
466 Ref = getLatencySavings(F);
467 Test = Visitor.getLatencySavingsForKnownConstants();
468 EXPECT_EQ(Test, Ref);
469 EXPECT_TRUE(Test > 0);
472 TEST_F(FunctionSpecializationTest, BinOp) {
473 // Verify that we can handle binary operators even when only one operand is
474 // constant.
475 const char *ModuleString = R"(
476 define i32 @foo(i1 %a, i1 %b) {
477 %and1 = and i1 %a, %b
478 %and2 = and i1 %b, %and1
479 %sel = select i1 %and2, i32 1, i32 0
480 ret i32 %sel
484 Module &M = parseModule(ModuleString);
485 Function *F = M.getFunction("foo");
486 FunctionSpecializer Specializer = getSpecializerFor(F);
487 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
489 Constant *False = ConstantInt::getFalse(M.getContext());
490 BasicBlock &BB = F->front();
491 Instruction &And1 = BB.front();
492 Instruction &And2 = *++BB.begin();
493 Instruction &Select = *++BB.begin();
495 Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) +
496 getCodeSizeSavings(Select);
497 Cost RefLatency = getLatencySavings(F);
499 Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
500 Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
502 EXPECT_EQ(TestCodeSize, RefCodeSize);
503 EXPECT_TRUE(TestCodeSize > 0);
504 EXPECT_EQ(TestLatency, RefLatency);
505 EXPECT_TRUE(TestLatency > 0);