1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/SparsePropagation.h"
10 #include "llvm/ADT/PointerIntPair.h"
11 #include "llvm/IR/CallSite.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "gtest/gtest.h"
/// To enable interprocedural analysis, we assign LLVM values to the following
/// groups. An LLVM Value can technically be in more than one group. It's
/// necessary to distinguish these groups so we can, for example, track a
/// global variable separately from the value stored at its location.
enum class IPOGrouping {
  Register, ///< SSA registers.
  Return,   ///< Return values of functions.
  Memory    ///< In-memory values.
};
25 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
26 /// The PointerIntPair header provides a DenseMapInfo specialization, so using
27 /// these as LatticeKeys is fine.
28 using TestLatticeKey
= PointerIntPair
<Value
*, 2, IPOGrouping
>;
32 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
33 /// must translate between LatticeKeys and LLVM Values when adding Values to
34 /// its work list and inspecting the state of control-flow related values.
35 template <> struct LatticeKeyInfo
<TestLatticeKey
> {
36 static inline Value
*getValueFromLatticeKey(TestLatticeKey Key
) {
37 return Key
.getPointer();
39 static inline TestLatticeKey
getLatticeKeyFromValue(Value
*V
) {
40 return TestLatticeKey(V
, IPOGrouping::Register
);
46 /// This class defines a simple test lattice value that could be used for
47 /// solving problems similar to constant propagation. The value is maintained
48 /// as a PointerIntPair.
49 class TestLatticeVal
{
  /// The states of the lattice value. Only the ConstantVal state is
  /// interesting; the rest are special states used by the generic solver. The
  /// UntrackedVal state differs from the other three in that the generic
  /// solver uses it to avoid doing unnecessary work. In particular, when a
  /// value moves to the UntrackedVal state, its users are not notified.
56 enum TestLatticeStateTy
{
63 TestLatticeVal() : LatticeVal(nullptr, UndefinedVal
) {}
64 TestLatticeVal(Constant
*C
, TestLatticeStateTy State
)
65 : LatticeVal(C
, State
) {}
67 /// Return true if this lattice value is in the Constant state. This is used
68 /// for checking the solver results.
69 bool isConstant() const { return LatticeVal
.getInt() == ConstantVal
; }
71 /// Return true if this lattice value is in the Overdefined state. This is
72 /// used for checking the solver results.
73 bool isOverdefined() const { return LatticeVal
.getInt() == OverdefinedVal
; }
75 bool operator==(const TestLatticeVal
&RHS
) const {
76 return LatticeVal
== RHS
.LatticeVal
;
79 bool operator!=(const TestLatticeVal
&RHS
) const {
80 return LatticeVal
!= RHS
.LatticeVal
;
84 /// A simple lattice value type for problems similar to constant propagation.
85 /// It holds the constant value and the lattice state.
86 PointerIntPair
<const Constant
*, 2, TestLatticeStateTy
> LatticeVal
;
/// This class defines a simple test lattice function that could be used for
/// solving problems similar to constant propagation. The test lattice differs
/// from a "real" lattice in a few ways. First, it initializes all return
/// values, values stored in global variables, and arguments in the undefined
/// state. This means that there are no limitations on what we can track
/// interprocedurally. For simplicity, all global values in the tests will be
/// given internal linkage, since this is not something this lattice function
/// tracks. Second, it only handles the few instructions necessary for the
/// tests.
99 : public AbstractLatticeFunction
<TestLatticeKey
, TestLatticeVal
> {
101 /// Construct a new test lattice function with special values for the
102 /// Undefined, Overdefined, and Untracked states.
104 : AbstractLatticeFunction(
105 TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal
),
106 TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal
),
107 TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal
)) {}
109 /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
110 /// test analysis, a LatticeKey will begin in the undefined state, unless it
111 /// represents an LLVM Constant in the register grouping.
112 TestLatticeVal
ComputeLatticeVal(TestLatticeKey Key
) override
{
113 if (Key
.getInt() == IPOGrouping::Register
)
114 if (auto *C
= dyn_cast
<Constant
>(Key
.getPointer()))
115 return TestLatticeVal(C
, TestLatticeVal::ConstantVal
);
116 return getUndefVal();
119 /// Merge the two given lattice values. This merge should be equivalent to
120 /// what is done for constant propagation. That is, the resulting lattice
121 /// value is constant only if the two given lattice values are constant and
122 /// hold the same value.
123 TestLatticeVal
MergeValues(TestLatticeVal X
, TestLatticeVal Y
) override
{
124 if (X
== getUntrackedVal() || Y
== getUntrackedVal())
125 return getUntrackedVal();
126 if (X
== getOverdefinedVal() || Y
== getOverdefinedVal())
127 return getOverdefinedVal();
128 if (X
== getUndefVal() && Y
== getUndefVal())
129 return getUndefVal();
130 if (X
== getUndefVal())
132 if (Y
== getUndefVal())
136 return getOverdefinedVal();
139 /// Compute the lattice values that change as a result of executing the given
140 /// instruction. We only handle the few instructions needed for the tests.
141 void ComputeInstructionState(
142 Instruction
&I
, DenseMap
<TestLatticeKey
, TestLatticeVal
> &ChangedValues
,
143 SparseSolver
<TestLatticeKey
, TestLatticeVal
> &SS
) override
{
144 switch (I
.getOpcode()) {
145 case Instruction::Call
:
146 return visitCallSite(cast
<CallInst
>(&I
), ChangedValues
, SS
);
147 case Instruction::Ret
:
148 return visitReturn(*cast
<ReturnInst
>(&I
), ChangedValues
, SS
);
149 case Instruction::Store
:
150 return visitStore(*cast
<StoreInst
>(&I
), ChangedValues
, SS
);
152 return visitInst(I
, ChangedValues
, SS
);
157 /// Handle call sites. The state of a called function's argument is the merge
158 /// of the current formal argument state with the call site's corresponding
159 /// actual argument state. The call site state is the merge of the call site
160 /// state with the returned value state of the called function.
161 void visitCallSite(CallSite CS
,
162 DenseMap
<TestLatticeKey
, TestLatticeVal
> &ChangedValues
,
163 SparseSolver
<TestLatticeKey
, TestLatticeVal
> &SS
) {
164 Function
*F
= CS
.getCalledFunction();
165 Instruction
*I
= CS
.getInstruction();
166 auto RegI
= TestLatticeKey(I
, IPOGrouping::Register
);
168 ChangedValues
[RegI
] = getOverdefinedVal();
171 SS
.MarkBlockExecutable(&F
->front());
172 for (Argument
&A
: F
->args()) {
173 auto RegFormal
= TestLatticeKey(&A
, IPOGrouping::Register
);
175 TestLatticeKey(CS
.getArgument(A
.getArgNo()), IPOGrouping::Register
);
176 ChangedValues
[RegFormal
] =
177 MergeValues(SS
.getValueState(RegFormal
), SS
.getValueState(RegActual
));
179 auto RetF
= TestLatticeKey(F
, IPOGrouping::Return
);
180 ChangedValues
[RegI
] =
181 MergeValues(SS
.getValueState(RegI
), SS
.getValueState(RetF
));
184 /// Handle return instructions. The function's return state is the merge of
185 /// the returned value state and the function's current return state.
186 void visitReturn(ReturnInst
&I
,
187 DenseMap
<TestLatticeKey
, TestLatticeVal
> &ChangedValues
,
188 SparseSolver
<TestLatticeKey
, TestLatticeVal
> &SS
) {
189 Function
*F
= I
.getParent()->getParent();
190 if (F
->getReturnType()->isVoidTy())
192 auto RegR
= TestLatticeKey(I
.getReturnValue(), IPOGrouping::Register
);
193 auto RetF
= TestLatticeKey(F
, IPOGrouping::Return
);
194 ChangedValues
[RetF
] =
195 MergeValues(SS
.getValueState(RegR
), SS
.getValueState(RetF
));
  /// Handle store instructions. If the pointer operand of the store is a
  /// global variable, we attempt to track the value. The global variable state
  /// is the merge of the stored value state with the current global variable
  /// state.
202 void visitStore(StoreInst
&I
,
203 DenseMap
<TestLatticeKey
, TestLatticeVal
> &ChangedValues
,
204 SparseSolver
<TestLatticeKey
, TestLatticeVal
> &SS
) {
205 auto *GV
= dyn_cast
<GlobalVariable
>(I
.getPointerOperand());
208 auto RegVal
= TestLatticeKey(I
.getValueOperand(), IPOGrouping::Register
);
209 auto MemPtr
= TestLatticeKey(GV
, IPOGrouping::Memory
);
210 ChangedValues
[MemPtr
] =
211 MergeValues(SS
.getValueState(RegVal
), SS
.getValueState(MemPtr
));
  /// Handle all other instructions. All other instructions are marked
  /// overdefined.
216 void visitInst(Instruction
&I
,
217 DenseMap
<TestLatticeKey
, TestLatticeVal
> &ChangedValues
,
218 SparseSolver
<TestLatticeKey
, TestLatticeVal
> &SS
) {
219 auto RegI
= TestLatticeKey(&I
, IPOGrouping::Register
);
220 ChangedValues
[RegI
] = getOverdefinedVal();
224 /// This class defines the common data used for all of the tests. The tests
225 /// should add code to the module and then run the solver.
226 class SparsePropagationTest
: public testing::Test
{
231 TestLatticeFunc Lattice
;
232 SparseSolver
<TestLatticeKey
, TestLatticeVal
> Solver
;
235 SparsePropagationTest()
236 : M("", Context
), Builder(Context
), Solver(&Lattice
) {}
/// Test that we mark discovered functions executable.
///
/// define internal void @f() {
///   call void @g()
///   ret void
/// }
///
/// define internal void @g() {
///   call void @f()
///   ret void
/// }
///
/// For this test, we initially mark "f" executable, and the solver discovers
/// "g" because of the call in "f". The mutually recursive call in "g" also
/// tests that we don't add a block to the basic block work list if it is
/// already executable. Doing so would put the solver into an infinite loop.
256 TEST_F(SparsePropagationTest
, MarkBlockExecutable
) {
257 Function
*F
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
258 GlobalValue::InternalLinkage
, "f", &M
);
259 Function
*G
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
260 GlobalValue::InternalLinkage
, "g", &M
);
261 BasicBlock
*FEntry
= BasicBlock::Create(Context
, "", F
);
262 BasicBlock
*GEntry
= BasicBlock::Create(Context
, "", G
);
263 Builder
.SetInsertPoint(FEntry
);
264 Builder
.CreateCall(G
);
265 Builder
.CreateRetVoid();
266 Builder
.SetInsertPoint(GEntry
);
267 Builder
.CreateCall(F
);
268 Builder
.CreateRetVoid();
270 Solver
.MarkBlockExecutable(FEntry
);
273 EXPECT_TRUE(Solver
.isBlockExecutable(GEntry
));
/// Test that we propagate information through global variables.
///
/// @gv = internal global i64
///
/// define internal void @f() {
///   store i64 1, i64* @gv
///   ret void
/// }
///
/// define internal void @g() {
///   store i64 1, i64* @gv
///   ret void
/// }
///
/// For this test, we initially mark both "f" and "g" executable, and the
/// solver computes the lattice state of the global variable as constant.
292 TEST_F(SparsePropagationTest
, GlobalVariableConstant
) {
293 Function
*F
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
294 GlobalValue::InternalLinkage
, "f", &M
);
295 Function
*G
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
296 GlobalValue::InternalLinkage
, "g", &M
);
298 new GlobalVariable(M
, Builder
.getInt64Ty(), false,
299 GlobalValue::InternalLinkage
, nullptr, "gv");
300 BasicBlock
*FEntry
= BasicBlock::Create(Context
, "", F
);
301 BasicBlock
*GEntry
= BasicBlock::Create(Context
, "", G
);
302 Builder
.SetInsertPoint(FEntry
);
303 Builder
.CreateStore(Builder
.getInt64(1), GV
);
304 Builder
.CreateRetVoid();
305 Builder
.SetInsertPoint(GEntry
);
306 Builder
.CreateStore(Builder
.getInt64(1), GV
);
307 Builder
.CreateRetVoid();
309 Solver
.MarkBlockExecutable(FEntry
);
310 Solver
.MarkBlockExecutable(GEntry
);
313 auto MemGV
= TestLatticeKey(GV
, IPOGrouping::Memory
);
314 EXPECT_TRUE(Solver
.getExistingValueState(MemGV
).isConstant());
/// Test that we propagate information through global variables.
///
/// @gv = internal global i64
///
/// define internal void @f() {
///   store i64 0, i64* @gv
///   ret void
/// }
///
/// define internal void @g() {
///   store i64 1, i64* @gv
///   ret void
/// }
///
/// For this test, we initially mark both "f" and "g" executable, and the
/// solver computes the lattice state of the global variable as overdefined.
333 TEST_F(SparsePropagationTest
, GlobalVariableOverDefined
) {
334 Function
*F
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
335 GlobalValue::InternalLinkage
, "f", &M
);
336 Function
*G
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
337 GlobalValue::InternalLinkage
, "g", &M
);
339 new GlobalVariable(M
, Builder
.getInt64Ty(), false,
340 GlobalValue::InternalLinkage
, nullptr, "gv");
341 BasicBlock
*FEntry
= BasicBlock::Create(Context
, "", F
);
342 BasicBlock
*GEntry
= BasicBlock::Create(Context
, "", G
);
343 Builder
.SetInsertPoint(FEntry
);
344 Builder
.CreateStore(Builder
.getInt64(0), GV
);
345 Builder
.CreateRetVoid();
346 Builder
.SetInsertPoint(GEntry
);
347 Builder
.CreateStore(Builder
.getInt64(1), GV
);
348 Builder
.CreateRetVoid();
350 Solver
.MarkBlockExecutable(FEntry
);
351 Solver
.MarkBlockExecutable(GEntry
);
354 auto MemGV
= TestLatticeKey(GV
, IPOGrouping::Memory
);
355 EXPECT_TRUE(Solver
.getExistingValueState(MemGV
).isOverdefined());
/// Test that we propagate information through function returns.
///
/// define internal i64 @f(i1* %cond) {
/// if:
///   %0 = load i1, i1* %cond
///   br i1 %0, label %then, label %else
///
/// then:
///   ret i64 1
///
/// else:
///   ret i64 1
/// }
///
/// For this test, we initially mark "f" executable, and the solver computes
/// the return value of the function as constant.
374 TEST_F(SparsePropagationTest
, FunctionDefined
) {
376 Function::Create(FunctionType::get(Builder
.getInt64Ty(),
377 {Type::getInt1PtrTy(Context
)}, false),
378 GlobalValue::InternalLinkage
, "f", &M
);
379 BasicBlock
*If
= BasicBlock::Create(Context
, "if", F
);
380 BasicBlock
*Then
= BasicBlock::Create(Context
, "then", F
);
381 BasicBlock
*Else
= BasicBlock::Create(Context
, "else", F
);
382 F
->arg_begin()->setName("cond");
383 Builder
.SetInsertPoint(If
);
384 LoadInst
*Cond
= Builder
.CreateLoad(Type::getInt1Ty(Context
), F
->arg_begin());
385 Builder
.CreateCondBr(Cond
, Then
, Else
);
386 Builder
.SetInsertPoint(Then
);
387 Builder
.CreateRet(Builder
.getInt64(1));
388 Builder
.SetInsertPoint(Else
);
389 Builder
.CreateRet(Builder
.getInt64(1));
391 Solver
.MarkBlockExecutable(If
);
394 auto RetF
= TestLatticeKey(F
, IPOGrouping::Return
);
395 EXPECT_TRUE(Solver
.getExistingValueState(RetF
).isConstant());
/// Test that we propagate information through function returns.
///
/// define internal i64 @f(i1* %cond) {
/// if:
///   %0 = load i1, i1* %cond
///   br i1 %0, label %then, label %else
///
/// then:
///   ret i64 0
///
/// else:
///   ret i64 1
/// }
///
/// For this test, we initially mark "f" executable, and the solver computes
/// the return value of the function as overdefined.
414 TEST_F(SparsePropagationTest
, FunctionOverDefined
) {
416 Function::Create(FunctionType::get(Builder
.getInt64Ty(),
417 {Type::getInt1PtrTy(Context
)}, false),
418 GlobalValue::InternalLinkage
, "f", &M
);
419 BasicBlock
*If
= BasicBlock::Create(Context
, "if", F
);
420 BasicBlock
*Then
= BasicBlock::Create(Context
, "then", F
);
421 BasicBlock
*Else
= BasicBlock::Create(Context
, "else", F
);
422 F
->arg_begin()->setName("cond");
423 Builder
.SetInsertPoint(If
);
424 LoadInst
*Cond
= Builder
.CreateLoad(Type::getInt1Ty(Context
), F
->arg_begin());
425 Builder
.CreateCondBr(Cond
, Then
, Else
);
426 Builder
.SetInsertPoint(Then
);
427 Builder
.CreateRet(Builder
.getInt64(0));
428 Builder
.SetInsertPoint(Else
);
429 Builder
.CreateRet(Builder
.getInt64(1));
431 Solver
.MarkBlockExecutable(If
);
434 auto RetF
= TestLatticeKey(F
, IPOGrouping::Return
);
435 EXPECT_TRUE(Solver
.getExistingValueState(RetF
).isOverdefined());
/// Test that we propagate information through arguments.
///
/// define internal void @f() {
///   call void @g(i64 0, i64 1)
///   call void @g(i64 1, i64 1)
///   ret void
/// }
///
/// define internal void @g(i64 %a, i64 %b) {
///   ret void
/// }
///
/// For this test, we initially mark "f" executable, and the solver discovers
/// "g" because of the calls in "f". The solver computes the state of argument
/// "a" as overdefined and the state of "b" as constant.
///
/// In addition, this test demonstrates that ComputeInstructionState can alter
/// the state of multiple lattice values, in addition to the one associated
/// with the instruction definition. Each call instruction in this test updates
/// the state of arguments "a" and "b".
458 TEST_F(SparsePropagationTest
, ComputeInstructionState
) {
459 Function
*F
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
460 GlobalValue::InternalLinkage
, "f", &M
);
461 Function
*G
= Function::Create(
462 FunctionType::get(Builder
.getVoidTy(),
463 {Builder
.getInt64Ty(), Builder
.getInt64Ty()}, false),
464 GlobalValue::InternalLinkage
, "g", &M
);
465 Argument
*A
= G
->arg_begin();
466 Argument
*B
= std::next(G
->arg_begin());
469 BasicBlock
*FEntry
= BasicBlock::Create(Context
, "", F
);
470 BasicBlock
*GEntry
= BasicBlock::Create(Context
, "", G
);
471 Builder
.SetInsertPoint(FEntry
);
472 Builder
.CreateCall(G
, {Builder
.getInt64(0), Builder
.getInt64(1)});
473 Builder
.CreateCall(G
, {Builder
.getInt64(1), Builder
.getInt64(1)});
474 Builder
.CreateRetVoid();
475 Builder
.SetInsertPoint(GEntry
);
476 Builder
.CreateRetVoid();
478 Solver
.MarkBlockExecutable(FEntry
);
481 auto RegA
= TestLatticeKey(A
, IPOGrouping::Register
);
482 auto RegB
= TestLatticeKey(B
, IPOGrouping::Register
);
483 EXPECT_TRUE(Solver
.getExistingValueState(RegA
).isOverdefined());
484 EXPECT_TRUE(Solver
.getExistingValueState(RegB
).isConstant());
/// Test that we can handle exceptional terminator instructions.
///
/// declare internal void @p()
///
/// declare internal void @g()
///
/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
/// entry:
///   invoke void @g()
///           to label %exit unwind label %catch.pad
///
/// catch.pad:
///   %0 = catchswitch within none [label %catch.body] unwind to caller
///
/// catch.body:
///   %1 = catchpad within %0 []
///   catchret from %1 to label %exit
///
/// exit:
///   ret void
/// }
///
/// For this test, we initially mark the entry block executable. The solver
/// then discovers the rest of the blocks in the function are executable.
511 TEST_F(SparsePropagationTest
, ExceptionalTerminatorInsts
) {
512 Function
*P
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
513 GlobalValue::InternalLinkage
, "p", &M
);
514 Function
*G
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
515 GlobalValue::InternalLinkage
, "g", &M
);
516 Function
*F
= Function::Create(FunctionType::get(Builder
.getVoidTy(), false),
517 GlobalValue::InternalLinkage
, "f", &M
);
519 ConstantExpr::getCast(Instruction::BitCast
, P
, Builder
.getInt8PtrTy());
520 F
->setPersonalityFn(C
);
521 BasicBlock
*Entry
= BasicBlock::Create(Context
, "entry", F
);
522 BasicBlock
*Pad
= BasicBlock::Create(Context
, "catch.pad", F
);
523 BasicBlock
*Body
= BasicBlock::Create(Context
, "catch.body", F
);
524 BasicBlock
*Exit
= BasicBlock::Create(Context
, "exit", F
);
525 Builder
.SetInsertPoint(Entry
);
526 Builder
.CreateInvoke(G
, Exit
, Pad
);
527 Builder
.SetInsertPoint(Pad
);
528 CatchSwitchInst
*CatchSwitch
=
529 Builder
.CreateCatchSwitch(ConstantTokenNone::get(Context
), nullptr, 1);
530 CatchSwitch
->addHandler(Body
);
531 Builder
.SetInsertPoint(Body
);
532 CatchPadInst
*CatchPad
= Builder
.CreateCatchPad(CatchSwitch
, {});
533 Builder
.CreateCatchRet(CatchPad
, Exit
);
534 Builder
.SetInsertPoint(Exit
);
535 Builder
.CreateRetVoid();
537 Solver
.MarkBlockExecutable(Entry
);
540 EXPECT_TRUE(Solver
.isBlockExecutable(Pad
));
541 EXPECT_TRUE(Solver
.isBlockExecutable(Body
));
542 EXPECT_TRUE(Solver
.isBlockExecutable(Exit
));