[flang][cuda] Do not register global constants (#118582)
[llvm-project.git] / llvm / unittests / Transforms / Utils / CodeExtractorTest.cpp
blobcfe07a2f6c461e6932ea8755fac6688a29a59ffa
1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/LLVMContext.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Verifier.h"
20 #include "llvm/IRReader/IRReader.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "gtest/gtest.h"
24 using namespace llvm;
26 namespace {
27 BasicBlock *getBlockByName(Function *F, StringRef name) {
28 for (auto &BB : *F)
29 if (BB.getName() == name)
30 return &BB;
31 return nullptr;
34 Instruction *getInstByName(Function *F, StringRef Name) {
35 for (Instruction &I : instructions(F))
36 if (I.getName() == Name)
37 return &I;
38 return nullptr;
41 TEST(CodeExtractor, ExitStub) {
42 LLVMContext Ctx;
43 SMDiagnostic Err;
44 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
45 define i32 @foo(i32 %x, i32 %y, i32 %z) {
46 header:
47 %0 = icmp ugt i32 %x, %y
48 br i1 %0, label %body1, label %body2
50 body1:
51 %1 = add i32 %z, 2
52 br label %notExtracted
54 body2:
55 %2 = mul i32 %z, 7
56 br label %notExtracted
58 notExtracted:
59 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
60 %4 = add i32 %3, %x
61 ret i32 %4
63 )invalid",
64 Err, Ctx));
66 Function *Func = M->getFunction("foo");
67 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
68 getBlockByName(Func, "body1"),
69 getBlockByName(Func, "body2") };
71 CodeExtractor CE(Candidates);
72 EXPECT_TRUE(CE.isEligible());
74 CodeExtractorAnalysisCache CEAC(*Func);
75 Function *Outlined = CE.extractCodeRegion(CEAC);
76 EXPECT_TRUE(Outlined);
77 BasicBlock *Exit = getBlockByName(Func, "notExtracted");
78 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
79 // Ensure that PHI in exit block has only one incoming value (from code
80 // replacer block).
81 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
82 // Ensure that there is a PHI in outlined function with 2 incoming values.
83 EXPECT_TRUE(ExitSplit &&
84 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
85 EXPECT_FALSE(verifyFunction(*Outlined));
86 EXPECT_FALSE(verifyFunction(*Func));
89 TEST(CodeExtractor, InputOutputMonitoring) {
90 LLVMContext Ctx;
91 SMDiagnostic Err;
92 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
93 define i32 @foo(i32 %x, i32 %y, i32 %z) {
94 header:
95 %0 = icmp ugt i32 %x, %y
96 br i1 %0, label %body1, label %body2
98 body1:
99 %1 = add i32 %z, 2
100 br label %notExtracted
102 body2:
103 %2 = mul i32 %z, 7
104 br label %notExtracted
106 notExtracted:
107 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
108 %4 = add i32 %3, %x
109 ret i32 %4
111 )invalid",
112 Err, Ctx));
114 Function *Func = M->getFunction("foo");
115 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
116 getBlockByName(Func, "body1"),
117 getBlockByName(Func, "body2")};
119 CodeExtractor CE(Candidates);
120 EXPECT_TRUE(CE.isEligible());
122 CodeExtractorAnalysisCache CEAC(*Func);
123 SetVector<Value *> Inputs, Outputs;
124 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
125 EXPECT_TRUE(Outlined);
127 EXPECT_EQ(Inputs.size(), 3u);
128 EXPECT_EQ(Inputs[0], Func->getArg(2));
129 EXPECT_EQ(Inputs[1], Func->getArg(0));
130 EXPECT_EQ(Inputs[2], Func->getArg(1));
131 EXPECT_EQ(Outputs.size(), 1u);
132 StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
133 Value *OutputVal = SI->getValueOperand();
134 EXPECT_EQ(Outputs[0], OutputVal);
135 BasicBlock *Exit = getBlockByName(Func, "notExtracted");
136 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
137 // Ensure that PHI in exit block has only one incoming value (from code
138 // replacer block).
139 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
140 // Ensure that there is a PHI in outlined function with 2 incoming values.
141 EXPECT_TRUE(ExitSplit &&
142 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
143 EXPECT_FALSE(verifyFunction(*Outlined));
144 EXPECT_FALSE(verifyFunction(*Func));
147 TEST(CodeExtractor, ExitBlockOrderingPhis) {
148 LLVMContext Ctx;
149 SMDiagnostic Err;
150 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
151 define void @foo(i32 %a, i32 %b) {
152 entry:
153 %0 = alloca i32, align 4
154 br label %test0
155 test0:
156 %c = load i32, i32* %0, align 4
157 br label %test1
158 test1:
159 %e = load i32, i32* %0, align 4
160 br i1 true, label %first, label %test
161 test:
162 %d = load i32, i32* %0, align 4
163 br i1 true, label %first, label %next
164 first:
165 %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
166 ret void
167 next:
168 %2 = add i32 %d, 1
169 %3 = add i32 %e, 1
170 ret void
172 )invalid",
173 Err, Ctx));
174 Function *Func = M->getFunction("foo");
175 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
176 getBlockByName(Func, "test1"),
177 getBlockByName(Func, "test") };
179 CodeExtractor CE(Candidates);
180 EXPECT_TRUE(CE.isEligible());
182 CodeExtractorAnalysisCache CEAC(*Func);
183 Function *Outlined = CE.extractCodeRegion(CEAC);
184 EXPECT_TRUE(Outlined);
186 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
187 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
189 Instruction *FirstTerm = FirstExitStub->getTerminator();
190 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
191 EXPECT_TRUE(FirstReturn);
192 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
193 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
195 Instruction *NextTerm = NextExitStub->getTerminator();
196 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
197 EXPECT_TRUE(NextReturn);
198 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
199 EXPECT_TRUE(CINext->getLimitedValue() == 0u);
201 EXPECT_FALSE(verifyFunction(*Outlined));
202 EXPECT_FALSE(verifyFunction(*Func));
205 TEST(CodeExtractor, ExitBlockOrdering) {
206 LLVMContext Ctx;
207 SMDiagnostic Err;
208 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
209 define void @foo(i32 %a, i32 %b) {
210 entry:
211 %0 = alloca i32, align 4
212 br label %test0
213 test0:
214 %c = load i32, i32* %0, align 4
215 br label %test1
216 test1:
217 %e = load i32, i32* %0, align 4
218 br i1 true, label %first, label %test
219 test:
220 %d = load i32, i32* %0, align 4
221 br i1 true, label %first, label %next
222 first:
223 ret void
224 next:
225 %1 = add i32 %d, 1
226 %2 = add i32 %e, 1
227 ret void
229 )invalid",
230 Err, Ctx));
231 Function *Func = M->getFunction("foo");
232 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
233 getBlockByName(Func, "test1"),
234 getBlockByName(Func, "test") };
236 CodeExtractor CE(Candidates);
237 EXPECT_TRUE(CE.isEligible());
239 CodeExtractorAnalysisCache CEAC(*Func);
240 Function *Outlined = CE.extractCodeRegion(CEAC);
241 EXPECT_TRUE(Outlined);
243 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
244 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
246 Instruction *FirstTerm = FirstExitStub->getTerminator();
247 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
248 EXPECT_TRUE(FirstReturn);
249 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
250 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
252 Instruction *NextTerm = NextExitStub->getTerminator();
253 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
254 EXPECT_TRUE(NextReturn);
255 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
256 EXPECT_TRUE(CINext->getLimitedValue() == 0u);
258 EXPECT_FALSE(verifyFunction(*Outlined));
259 EXPECT_FALSE(verifyFunction(*Func));
262 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
263 LLVMContext Ctx;
264 SMDiagnostic Err;
265 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
266 define i32 @foo() {
267 header:
268 br i1 undef, label %extracted1, label %pred
270 pred:
271 br i1 undef, label %exit1, label %exit2
273 extracted1:
274 br i1 undef, label %extracted2, label %exit1
276 extracted2:
277 br label %exit2
279 exit1:
280 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
281 ret i32 %0
283 exit2:
284 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
285 ret i32 %1
287 )invalid", Err, Ctx));
289 Function *Func = M->getFunction("foo");
290 SmallVector<BasicBlock *, 2> ExtractedBlocks{
291 getBlockByName(Func, "extracted1"),
292 getBlockByName(Func, "extracted2")
295 CodeExtractor CE(ExtractedBlocks);
296 EXPECT_TRUE(CE.isEligible());
298 CodeExtractorAnalysisCache CEAC(*Func);
299 Function *Outlined = CE.extractCodeRegion(CEAC);
300 EXPECT_TRUE(Outlined);
301 BasicBlock *Exit1 = getBlockByName(Func, "exit1");
302 BasicBlock *Exit2 = getBlockByName(Func, "exit2");
303 // Ensure that PHIs in exits are not splitted (since that they have only one
304 // incoming value from extracted region).
305 EXPECT_TRUE(Exit1 &&
306 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
307 EXPECT_TRUE(Exit2 &&
308 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
309 EXPECT_FALSE(verifyFunction(*Outlined));
310 EXPECT_FALSE(verifyFunction(*Func));
313 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
314 LLVMContext Ctx;
315 SMDiagnostic Err;
316 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
317 declare i8 @hoge()
319 define i32 @foo() personality i8* null {
320 entry:
321 %call = invoke i8 @hoge()
322 to label %invoke.cont unwind label %lpad
324 invoke.cont: ; preds = %entry
325 unreachable
327 lpad: ; preds = %entry
328 %0 = landingpad { i8*, i32 }
329 catch i8* null
330 br i1 undef, label %catch, label %finally.catchall
332 catch: ; preds = %lpad
333 %call2 = invoke i8 @hoge()
334 to label %invoke.cont2 unwind label %lpad2
336 invoke.cont2: ; preds = %catch
337 %call3 = invoke i8 @hoge()
338 to label %invoke.cont3 unwind label %lpad2
340 invoke.cont3: ; preds = %invoke.cont2
341 unreachable
343 lpad2: ; preds = %invoke.cont2, %catch
344 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
345 %1 = landingpad { i8*, i32 }
346 catch i8* null
347 br label %finally.catchall
349 finally.catchall: ; preds = %lpad33, %lpad
350 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
351 unreachable
353 )invalid", Err, Ctx));
355 if (!M) {
356 Err.print("unit", errs());
357 exit(1);
360 Function *Func = M->getFunction("foo");
361 EXPECT_FALSE(verifyFunction(*Func, &errs()));
363 SmallVector<BasicBlock *, 2> ExtractedBlocks{
364 getBlockByName(Func, "catch"),
365 getBlockByName(Func, "invoke.cont2"),
366 getBlockByName(Func, "invoke.cont3"),
367 getBlockByName(Func, "lpad2")
370 CodeExtractor CE(ExtractedBlocks);
371 EXPECT_TRUE(CE.isEligible());
373 CodeExtractorAnalysisCache CEAC(*Func);
374 Function *Outlined = CE.extractCodeRegion(CEAC);
375 EXPECT_TRUE(Outlined);
376 EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
377 EXPECT_FALSE(verifyFunction(*Func, &errs()));
380 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
381 LLVMContext Ctx;
382 SMDiagnostic Err;
383 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
384 declare i32 @bar()
386 define i32 @foo() personality i8* null {
387 entry:
388 %0 = invoke i32 @bar() to label %exit unwind label %lpad
390 exit:
391 ret i32 %0
393 lpad:
394 %1 = landingpad { i8*, i32 }
395 cleanup
396 resume { i8*, i32 } %1
398 )invalid",
399 Err, Ctx));
401 Function *Func = M->getFunction("foo");
402 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
403 getBlockByName(Func, "lpad") };
405 CodeExtractor CE(Blocks);
406 EXPECT_TRUE(CE.isEligible());
408 CodeExtractorAnalysisCache CEAC(*Func);
409 Function *Outlined = CE.extractCodeRegion(CEAC);
410 EXPECT_TRUE(Outlined);
411 EXPECT_FALSE(verifyFunction(*Outlined));
412 EXPECT_FALSE(verifyFunction(*Func));
415 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
416 LLVMContext Ctx;
417 SMDiagnostic Err;
418 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
419 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
420 target triple = "aarch64"
422 %b = type { i64 }
423 declare void @g(i8*)
425 declare void @llvm.assume(i1) #0
427 define void @test() {
428 entry:
429 br label %label
431 label:
432 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
433 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
434 %2 = load i64, i64* %1, align 8
435 %3 = icmp ugt i64 %2, 1
436 br i1 %3, label %if.then, label %if.else
438 if.then:
439 unreachable
441 if.else:
442 call void @g(i8* undef)
443 store i64 undef, i64* null, align 536870912
444 %4 = icmp eq i64 %2, 0
445 call void @llvm.assume(i1 %4)
446 unreachable
449 attributes #0 = { nounwind willreturn }
450 )ir",
451 Err, Ctx));
453 assert(M && "Could not parse module?");
454 Function *Func = M->getFunction("test");
455 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
456 AssumptionCache AC(*Func);
457 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
458 EXPECT_TRUE(CE.isEligible());
460 CodeExtractorAnalysisCache CEAC(*Func);
461 Function *Outlined = CE.extractCodeRegion(CEAC);
462 EXPECT_TRUE(Outlined);
463 EXPECT_FALSE(verifyFunction(*Outlined));
464 EXPECT_FALSE(verifyFunction(*Func));
465 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
468 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
469 LLVMContext Ctx;
470 SMDiagnostic Err;
471 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
472 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
473 target triple = "x86_64-unknown-linux-gnu"
475 declare void @use(i32*)
476 declare void @llvm.lifetime.start.p0i8(i64, i8*)
477 declare void @llvm.lifetime.end.p0i8(i64, i8*)
479 define void @foo() {
480 entry:
481 %0 = alloca i32
482 br label %extract
484 extract:
485 %1 = bitcast i32* %0 to i8*
486 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
487 call void @use(i32* %0)
488 br label %exit
490 exit:
491 call void @use(i32* %0)
492 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
493 ret void
495 )ir",
496 Err, Ctx));
498 Function *Func = M->getFunction("foo");
499 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
501 CodeExtractor CE(Blocks);
502 EXPECT_TRUE(CE.isEligible());
504 CodeExtractorAnalysisCache CEAC(*Func);
505 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
506 BasicBlock *CommonExit = nullptr;
507 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
508 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
509 EXPECT_EQ(Outputs.size(), 0U);
511 Function *Outlined = CE.extractCodeRegion(CEAC);
512 EXPECT_TRUE(Outlined);
513 EXPECT_FALSE(verifyFunction(*Outlined));
514 EXPECT_FALSE(verifyFunction(*Func));
517 TEST(CodeExtractor, PartialAggregateArgs) {
518 LLVMContext Ctx;
519 SMDiagnostic Err;
520 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
521 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
522 target triple = "x86_64-unknown-linux-gnu"
524 ; use different types such that an index mismatch will result in a type mismatch during verification.
525 declare void @use16(i16)
526 declare void @use32(i32)
527 declare void @use64(i64)
529 define void @foo(i16 %a, i32 %b, i64 %c) {
530 entry:
531 br label %extract
533 extract:
534 call void @use16(i16 %a)
535 call void @use32(i32 %b)
536 call void @use64(i64 %c)
537 %d = add i16 21, 21
538 %e = add i32 21, 21
539 %f = add i64 21, 21
540 br label %exit
542 exit:
543 call void @use16(i16 %d)
544 call void @use32(i32 %e)
545 call void @use64(i64 %f)
546 ret void
548 )ir",
549 Err, Ctx));
551 Function *Func = M->getFunction("foo");
552 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
554 // Create the CodeExtractor with arguments aggregation enabled.
555 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
556 /* AggregateArgs */ true);
557 EXPECT_TRUE(CE.isEligible());
559 CodeExtractorAnalysisCache CEAC(*Func);
560 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
561 BasicBlock *CommonExit = nullptr;
562 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
563 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
564 // Exclude the middle input and output from the argument aggregate.
565 CE.excludeArgFromAggregate(Inputs[1]);
566 CE.excludeArgFromAggregate(Outputs[1]);
568 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
569 EXPECT_TRUE(Outlined);
570 // Expect 3 arguments in the outlined function: the excluded input, the
571 // excluded output, and the struct aggregate for the remaining inputs.
572 EXPECT_EQ(Outlined->arg_size(), 3U);
573 EXPECT_FALSE(verifyFunction(*Outlined));
574 EXPECT_FALSE(verifyFunction(*Func));
577 TEST(CodeExtractor, AllocaBlock) {
578 LLVMContext Ctx;
579 SMDiagnostic Err;
580 std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
581 define i32 @foo(i32 %x, i32 %y, i32 %z) {
582 entry:
583 br label %allocas
585 allocas:
586 br label %body
588 body:
589 %w = add i32 %x, %y
590 br label %notExtracted
592 notExtracted:
593 %r = add i32 %w, %x
594 ret i32 %r
596 )invalid",
597 Err, Ctx));
599 Function *Func = M->getFunction("foo");
600 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "body")};
602 BasicBlock *AllocaBlock = getBlockByName(Func, "allocas");
603 CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false,
604 false, AllocaBlock);
605 CE.excludeArgFromAggregate(Func->getArg(0));
606 CE.excludeArgFromAggregate(getInstByName(Func, "w"));
607 EXPECT_TRUE(CE.isEligible());
609 CodeExtractorAnalysisCache CEAC(*Func);
610 SetVector<Value *> Inputs, Outputs;
611 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
612 EXPECT_TRUE(Outlined);
613 EXPECT_FALSE(verifyFunction(*Outlined));
614 EXPECT_FALSE(verifyFunction(*Func));
616 // The only added allocas may be in the dedicated alloca block. There should
617 // be one alloca for the struct, and another one for the reload value.
618 int NumAllocas = 0;
619 for (Instruction &I : instructions(Func)) {
620 if (!isa<AllocaInst>(I))
621 continue;
622 EXPECT_EQ(I.getParent(), AllocaBlock);
623 NumAllocas += 1;
625 EXPECT_EQ(NumAllocas, 2);
628 /// Regression test to ensure we don't crash trying to set the name of the ptr
629 /// argument
630 TEST(CodeExtractor, PartialAggregateArgs2) {
631 LLVMContext Ctx;
632 SMDiagnostic Err;
633 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
634 declare void @usei(i32)
635 declare void @usep(ptr)
637 define void @foo(i32 %a, i32 %b, ptr %p) {
638 entry:
639 br label %extract
641 extract:
642 call void @usei(i32 %a)
643 call void @usei(i32 %b)
644 call void @usep(ptr %p)
645 br label %exit
647 exit:
648 ret void
650 )ir",
651 Err, Ctx));
653 Function *Func = M->getFunction("foo");
654 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
656 // Create the CodeExtractor with arguments aggregation enabled.
657 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
658 /* AggregateArgs */ true);
659 EXPECT_TRUE(CE.isEligible());
661 CodeExtractorAnalysisCache CEAC(*Func);
662 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
663 BasicBlock *CommonExit = nullptr;
664 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
665 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
666 // Exclude the last input from the argument aggregate.
667 CE.excludeArgFromAggregate(Inputs[2]);
669 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
670 EXPECT_TRUE(Outlined);
671 EXPECT_FALSE(verifyFunction(*Outlined));
672 EXPECT_FALSE(verifyFunction(*Func));
675 TEST(CodeExtractor, OpenMPAggregateArgs) {
676 LLVMContext Ctx;
677 SMDiagnostic Err;
678 std::unique_ptr<Module> M(parseAssemblyString(R"ir(
679 target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
680 target triple = "amdgcn-amd-amdhsa"
682 define void @foo(ptr %0) {
683 %2= alloca ptr, align 8, addrspace(5)
684 %3 = addrspacecast ptr addrspace(5) %2 to ptr
685 store ptr %0, ptr %3, align 8
686 %4 = load ptr, ptr %3, align 8
687 br label %entry
689 entry:
690 br label %extract
692 extract:
693 store i64 10, ptr %4, align 4
694 br label %exit
696 exit:
697 ret void
699 )ir",
700 Err, Ctx));
701 Function *Func = M->getFunction("foo");
702 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
704 // Create the CodeExtractor with arguments aggregation enabled.
705 // Outlined function argument should be declared in 0 address space
706 // even if the default alloca address space is 5.
707 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
708 /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
709 /* BranchProbabilityInfo */ nullptr,
710 /* AssumptionCache */ nullptr,
711 /* AllowVarArgs */ true,
712 /* AllowAlloca */ true,
713 /* AllocaBlock*/ &Func->getEntryBlock(),
714 /* Suffix */ ".outlined",
715 /* ArgsInZeroAddressSpace */ true);
717 EXPECT_TRUE(CE.isEligible());
719 CodeExtractorAnalysisCache CEAC(*Func);
720 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
721 BasicBlock *CommonExit = nullptr;
722 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
723 CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
725 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
726 EXPECT_TRUE(Outlined);
727 EXPECT_EQ(Outlined->arg_size(), 1U);
728 // Check address space of outlined argument is ptr in address space 0
729 EXPECT_EQ(Outlined->getArg(0)->getType(),
730 PointerType::get(M->getContext(), 0));
731 EXPECT_FALSE(verifyFunction(*Outlined));
732 EXPECT_FALSE(verifyFunction(*Func));
734 } // end anonymous namespace