[llvm-objdump] - Remove one overload of reportError. NFCI.
[llvm-complete.git] / unittests / Transforms / Utils / CodeExtractorTest.cpp
blob8b86951fa5e199ea8cf542074a22b1e07432e74f
1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/Dominators.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/LLVMContext.h"
15 #include "llvm/IR/Module.h"
16 #include "llvm/IR/Verifier.h"
17 #include "llvm/IRReader/IRReader.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "gtest/gtest.h"
21 using namespace llvm;
23 namespace {
24 BasicBlock *getBlockByName(Function *F, StringRef name) {
25 for (auto &BB : *F)
26 if (BB.getName() == name)
27 return &BB;
28 return nullptr;
31 TEST(CodeExtractor, ExitStub) {
32 LLVMContext Ctx;
33 SMDiagnostic Err;
34 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
35 define i32 @foo(i32 %x, i32 %y, i32 %z) {
36 header:
37 %0 = icmp ugt i32 %x, %y
38 br i1 %0, label %body1, label %body2
40 body1:
41 %1 = add i32 %z, 2
42 br label %notExtracted
44 body2:
45 %2 = mul i32 %z, 7
46 br label %notExtracted
48 notExtracted:
49 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
50 %4 = add i32 %3, %x
51 ret i32 %4
53 )invalid",
54 Err, Ctx));
56 Function *Func = M->getFunction("foo");
57 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
58 getBlockByName(Func, "body1"),
59 getBlockByName(Func, "body2") };
61 CodeExtractor CE(Candidates);
62 EXPECT_TRUE(CE.isEligible());
64 Function *Outlined = CE.extractCodeRegion();
65 EXPECT_TRUE(Outlined);
66 BasicBlock *Exit = getBlockByName(Func, "notExtracted");
67 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
68 // Ensure that PHI in exit block has only one incoming value (from code
69 // replacer block).
70 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
71 // Ensure that there is a PHI in outlined function with 2 incoming values.
72 EXPECT_TRUE(ExitSplit &&
73 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
74 EXPECT_FALSE(verifyFunction(*Outlined));
75 EXPECT_FALSE(verifyFunction(*Func));
78 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
79 LLVMContext Ctx;
80 SMDiagnostic Err;
81 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
82 define i32 @foo() {
83 header:
84 br i1 undef, label %extracted1, label %pred
86 pred:
87 br i1 undef, label %exit1, label %exit2
89 extracted1:
90 br i1 undef, label %extracted2, label %exit1
92 extracted2:
93 br label %exit2
95 exit1:
96 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
97 ret i32 %0
99 exit2:
100 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
101 ret i32 %1
103 )invalid", Err, Ctx));
105 Function *Func = M->getFunction("foo");
106 SmallVector<BasicBlock *, 2> ExtractedBlocks{
107 getBlockByName(Func, "extracted1"),
108 getBlockByName(Func, "extracted2")
111 CodeExtractor CE(ExtractedBlocks);
112 EXPECT_TRUE(CE.isEligible());
114 Function *Outlined = CE.extractCodeRegion();
115 EXPECT_TRUE(Outlined);
116 BasicBlock *Exit1 = getBlockByName(Func, "exit1");
117 BasicBlock *Exit2 = getBlockByName(Func, "exit2");
118 // Ensure that PHIs in exits are not splitted (since that they have only one
119 // incoming value from extracted region).
120 EXPECT_TRUE(Exit1 &&
121 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
122 EXPECT_TRUE(Exit2 &&
123 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
124 EXPECT_FALSE(verifyFunction(*Outlined));
125 EXPECT_FALSE(verifyFunction(*Func));
128 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
129 LLVMContext Ctx;
130 SMDiagnostic Err;
131 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
132 declare i8 @hoge()
134 define i32 @foo() personality i8* null {
135 entry:
136 %call = invoke i8 @hoge()
137 to label %invoke.cont unwind label %lpad
139 invoke.cont: ; preds = %entry
140 unreachable
142 lpad: ; preds = %entry
143 %0 = landingpad { i8*, i32 }
144 catch i8* null
145 br i1 undef, label %catch, label %finally.catchall
147 catch: ; preds = %lpad
148 %call2 = invoke i8 @hoge()
149 to label %invoke.cont2 unwind label %lpad2
151 invoke.cont2: ; preds = %catch
152 %call3 = invoke i8 @hoge()
153 to label %invoke.cont3 unwind label %lpad2
155 invoke.cont3: ; preds = %invoke.cont2
156 unreachable
158 lpad2: ; preds = %invoke.cont2, %catch
159 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
160 %1 = landingpad { i8*, i32 }
161 catch i8* null
162 br label %finally.catchall
164 finally.catchall: ; preds = %lpad33, %lpad
165 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
166 unreachable
168 )invalid", Err, Ctx));
170 if (!M) {
171 Err.print("unit", errs());
172 exit(1);
175 Function *Func = M->getFunction("foo");
176 EXPECT_FALSE(verifyFunction(*Func, &errs()));
178 SmallVector<BasicBlock *, 2> ExtractedBlocks{
179 getBlockByName(Func, "catch"),
180 getBlockByName(Func, "invoke.cont2"),
181 getBlockByName(Func, "invoke.cont3"),
182 getBlockByName(Func, "lpad2")
185 CodeExtractor CE(ExtractedBlocks);
186 EXPECT_TRUE(CE.isEligible());
188 Function *Outlined = CE.extractCodeRegion();
189 EXPECT_TRUE(Outlined);
190 EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
191 EXPECT_FALSE(verifyFunction(*Func, &errs()));
194 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
195 LLVMContext Ctx;
196 SMDiagnostic Err;
197 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
198 declare i32 @bar()
200 define i32 @foo() personality i8* null {
201 entry:
202 %0 = invoke i32 @bar() to label %exit unwind label %lpad
204 exit:
205 ret i32 %0
207 lpad:
208 %1 = landingpad { i8*, i32 }
209 cleanup
210 resume { i8*, i32 } %1
212 )invalid",
213 Err, Ctx));
215 Function *Func = M->getFunction("foo");
216 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
217 getBlockByName(Func, "lpad") };
219 CodeExtractor CE(Blocks);
220 EXPECT_TRUE(CE.isEligible());
222 Function *Outlined = CE.extractCodeRegion();
223 EXPECT_TRUE(Outlined);
224 EXPECT_FALSE(verifyFunction(*Outlined));
225 EXPECT_FALSE(verifyFunction(*Func));
228 } // end anonymous namespace