1 //===- DivergenceAnalysisTest.cpp - DivergenceAnalysis unit tests ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/ADT/SmallVector.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/DivergenceAnalysis.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/SyncDependenceAnalysis.h"
15 #include "llvm/Analysis/TargetLibraryInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/Dominators.h"
19 #include "llvm/IR/GlobalVariable.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/LegacyPassManager.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "gtest/gtest.h"
// Helper that takes a block name and a function and yields a BasicBlock
// pointer. The visible fragment compares BB.getName() against BlockName.
// NOTE(review): the loop header and the return statements are not visible in
// this excerpt -- presumably it walks the blocks of F and returns the one
// whose name matches; confirm against the full file. Callers in this file
// dereference the result directly, without a null check.
32 BasicBlock
*GetBlockByName(StringRef BlockName
, Function
&F
) {
// Name-mismatch test; the statement it guards is not visible here.
34 if (BB
.getName() != BlockName
)
41 // We use this fixture to ensure that we clean up DivergenceAnalysis before
42 // deleting the PassManager.
//
// The fixture owns the analyses that buildDA() constructs (dominator tree,
// post-dominator tree, loop info, sync dependence analysis) as members.
// NOTE(review): presumably this keeps them alive for as long as the returned
// DivergenceAnalysis is used -- confirm that DivergenceAnalysis keeps
// references to them.
43 class DivergenceAnalysisTest
: public testing::Test
{
// Target library info; TLI is constructed from TLII in the constructor.
47 TargetLibraryInfoImpl TLII
;
48 TargetLibraryInfo TLI
;
// Analyses rebuilt on every buildDA() call.
50 std::unique_ptr
<DominatorTree
> DT
;
51 std::unique_ptr
<PostDominatorTree
> PDT
;
52 std::unique_ptr
<LoopInfo
> LI
;
53 std::unique_ptr
<SyncDependenceAnalysis
> SDA
;
// Creates an empty-named module M in Context and initializes TLII/TLI.
// (The declarations of Context and M are not visible in this excerpt.)
55 DivergenceAnalysisTest() : M("", Context
), TLII(), TLI(TLII
) {}
// Builds a DivergenceAnalysis for F.
//
// Replaces the fixture's DT, PDT, LI and SDA with fresh instances computed
// for F, then constructs the DivergenceAnalysis from F, *DT, *LI and *SDA.
// IsLCSSA is forwarded unchanged. Because the members are reset, the
// analyses backing any previously built DivergenceAnalysis are destroyed by
// a subsequent call.
57 DivergenceAnalysis
buildDA(Function
&F
, bool IsLCSSA
) {
58 DT
.reset(new DominatorTree(F
));
59 PDT
.reset(new PostDominatorTree(F
));
// LoopInfo is derived from the dominator tree built above.
60 LI
.reset(new LoopInfo(*DT
));
61 SDA
.reset(new SyncDependenceAnalysis(*DT
, *PDT
, *LI
));
// NOTE(review): the nullptr second argument is presumably the region loop
// (DAInitialState expects getRegionLoop() == nullptr) -- confirm against the
// DivergenceAnalysis constructor.
62 return DivergenceAnalysis(F
, nullptr, *DT
, *LI
, *SDA
, IsLCSSA
);
// Fragment of a helper whose name and opening line are not visible in this
// excerpt. Its parameters are a module, a function name, the LCSSA flag and a
// callback receiving (Function &, LoopInfo &, DivergenceAnalysis &).
66 Module
&M
, StringRef FuncName
, bool IsLCSSA
,
67 function_ref
<void(Function
&F
, LoopInfo
&LI
, DivergenceAnalysis
&DA
)>
// Look the function up by name and fail the test if it does not exist.
69 auto *F
= M
.getFunction(FuncName
);
70 ASSERT_NE(F
, nullptr) << "Could not find " << FuncName
;
// Build the analysis for the found function; the callback invocation is not
// visible in this excerpt.
71 DivergenceAnalysis DA
= buildDA(*F
, IsLCSSA
);
76 // Simple initial state test
//
// Builds a function "f" of type void(i32) with a single "entry" block that
// only returns, then checks the analysis state before and after an argument
// is marked divergent.
77 TEST_F(DivergenceAnalysisTest
, DAInitialState
) {
78 IntegerType
*IntTy
= IntegerType::getInt32Ty(Context
);
// Function type: void return, one i32 parameter, not variadic. (The line
// declaring the variable that receives it is not visible in this excerpt.)
80 FunctionType::get(Type::getVoidTy(Context
), {IntTy
}, false);
81 Function
*F
= Function::Create(FTy
, Function::ExternalLinkage
, "f", M
);
82 BasicBlock
*BB
= BasicBlock::Create(Context
, "entry", F
);
// A null return value makes this a `ret void`.
83 ReturnInst::Create(Context
, nullptr, BB
);
85 DivergenceAnalysis DA
= buildDA(*F
, false);
87 // Whole function region
88 EXPECT_EQ(DA
.getRegionLoop(), nullptr);
90 // No divergence in initial state
91 EXPECT_FALSE(DA
.hasDetectedDivergence());
93 // No spurious divergence
// NOTE(review): the statement preceding this check is not visible in this
// excerpt (presumably the analysis is run first) -- confirm.
95 EXPECT_FALSE(DA
.hasDetectedDivergence());
97 // Detected divergence after marking
// The function's only argument is used as the divergence seed.
98 Argument
&arg
= *F
->arg_begin();
99 DA
.markDivergent(arg
);
101 EXPECT_TRUE(DA
.hasDetectedDivergence());
102 EXPECT_TRUE(DA
.isDivergent(arg
));
// The same two expectations are checked again; the statement in between is
// not visible in this excerpt.
105 EXPECT_TRUE(DA
.hasDetectedDivergence());
106 EXPECT_TRUE(DA
.isDivergent(arg
));
// Builds the analysis with IsLCSSA = false for a loop function, seeds
// divergence in an argument and expects the return instruction in the block
// named "for.end.loopexit" to be reported divergent.
109 TEST_F(DivergenceAnalysisTest
, DANoLCSSA
) {
// Test IR (only partially visible in this excerpt): @f_1 contains a loop with
// two phi nodes, %iv0 and %iv1; the loop continues while %iv0 < %n (signed).
113 std::unique_ptr
<Module
> M
= parseAssemblyString(
114 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
116 "define i32 @f_1(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
117 " local_unnamed_addr { "
119 " br label %loop.ph "
125 " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] "
126 " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] "
127 " %iv0.inc = add i32 %iv0, 1 "
128 " %iv1.inc = add i32 %iv1, 3 "
129 " %cond.cont = icmp slt i32 %iv0, %n "
130 " br i1 %cond.cont, label %loop, label %for.end.loopexit "
137 Function
*F
= M
->getFunction("f_1");
138 DivergenceAnalysis DA
= buildDA(*F
, false);
// Nothing has been marked divergent yet.
139 EXPECT_FALSE(DA
.hasDetectedDivergence());
// Argument iterator; the lines deriving NArg from it are not visible in this
// excerpt.
141 auto ItArg
= F
->arg_begin();
145 // Seed divergence in argument %n
146 DA
.markDivergent(NArg
);
149 EXPECT_TRUE(DA
.hasDetectedDivergence());
// NOTE(review): the visible IR names the phi %iv0; the "%iv.0" spelling in
// the comment below does not match any value name visible in this excerpt.
151 // Verify that "ret %iv.0" is divergent
// NOTE(review): ItBlock is advanced but not used in the visible lines; the
// exit block is looked up by name instead.
152 auto ItBlock
= F
->begin();
153 std::advance(ItBlock
, 3);
154 auto &ExitBlock
= *GetBlockByName("for.end.loopexit", *F
);
// The first instruction of the exit block must be a ReturnInst (cast<>
// asserts on a type mismatch).
155 auto &RetInst
= *cast
<ReturnInst
>(ExitBlock
.begin());
156 EXPECT_TRUE(DA
.isDivergent(RetInst
));
// Counterpart of DANoLCSSA with IsLCSSA = true: seeds divergence in an
// argument and expects the return instruction in the block named
// "detached.return" to be reported divergent.
159 TEST_F(DivergenceAnalysisTest
, DALCSSA
) {
// Test IR (only partially visible in this excerpt): the same loop shape with
// phis %iv0 and %iv1, plus a single-incoming phi %val.ret = phi [%iv0, %loop]
// that is followed by a branch to %detached.return.
163 std::unique_ptr
<Module
> M
= parseAssemblyString(
164 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
166 "define i32 @f_lcssa(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
167 " local_unnamed_addr { "
169 " br label %loop.ph "
175 " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] "
176 " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] "
177 " %iv0.inc = add i32 %iv0, 1 "
178 " %iv1.inc = add i32 %iv1, 3 "
179 " %cond.cont = icmp slt i32 %iv0, %n "
180 " br i1 %cond.cont, label %loop, label %for.end.loopexit "
183 " %val.ret = phi i32 [ %iv0, %loop ] "
184 " br label %detached.return "
191 Function
*F
= M
->getFunction("f_lcssa");
192 DivergenceAnalysis DA
= buildDA(*F
, true);
// Nothing has been marked divergent yet.
193 EXPECT_FALSE(DA
.hasDetectedDivergence());
// Argument iterator; the lines deriving NArg from it are not visible in this
// excerpt.
195 auto ItArg
= F
->arg_begin();
199 // Seed divergence in argument %n
200 DA
.markDivergent(NArg
);
203 EXPECT_TRUE(DA
.hasDetectedDivergence());
// NOTE(review): the visible IR names the values %iv0 and %val.ret; the
// "%iv.0" spelling in the comment below does not match any value name visible
// in this excerpt.
205 // Verify that "ret %iv.0" is divergent
// NOTE(review): ItBlock is advanced but not used in the visible lines; the
// block holding the return is looked up by name instead.
206 auto ItBlock
= F
->begin();
207 std::advance(ItBlock
, 4);
208 auto &ExitBlock
= *GetBlockByName("detached.return", *F
);
// The first instruction of that block must be a ReturnInst (cast<> asserts on
// a type mismatch).
209 auto &RetInst
= *cast
<ReturnInst
>(ExitBlock
.begin());
210 EXPECT_TRUE(DA
.isDivergent(RetInst
));
// For three small control-flow graphs (@f_1, @f_2, @f_3), marks each i1
// argument divergent in turn and checks which blocks' leading phi nodes are
// reported divergent as a result.
213 TEST_F(DivergenceAnalysisTest
, DAJoinDivergence
) {
// Test IR (only partially visible in this excerpt). Each function takes the
// branch conditions %a, %b, %c and has phi nodes named %<block>.join.
217 std::unique_ptr
<Module
> M
= parseAssemblyString(
218 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
220 "define void @f_1(i1 %a, i1 %b, i1 %c) "
221 " local_unnamed_addr { "
223 " br i1 %a, label %B, label %C "
226 " br i1 %b, label %C, label %D "
229 " %c.join = phi i32 [ 0, %A ], [ 1, %B ] "
230 " br i1 %c, label %D, label %E "
233 " %d.join = phi i32 [ 0, %B ], [ 1, %C ] "
237 " %e.join = phi i32 [ 0, %C ], [ 1, %D ] "
241 "define void @f_2(i1 %a, i1 %b, i1 %c) "
242 " local_unnamed_addr { "
244 " br i1 %a, label %B, label %E "
247 " br i1 %b, label %C, label %D "
253 " %d.join = phi i32 [ 0, %B ], [ 1, %C ] "
257 " %e.join = phi i32 [ 0, %A ], [ 1, %D ] "
261 "define void @f_3(i1 %a, i1 %b, i1 %c)"
262 " local_unnamed_addr { "
264 " br i1 %a, label %B, label %C "
270 " %c.join = phi i32 [ 0, %A ], [ 1, %B ] "
271 " br i1 %c, label %D, label %E "
277 " %e.join = phi i32 [ 0, %C ], [ 1, %D ] "
282 // Maps divergent conditions to the basic blocks whose Phi nodes become
283 // divergent. Blocks need to be listed in IR order.
284 using SmallBlockVec
= SmallVector
<const BasicBlock
*, 4>;
285 using InducedDivJoinMap
= std::map
<const Value
*, SmallBlockVec
>;
287 // Actual function performing the checks.
// Captures `this` in order to call the fixture's buildDA().
288 auto CheckDivergenceFunc
= [this](Function
&F
,
289 InducedDivJoinMap
&ExpectedDivJoins
) {
// One independent run per map entry: a fresh analysis is built and only that
// entry's key value is marked divergent.
290 for (auto &ItCase
: ExpectedDivJoins
) {
291 auto *DivVal
= ItCase
.first
;
292 auto DA
= buildDA(F
, false);
293 DA
.markDivergent(*DivVal
);
296 // List of basic blocks that shall host divergent Phi nodes.
297 auto ItDivJoins
= ItCase
.second
.begin();
// The first instruction of the current block BB, if it is a phi. (The loop
// header introducing BB and the handling of a null Phi are not visible in
// this excerpt.)
300 auto *Phi
= dyn_cast
<PHINode
>(BB
.begin());
// BB is the next expected join block: its phi must be divergent.
304 if (ItDivJoins
!= ItCase
.second
.end() && &BB
== *ItDivJoins
) {
305 EXPECT_TRUE(DA
.isDivergent(*Phi
));
306 // Advance to next block with expected divergent PHI node.
// Expectation for phis in blocks that are not the next expected join block.
309 EXPECT_FALSE(DA
.isDivergent(*Phi
));
// ---- @f_1: pick blocks C, D, E by position in the function's block list.
316 auto *F
= M
->getFunction("f_1");
317 auto ItBlocks
= F
->begin();
318 ItBlocks
++; // Skip A
319 ItBlocks
++; // Skip B
320 auto *C
= &*ItBlocks
++;
321 auto *D
= &*ItBlocks
++;
322 auto *E
= &*ItBlocks
;
// Arguments %a, %b, %c in declaration order.
324 auto ItArg
= F
->arg_begin();
325 auto *AArg
= &*ItArg
++;
326 auto *BArg
= &*ItArg
++;
327 auto *CArg
= &*ItArg
;
// Expected divergent join blocks per seeded argument:
//   %a -> C, D, E    %b -> D, E    %c -> E
329 InducedDivJoinMap DivJoins
;
330 DivJoins
.emplace(AArg
, SmallBlockVec({C
, D
, E
}));
331 DivJoins
.emplace(BArg
, SmallBlockVec({D
, E
}));
332 DivJoins
.emplace(CArg
, SmallBlockVec({E
}));
334 CheckDivergenceFunc(*F
, DivJoins
);
// ---- @f_2: pick blocks D and E by position in the function's block list.
338 auto *F
= M
->getFunction("f_2");
339 auto ItBlocks
= F
->begin();
340 ItBlocks
++; // Skip A
341 ItBlocks
++; // Skip B
342 ItBlocks
++; // Skip C
343 auto *D
= &*ItBlocks
++;
344 auto *E
= &*ItBlocks
;
// Arguments %a, %b, %c in declaration order.
346 auto ItArg
= F
->arg_begin();
347 auto *AArg
= &*ItArg
++;
348 auto *BArg
= &*ItArg
++;
349 auto *CArg
= &*ItArg
;
// Expected divergent join blocks per seeded argument:
//   %a -> E    %b -> D    %c -> (none)
351 InducedDivJoinMap DivJoins
;
352 DivJoins
.emplace(AArg
, SmallBlockVec({E
}));
353 DivJoins
.emplace(BArg
, SmallBlockVec({D
}));
354 DivJoins
.emplace(CArg
, SmallBlockVec({}));
356 CheckDivergenceFunc(*F
, DivJoins
);
// ---- @f_3: pick blocks C and E by position in the function's block list.
360 auto *F
= M
->getFunction("f_3");
361 auto ItBlocks
= F
->begin();
362 ItBlocks
++; // Skip A
363 ItBlocks
++; // Skip B
364 auto *C
= &*ItBlocks
++;
365 ItBlocks
++; // Skip D
366 auto *E
= &*ItBlocks
;
// Arguments %a, %b, %c in declaration order.
368 auto ItArg
= F
->arg_begin();
369 auto *AArg
= &*ItArg
++;
370 auto *BArg
= &*ItArg
++;
371 auto *CArg
= &*ItArg
;
// Expected divergent join blocks per seeded argument:
//   %a -> C    %b -> (none)    %c -> E
373 InducedDivJoinMap DivJoins
;
374 DivJoins
.emplace(AArg
, SmallBlockVec({C
}));
375 DivJoins
.emplace(BArg
, SmallBlockVec({}));
376 DivJoins
.emplace(CArg
, SmallBlockVec({E
}));
378 CheckDivergenceFunc(*F
, DivJoins
);
// Marks the switch condition divergent and checks the phi that joins the two
// case blocks in %sw.epilog.
// NOTE(review): per the test name, %sw.default is presumably an unreachable
// block; its body is not visible in this excerpt.
382 TEST_F(DivergenceAnalysisTest
, DASwitchUnreachableDefault
) {
// Test IR (only partially visible in this excerpt): a switch on %cond with
// cases 0 -> %sw.bb0 and 1 -> %sw.bb1 and default %sw.default; the visible
// branches target %sw.epilog, where %div.dbl merges values from %sw.bb0 and
// %sw.bb1.
386 std::unique_ptr
<Module
> M
= parseAssemblyString(
387 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
389 "define void @switch_unreachable_default(i32 %cond) local_unnamed_addr { "
391 " switch i32 %cond, label %sw.default [ "
392 " i32 0, label %sw.bb0 "
393 " i32 1, label %sw.bb1 "
397 " br label %sw.epilog "
400 " br label %sw.epilog "
406 " %div.dbl = phi double [ 0.0, %sw.bb0], [ -1.0, %sw.bb1 ] "
411 auto *F
= M
->getFunction("switch_unreachable_default");
// The function's only argument is the switch condition.
412 auto &CondArg
= *F
->arg_begin();
413 auto DA
= buildDA(*F
, false);
// Nothing has been marked divergent yet.
415 EXPECT_FALSE(DA
.hasDetectedDivergence());
417 DA
.markDivergent(CondArg
);
420 // Still %CondArg is divergent.
421 EXPECT_TRUE(DA
.hasDetectedDivergence());
423 // The join %div.dbl is expected to be divergent (see D52221)
424 auto &ExitBlock
= *GetBlockByName("sw.epilog", *F
);
// The first instruction of %sw.epilog must be a phi (cast<> asserts on a
// type mismatch).
425 auto &DivDblPhi
= *cast
<PHINode
>(ExitBlock
.begin());
426 EXPECT_TRUE(DA
.isDivergent(DivDblPhi
));
429 } // end anonymous namespace
430 } // end namespace llvm