Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / unittests / Analysis / SparsePropagation.cpp
blobca73a480cbb2db3108dfaf9b56e70e81ce6db77f
1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/SparsePropagation.h"
10 #include "llvm/ADT/PointerIntPair.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/Module.h"
13 #include "gtest/gtest.h"
14 using namespace llvm;
16 namespace {
17 /// To enable interprocedural analysis, we assign LLVM values to the following
18 /// groups. The register group represents SSA registers, the return group
19 /// represents the return values of functions, and the memory group represents
20 /// in-memory values. An LLVM Value can technically be in more than one group.
21 /// It's necessary to distinguish these groups so we can, for example, track a
22 /// global variable separately from the value stored at its location.
23 enum class IPOGrouping { Register, Return, Memory };
25 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
26 /// The PointerIntPair header provides a DenseMapInfo specialization, so using
27 /// these as LatticeKeys is fine.
28 using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
29 } // namespace
31 namespace llvm {
32 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
33 /// must translate between LatticeKeys and LLVM Values when adding Values to
34 /// its work list and inspecting the state of control-flow related values.
35 template <> struct LatticeKeyInfo<TestLatticeKey> {
36 static inline Value *getValueFromLatticeKey(TestLatticeKey Key) {
37 return Key.getPointer();
39 static inline TestLatticeKey getLatticeKeyFromValue(Value *V) {
40 return TestLatticeKey(V, IPOGrouping::Register);
43 } // namespace llvm
45 namespace {
46 /// This class defines a simple test lattice value that could be used for
47 /// solving problems similar to constant propagation. The value is maintained
48 /// as a PointerIntPair.
49 class TestLatticeVal {
50 public:
51 /// The states of the lattices value. Only the ConstantVal state is
52 /// interesting; the rest are special states used by the generic solver. The
53 /// UntrackedVal state differs from the other three in that the generic
54 /// solver uses it to avoid doing unnecessary work. In particular, when a
55 /// value moves to the UntrackedVal state, it's users are not notified.
56 enum TestLatticeStateTy {
57 UndefinedVal,
58 ConstantVal,
59 OverdefinedVal,
60 UntrackedVal
63 TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
64 TestLatticeVal(Constant *C, TestLatticeStateTy State)
65 : LatticeVal(C, State) {}
67 /// Return true if this lattice value is in the Constant state. This is used
68 /// for checking the solver results.
69 bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
71 /// Return true if this lattice value is in the Overdefined state. This is
72 /// used for checking the solver results.
73 bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
75 bool operator==(const TestLatticeVal &RHS) const {
76 return LatticeVal == RHS.LatticeVal;
79 bool operator!=(const TestLatticeVal &RHS) const {
80 return LatticeVal != RHS.LatticeVal;
83 private:
84 /// A simple lattice value type for problems similar to constant propagation.
85 /// It holds the constant value and the lattice state.
86 PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
89 /// This class defines a simple test lattice function that could be used for
90 /// solving problems similar to constant propagation. The test lattice differs
91 /// from a "real" lattice in a few ways. First, it initializes all return
92 /// values, values stored in global variables, and arguments in the undefined
93 /// state. This means that there are no limitations on what we can track
94 /// interprocedurally. For simplicity, all global values in the tests will be
95 /// given internal linkage, since this is not something this lattice function
96 /// tracks. Second, it only handles the few instructions necessary for the
97 /// tests.
98 class TestLatticeFunc
99 : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
100 public:
101 /// Construct a new test lattice function with special values for the
102 /// Undefined, Overdefined, and Untracked states.
103 TestLatticeFunc()
104 : AbstractLatticeFunction(
105 TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
106 TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
107 TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
109 /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
110 /// test analysis, a LatticeKey will begin in the undefined state, unless it
111 /// represents an LLVM Constant in the register grouping.
112 TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
113 if (Key.getInt() == IPOGrouping::Register)
114 if (auto *C = dyn_cast<Constant>(Key.getPointer()))
115 return TestLatticeVal(C, TestLatticeVal::ConstantVal);
116 return getUndefVal();
119 /// Merge the two given lattice values. This merge should be equivalent to
120 /// what is done for constant propagation. That is, the resulting lattice
121 /// value is constant only if the two given lattice values are constant and
122 /// hold the same value.
123 TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
124 if (X == getUntrackedVal() || Y == getUntrackedVal())
125 return getUntrackedVal();
126 if (X == getOverdefinedVal() || Y == getOverdefinedVal())
127 return getOverdefinedVal();
128 if (X == getUndefVal() && Y == getUndefVal())
129 return getUndefVal();
130 if (X == getUndefVal())
131 return Y;
132 if (Y == getUndefVal())
133 return X;
134 if (X == Y)
135 return X;
136 return getOverdefinedVal();
139 /// Compute the lattice values that change as a result of executing the given
140 /// instruction. We only handle the few instructions needed for the tests.
141 void ComputeInstructionState(
142 Instruction &I,
143 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
144 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
145 switch (I.getOpcode()) {
146 case Instruction::Call:
147 return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
148 case Instruction::Ret:
149 return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
150 case Instruction::Store:
151 return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
152 default:
153 return visitInst(I, ChangedValues, SS);
157 private:
158 /// Handle call sites. The state of a called function's argument is the merge
159 /// of the current formal argument state with the call site's corresponding
160 /// actual argument state. The call site state is the merge of the call site
161 /// state with the returned value state of the called function.
162 void visitCallBase(CallBase &I,
163 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
164 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
165 Function *F = I.getCalledFunction();
166 auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
167 if (!F) {
168 ChangedValues[RegI] = getOverdefinedVal();
169 return;
171 SS.MarkBlockExecutable(&F->front());
172 for (Argument &A : F->args()) {
173 auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
174 auto RegActual =
175 TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register);
176 ChangedValues[RegFormal] =
177 MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
179 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
180 ChangedValues[RegI] =
181 MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
184 /// Handle return instructions. The function's return state is the merge of
185 /// the returned value state and the function's current return state.
186 void visitReturn(ReturnInst &I,
187 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
188 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
189 Function *F = I.getParent()->getParent();
190 if (F->getReturnType()->isVoidTy())
191 return;
192 auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
193 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
194 ChangedValues[RetF] =
195 MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
198 /// Handle store instructions. If the pointer operand of the store is a
199 /// global variable, we attempt to track the value. The global variable state
200 /// is the merge of the stored value state with the current global variable
201 /// state.
202 void visitStore(StoreInst &I,
203 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
204 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
205 auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
206 if (!GV)
207 return;
208 auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
209 auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
210 ChangedValues[MemPtr] =
211 MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
214 /// Handle all other instructions. All other instructions are marked
215 /// overdefined.
216 void visitInst(Instruction &I,
217 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
218 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
219 auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
220 ChangedValues[RegI] = getOverdefinedVal();
224 /// This class defines the common data used for all of the tests. The tests
225 /// should add code to the module and then run the solver.
226 class SparsePropagationTest : public testing::Test {
227 protected:
228 LLVMContext Context;
229 Module M;
230 IRBuilder<> Builder;
231 TestLatticeFunc Lattice;
232 SparseSolver<TestLatticeKey, TestLatticeVal> Solver;
234 public:
235 SparsePropagationTest()
236 : M("", Context), Builder(Context), Solver(&Lattice) {}
238 } // namespace
240 /// Test that we mark discovered functions executable.
242 /// define internal void @f() {
243 /// call void @g()
244 /// ret void
245 /// }
247 /// define internal void @g() {
248 /// call void @f()
249 /// ret void
250 /// }
252 /// For this test, we initially mark "f" executable, and the solver discovers
253 /// "g" because of the call in "f". The mutually recursive call in "g" also
254 /// tests that we don't add a block to the basic block work list if it is
255 /// already executable. Doing so would put the solver into an infinite loop.
256 TEST_F(SparsePropagationTest, MarkBlockExecutable) {
257 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
258 GlobalValue::InternalLinkage, "f", &M);
259 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
260 GlobalValue::InternalLinkage, "g", &M);
261 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
262 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
263 Builder.SetInsertPoint(FEntry);
264 Builder.CreateCall(G);
265 Builder.CreateRetVoid();
266 Builder.SetInsertPoint(GEntry);
267 Builder.CreateCall(F);
268 Builder.CreateRetVoid();
270 Solver.MarkBlockExecutable(FEntry);
271 Solver.Solve();
273 EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
276 /// Test that we propagate information through global variables.
278 /// @gv = internal global i64
280 /// define internal void @f() {
281 /// store i64 1, i64* @gv
282 /// ret void
283 /// }
285 /// define internal void @g() {
286 /// store i64 1, i64* @gv
287 /// ret void
288 /// }
290 /// For this test, we initially mark both "f" and "g" executable, and the
291 /// solver computes the lattice state of the global variable as constant.
292 TEST_F(SparsePropagationTest, GlobalVariableConstant) {
293 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
294 GlobalValue::InternalLinkage, "f", &M);
295 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
296 GlobalValue::InternalLinkage, "g", &M);
297 GlobalVariable *GV =
298 new GlobalVariable(M, Builder.getInt64Ty(), false,
299 GlobalValue::InternalLinkage, nullptr, "gv");
300 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
301 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
302 Builder.SetInsertPoint(FEntry);
303 Builder.CreateStore(Builder.getInt64(1), GV);
304 Builder.CreateRetVoid();
305 Builder.SetInsertPoint(GEntry);
306 Builder.CreateStore(Builder.getInt64(1), GV);
307 Builder.CreateRetVoid();
309 Solver.MarkBlockExecutable(FEntry);
310 Solver.MarkBlockExecutable(GEntry);
311 Solver.Solve();
313 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
314 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
317 /// Test that we propagate information through global variables.
319 /// @gv = internal global i64
321 /// define internal void @f() {
322 /// store i64 0, i64* @gv
323 /// ret void
324 /// }
326 /// define internal void @g() {
327 /// store i64 1, i64* @gv
328 /// ret void
329 /// }
331 /// For this test, we initially mark both "f" and "g" executable, and the
332 /// solver computes the lattice state of the global variable as overdefined.
333 TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
334 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
335 GlobalValue::InternalLinkage, "f", &M);
336 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
337 GlobalValue::InternalLinkage, "g", &M);
338 GlobalVariable *GV =
339 new GlobalVariable(M, Builder.getInt64Ty(), false,
340 GlobalValue::InternalLinkage, nullptr, "gv");
341 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
342 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
343 Builder.SetInsertPoint(FEntry);
344 Builder.CreateStore(Builder.getInt64(0), GV);
345 Builder.CreateRetVoid();
346 Builder.SetInsertPoint(GEntry);
347 Builder.CreateStore(Builder.getInt64(1), GV);
348 Builder.CreateRetVoid();
350 Solver.MarkBlockExecutable(FEntry);
351 Solver.MarkBlockExecutable(GEntry);
352 Solver.Solve();
354 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
355 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
358 /// Test that we propagate information through function returns.
360 /// define internal i64 @f(i1* %cond) {
361 /// if:
362 /// %0 = load i1, i1* %cond
363 /// br i1 %0, label %then, label %else
365 /// then:
366 /// ret i64 1
368 /// else:
369 /// ret i64 1
370 /// }
372 /// For this test, we initially mark "f" executable, and the solver computes
373 /// the return value of the function as constant.
374 TEST_F(SparsePropagationTest, FunctionDefined) {
375 Function *F =
376 Function::Create(FunctionType::get(Builder.getInt64Ty(),
377 {PointerType::get(Context, 0)}, false),
378 GlobalValue::InternalLinkage, "f", &M);
379 BasicBlock *If = BasicBlock::Create(Context, "if", F);
380 BasicBlock *Then = BasicBlock::Create(Context, "then", F);
381 BasicBlock *Else = BasicBlock::Create(Context, "else", F);
382 F->arg_begin()->setName("cond");
383 Builder.SetInsertPoint(If);
384 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin());
385 Builder.CreateCondBr(Cond, Then, Else);
386 Builder.SetInsertPoint(Then);
387 Builder.CreateRet(Builder.getInt64(1));
388 Builder.SetInsertPoint(Else);
389 Builder.CreateRet(Builder.getInt64(1));
391 Solver.MarkBlockExecutable(If);
392 Solver.Solve();
394 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
395 EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
398 /// Test that we propagate information through function returns.
400 /// define internal i64 @f(i1* %cond) {
401 /// if:
402 /// %0 = load i1, i1* %cond
403 /// br i1 %0, label %then, label %else
405 /// then:
406 /// ret i64 0
408 /// else:
409 /// ret i64 1
410 /// }
412 /// For this test, we initially mark "f" executable, and the solver computes
413 /// the return value of the function as overdefined.
414 TEST_F(SparsePropagationTest, FunctionOverDefined) {
415 Function *F =
416 Function::Create(FunctionType::get(Builder.getInt64Ty(),
417 {PointerType::get(Context, 0)}, false),
418 GlobalValue::InternalLinkage, "f", &M);
419 BasicBlock *If = BasicBlock::Create(Context, "if", F);
420 BasicBlock *Then = BasicBlock::Create(Context, "then", F);
421 BasicBlock *Else = BasicBlock::Create(Context, "else", F);
422 F->arg_begin()->setName("cond");
423 Builder.SetInsertPoint(If);
424 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin());
425 Builder.CreateCondBr(Cond, Then, Else);
426 Builder.SetInsertPoint(Then);
427 Builder.CreateRet(Builder.getInt64(0));
428 Builder.SetInsertPoint(Else);
429 Builder.CreateRet(Builder.getInt64(1));
431 Solver.MarkBlockExecutable(If);
432 Solver.Solve();
434 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
435 EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
438 /// Test that we propagate information through arguments.
440 /// define internal void @f() {
441 /// call void @g(i64 0, i64 1)
442 /// call void @g(i64 1, i64 1)
443 /// ret void
444 /// }
446 /// define internal void @g(i64 %a, i64 %b) {
447 /// ret void
448 /// }
450 /// For this test, we initially mark "f" executable, and the solver discovers
451 /// "g" because of the calls in "f". The solver computes the state of argument
452 /// "a" as overdefined and the state of "b" as constant.
454 /// In addition, this test demonstrates that ComputeInstructionState can alter
455 /// the state of multiple lattice values, in addition to the one associated
456 /// with the instruction definition. Each call instruction in this test updates
457 /// the state of arguments "a" and "b".
458 TEST_F(SparsePropagationTest, ComputeInstructionState) {
459 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
460 GlobalValue::InternalLinkage, "f", &M);
461 Function *G = Function::Create(
462 FunctionType::get(Builder.getVoidTy(),
463 {Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
464 GlobalValue::InternalLinkage, "g", &M);
465 Argument *A = G->arg_begin();
466 Argument *B = std::next(G->arg_begin());
467 A->setName("a");
468 B->setName("b");
469 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
470 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
471 Builder.SetInsertPoint(FEntry);
472 Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
473 Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
474 Builder.CreateRetVoid();
475 Builder.SetInsertPoint(GEntry);
476 Builder.CreateRetVoid();
478 Solver.MarkBlockExecutable(FEntry);
479 Solver.Solve();
481 auto RegA = TestLatticeKey(A, IPOGrouping::Register);
482 auto RegB = TestLatticeKey(B, IPOGrouping::Register);
483 EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
484 EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
487 /// Test that we can handle exceptional terminator instructions.
489 /// declare internal void @p()
491 /// declare internal void @g()
493 /// define internal void @f() personality ptr @p {
494 /// entry:
495 /// invoke void @g()
496 /// to label %exit unwind label %catch.pad
498 /// catch.pad:
499 /// %0 = catchswitch within none [label %catch.body] unwind to caller
501 /// catch.body:
502 /// %1 = catchpad within %0 []
503 /// catchret from %1 to label %exit
505 /// exit:
506 /// ret void
507 /// }
509 /// For this test, we initially mark the entry block executable. The solver
510 /// then discovers the rest of the blocks in the function are executable.
511 TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
512 Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
513 GlobalValue::InternalLinkage, "p", &M);
514 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
515 GlobalValue::InternalLinkage, "g", &M);
516 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
517 GlobalValue::InternalLinkage, "f", &M);
518 F->setPersonalityFn(P);
519 BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
520 BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
521 BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
522 BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
523 Builder.SetInsertPoint(Entry);
524 Builder.CreateInvoke(G, Exit, Pad);
525 Builder.SetInsertPoint(Pad);
526 CatchSwitchInst *CatchSwitch =
527 Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
528 CatchSwitch->addHandler(Body);
529 Builder.SetInsertPoint(Body);
530 CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
531 Builder.CreateCatchRet(CatchPad, Exit);
532 Builder.SetInsertPoint(Exit);
533 Builder.CreateRetVoid();
535 Solver.MarkBlockExecutable(Entry);
536 Solver.Solve();
538 EXPECT_TRUE(Solver.isBlockExecutable(Pad));
539 EXPECT_TRUE(Solver.isBlockExecutable(Body));
540 EXPECT_TRUE(Solver.isBlockExecutable(Exit));