[docs] Fix build-docs.sh
[llvm-project.git] / llvm / unittests / Transforms / Utils / CodeExtractorTest.cpp
blobc142729e2c6f424900f910b828e6755ac13d5208
//===- CodeExtractorTest.cpp - Unit tests for CodeExtractor ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Verifier.h"
19 #include "llvm/IRReader/IRReader.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "gtest/gtest.h"
23 using namespace llvm;
25 namespace {
26 BasicBlock *getBlockByName(Function *F, StringRef name) {
27 for (auto &BB : *F)
28 if (BB.getName() == name)
29 return &BB;
30 return nullptr;
33 TEST(CodeExtractor, ExitStub) {
34 LLVMContext Ctx;
35 SMDiagnostic Err;
36 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
37 define i32 @foo(i32 %x, i32 %y, i32 %z) {
38 header:
39 %0 = icmp ugt i32 %x, %y
40 br i1 %0, label %body1, label %body2
42 body1:
43 %1 = add i32 %z, 2
44 br label %notExtracted
46 body2:
47 %2 = mul i32 %z, 7
48 br label %notExtracted
50 notExtracted:
51 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
52 %4 = add i32 %3, %x
53 ret i32 %4
55 )invalid",
56 Err, Ctx));
58 Function *Func = M->getFunction("foo");
59 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
60 getBlockByName(Func, "body1"),
61 getBlockByName(Func, "body2") };
63 CodeExtractor CE(Candidates);
64 EXPECT_TRUE(CE.isEligible());
66 CodeExtractorAnalysisCache CEAC(*Func);
67 Function *Outlined = CE.extractCodeRegion(CEAC);
68 EXPECT_TRUE(Outlined);
69 BasicBlock *Exit = getBlockByName(Func, "notExtracted");
70 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
71 // Ensure that PHI in exit block has only one incoming value (from code
72 // replacer block).
73 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
74 // Ensure that there is a PHI in outlined function with 2 incoming values.
75 EXPECT_TRUE(ExitSplit &&
76 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
77 EXPECT_FALSE(verifyFunction(*Outlined));
78 EXPECT_FALSE(verifyFunction(*Func));
81 TEST(CodeExtractor, InputOutputMonitoring) {
82 LLVMContext Ctx;
83 SMDiagnostic Err;
84 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
85 define i32 @foo(i32 %x, i32 %y, i32 %z) {
86 header:
87 %0 = icmp ugt i32 %x, %y
88 br i1 %0, label %body1, label %body2
90 body1:
91 %1 = add i32 %z, 2
92 br label %notExtracted
94 body2:
95 %2 = mul i32 %z, 7
96 br label %notExtracted
98 notExtracted:
99 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
100 %4 = add i32 %3, %x
101 ret i32 %4
103 )invalid",
104 Err, Ctx));
106 Function *Func = M->getFunction("foo");
107 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
108 getBlockByName(Func, "body1"),
109 getBlockByName(Func, "body2")};
111 CodeExtractor CE(Candidates);
112 EXPECT_TRUE(CE.isEligible());
114 CodeExtractorAnalysisCache CEAC(*Func);
115 SetVector<Value *> Inputs, Outputs;
116 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
117 EXPECT_TRUE(Outlined);
119 EXPECT_EQ(Inputs.size(), 3u);
120 EXPECT_EQ(Inputs[0], Func->getArg(2));
121 EXPECT_EQ(Inputs[1], Func->getArg(0));
122 EXPECT_EQ(Inputs[2], Func->getArg(1));
123 EXPECT_EQ(Outputs.size(), 1u);
124 StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
125 Value *OutputVal = SI->getValueOperand();
126 EXPECT_EQ(Outputs[0], OutputVal);
127 BasicBlock *Exit = getBlockByName(Func, "notExtracted");
128 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
129 // Ensure that PHI in exit block has only one incoming value (from code
130 // replacer block).
131 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
132 // Ensure that there is a PHI in outlined function with 2 incoming values.
133 EXPECT_TRUE(ExitSplit &&
134 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
135 EXPECT_FALSE(verifyFunction(*Outlined));
136 EXPECT_FALSE(verifyFunction(*Func));
139 TEST(CodeExtractor, ExitBlockOrderingPhis) {
140 LLVMContext Ctx;
141 SMDiagnostic Err;
142 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
143 define void @foo(i32 %a, i32 %b) {
144 entry:
145 %0 = alloca i32, align 4
146 br label %test0
147 test0:
148 %c = load i32, i32* %0, align 4
149 br label %test1
150 test1:
151 %e = load i32, i32* %0, align 4
152 br i1 true, label %first, label %test
153 test:
154 %d = load i32, i32* %0, align 4
155 br i1 true, label %first, label %next
156 first:
157 %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
158 ret void
159 next:
160 %2 = add i32 %d, 1
161 %3 = add i32 %e, 1
162 ret void
164 )invalid",
165 Err, Ctx));
166 Function *Func = M->getFunction("foo");
167 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
168 getBlockByName(Func, "test1"),
169 getBlockByName(Func, "test") };
171 CodeExtractor CE(Candidates);
172 EXPECT_TRUE(CE.isEligible());
174 CodeExtractorAnalysisCache CEAC(*Func);
175 Function *Outlined = CE.extractCodeRegion(CEAC);
176 EXPECT_TRUE(Outlined);
178 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
179 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
181 Instruction *FirstTerm = FirstExitStub->getTerminator();
182 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
183 EXPECT_TRUE(FirstReturn);
184 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
185 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
187 Instruction *NextTerm = NextExitStub->getTerminator();
188 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
189 EXPECT_TRUE(NextReturn);
190 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
191 EXPECT_TRUE(CINext->getLimitedValue() == 0u);
193 EXPECT_FALSE(verifyFunction(*Outlined));
194 EXPECT_FALSE(verifyFunction(*Func));
197 TEST(CodeExtractor, ExitBlockOrdering) {
198 LLVMContext Ctx;
199 SMDiagnostic Err;
200 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
201 define void @foo(i32 %a, i32 %b) {
202 entry:
203 %0 = alloca i32, align 4
204 br label %test0
205 test0:
206 %c = load i32, i32* %0, align 4
207 br label %test1
208 test1:
209 %e = load i32, i32* %0, align 4
210 br i1 true, label %first, label %test
211 test:
212 %d = load i32, i32* %0, align 4
213 br i1 true, label %first, label %next
214 first:
215 ret void
216 next:
217 %1 = add i32 %d, 1
218 %2 = add i32 %e, 1
219 ret void
221 )invalid",
222 Err, Ctx));
223 Function *Func = M->getFunction("foo");
224 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
225 getBlockByName(Func, "test1"),
226 getBlockByName(Func, "test") };
228 CodeExtractor CE(Candidates);
229 EXPECT_TRUE(CE.isEligible());
231 CodeExtractorAnalysisCache CEAC(*Func);
232 Function *Outlined = CE.extractCodeRegion(CEAC);
233 EXPECT_TRUE(Outlined);
235 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
236 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
238 Instruction *FirstTerm = FirstExitStub->getTerminator();
239 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
240 EXPECT_TRUE(FirstReturn);
241 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
242 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
244 Instruction *NextTerm = NextExitStub->getTerminator();
245 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
246 EXPECT_TRUE(NextReturn);
247 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
248 EXPECT_TRUE(CINext->getLimitedValue() == 0u);
250 EXPECT_FALSE(verifyFunction(*Outlined));
251 EXPECT_FALSE(verifyFunction(*Func));
254 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
255 LLVMContext Ctx;
256 SMDiagnostic Err;
257 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
258 define i32 @foo() {
259 header:
260 br i1 undef, label %extracted1, label %pred
262 pred:
263 br i1 undef, label %exit1, label %exit2
265 extracted1:
266 br i1 undef, label %extracted2, label %exit1
268 extracted2:
269 br label %exit2
271 exit1:
272 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
273 ret i32 %0
275 exit2:
276 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
277 ret i32 %1
279 )invalid", Err, Ctx));
281 Function *Func = M->getFunction("foo");
282 SmallVector<BasicBlock *, 2> ExtractedBlocks{
283 getBlockByName(Func, "extracted1"),
284 getBlockByName(Func, "extracted2")
287 CodeExtractor CE(ExtractedBlocks);
288 EXPECT_TRUE(CE.isEligible());
290 CodeExtractorAnalysisCache CEAC(*Func);
291 Function *Outlined = CE.extractCodeRegion(CEAC);
292 EXPECT_TRUE(Outlined);
293 BasicBlock *Exit1 = getBlockByName(Func, "exit1");
294 BasicBlock *Exit2 = getBlockByName(Func, "exit2");
295 // Ensure that PHIs in exits are not splitted (since that they have only one
296 // incoming value from extracted region).
297 EXPECT_TRUE(Exit1 &&
298 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
299 EXPECT_TRUE(Exit2 &&
300 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
301 EXPECT_FALSE(verifyFunction(*Outlined));
302 EXPECT_FALSE(verifyFunction(*Func));
305 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
306 LLVMContext Ctx;
307 SMDiagnostic Err;
308 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
309 declare i8 @hoge()
311 define i32 @foo() personality i8* null {
312 entry:
313 %call = invoke i8 @hoge()
314 to label %invoke.cont unwind label %lpad
316 invoke.cont: ; preds = %entry
317 unreachable
319 lpad: ; preds = %entry
320 %0 = landingpad { i8*, i32 }
321 catch i8* null
322 br i1 undef, label %catch, label %finally.catchall
324 catch: ; preds = %lpad
325 %call2 = invoke i8 @hoge()
326 to label %invoke.cont2 unwind label %lpad2
328 invoke.cont2: ; preds = %catch
329 %call3 = invoke i8 @hoge()
330 to label %invoke.cont3 unwind label %lpad2
332 invoke.cont3: ; preds = %invoke.cont2
333 unreachable
335 lpad2: ; preds = %invoke.cont2, %catch
336 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
337 %1 = landingpad { i8*, i32 }
338 catch i8* null
339 br label %finally.catchall
341 finally.catchall: ; preds = %lpad33, %lpad
342 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
343 unreachable
345 )invalid", Err, Ctx));
347 if (!M) {
348 Err.print("unit", errs());
349 exit(1);
352 Function *Func = M->getFunction("foo");
353 EXPECT_FALSE(verifyFunction(*Func, &errs()));
355 SmallVector<BasicBlock *, 2> ExtractedBlocks{
356 getBlockByName(Func, "catch"),
357 getBlockByName(Func, "invoke.cont2"),
358 getBlockByName(Func, "invoke.cont3"),
359 getBlockByName(Func, "lpad2")
362 CodeExtractor CE(ExtractedBlocks);
363 EXPECT_TRUE(CE.isEligible());
365 CodeExtractorAnalysisCache CEAC(*Func);
366 Function *Outlined = CE.extractCodeRegion(CEAC);
367 EXPECT_TRUE(Outlined);
368 EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
369 EXPECT_FALSE(verifyFunction(*Func, &errs()));
372 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
373 LLVMContext Ctx;
374 SMDiagnostic Err;
375 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
376 declare i32 @bar()
378 define i32 @foo() personality i8* null {
379 entry:
380 %0 = invoke i32 @bar() to label %exit unwind label %lpad
382 exit:
383 ret i32 %0
385 lpad:
386 %1 = landingpad { i8*, i32 }
387 cleanup
388 resume { i8*, i32 } %1
390 )invalid",
391 Err, Ctx));
393 Function *Func = M->getFunction("foo");
394 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
395 getBlockByName(Func, "lpad") };
397 CodeExtractor CE(Blocks);
398 EXPECT_TRUE(CE.isEligible());
400 CodeExtractorAnalysisCache CEAC(*Func);
401 Function *Outlined = CE.extractCodeRegion(CEAC);
402 EXPECT_TRUE(Outlined);
403 EXPECT_FALSE(verifyFunction(*Outlined));
404 EXPECT_FALSE(verifyFunction(*Func));
407 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
408 LLVMContext Ctx;
409 SMDiagnostic Err;
410 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
411 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
412 target triple = "aarch64"
414 %b = type { i64 }
415 declare void @g(i8*)
417 declare void @llvm.assume(i1) #0
419 define void @test() {
420 entry:
421 br label %label
423 label:
424 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
425 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
426 %2 = load i64, i64* %1, align 8
427 %3 = icmp ugt i64 %2, 1
428 br i1 %3, label %if.then, label %if.else
430 if.then:
431 unreachable
433 if.else:
434 call void @g(i8* undef)
435 store i64 undef, i64* null, align 536870912
436 %4 = icmp eq i64 %2, 0
437 call void @llvm.assume(i1 %4)
438 unreachable
441 attributes #0 = { nounwind willreturn }
442 )ir",
443 Err, Ctx));
445 assert(M && "Could not parse module?");
446 Function *Func = M->getFunction("test");
447 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
448 AssumptionCache AC(*Func);
449 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
450 EXPECT_TRUE(CE.isEligible());
452 CodeExtractorAnalysisCache CEAC(*Func);
453 Function *Outlined = CE.extractCodeRegion(CEAC);
454 EXPECT_TRUE(Outlined);
455 EXPECT_FALSE(verifyFunction(*Outlined));
456 EXPECT_FALSE(verifyFunction(*Func));
457 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
460 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
461 LLVMContext Ctx;
462 SMDiagnostic Err;
463 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
464 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
465 target triple = "x86_64-unknown-linux-gnu"
467 declare void @use(i32*)
468 declare void @llvm.lifetime.start.p0i8(i64, i8*)
469 declare void @llvm.lifetime.end.p0i8(i64, i8*)
471 define void @foo() {
472 entry:
473 %0 = alloca i32
474 br label %extract
476 extract:
477 %1 = bitcast i32* %0 to i8*
478 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
479 call void @use(i32* %0)
480 br label %exit
482 exit:
483 call void @use(i32* %0)
484 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
485 ret void
487 )ir",
488 Err, Ctx));
490 Function *Func = M->getFunction("foo");
491 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
493 CodeExtractor CE(Blocks);
494 EXPECT_TRUE(CE.isEligible());
496 CodeExtractorAnalysisCache CEAC(*Func);
497 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
498 BasicBlock *CommonExit = nullptr;
499 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
500 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
501 EXPECT_EQ(Outputs.size(), 0U);
503 Function *Outlined = CE.extractCodeRegion(CEAC);
504 EXPECT_TRUE(Outlined);
505 EXPECT_FALSE(verifyFunction(*Outlined));
506 EXPECT_FALSE(verifyFunction(*Func));
509 TEST(CodeExtractor, PartialAggregateArgs) {
510 LLVMContext Ctx;
511 SMDiagnostic Err;
512 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
513 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
514 target triple = "x86_64-unknown-linux-gnu"
516 declare void @use(i32)
518 define void @foo(i32 %a, i32 %b, i32 %c) {
519 entry:
520 br label %extract
522 extract:
523 call void @use(i32 %a)
524 call void @use(i32 %b)
525 call void @use(i32 %c)
526 br label %exit
528 exit:
529 ret void
531 )ir",
532 Err, Ctx));
534 Function *Func = M->getFunction("foo");
535 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
537 // Create the CodeExtractor with arguments aggregation enabled.
538 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
539 /* AggregateArgs */ true);
540 EXPECT_TRUE(CE.isEligible());
542 CodeExtractorAnalysisCache CEAC(*Func);
543 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
544 BasicBlock *CommonExit = nullptr;
545 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
546 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
547 // Exclude the first input from the argument aggregate.
548 CE.excludeArgFromAggregate(Inputs[0]);
550 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
551 EXPECT_TRUE(Outlined);
552 // Expect 2 arguments in the outlined function: the excluded input and the
553 // struct aggregate for the remaining inputs.
554 EXPECT_EQ(Outlined->arg_size(), 2U);
555 EXPECT_FALSE(verifyFunction(*Outlined));
556 EXPECT_FALSE(verifyFunction(*Func));
558 } // end anonymous namespace