[x86/SLH] Fix an issue where we wouldn't harden any loads if we found
[llvm-complete.git] / unittests / Analysis / SparsePropagation.cpp
blob298b1403eb5a61fd751d0c1ff1d282a1b08dbfad
1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
10 #include "llvm/Analysis/SparsePropagation.h"
11 #include "llvm/ADT/PointerIntPair.h"
12 #include "llvm/IR/CallSite.h"
13 #include "llvm/IR/IRBuilder.h"
14 #include "gtest/gtest.h"
15 using namespace llvm;
17 namespace {
18 /// To enable interprocedural analysis, we assign LLVM values to the following
19 /// groups. The register group represents SSA registers, the return group
20 /// represents the return values of functions, and the memory group represents
21 /// in-memory values. An LLVM Value can technically be in more than one group.
22 /// It's necessary to distinguish these groups so we can, for example, track a
23 /// global variable separately from the value stored at its location.
24 enum class IPOGrouping { Register, Return, Memory };
26 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
27 /// The PointerIntPair header provides a DenseMapInfo specialization, so using
28 /// these as LatticeKeys is fine.
29 using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
30 } // namespace
32 namespace llvm {
33 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
34 /// must translate between LatticeKeys and LLVM Values when adding Values to
35 /// its work list and inspecting the state of control-flow related values.
36 template <> struct LatticeKeyInfo<TestLatticeKey> {
37 static inline Value *getValueFromLatticeKey(TestLatticeKey Key) {
38 return Key.getPointer();
40 static inline TestLatticeKey getLatticeKeyFromValue(Value *V) {
41 return TestLatticeKey(V, IPOGrouping::Register);
44 } // namespace llvm
46 namespace {
47 /// This class defines a simple test lattice value that could be used for
48 /// solving problems similar to constant propagation. The value is maintained
49 /// as a PointerIntPair.
50 class TestLatticeVal {
51 public:
52 /// The states of the lattices value. Only the ConstantVal state is
53 /// interesting; the rest are special states used by the generic solver. The
54 /// UntrackedVal state differs from the other three in that the generic
55 /// solver uses it to avoid doing unnecessary work. In particular, when a
56 /// value moves to the UntrackedVal state, it's users are not notified.
57 enum TestLatticeStateTy {
58 UndefinedVal,
59 ConstantVal,
60 OverdefinedVal,
61 UntrackedVal
64 TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
65 TestLatticeVal(Constant *C, TestLatticeStateTy State)
66 : LatticeVal(C, State) {}
68 /// Return true if this lattice value is in the Constant state. This is used
69 /// for checking the solver results.
70 bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
72 /// Return true if this lattice value is in the Overdefined state. This is
73 /// used for checking the solver results.
74 bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
76 bool operator==(const TestLatticeVal &RHS) const {
77 return LatticeVal == RHS.LatticeVal;
80 bool operator!=(const TestLatticeVal &RHS) const {
81 return LatticeVal != RHS.LatticeVal;
84 private:
85 /// A simple lattice value type for problems similar to constant propagation.
86 /// It holds the constant value and the lattice state.
87 PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
90 /// This class defines a simple test lattice function that could be used for
91 /// solving problems similar to constant propagation. The test lattice differs
92 /// from a "real" lattice in a few ways. First, it initializes all return
93 /// values, values stored in global variables, and arguments in the undefined
94 /// state. This means that there are no limitations on what we can track
95 /// interprocedurally. For simplicity, all global values in the tests will be
96 /// given internal linkage, since this is not something this lattice function
97 /// tracks. Second, it only handles the few instructions necessary for the
98 /// tests.
99 class TestLatticeFunc
100 : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
101 public:
102 /// Construct a new test lattice function with special values for the
103 /// Undefined, Overdefined, and Untracked states.
104 TestLatticeFunc()
105 : AbstractLatticeFunction(
106 TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
107 TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
108 TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
110 /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
111 /// test analysis, a LatticeKey will begin in the undefined state, unless it
112 /// represents an LLVM Constant in the register grouping.
113 TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
114 if (Key.getInt() == IPOGrouping::Register)
115 if (auto *C = dyn_cast<Constant>(Key.getPointer()))
116 return TestLatticeVal(C, TestLatticeVal::ConstantVal);
117 return getUndefVal();
120 /// Merge the two given lattice values. This merge should be equivalent to
121 /// what is done for constant propagation. That is, the resulting lattice
122 /// value is constant only if the two given lattice values are constant and
123 /// hold the same value.
124 TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
125 if (X == getUntrackedVal() || Y == getUntrackedVal())
126 return getUntrackedVal();
127 if (X == getOverdefinedVal() || Y == getOverdefinedVal())
128 return getOverdefinedVal();
129 if (X == getUndefVal() && Y == getUndefVal())
130 return getUndefVal();
131 if (X == getUndefVal())
132 return Y;
133 if (Y == getUndefVal())
134 return X;
135 if (X == Y)
136 return X;
137 return getOverdefinedVal();
140 /// Compute the lattice values that change as a result of executing the given
141 /// instruction. We only handle the few instructions needed for the tests.
142 void ComputeInstructionState(
143 Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
144 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
145 switch (I.getOpcode()) {
146 case Instruction::Call:
147 return visitCallSite(cast<CallInst>(&I), ChangedValues, SS);
148 case Instruction::Ret:
149 return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
150 case Instruction::Store:
151 return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
152 default:
153 return visitInst(I, ChangedValues, SS);
157 private:
158 /// Handle call sites. The state of a called function's argument is the merge
159 /// of the current formal argument state with the call site's corresponding
160 /// actual argument state. The call site state is the merge of the call site
161 /// state with the returned value state of the called function.
162 void visitCallSite(CallSite CS,
163 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
164 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
165 Function *F = CS.getCalledFunction();
166 Instruction *I = CS.getInstruction();
167 auto RegI = TestLatticeKey(I, IPOGrouping::Register);
168 if (!F) {
169 ChangedValues[RegI] = getOverdefinedVal();
170 return;
172 SS.MarkBlockExecutable(&F->front());
173 for (Argument &A : F->args()) {
174 auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
175 auto RegActual =
176 TestLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register);
177 ChangedValues[RegFormal] =
178 MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
180 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
181 ChangedValues[RegI] =
182 MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
185 /// Handle return instructions. The function's return state is the merge of
186 /// the returned value state and the function's current return state.
187 void visitReturn(ReturnInst &I,
188 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
189 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
190 Function *F = I.getParent()->getParent();
191 if (F->getReturnType()->isVoidTy())
192 return;
193 auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
194 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
195 ChangedValues[RetF] =
196 MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
199 /// Handle store instructions. If the pointer operand of the store is a
200 /// global variable, we attempt to track the value. The global variable state
201 /// is the merge of the stored value state with the current global variable
202 /// state.
203 void visitStore(StoreInst &I,
204 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
205 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
206 auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
207 if (!GV)
208 return;
209 auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
210 auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
211 ChangedValues[MemPtr] =
212 MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
215 /// Handle all other instructions. All other instructions are marked
216 /// overdefined.
217 void visitInst(Instruction &I,
218 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
219 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
220 auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
221 ChangedValues[RegI] = getOverdefinedVal();
225 /// This class defines the common data used for all of the tests. The tests
226 /// should add code to the module and then run the solver.
227 class SparsePropagationTest : public testing::Test {
228 protected:
229 LLVMContext Context;
230 Module M;
231 IRBuilder<> Builder;
232 TestLatticeFunc Lattice;
233 SparseSolver<TestLatticeKey, TestLatticeVal> Solver;
235 public:
236 SparsePropagationTest()
237 : M("", Context), Builder(Context), Solver(&Lattice) {}
239 } // namespace
241 /// Test that we mark discovered functions executable.
243 /// define internal void @f() {
244 /// call void @g()
245 /// ret void
246 /// }
248 /// define internal void @g() {
249 /// call void @f()
250 /// ret void
251 /// }
253 /// For this test, we initially mark "f" executable, and the solver discovers
254 /// "g" because of the call in "f". The mutually recursive call in "g" also
255 /// tests that we don't add a block to the basic block work list if it is
256 /// already executable. Doing so would put the solver into an infinite loop.
257 TEST_F(SparsePropagationTest, MarkBlockExecutable) {
258 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
259 GlobalValue::InternalLinkage, "f", &M);
260 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
261 GlobalValue::InternalLinkage, "g", &M);
262 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
263 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
264 Builder.SetInsertPoint(FEntry);
265 Builder.CreateCall(G);
266 Builder.CreateRetVoid();
267 Builder.SetInsertPoint(GEntry);
268 Builder.CreateCall(F);
269 Builder.CreateRetVoid();
271 Solver.MarkBlockExecutable(FEntry);
272 Solver.Solve();
274 EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
277 /// Test that we propagate information through global variables.
279 /// @gv = internal global i64
281 /// define internal void @f() {
282 /// store i64 1, i64* @gv
283 /// ret void
284 /// }
286 /// define internal void @g() {
287 /// store i64 1, i64* @gv
288 /// ret void
289 /// }
291 /// For this test, we initially mark both "f" and "g" executable, and the
292 /// solver computes the lattice state of the global variable as constant.
293 TEST_F(SparsePropagationTest, GlobalVariableConstant) {
294 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
295 GlobalValue::InternalLinkage, "f", &M);
296 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
297 GlobalValue::InternalLinkage, "g", &M);
298 GlobalVariable *GV =
299 new GlobalVariable(M, Builder.getInt64Ty(), false,
300 GlobalValue::InternalLinkage, nullptr, "gv");
301 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
302 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
303 Builder.SetInsertPoint(FEntry);
304 Builder.CreateStore(Builder.getInt64(1), GV);
305 Builder.CreateRetVoid();
306 Builder.SetInsertPoint(GEntry);
307 Builder.CreateStore(Builder.getInt64(1), GV);
308 Builder.CreateRetVoid();
310 Solver.MarkBlockExecutable(FEntry);
311 Solver.MarkBlockExecutable(GEntry);
312 Solver.Solve();
314 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
315 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
318 /// Test that we propagate information through global variables.
320 /// @gv = internal global i64
322 /// define internal void @f() {
323 /// store i64 0, i64* @gv
324 /// ret void
325 /// }
327 /// define internal void @g() {
328 /// store i64 1, i64* @gv
329 /// ret void
330 /// }
332 /// For this test, we initially mark both "f" and "g" executable, and the
333 /// solver computes the lattice state of the global variable as overdefined.
334 TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
335 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
336 GlobalValue::InternalLinkage, "f", &M);
337 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
338 GlobalValue::InternalLinkage, "g", &M);
339 GlobalVariable *GV =
340 new GlobalVariable(M, Builder.getInt64Ty(), false,
341 GlobalValue::InternalLinkage, nullptr, "gv");
342 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
343 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
344 Builder.SetInsertPoint(FEntry);
345 Builder.CreateStore(Builder.getInt64(0), GV);
346 Builder.CreateRetVoid();
347 Builder.SetInsertPoint(GEntry);
348 Builder.CreateStore(Builder.getInt64(1), GV);
349 Builder.CreateRetVoid();
351 Solver.MarkBlockExecutable(FEntry);
352 Solver.MarkBlockExecutable(GEntry);
353 Solver.Solve();
355 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
356 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
359 /// Test that we propagate information through function returns.
361 /// define internal i64 @f(i1* %cond) {
362 /// if:
363 /// %0 = load i1, i1* %cond
364 /// br i1 %0, label %then, label %else
366 /// then:
367 /// ret i64 1
369 /// else:
370 /// ret i64 1
371 /// }
373 /// For this test, we initially mark "f" executable, and the solver computes
374 /// the return value of the function as constant.
375 TEST_F(SparsePropagationTest, FunctionDefined) {
376 Function *F =
377 Function::Create(FunctionType::get(Builder.getInt64Ty(),
378 {Type::getInt1PtrTy(Context)}, false),
379 GlobalValue::InternalLinkage, "f", &M);
380 BasicBlock *If = BasicBlock::Create(Context, "if", F);
381 BasicBlock *Then = BasicBlock::Create(Context, "then", F);
382 BasicBlock *Else = BasicBlock::Create(Context, "else", F);
383 F->arg_begin()->setName("cond");
384 Builder.SetInsertPoint(If);
385 LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
386 Builder.CreateCondBr(Cond, Then, Else);
387 Builder.SetInsertPoint(Then);
388 Builder.CreateRet(Builder.getInt64(1));
389 Builder.SetInsertPoint(Else);
390 Builder.CreateRet(Builder.getInt64(1));
392 Solver.MarkBlockExecutable(If);
393 Solver.Solve();
395 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
396 EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
399 /// Test that we propagate information through function returns.
401 /// define internal i64 @f(i1* %cond) {
402 /// if:
403 /// %0 = load i1, i1* %cond
404 /// br i1 %0, label %then, label %else
406 /// then:
407 /// ret i64 0
409 /// else:
410 /// ret i64 1
411 /// }
413 /// For this test, we initially mark "f" executable, and the solver computes
414 /// the return value of the function as overdefined.
415 TEST_F(SparsePropagationTest, FunctionOverDefined) {
416 Function *F =
417 Function::Create(FunctionType::get(Builder.getInt64Ty(),
418 {Type::getInt1PtrTy(Context)}, false),
419 GlobalValue::InternalLinkage, "f", &M);
420 BasicBlock *If = BasicBlock::Create(Context, "if", F);
421 BasicBlock *Then = BasicBlock::Create(Context, "then", F);
422 BasicBlock *Else = BasicBlock::Create(Context, "else", F);
423 F->arg_begin()->setName("cond");
424 Builder.SetInsertPoint(If);
425 LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
426 Builder.CreateCondBr(Cond, Then, Else);
427 Builder.SetInsertPoint(Then);
428 Builder.CreateRet(Builder.getInt64(0));
429 Builder.SetInsertPoint(Else);
430 Builder.CreateRet(Builder.getInt64(1));
432 Solver.MarkBlockExecutable(If);
433 Solver.Solve();
435 auto RetF = TestLatticeKey(F, IPOGrouping::Return);
436 EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
439 /// Test that we propagate information through arguments.
441 /// define internal void @f() {
442 /// call void @g(i64 0, i64 1)
443 /// call void @g(i64 1, i64 1)
444 /// ret void
445 /// }
447 /// define internal void @g(i64 %a, i64 %b) {
448 /// ret void
449 /// }
451 /// For this test, we initially mark "f" executable, and the solver discovers
452 /// "g" because of the calls in "f". The solver computes the state of argument
453 /// "a" as overdefined and the state of "b" as constant.
455 /// In addition, this test demonstrates that ComputeInstructionState can alter
456 /// the state of multiple lattice values, in addition to the one associated
457 /// with the instruction definition. Each call instruction in this test updates
458 /// the state of arguments "a" and "b".
459 TEST_F(SparsePropagationTest, ComputeInstructionState) {
460 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
461 GlobalValue::InternalLinkage, "f", &M);
462 Function *G = Function::Create(
463 FunctionType::get(Builder.getVoidTy(),
464 {Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
465 GlobalValue::InternalLinkage, "g", &M);
466 Argument *A = G->arg_begin();
467 Argument *B = std::next(G->arg_begin());
468 A->setName("a");
469 B->setName("b");
470 BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
471 BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
472 Builder.SetInsertPoint(FEntry);
473 Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
474 Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
475 Builder.CreateRetVoid();
476 Builder.SetInsertPoint(GEntry);
477 Builder.CreateRetVoid();
479 Solver.MarkBlockExecutable(FEntry);
480 Solver.Solve();
482 auto RegA = TestLatticeKey(A, IPOGrouping::Register);
483 auto RegB = TestLatticeKey(B, IPOGrouping::Register);
484 EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
485 EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
488 /// Test that we can handle exceptional terminator instructions.
490 /// declare internal void @p()
492 /// declare internal void @g()
494 /// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
495 /// entry:
496 /// invoke void @g()
497 /// to label %exit unwind label %catch.pad
499 /// catch.pad:
500 /// %0 = catchswitch within none [label %catch.body] unwind to caller
502 /// catch.body:
503 /// %1 = catchpad within %0 []
504 /// catchret from %1 to label %exit
506 /// exit:
507 /// ret void
508 /// }
510 /// For this test, we initially mark the entry block executable. The solver
511 /// then discovers the rest of the blocks in the function are executable.
512 TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
513 Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
514 GlobalValue::InternalLinkage, "p", &M);
515 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
516 GlobalValue::InternalLinkage, "g", &M);
517 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
518 GlobalValue::InternalLinkage, "f", &M);
519 Constant *C =
520 ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy());
521 F->setPersonalityFn(C);
522 BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
523 BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
524 BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
525 BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
526 Builder.SetInsertPoint(Entry);
527 Builder.CreateInvoke(G, Exit, Pad);
528 Builder.SetInsertPoint(Pad);
529 CatchSwitchInst *CatchSwitch =
530 Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
531 CatchSwitch->addHandler(Body);
532 Builder.SetInsertPoint(Body);
533 CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
534 Builder.CreateCatchRet(CatchPad, Exit);
535 Builder.SetInsertPoint(Exit);
536 Builder.CreateRetVoid();
538 Solver.MarkBlockExecutable(Entry);
539 Solver.Solve();
541 EXPECT_TRUE(Solver.isBlockExecutable(Pad));
542 EXPECT_TRUE(Solver.isBlockExecutable(Body));
543 EXPECT_TRUE(Solver.isBlockExecutable(Exit));