Bump version to 19.1.0-rc3
[llvm-project.git] / llvm / unittests / Transforms / Coroutines / ExtraRematTest.cpp
blobda78c151e7f68d0a714d78b175a6bdea1c4984bc
1 //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/AsmParser/Parser.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/Passes/PassBuilder.h"
12 #include "llvm/Support/SourceMgr.h"
13 #include "llvm/Testing/Support/Error.h"
14 #include "llvm/Transforms/Coroutines/CoroSplit.h"
15 #include "gtest/gtest.h"
17 using namespace llvm;
19 namespace {
21 struct ExtraRematTest : public testing::Test {
22 LLVMContext Ctx;
23 ModulePassManager MPM;
24 PassBuilder PB;
25 LoopAnalysisManager LAM;
26 FunctionAnalysisManager FAM;
27 CGSCCAnalysisManager CGAM;
28 ModuleAnalysisManager MAM;
29 LLVMContext Context;
30 std::unique_ptr<Module> M;
32 ExtraRematTest() {
33 PB.registerModuleAnalyses(MAM);
34 PB.registerCGSCCAnalyses(CGAM);
35 PB.registerFunctionAnalyses(FAM);
36 PB.registerLoopAnalyses(LAM);
37 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
40 BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
41 for (BasicBlock &BB : *F) {
42 if (BB.getName() == Name)
43 return &BB;
45 return nullptr;
48 CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
49 for (Instruction &I : *BB) {
50 if (CallInst *CI = dyn_cast<CallInst>(&I))
51 if (CI->getCalledFunction()->getName() == Name)
52 return CI;
54 return nullptr;
57 void ParseAssembly(const StringRef IR) {
58 SMDiagnostic Error;
59 M = parseAssemblyString(IR, Error, Context);
60 std::string errMsg;
61 raw_string_ostream os(errMsg);
62 Error.print("", os);
64 // A failure here means that the test itself is buggy.
65 if (!M)
66 report_fatal_error(os.str().c_str());
70 StringRef Text = R"(
71 define ptr @f(i32 %n) presplitcoroutine {
72 entry:
73 %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
74 %size = call i32 @llvm.coro.size.i32()
75 %alloc = call ptr @malloc(i32 %size)
76 %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
78 %inc1 = add i32 %n, 1
79 %val2 = call i32 @should.remat(i32 %inc1)
80 %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
81 switch i8 %sp1, label %suspend [i8 0, label %resume1
82 i8 1, label %cleanup]
83 resume1:
84 %inc2 = add i32 %val2, 1
85 %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
86 switch i8 %sp1, label %suspend [i8 0, label %resume2
87 i8 1, label %cleanup]
89 resume2:
90 call void @print(i32 %val2)
91 call void @print(i32 %inc2)
92 br label %cleanup
94 cleanup:
95 %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
96 call void @free(ptr %mem)
97 br label %suspend
98 suspend:
99 call i1 @llvm.coro.end(ptr %hdl, i1 0)
100 ret ptr %hdl
103 declare ptr @llvm.coro.free(token, ptr)
104 declare i32 @llvm.coro.size.i32()
105 declare i8 @llvm.coro.suspend(token, i1)
106 declare void @llvm.coro.resume(ptr)
107 declare void @llvm.coro.destroy(ptr)
109 declare token @llvm.coro.id(i32, ptr, ptr, ptr)
110 declare i1 @llvm.coro.alloc(token)
111 declare ptr @llvm.coro.begin(token, ptr)
112 declare i1 @llvm.coro.end(ptr, i1)
114 declare i32 @should.remat(i32)
116 declare noalias ptr @malloc(i32)
117 declare void @print(i32)
118 declare void @free(ptr)
121 // Materializable callback with extra rematerialization
122 bool ExtraMaterializable(Instruction &I) {
123 if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
124 isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I))
125 return true;
127 if (auto *CI = dyn_cast<CallInst>(&I)) {
128 auto *CalledFunc = CI->getCalledFunction();
129 if (CalledFunc && CalledFunc->getName().starts_with("should.remat"))
130 return true;
133 return false;
136 TEST_F(ExtraRematTest, TestCoroRematDefault) {
137 ParseAssembly(Text);
139 ASSERT_TRUE(M);
141 CGSCCPassManager CGPM;
142 CGPM.addPass(CoroSplitPass());
143 MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
144 MPM.run(*M, MAM);
146 // Verify that extra rematerializable instruction has been rematerialized
147 Function *F = M->getFunction("f.resume");
148 ASSERT_TRUE(F) << "could not find split function f.resume";
150 BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
151 ASSERT_TRUE(Resume1)
152 << "could not find expected BB resume1 in split function";
154 // With default materialization the intrinsic should not have been
155 // rematerialized
156 CallInst *CI = getCallByName(Resume1, "should.remat");
157 ASSERT_FALSE(CI);
160 TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
161 ParseAssembly(Text);
163 ASSERT_TRUE(M);
165 CGSCCPassManager CGPM;
166 CGPM.addPass(
167 CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
168 MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
169 MPM.run(*M, MAM);
171 // Verify that extra rematerializable instruction has been rematerialized
172 Function *F = M->getFunction("f.resume");
173 ASSERT_TRUE(F) << "could not find split function f.resume";
175 BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
176 ASSERT_TRUE(Resume1)
177 << "could not find expected BB resume1 in split function";
179 // With callback the extra rematerialization of the function should have
180 // happened
181 CallInst *CI = getCallByName(Resume1, "should.remat");
182 ASSERT_TRUE(CI);
184 } // namespace