//===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "SPIRVUtils.h"
#include "llvm/Analysis/DominanceFrontier.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/TypedPointerType.h"
#include "llvm/Support/SourceMgr.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using namespace llvm;
using namespace llvm::SPIRV;
29 class SPIRVPartialOrderingVisitorTest
: public testing::Test
{
31 void TearDown() override
{ M
.reset(); }
33 void run(StringRef Assembly
) {
34 assert(M
== nullptr &&
35 "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
38 M
= parseAssemblyString(Assembly
, Error
, Context
);
39 assert(M
&& "Bad assembly. Bad test?");
41 llvm::Function
*F
= M
->getFunction("main");
42 Visitor
= std::make_unique
<PartialOrderingVisitor
>(*F
);
46 checkBasicBlockRank(std::vector
<std::pair
<const char *, size_t>> &&Expected
) {
47 llvm::Function
*F
= M
->getFunction("main");
48 auto It
= Expected
.begin();
49 Visitor
->partialOrderVisit(*F
->begin(), [&](BasicBlock
*BB
) {
50 const auto &[Name
, Rank
] = *It
;
51 EXPECT_TRUE(It
!= Expected
.end())
52 << "Unexpected block \"" << BB
->getName() << " visited.";
53 EXPECT_TRUE(BB
->getName() == Name
)
54 << "Error: expected block \"" << Name
<< "\" got \"" << BB
->getName()
56 EXPECT_EQ(Rank
, Visitor
->GetNodeRank(BB
))
57 << "Bad rank for BB \"" << BB
->getName() << "\"";
61 ASSERT_TRUE(It
== Expected
.end())
62 << "Expected block \"" << It
->first
63 << "\" but reached the end of the function instead.";
68 std::unique_ptr
<Module
> M
;
69 std::unique_ptr
<PartialOrderingVisitor
> Visitor
;
72 TEST_F(SPIRVPartialOrderingVisitorTest
, EmptyFunction
) {
73 StringRef Assembly
= R
"(
74 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
80 checkBasicBlockRank({{"", 0}});
83 TEST_F(SPIRVPartialOrderingVisitorTest
, BasicBlockSwap
) {
84 StringRef Assembly
= R
"(
85 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
96 checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}});
103 TEST_F(SPIRVPartialOrderingVisitorTest
, SkipCondition
) {
104 StringRef Assembly
= R
"(
105 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
107 %1 = icmp ne i32 0, 0
108 br i1 %1, label %c, label %a
117 checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}});
121 // entry -> header <-----------------+
122 // | `-> body -> continue -+
124 TEST_F(SPIRVPartialOrderingVisitorTest
, LoopOrdering
) {
125 StringRef Assembly
= R
"(
126 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
128 %1 = icmp ne i32 0, 0
137 br i1 %1, label %body, label %end
143 {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}});
146 // Diamond condition:
151 // A and B order can be flipped with no effect, but it must be remain
152 // deterministic/stable.
153 TEST_F(SPIRVPartialOrderingVisitorTest
, DiamondCondition
) {
154 StringRef Assembly
= R
"(
155 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
157 %1 = icmp ne i32 0, 0
158 br i1 %1, label %a, label %b
169 checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}});
172 // Crossing conditions:
175 // entry -+ +--_|_-+ +-> E
177 // +------+----> D -+
179 // A & B have the same rank.
180 // C & D have the same rank, but are after A & B.
181 // E if the last block.
182 TEST_F(SPIRVPartialOrderingVisitorTest
, CrossingCondition
) {
183 StringRef Assembly
= R
"(
184 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
186 %1 = icmp ne i32 0, 0
187 br i1 %1, label %a, label %b
193 br i1 %1, label %d, label %c
197 br i1 %1, label %c, label %d
203 {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}});
// Loop whose body contains a diamond: the visible IR shows conditional
// branches to %body/%end, %inside_a/%break and %inside_c/%inside_d.
// NOTE(review): this copy of the test is truncated. The remaining IR lines,
// the closing of the raw string, the run(Assembly) call, and every expected
// (block, rank) entry after {"entry", 0} are missing here. Restore them from
// version control rather than re-deriving the expected ranks by hand.
206 TEST_F(SPIRVPartialOrderingVisitorTest
, LoopDiamond
) {
207 StringRef Assembly
= R
"(
208 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
210 %1 = icmp ne i32 0, 0
213 br i1 %1, label %body, label %end
215 br i1 %1, label %inside_a, label %break
219 br i1 %1, label %inside_c, label %inside_d
234 checkBasicBlockRank({{"entry", 0},
// Presumably exercises nested loops (judging by the test name): the visible
// IR shows conditional branches to %h/%b and %d/%e. Confirm against the full
// IR.
// NOTE(review): this copy of the test is truncated. The remaining IR lines,
// the closing of the raw string, the run(Assembly) call, and every expected
// (block, rank) entry after {"entry", 0} are missing here. Restore them from
// version control rather than re-deriving the expected ranks by hand.
246 TEST_F(SPIRVPartialOrderingVisitorTest
, LoopNested
) {
247 StringRef Assembly
= R
"(
248 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
250 %1 = icmp ne i32 0, 0
253 br i1 %1, label %h, label %b
257 br i1 %1, label %d, label %e
272 checkBasicBlockRank({{"entry", 0},
// Nested if-conditions: the visible IR branches on the constant `true` to
// %a/%d, %b/%c, %e/%f and %g/%h.
// NOTE(review): this copy of the test is truncated. The remaining IR lines,
// the closing of the raw string, the run(Assembly) call, and every expected
// (block, rank) entry after {"entry", 0} are missing here. Restore them from
// version control rather than re-deriving the expected ranks by hand.
283 TEST_F(SPIRVPartialOrderingVisitorTest
, IfNested
) {
284 StringRef Assembly
= R
"(
285 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
287 br i1 true, label %a, label %d
289 br i1 true, label %b, label %c
295 br i1 true, label %e, label %f
299 br i1 true, label %g, label %h
311 checkBasicBlockRank({{"entry", 0},
324 TEST_F(SPIRVPartialOrderingVisitorTest
, CheckDeathIrreducible
) {
325 StringRef Assembly
= R
"(
326 define void @main() convergent "hlsl
.numthreads
"="4,8,16" "hlsl
.shader
"="compute
" {
328 %1 = icmp ne i32 0, 0
331 br i1 %1, label %a, label %c
335 br i1 %1, label %b, label %c
341 "No valid candidate in the queue. Is the graph reducible?");