Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / unittests / Analysis / PluginInlineAdvisorAnalysisTest.cpp
blob ca4ea8b627e839bfde810a3e9b09049d99b98e21
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/IR/Module.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/Passes/PassPlugin.h"
7 #include "llvm/Support/CommandLine.h"
8 #include "llvm/Support/raw_ostream.h"
9 #include "llvm/Testing/Support/Error.h"
10 #include "gtest/gtest.h"
12 namespace llvm {
14 namespace {
16 void anchor() {}
18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
19 const auto &Argvs = testing::internal::GetArgvs();
20 const char *Argv0 =
21 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
22 void *Ptr = (void *)(intptr_t)anchor;
23 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
24 llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
25 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
26 return std::string(Buf.str());
29 // Example of a custom InlineAdvisor that only inlines calls to functions called
30 // "foo".
31 class FooOnlyInlineAdvisor : public InlineAdvisor {
32 public:
33 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
34 InlineParams Params, InlineContext IC)
35 : InlineAdvisor(M, FAM, IC) {}
37 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
38 if (CB.getCalledFunction()->getName() == "foo")
39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
40 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
45 InlineParams Params, InlineContext IC) {
46 return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
49 struct CompilerInstance {
50 LLVMContext Ctx;
51 ModulePassManager MPM;
52 InlineParams IP;
54 PassBuilder PB;
55 LoopAnalysisManager LAM;
56 FunctionAnalysisManager FAM;
57 CGSCCAnalysisManager CGAM;
58 ModuleAnalysisManager MAM;
60 SMDiagnostic Error;
62 // connect the plugin to our compiler instance
63 void setupPlugin() {
64 auto PluginPath = libPath();
65 ASSERT_NE("", PluginPath);
66 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
67 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
68 Plugin->registerPassBuilderCallbacks(PB);
71 // connect the FooOnlyInlineAdvisor to our compiler instance
72 void setupFooOnly() {
73 MAM.registerPass(
74 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
77 CompilerInstance() {
78 IP = getInlineParams(3, 0);
79 PB.registerModuleAnalyses(MAM);
80 PB.registerCGSCCAnalyses(CGAM);
81 PB.registerFunctionAnalyses(FAM);
82 PB.registerLoopAnalyses(LAM);
83 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
84 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
85 ThinOrFullLTOPhase::None));
88 std::string output;
89 std::unique_ptr<Module> outputM;
91 auto run(StringRef IR) {
92 outputM = parseAssemblyString(IR, Error, Ctx);
93 MPM.run(*outputM, MAM);
94 ASSERT_TRUE(outputM);
95 output.clear();
96 raw_string_ostream o_stream{output};
97 outputM->print(o_stream, nullptr);
98 ASSERT_TRUE(true);
// Textual IR modules iterated by both tests below; each entry is parsed and run
// through the module inliner by CompilerInstance::run, never executed (e.g. the
// loop in @f4 of the third module has no exit).
102 StringRef TestIRS[] = {
103 // Simple 3 function inline case
105 define void @f1() {
106 call void @foo()
107 ret void
109 define void @foo() {
110 call void @f3()
111 ret void
113 define void @f3() {
114 ret void
117 // Test that has 5 functions of which 2 are recursive
119 define void @f1() {
120 call void @foo()
121 ret void
123 define void @f2() {
124 call void @foo()
125 ret void
127 define void @foo() {
128 call void @f4()
129 call void @f5()
130 ret void
132 define void @f4() {
133 ret void
135 define void @f5() {
136 call void @foo()
137 ret void
140 // test with 3 mutually recursive functions and 1 function with a loop
142 define void @f1() {
143 call void @f2()
144 ret void
146 define void @f2() {
147 call void @f3()
148 ret void
150 define void @f3() {
151 call void @f1()
152 ret void
154 define void @f4() {
155 br label %loop
156 loop:
157 call void @f5()
158 br label %loop
160 define void @f5() {
161 ret void
164 // test that has a function that computes fibonacci in a loop, one in a
165 // recursive manner, and one that calls both and compares them
167 define i32 @fib_loop(i32 %n){
168 %curr = alloca i32
169 %last = alloca i32
170 %i = alloca i32
171 store i32 1, i32* %curr
172 store i32 1, i32* %last
173 store i32 2, i32* %i
174 br label %loop_cond
175 loop_cond:
176 %i_val = load i32, i32* %i
177 %cmp = icmp slt i32 %i_val, %n
178 br i1 %cmp, label %loop_body, label %loop_end
179 loop_body:
180 %curr_val = load i32, i32* %curr
181 %last_val = load i32, i32* %last
182 %add = add i32 %curr_val, %last_val
183 store i32 %add, i32* %last
184 store i32 %curr_val, i32* %curr
185 %i_val2 = load i32, i32* %i
186 %add2 = add i32 %i_val2, 1
187 store i32 %add2, i32* %i
188 br label %loop_cond
189 loop_end:
190 %curr_val3 = load i32, i32* %curr
191 ret i32 %curr_val3
194 define i32 @fib_rec(i32 %n){
195 %cmp = icmp eq i32 %n, 0
196 %cmp2 = icmp eq i32 %n, 1
197 %or = or i1 %cmp, %cmp2
198 br i1 %or, label %if_true, label %if_false
199 if_true:
200 ret i32 1
201 if_false:
202 %sub = sub i32 %n, 1
203 %call = call i32 @fib_rec(i32 %sub)
204 %sub2 = sub i32 %n, 2
205 %call2 = call i32 @fib_rec(i32 %sub2)
206 %add = add i32 %call, %call2
207 ret i32 %add
210 define i32 @fib_check(){
211 %correct = alloca i32
212 %i = alloca i32
213 store i32 1, i32* %correct
214 store i32 0, i32* %i
215 br label %loop_cond
216 loop_cond:
217 %i_val = load i32, i32* %i
218 %cmp = icmp slt i32 %i_val, 10
219 br i1 %cmp, label %loop_body, label %loop_end
220 loop_body:
221 %i_val2 = load i32, i32* %i
222 %call = call i32 @fib_loop(i32 %i_val2)
223 %i_val3 = load i32, i32* %i
224 %call2 = call i32 @fib_rec(i32 %i_val3)
225 %cmp2 = icmp ne i32 %call, %call2
226 br i1 %cmp2, label %if_true, label %if_false
227 if_true:
228 store i32 0, i32* %correct
229 br label %if_end
230 if_false:
231 br label %if_end
232 if_end:
233 %i_val4 = load i32, i32* %i
234 %add = add i32 %i_val4, 1
235 store i32 %add, i32* %i
236 br label %loop_cond
237 loop_end:
238 %correct_val = load i32, i32* %correct
239 ret i32 %correct_val
241 )"};
243 } // namespace
245 // check that loading a plugin works
246 // the plugin being loaded acts identically to the default inliner
247 TEST(PluginInlineAdvisorTest, PluginLoad) {
248 #if !defined(LLVM_ENABLE_PLUGINS)
249 // Skip the test if plugins are disabled.
250 GTEST_SKIP();
251 #endif
252 CompilerInstance DefaultCI{};
254 CompilerInstance PluginCI{};
255 PluginCI.setupPlugin();
257 for (StringRef IR : TestIRS) {
258 DefaultCI.run(IR);
259 std::string default_output = DefaultCI.output;
260 PluginCI.run(IR);
261 std::string dynamic_output = PluginCI.output;
262 ASSERT_EQ(default_output, dynamic_output);
266 // check that the behaviour of a custom inliner is correct
267 // the custom inliner inlines all functions that are not named "foo"
268 // this testdoes not require plugins to be enabled
269 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
270 CompilerInstance CI{};
271 CI.setupFooOnly();
273 for (StringRef IR : TestIRS) {
274 CI.run(IR);
275 CallGraph CGraph = CallGraph(*CI.outputM);
276 for (auto &node : CGraph) {
277 for (auto &edge : *node.second) {
278 if (!edge.first)
279 continue;
280 ASSERT_NE(edge.second->getFunction()->getName(), "foo");
286 } // namespace llvm