Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / clang / unittests / AST / RecursiveASTVisitorTest.cpp
blob9d7ff5947fe530128e474427aabb9632fe368372
1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "clang/AST/RecursiveASTVisitor.h"
10 #include "clang/AST/ASTConsumer.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Attr.h"
13 #include "clang/AST/Decl.h"
14 #include "clang/AST/TypeLoc.h"
15 #include "clang/Frontend/FrontendAction.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/ADT/FunctionExtras.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include <cassert>
23 using namespace clang;
24 using ::testing::ElementsAre;
26 namespace {
27 class ProcessASTAction : public clang::ASTFrontendAction {
28 public:
29 ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
30 : Process(std::move(Process)) {
31 assert(this->Process);
34 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
35 StringRef InFile) {
36 class Consumer : public ASTConsumer {
37 public:
38 Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
39 : Process(Process) {}
41 void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
43 private:
44 llvm::function_ref<void(ASTContext &CTx)> Process;
47 return std::make_unique<Consumer>(Process);
50 private:
51 llvm::unique_function<void(clang::ASTContext &)> Process;
/// Events recorded by CollectInterestingEvents.
///
/// Every traversal hook of interest has a Start/End pair: Start is pushed
/// before delegating to the base visitor, End right after it returns.
enum class VisitEvent {
  StartTraverseFunction,
  EndTraverseFunction,
  StartTraverseAttr,
  EndTraverseAttr,
  StartTraverseEnum,
  EndTraverseEnum,
  StartTraverseTypedefType,
  EndTraverseTypedefType,
  StartTraverseObjCInterface,
  EndTraverseObjCInterface,
  StartTraverseObjCProtocol,
  EndTraverseObjCProtocol,
  StartTraverseObjCProtocolLoc,
  EndTraverseObjCProtocolLoc,
};
71 class CollectInterestingEvents
72 : public RecursiveASTVisitor<CollectInterestingEvents> {
73 public:
74 bool TraverseFunctionDecl(FunctionDecl *D) {
75 Events.push_back(VisitEvent::StartTraverseFunction);
76 bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
77 Events.push_back(VisitEvent::EndTraverseFunction);
79 return Ret;
82 bool TraverseAttr(Attr *A) {
83 Events.push_back(VisitEvent::StartTraverseAttr);
84 bool Ret = RecursiveASTVisitor::TraverseAttr(A);
85 Events.push_back(VisitEvent::EndTraverseAttr);
87 return Ret;
90 bool TraverseEnumDecl(EnumDecl *D) {
91 Events.push_back(VisitEvent::StartTraverseEnum);
92 bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D);
93 Events.push_back(VisitEvent::EndTraverseEnum);
95 return Ret;
98 bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) {
99 Events.push_back(VisitEvent::StartTraverseTypedefType);
100 bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL);
101 Events.push_back(VisitEvent::EndTraverseTypedefType);
103 return Ret;
106 bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) {
107 Events.push_back(VisitEvent::StartTraverseObjCInterface);
108 bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
109 Events.push_back(VisitEvent::EndTraverseObjCInterface);
111 return Ret;
114 bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
115 Events.push_back(VisitEvent::StartTraverseObjCProtocol);
116 bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
117 Events.push_back(VisitEvent::EndTraverseObjCProtocol);
119 return Ret;
122 bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
123 Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
124 bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
125 Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
127 return Ret;
130 std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
132 private:
133 std::vector<VisitEvent> Events;
136 std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
137 const Twine &FileName = "input.cc") {
138 CollectInterestingEvents Visitor;
139 clang::tooling::runToolOnCode(
140 std::make_unique<ProcessASTAction>(
141 [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
142 Code, FileName);
143 return std::move(Visitor).takeEvents();
145 } // namespace
147 TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
148 /// Check attributes are traversed inside TraverseFunctionDecl.
149 llvm::StringRef Code = R"cpp(
150 __attribute__((annotate("something"))) int foo() { return 10; }
151 )cpp";
153 EXPECT_THAT(collectEvents(Code),
154 ElementsAre(VisitEvent::StartTraverseFunction,
155 VisitEvent::StartTraverseAttr,
156 VisitEvent::EndTraverseAttr,
157 VisitEvent::EndTraverseFunction));
160 TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
161 // Check enum and its integer base is visited.
162 llvm::StringRef Code = R"cpp(
163 typedef int Foo;
164 enum Bar : Foo;
165 )cpp";
167 EXPECT_THAT(collectEvents(Code),
168 ElementsAre(VisitEvent::StartTraverseEnum,
169 VisitEvent::StartTraverseTypedefType,
170 VisitEvent::EndTraverseTypedefType,
171 VisitEvent::EndTraverseEnum));
174 TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
175 // Check interface and its protocols are visited.
176 llvm::StringRef Code = R"cpp(
177 @protocol Foo
178 @end
179 @protocol Bar
180 @end
182 @interface SomeObject <Foo, Bar>
183 @end
184 )cpp";
186 EXPECT_THAT(collectEvents(Code, "input.m"),
187 ElementsAre(VisitEvent::StartTraverseObjCProtocol,
188 VisitEvent::EndTraverseObjCProtocol,
189 VisitEvent::StartTraverseObjCProtocol,
190 VisitEvent::EndTraverseObjCProtocol,
191 VisitEvent::StartTraverseObjCInterface,
192 VisitEvent::StartTraverseObjCProtocolLoc,
193 VisitEvent::EndTraverseObjCProtocolLoc,
194 VisitEvent::StartTraverseObjCProtocolLoc,
195 VisitEvent::EndTraverseObjCProtocolLoc,
196 VisitEvent::EndTraverseObjCInterface));