Bump version to 19.1.0-rc3
[llvm-project.git] / llvm / unittests / Analysis / PluginInlineAdvisorAnalysisTest.cpp
blob3330751120e6c8531502c7743b57c7571ba02fa9
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/IR/Module.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/Passes/PassPlugin.h"
7 #include "llvm/Support/CommandLine.h"
8 #include "llvm/Support/raw_ostream.h"
9 #include "llvm/Testing/Support/Error.h"
10 #include "gtest/gtest.h"
12 namespace llvm {
14 namespace {
// Address-only helper: libPath() passes a pointer to this function to
// sys::fs::getMainExecutable when resolving the path of the test binary.
void anchor() {
  // Intentionally empty; only its address is used.
}
18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
19 const auto &Argvs = testing::internal::GetArgvs();
20 const char *Argv0 =
21 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
22 void *Ptr = (void *)(intptr_t)anchor;
23 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
24 llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
25 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
26 return std::string(Buf.str());
29 // Example of a custom InlineAdvisor that only inlines calls to functions called
30 // "foo".
31 class FooOnlyInlineAdvisor : public InlineAdvisor {
32 public:
33 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
34 InlineParams Params, InlineContext IC)
35 : InlineAdvisor(M, FAM, IC) {}
37 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
38 if (CB.getCalledFunction()->getName() == "foo")
39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
40 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
45 InlineParams Params, InlineContext IC) {
46 return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
49 struct CompilerInstance {
50 LLVMContext Ctx;
51 ModulePassManager MPM;
52 InlineParams IP;
54 PassBuilder PB;
55 LoopAnalysisManager LAM;
56 FunctionAnalysisManager FAM;
57 CGSCCAnalysisManager CGAM;
58 ModuleAnalysisManager MAM;
60 SMDiagnostic Error;
62 // connect the plugin to our compiler instance
63 void setupPlugin() {
64 auto PluginPath = libPath();
65 ASSERT_NE("", PluginPath);
66 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
67 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
68 Plugin->registerPassBuilderCallbacks(PB);
69 ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"),
70 Succeeded());
73 // connect the FooOnlyInlineAdvisor to our compiler instance
74 void setupFooOnly() {
75 MAM.registerPass(
76 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
79 CompilerInstance() {
80 IP = getInlineParams(3, 0);
81 PB.registerModuleAnalyses(MAM);
82 PB.registerCGSCCAnalyses(CGAM);
83 PB.registerFunctionAnalyses(FAM);
84 PB.registerLoopAnalyses(LAM);
85 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
86 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
87 ThinOrFullLTOPhase::None));
90 ~CompilerInstance() {
91 // Reset the static variable that tracks if the plugin has been registered.
92 // This is needed to allow the test to run multiple times.
93 PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
96 std::string output;
97 std::unique_ptr<Module> outputM;
99 // run with the default inliner
100 auto run_default(StringRef IR) {
101 PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
102 outputM = parseAssemblyString(IR, Error, Ctx);
103 MPM.run(*outputM, MAM);
104 ASSERT_TRUE(outputM);
105 output.clear();
106 raw_string_ostream o_stream{output};
107 outputM->print(o_stream, nullptr);
108 ASSERT_TRUE(true);
111 // run with the dnamic inliner
112 auto run_dynamic(StringRef IR) {
113 // note typically the constructor for the DynamicInlineAdvisorAnalysis
114 // will automatically set this to true, we controll it here only to
115 // altenate between the default and dynamic inliner in our test
116 PluginInlineAdvisorAnalysis::HasBeenRegistered = true;
117 outputM = parseAssemblyString(IR, Error, Ctx);
118 MPM.run(*outputM, MAM);
119 ASSERT_TRUE(outputM);
120 output.clear();
121 raw_string_ostream o_stream{output};
122 outputM->print(o_stream, nullptr);
123 ASSERT_TRUE(true);
127 StringRef TestIRS[] = {
128 // Simple 3 function inline case
130 define void @f1() {
131 call void @foo()
132 ret void
134 define void @foo() {
135 call void @f3()
136 ret void
138 define void @f3() {
139 ret void
142 // Test that has 5 functions of which 2 are recursive
144 define void @f1() {
145 call void @foo()
146 ret void
148 define void @f2() {
149 call void @foo()
150 ret void
152 define void @foo() {
153 call void @f4()
154 call void @f5()
155 ret void
157 define void @f4() {
158 ret void
160 define void @f5() {
161 call void @foo()
162 ret void
165 // test with 2 mutually recursive functions and 1 function with a loop
167 define void @f1() {
168 call void @f2()
169 ret void
171 define void @f2() {
172 call void @f3()
173 ret void
175 define void @f3() {
176 call void @f1()
177 ret void
179 define void @f4() {
180 br label %loop
181 loop:
182 call void @f5()
183 br label %loop
185 define void @f5() {
186 ret void
189 // test that has a function that computes fibonacci in a loop, one in a
190 // recurisve manner, and one that calls both and compares them
192 define i32 @fib_loop(i32 %n){
193 %curr = alloca i32
194 %last = alloca i32
195 %i = alloca i32
196 store i32 1, i32* %curr
197 store i32 1, i32* %last
198 store i32 2, i32* %i
199 br label %loop_cond
200 loop_cond:
201 %i_val = load i32, i32* %i
202 %cmp = icmp slt i32 %i_val, %n
203 br i1 %cmp, label %loop_body, label %loop_end
204 loop_body:
205 %curr_val = load i32, i32* %curr
206 %last_val = load i32, i32* %last
207 %add = add i32 %curr_val, %last_val
208 store i32 %add, i32* %last
209 store i32 %curr_val, i32* %curr
210 %i_val2 = load i32, i32* %i
211 %add2 = add i32 %i_val2, 1
212 store i32 %add2, i32* %i
213 br label %loop_cond
214 loop_end:
215 %curr_val3 = load i32, i32* %curr
216 ret i32 %curr_val3
219 define i32 @fib_rec(i32 %n){
220 %cmp = icmp eq i32 %n, 0
221 %cmp2 = icmp eq i32 %n, 1
222 %or = or i1 %cmp, %cmp2
223 br i1 %or, label %if_true, label %if_false
224 if_true:
225 ret i32 1
226 if_false:
227 %sub = sub i32 %n, 1
228 %call = call i32 @fib_rec(i32 %sub)
229 %sub2 = sub i32 %n, 2
230 %call2 = call i32 @fib_rec(i32 %sub2)
231 %add = add i32 %call, %call2
232 ret i32 %add
235 define i32 @fib_check(){
236 %correct = alloca i32
237 %i = alloca i32
238 store i32 1, i32* %correct
239 store i32 0, i32* %i
240 br label %loop_cond
241 loop_cond:
242 %i_val = load i32, i32* %i
243 %cmp = icmp slt i32 %i_val, 10
244 br i1 %cmp, label %loop_body, label %loop_end
245 loop_body:
246 %i_val2 = load i32, i32* %i
247 %call = call i32 @fib_loop(i32 %i_val2)
248 %i_val3 = load i32, i32* %i
249 %call2 = call i32 @fib_rec(i32 %i_val3)
250 %cmp2 = icmp ne i32 %call, %call2
251 br i1 %cmp2, label %if_true, label %if_false
252 if_true:
253 store i32 0, i32* %correct
254 br label %if_end
255 if_false:
256 br label %if_end
257 if_end:
258 %i_val4 = load i32, i32* %i
259 %add = add i32 %i_val4, 1
260 store i32 %add, i32* %i
261 br label %loop_cond
262 loop_end:
263 %correct_val = load i32, i32* %correct
264 ret i32 %correct_val
266 )"};
268 } // namespace
270 // check that loading a plugin works
271 // the plugin being loaded acts identically to the default inliner
272 TEST(PluginInlineAdvisorTest, PluginLoad) {
273 #if !defined(LLVM_ENABLE_PLUGINS)
274 // Skip the test if plugins are disabled.
275 GTEST_SKIP();
276 #endif
277 CompilerInstance CI{};
278 CI.setupPlugin();
280 for (StringRef IR : TestIRS) {
281 CI.run_default(IR);
282 std::string default_output = CI.output;
283 CI.run_dynamic(IR);
284 std::string dynamic_output = CI.output;
285 ASSERT_EQ(default_output, dynamic_output);
289 // check that the behaviour of a custom inliner is correct
290 // the custom inliner inlines all functions that are not named "foo"
291 // this testdoes not require plugins to be enabled
292 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
293 CompilerInstance CI{};
294 CI.setupFooOnly();
296 for (StringRef IR : TestIRS) {
297 CI.run_dynamic(IR);
298 CallGraph CGraph = CallGraph(*CI.outputM);
299 for (auto &node : CGraph) {
300 for (auto &edge : *node.second) {
301 if (!edge.first)
302 continue;
303 ASSERT_NE(edge.second->getFunction()->getName(), "foo");
309 } // namespace llvm