1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Verifier.h"
19 #include "llvm/IRReader/IRReader.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "gtest/gtest.h"
26 BasicBlock
*getBlockByName(Function
*F
, StringRef name
) {
28 if (BB
.getName() == name
)
33 TEST(CodeExtractor
, ExitStub
) {
36 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
37 define i32 @foo(i32 %x, i32 %y, i32 %z) {
39 %0 = icmp ugt i32 %x, %y
40 br i1 %0, label %body1, label %body2
44 br label %notExtracted
48 br label %notExtracted
51 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
58 Function
*Func
= M
->getFunction("foo");
59 SmallVector
<BasicBlock
*, 3> Candidates
{ getBlockByName(Func
, "header"),
60 getBlockByName(Func
, "body1"),
61 getBlockByName(Func
, "body2") };
63 CodeExtractor
CE(Candidates
);
64 EXPECT_TRUE(CE
.isEligible());
66 CodeExtractorAnalysisCache
CEAC(*Func
);
67 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
68 EXPECT_TRUE(Outlined
);
69 BasicBlock
*Exit
= getBlockByName(Func
, "notExtracted");
70 BasicBlock
*ExitSplit
= getBlockByName(Outlined
, "notExtracted.split");
71 // Ensure that PHI in exit block has only one incoming value (from code
73 EXPECT_TRUE(Exit
&& cast
<PHINode
>(Exit
->front()).getNumIncomingValues() == 1);
74 // Ensure that there is a PHI in outlined function with 2 incoming values.
75 EXPECT_TRUE(ExitSplit
&&
76 cast
<PHINode
>(ExitSplit
->front()).getNumIncomingValues() == 2);
77 EXPECT_FALSE(verifyFunction(*Outlined
));
78 EXPECT_FALSE(verifyFunction(*Func
));
81 TEST(CodeExtractor
, InputOutputMonitoring
) {
84 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
85 define i32 @foo(i32 %x, i32 %y, i32 %z) {
87 %0 = icmp ugt i32 %x, %y
88 br i1 %0, label %body1, label %body2
92 br label %notExtracted
96 br label %notExtracted
99 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
106 Function
*Func
= M
->getFunction("foo");
107 SmallVector
<BasicBlock
*, 3> Candidates
{getBlockByName(Func
, "header"),
108 getBlockByName(Func
, "body1"),
109 getBlockByName(Func
, "body2")};
111 CodeExtractor
CE(Candidates
);
112 EXPECT_TRUE(CE
.isEligible());
114 CodeExtractorAnalysisCache
CEAC(*Func
);
115 SetVector
<Value
*> Inputs
, Outputs
;
116 Function
*Outlined
= CE
.extractCodeRegion(CEAC
, Inputs
, Outputs
);
117 EXPECT_TRUE(Outlined
);
119 EXPECT_EQ(Inputs
.size(), 3u);
120 EXPECT_EQ(Inputs
[0], Func
->getArg(2));
121 EXPECT_EQ(Inputs
[1], Func
->getArg(0));
122 EXPECT_EQ(Inputs
[2], Func
->getArg(1));
123 EXPECT_EQ(Outputs
.size(), 1u);
124 StoreInst
*SI
= cast
<StoreInst
>(Outlined
->getArg(3)->user_back());
125 Value
*OutputVal
= SI
->getValueOperand();
126 EXPECT_EQ(Outputs
[0], OutputVal
);
127 BasicBlock
*Exit
= getBlockByName(Func
, "notExtracted");
128 BasicBlock
*ExitSplit
= getBlockByName(Outlined
, "notExtracted.split");
129 // Ensure that PHI in exit block has only one incoming value (from code
131 EXPECT_TRUE(Exit
&& cast
<PHINode
>(Exit
->front()).getNumIncomingValues() == 1);
132 // Ensure that there is a PHI in outlined function with 2 incoming values.
133 EXPECT_TRUE(ExitSplit
&&
134 cast
<PHINode
>(ExitSplit
->front()).getNumIncomingValues() == 2);
135 EXPECT_FALSE(verifyFunction(*Outlined
));
136 EXPECT_FALSE(verifyFunction(*Func
));
139 TEST(CodeExtractor
, ExitBlockOrderingPhis
) {
142 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
143 define void @foo(i32 %a, i32 %b) {
145 %0 = alloca i32, align 4
148 %c = load i32, i32* %0, align 4
151 %e = load i32, i32* %0, align 4
152 br i1 true, label %first, label %test
154 %d = load i32, i32* %0, align 4
155 br i1 true, label %first, label %next
157 %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
166 Function
*Func
= M
->getFunction("foo");
167 SmallVector
<BasicBlock
*, 3> Candidates
{ getBlockByName(Func
, "test0"),
168 getBlockByName(Func
, "test1"),
169 getBlockByName(Func
, "test") };
171 CodeExtractor
CE(Candidates
);
172 EXPECT_TRUE(CE
.isEligible());
174 CodeExtractorAnalysisCache
CEAC(*Func
);
175 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
176 EXPECT_TRUE(Outlined
);
178 BasicBlock
*FirstExitStub
= getBlockByName(Outlined
, "first.exitStub");
179 BasicBlock
*NextExitStub
= getBlockByName(Outlined
, "next.exitStub");
181 Instruction
*FirstTerm
= FirstExitStub
->getTerminator();
182 ReturnInst
*FirstReturn
= dyn_cast
<ReturnInst
>(FirstTerm
);
183 EXPECT_TRUE(FirstReturn
);
184 ConstantInt
*CIFirst
= dyn_cast
<ConstantInt
>(FirstReturn
->getReturnValue());
185 EXPECT_TRUE(CIFirst
->getLimitedValue() == 1u);
187 Instruction
*NextTerm
= NextExitStub
->getTerminator();
188 ReturnInst
*NextReturn
= dyn_cast
<ReturnInst
>(NextTerm
);
189 EXPECT_TRUE(NextReturn
);
190 ConstantInt
*CINext
= dyn_cast
<ConstantInt
>(NextReturn
->getReturnValue());
191 EXPECT_TRUE(CINext
->getLimitedValue() == 0u);
193 EXPECT_FALSE(verifyFunction(*Outlined
));
194 EXPECT_FALSE(verifyFunction(*Func
));
197 TEST(CodeExtractor
, ExitBlockOrdering
) {
200 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
201 define void @foo(i32 %a, i32 %b) {
203 %0 = alloca i32, align 4
206 %c = load i32, i32* %0, align 4
209 %e = load i32, i32* %0, align 4
210 br i1 true, label %first, label %test
212 %d = load i32, i32* %0, align 4
213 br i1 true, label %first, label %next
223 Function
*Func
= M
->getFunction("foo");
224 SmallVector
<BasicBlock
*, 3> Candidates
{ getBlockByName(Func
, "test0"),
225 getBlockByName(Func
, "test1"),
226 getBlockByName(Func
, "test") };
228 CodeExtractor
CE(Candidates
);
229 EXPECT_TRUE(CE
.isEligible());
231 CodeExtractorAnalysisCache
CEAC(*Func
);
232 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
233 EXPECT_TRUE(Outlined
);
235 BasicBlock
*FirstExitStub
= getBlockByName(Outlined
, "first.exitStub");
236 BasicBlock
*NextExitStub
= getBlockByName(Outlined
, "next.exitStub");
238 Instruction
*FirstTerm
= FirstExitStub
->getTerminator();
239 ReturnInst
*FirstReturn
= dyn_cast
<ReturnInst
>(FirstTerm
);
240 EXPECT_TRUE(FirstReturn
);
241 ConstantInt
*CIFirst
= dyn_cast
<ConstantInt
>(FirstReturn
->getReturnValue());
242 EXPECT_TRUE(CIFirst
->getLimitedValue() == 1u);
244 Instruction
*NextTerm
= NextExitStub
->getTerminator();
245 ReturnInst
*NextReturn
= dyn_cast
<ReturnInst
>(NextTerm
);
246 EXPECT_TRUE(NextReturn
);
247 ConstantInt
*CINext
= dyn_cast
<ConstantInt
>(NextReturn
->getReturnValue());
248 EXPECT_TRUE(CINext
->getLimitedValue() == 0u);
250 EXPECT_FALSE(verifyFunction(*Outlined
));
251 EXPECT_FALSE(verifyFunction(*Func
));
254 TEST(CodeExtractor
, ExitPHIOnePredFromRegion
) {
257 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
260 br i1 undef, label %extracted1, label %pred
263 br i1 undef, label %exit1, label %exit2
266 br i1 undef, label %extracted2, label %exit1
272 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
276 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
279 )invalid", Err
, Ctx
));
281 Function
*Func
= M
->getFunction("foo");
282 SmallVector
<BasicBlock
*, 2> ExtractedBlocks
{
283 getBlockByName(Func
, "extracted1"),
284 getBlockByName(Func
, "extracted2")
287 CodeExtractor
CE(ExtractedBlocks
);
288 EXPECT_TRUE(CE
.isEligible());
290 CodeExtractorAnalysisCache
CEAC(*Func
);
291 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
292 EXPECT_TRUE(Outlined
);
293 BasicBlock
*Exit1
= getBlockByName(Func
, "exit1");
294 BasicBlock
*Exit2
= getBlockByName(Func
, "exit2");
295 // Ensure that PHIs in exits are not splitted (since that they have only one
296 // incoming value from extracted region).
298 cast
<PHINode
>(Exit1
->front()).getNumIncomingValues() == 2);
300 cast
<PHINode
>(Exit2
->front()).getNumIncomingValues() == 2);
301 EXPECT_FALSE(verifyFunction(*Outlined
));
302 EXPECT_FALSE(verifyFunction(*Func
));
305 TEST(CodeExtractor
, StoreOutputInvokeResultAfterEHPad
) {
308 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
311 define i32 @foo() personality i8* null {
313 %call = invoke i8 @hoge()
314 to label %invoke.cont unwind label %lpad
316 invoke.cont: ; preds = %entry
319 lpad: ; preds = %entry
320 %0 = landingpad { i8*, i32 }
322 br i1 undef, label %catch, label %finally.catchall
324 catch: ; preds = %lpad
325 %call2 = invoke i8 @hoge()
326 to label %invoke.cont2 unwind label %lpad2
328 invoke.cont2: ; preds = %catch
329 %call3 = invoke i8 @hoge()
330 to label %invoke.cont3 unwind label %lpad2
332 invoke.cont3: ; preds = %invoke.cont2
335 lpad2: ; preds = %invoke.cont2, %catch
336 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
337 %1 = landingpad { i8*, i32 }
339 br label %finally.catchall
341 finally.catchall: ; preds = %lpad33, %lpad
342 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
345 )invalid", Err
, Ctx
));
348 Err
.print("unit", errs());
352 Function
*Func
= M
->getFunction("foo");
353 EXPECT_FALSE(verifyFunction(*Func
, &errs()));
355 SmallVector
<BasicBlock
*, 2> ExtractedBlocks
{
356 getBlockByName(Func
, "catch"),
357 getBlockByName(Func
, "invoke.cont2"),
358 getBlockByName(Func
, "invoke.cont3"),
359 getBlockByName(Func
, "lpad2")
362 CodeExtractor
CE(ExtractedBlocks
);
363 EXPECT_TRUE(CE
.isEligible());
365 CodeExtractorAnalysisCache
CEAC(*Func
);
366 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
367 EXPECT_TRUE(Outlined
);
368 EXPECT_FALSE(verifyFunction(*Outlined
, &errs()));
369 EXPECT_FALSE(verifyFunction(*Func
, &errs()));
372 TEST(CodeExtractor
, StoreOutputInvokeResultInExitStub
) {
375 std::unique_ptr
<Module
> M(parseAssemblyString(R
"invalid(
378 define i32 @foo() personality i8* null {
380 %0 = invoke i32 @bar() to label %exit unwind label %lpad
386 %1 = landingpad { i8*, i32 }
388 resume { i8*, i32 } %1
393 Function
*Func
= M
->getFunction("foo");
394 SmallVector
<BasicBlock
*, 1> Blocks
{ getBlockByName(Func
, "entry"),
395 getBlockByName(Func
, "lpad") };
397 CodeExtractor
CE(Blocks
);
398 EXPECT_TRUE(CE
.isEligible());
400 CodeExtractorAnalysisCache
CEAC(*Func
);
401 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
402 EXPECT_TRUE(Outlined
);
403 EXPECT_FALSE(verifyFunction(*Outlined
));
404 EXPECT_FALSE(verifyFunction(*Func
));
407 TEST(CodeExtractor
, ExtractAndInvalidateAssumptionCache
) {
410 std::unique_ptr
<Module
> M(parseAssemblyString(R
"ir(
411 target datalayout = "e
-m
:e
-i8
:8:32-i16
:16:32-i64
:64-i128
:128-n32
:64-S128
"
412 target triple = "aarch64
"
417 declare void @llvm.assume(i1) #0
419 define void @test() {
424 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
425 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
426 %2 = load i64, i64* %1, align 8
427 %3 = icmp ugt i64 %2, 1
428 br i1 %3, label %if.then, label %if.else
434 call void @g(i8* undef)
435 store i64 undef, i64* null, align 536870912
436 %4 = icmp eq i64 %2, 0
437 call void @llvm.assume(i1 %4)
441 attributes #0 = { nounwind willreturn }
445 assert(M
&& "Could not parse module?");
446 Function
*Func
= M
->getFunction("test");
447 SmallVector
<BasicBlock
*, 1> Blocks
{ getBlockByName(Func
, "if.else") };
448 AssumptionCache
AC(*Func
);
449 CodeExtractor
CE(Blocks
, nullptr, false, nullptr, nullptr, &AC
);
450 EXPECT_TRUE(CE
.isEligible());
452 CodeExtractorAnalysisCache
CEAC(*Func
);
453 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
454 EXPECT_TRUE(Outlined
);
455 EXPECT_FALSE(verifyFunction(*Outlined
));
456 EXPECT_FALSE(verifyFunction(*Func
));
457 EXPECT_FALSE(CE
.verifyAssumptionCache(*Func
, *Outlined
, &AC
));
460 TEST(CodeExtractor
, RemoveBitcastUsesFromOuterLifetimeMarkers
) {
463 std::unique_ptr
<Module
> M(parseAssemblyString(R
"ir(
464 target datalayout = "e
-m
:e
-p270
:32:32-p271
:32:32-p272
:64:64-i64
:64-f80
:128-n8
:16:32:64-S128
"
465 target triple = "x86_64
-unknown
-linux
-gnu
"
467 declare void @use(i32*)
468 declare void @llvm.lifetime.start.p0i8(i64, i8*)
469 declare void @llvm.lifetime.end.p0i8(i64, i8*)
477 %1 = bitcast i32* %0 to i8*
478 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
479 call void @use(i32* %0)
483 call void @use(i32* %0)
484 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
490 Function
*Func
= M
->getFunction("foo");
491 SmallVector
<BasicBlock
*, 1> Blocks
{getBlockByName(Func
, "extract")};
493 CodeExtractor
CE(Blocks
);
494 EXPECT_TRUE(CE
.isEligible());
496 CodeExtractorAnalysisCache
CEAC(*Func
);
497 SetVector
<Value
*> Inputs
, Outputs
, SinkingCands
, HoistingCands
;
498 BasicBlock
*CommonExit
= nullptr;
499 CE
.findAllocas(CEAC
, SinkingCands
, HoistingCands
, CommonExit
);
500 CE
.findInputsOutputs(Inputs
, Outputs
, SinkingCands
);
501 EXPECT_EQ(Outputs
.size(), 0U);
503 Function
*Outlined
= CE
.extractCodeRegion(CEAC
);
504 EXPECT_TRUE(Outlined
);
505 EXPECT_FALSE(verifyFunction(*Outlined
));
506 EXPECT_FALSE(verifyFunction(*Func
));
509 TEST(CodeExtractor
, PartialAggregateArgs
) {
512 std::unique_ptr
<Module
> M(parseAssemblyString(R
"ir(
513 target datalayout = "e
-m
:e
-p270
:32:32-p271
:32:32-p272
:64:64-i64
:64-f80
:128-n8
:16:32:64-S128
"
514 target triple = "x86_64
-unknown
-linux
-gnu
"
516 declare void @use(i32)
518 define void @foo(i32 %a, i32 %b, i32 %c) {
523 call void @use(i32 %a)
524 call void @use(i32 %b)
525 call void @use(i32 %c)
534 Function
*Func
= M
->getFunction("foo");
535 SmallVector
<BasicBlock
*, 1> Blocks
{getBlockByName(Func
, "extract")};
537 // Create the CodeExtractor with arguments aggregation enabled.
538 CodeExtractor
CE(Blocks
, /* DominatorTree */ nullptr,
539 /* AggregateArgs */ true);
540 EXPECT_TRUE(CE
.isEligible());
542 CodeExtractorAnalysisCache
CEAC(*Func
);
543 SetVector
<Value
*> Inputs
, Outputs
, SinkingCands
, HoistingCands
;
544 BasicBlock
*CommonExit
= nullptr;
545 CE
.findAllocas(CEAC
, SinkingCands
, HoistingCands
, CommonExit
);
546 CE
.findInputsOutputs(Inputs
, Outputs
, SinkingCands
);
547 // Exclude the first input from the argument aggregate.
548 CE
.excludeArgFromAggregate(Inputs
[0]);
550 Function
*Outlined
= CE
.extractCodeRegion(CEAC
, Inputs
, Outputs
);
551 EXPECT_TRUE(Outlined
);
552 // Expect 2 arguments in the outlined function: the excluded input and the
553 // struct aggregate for the remaining inputs.
554 EXPECT_EQ(Outlined
->arg_size(), 2U);
555 EXPECT_FALSE(verifyFunction(*Outlined
));
556 EXPECT_FALSE(verifyFunction(*Func
));
558 } // end anonymous namespace