1 //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/AsmParser/Parser.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/Passes/PassBuilder.h"
12 #include "llvm/Support/SourceMgr.h"
13 #include "llvm/Testing/Support/Error.h"
14 #include "llvm/Transforms/Coroutines/CoroSplit.h"
15 #include "gtest/gtest.h"
21 struct ExtraRematTest
: public testing::Test
{
23 ModulePassManager MPM
;
25 LoopAnalysisManager LAM
;
26 FunctionAnalysisManager FAM
;
27 CGSCCAnalysisManager CGAM
;
28 ModuleAnalysisManager MAM
;
30 std::unique_ptr
<Module
> M
;
33 PB
.registerModuleAnalyses(MAM
);
34 PB
.registerCGSCCAnalyses(CGAM
);
35 PB
.registerFunctionAnalyses(FAM
);
36 PB
.registerLoopAnalyses(LAM
);
37 PB
.crossRegisterProxies(LAM
, FAM
, CGAM
, MAM
);
40 BasicBlock
*getBasicBlockByName(Function
*F
, StringRef Name
) const {
41 for (BasicBlock
&BB
: *F
) {
42 if (BB
.getName() == Name
)
48 CallInst
*getCallByName(BasicBlock
*BB
, StringRef Name
) const {
49 for (Instruction
&I
: *BB
) {
50 if (CallInst
*CI
= dyn_cast
<CallInst
>(&I
))
51 if (CI
->getCalledFunction()->getName() == Name
)
57 void ParseAssembly(const StringRef IR
) {
59 M
= parseAssemblyString(IR
, Error
, Context
);
61 raw_string_ostream
os(errMsg
);
64 // A failure here means that the test itself is buggy.
66 report_fatal_error(os
.str().c_str());
71 define ptr @f(i32 %n) presplitcoroutine {
73 %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
74 %size = call i32 @llvm.coro.size.i32()
75 %alloc = call ptr @malloc(i32 %size)
76 %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
79 %val2 = call i32 @should.remat(i32 %inc1)
80 %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
81 switch i8 %sp1, label %suspend [i8 0, label %resume1
84 %inc2 = add i32 %val2, 1
85 %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
86 switch i8 %sp1, label %suspend [i8 0, label %resume2
90 call void @print(i32 %val2)
91 call void @print(i32 %inc2)
95 %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
96 call void @free(ptr %mem)
99 call i1 @llvm.coro.end(ptr %hdl, i1 0)
103 declare ptr @llvm.coro.free(token, ptr)
104 declare i32 @llvm.coro.size.i32()
105 declare i8 @llvm.coro.suspend(token, i1)
106 declare void @llvm.coro.resume(ptr)
107 declare void @llvm.coro.destroy(ptr)
109 declare token @llvm.coro.id(i32, ptr, ptr, ptr)
110 declare i1 @llvm.coro.alloc(token)
111 declare ptr @llvm.coro.begin(token, ptr)
112 declare i1 @llvm.coro.end(ptr, i1)
114 declare i32 @should.remat(i32)
116 declare noalias ptr @malloc(i32)
117 declare void @print(i32)
118 declare void @free(ptr)
121 // Materializable callback with extra rematerialization
122 bool ExtraMaterializable(Instruction
&I
) {
123 if (isa
<CastInst
>(&I
) || isa
<GetElementPtrInst
>(&I
) ||
124 isa
<BinaryOperator
>(&I
) || isa
<CmpInst
>(&I
) || isa
<SelectInst
>(&I
))
127 if (auto *CI
= dyn_cast
<CallInst
>(&I
)) {
128 auto *CalledFunc
= CI
->getCalledFunction();
129 if (CalledFunc
&& CalledFunc
->getName().starts_with("should.remat"))
136 TEST_F(ExtraRematTest
, TestCoroRematDefault
) {
141 CGSCCPassManager CGPM
;
142 CGPM
.addPass(CoroSplitPass());
143 MPM
.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM
)));
146 // Verify that extra rematerializable instruction has been rematerialized
147 Function
*F
= M
->getFunction("f.resume");
148 ASSERT_TRUE(F
) << "could not find split function f.resume";
150 BasicBlock
*Resume1
= getBasicBlockByName(F
, "resume1");
152 << "could not find expected BB resume1 in split function";
154 // With default materialization the intrinsic should not have been
156 CallInst
*CI
= getCallByName(Resume1
, "should.remat");
160 TEST_F(ExtraRematTest
, TestCoroRematWithCallback
) {
165 CGSCCPassManager CGPM
;
167 CoroSplitPass(std::function
<bool(Instruction
&)>(ExtraMaterializable
)));
168 MPM
.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM
)));
171 // Verify that extra rematerializable instruction has been rematerialized
172 Function
*F
= M
->getFunction("f.resume");
173 ASSERT_TRUE(F
) << "could not find split function f.resume";
175 BasicBlock
*Resume1
= getBasicBlockByName(F
, "resume1");
177 << "could not find expected BB resume1 in split function";
179 // With callback the extra rematerialization of the function should have
181 CallInst
*CI
= getCallByName(Resume1
, "should.remat");