//===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using namespace clang;
using ::testing::ElementsAre;
27 class ProcessASTAction
: public clang::ASTFrontendAction
{
29 ProcessASTAction(llvm::unique_function
<void(clang::ASTContext
&)> Process
)
30 : Process(std::move(Process
)) {
31 assert(this->Process
);
34 std::unique_ptr
<ASTConsumer
> CreateASTConsumer(CompilerInstance
&CI
,
36 class Consumer
: public ASTConsumer
{
38 Consumer(llvm::function_ref
<void(ASTContext
&CTx
)> Process
)
41 void HandleTranslationUnit(ASTContext
&Ctx
) override
{ Process(Ctx
); }
44 llvm::function_ref
<void(ASTContext
&CTx
)> Process
;
47 return std::make_unique
<Consumer
>(Process
);
51 llvm::unique_function
<void(clang::ASTContext
&)> Process
;
/// Events recorded while traversing an AST. Each node kind of interest has a
/// Start/End pair that brackets the traversal of that node, so a recorded
/// sequence shows both visiting order and nesting.
enum class VisitEvent {
  StartTraverseFunction,
  EndTraverseFunction,
  StartTraverseAttr,
  EndTraverseAttr,
  StartTraverseEnum,
  EndTraverseEnum,
  StartTraverseTypedefType,
  EndTraverseTypedefType,
  StartTraverseObjCInterface,
  EndTraverseObjCInterface,
  StartTraverseObjCProtocol,
  EndTraverseObjCProtocol,
  StartTraverseObjCProtocolLoc,
  EndTraverseObjCProtocolLoc,
};
71 class CollectInterestingEvents
72 : public RecursiveASTVisitor
<CollectInterestingEvents
> {
74 bool TraverseFunctionDecl(FunctionDecl
*D
) {
75 Events
.push_back(VisitEvent::StartTraverseFunction
);
76 bool Ret
= RecursiveASTVisitor::TraverseFunctionDecl(D
);
77 Events
.push_back(VisitEvent::EndTraverseFunction
);
82 bool TraverseAttr(Attr
*A
) {
83 Events
.push_back(VisitEvent::StartTraverseAttr
);
84 bool Ret
= RecursiveASTVisitor::TraverseAttr(A
);
85 Events
.push_back(VisitEvent::EndTraverseAttr
);
90 bool TraverseEnumDecl(EnumDecl
*D
) {
91 Events
.push_back(VisitEvent::StartTraverseEnum
);
92 bool Ret
= RecursiveASTVisitor::TraverseEnumDecl(D
);
93 Events
.push_back(VisitEvent::EndTraverseEnum
);
98 bool TraverseTypedefTypeLoc(TypedefTypeLoc TL
) {
99 Events
.push_back(VisitEvent::StartTraverseTypedefType
);
100 bool Ret
= RecursiveASTVisitor::TraverseTypedefTypeLoc(TL
);
101 Events
.push_back(VisitEvent::EndTraverseTypedefType
);
106 bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl
*ID
) {
107 Events
.push_back(VisitEvent::StartTraverseObjCInterface
);
108 bool Ret
= RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID
);
109 Events
.push_back(VisitEvent::EndTraverseObjCInterface
);
114 bool TraverseObjCProtocolDecl(ObjCProtocolDecl
*PD
) {
115 Events
.push_back(VisitEvent::StartTraverseObjCProtocol
);
116 bool Ret
= RecursiveASTVisitor::TraverseObjCProtocolDecl(PD
);
117 Events
.push_back(VisitEvent::EndTraverseObjCProtocol
);
122 bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc
) {
123 Events
.push_back(VisitEvent::StartTraverseObjCProtocolLoc
);
124 bool Ret
= RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc
);
125 Events
.push_back(VisitEvent::EndTraverseObjCProtocolLoc
);
130 std::vector
<VisitEvent
> takeEvents() && { return std::move(Events
); }
133 std::vector
<VisitEvent
> Events
;
136 std::vector
<VisitEvent
> collectEvents(llvm::StringRef Code
,
137 const Twine
&FileName
= "input.cc") {
138 CollectInterestingEvents Visitor
;
139 clang::tooling::runToolOnCode(
140 std::make_unique
<ProcessASTAction
>(
141 [&](clang::ASTContext
&Ctx
) { Visitor
.TraverseAST(Ctx
); }),
143 return std::move(Visitor
).takeEvents();
147 TEST(RecursiveASTVisitorTest
, AttributesInsideDecls
) {
148 /// Check attributes are traversed inside TraverseFunctionDecl.
149 llvm::StringRef Code
= R
"cpp(
150 __attribute__((annotate("something
"))) int foo() { return 10; }
153 EXPECT_THAT(collectEvents(Code
),
154 ElementsAre(VisitEvent::StartTraverseFunction
,
155 VisitEvent::StartTraverseAttr
,
156 VisitEvent::EndTraverseAttr
,
157 VisitEvent::EndTraverseFunction
));
// NOTE(review): this copy of the test is damaged. Every line carries a fused
// original line number, and the C++ input snippet after `R"cpp(` is missing
// entirely, as are its `)cpp";` terminator and the closing `}`. The expected
// events (an enum traversal enclosing a TypedefTypeLoc traversal) suggest an
// enum whose integer base is spelled through a typedef. TODO: restore the
// input from upstream rather than guessing it.
160 TEST(RecursiveASTVisitorTest
, EnumDeclWithBase
) {
161 // Check enum and its integer base is visited.
162 llvm::StringRef Code
= R
"cpp(
167 EXPECT_THAT(collectEvents(Code
),
168 ElementsAre(VisitEvent::StartTraverseEnum
,
169 VisitEvent::StartTraverseTypedefType
,
170 VisitEvent::EndTraverseTypedefType
,
171 VisitEvent::EndTraverseEnum
));
174 TEST(RecursiveASTVisitorTest
, InterfaceDeclWithProtocols
) {
175 // Check interface and its protocols are visited.
176 llvm::StringRef Code
= R
"cpp(
182 @interface SomeObject <Foo, Bar>
186 EXPECT_THAT(collectEvents(Code
, "input.m"),
187 ElementsAre(VisitEvent::StartTraverseObjCProtocol
,
188 VisitEvent::EndTraverseObjCProtocol
,
189 VisitEvent::StartTraverseObjCProtocol
,
190 VisitEvent::EndTraverseObjCProtocol
,
191 VisitEvent::StartTraverseObjCInterface
,
192 VisitEvent::StartTraverseObjCProtocolLoc
,
193 VisitEvent::EndTraverseObjCProtocolLoc
,
194 VisitEvent::StartTraverseObjCProtocolLoc
,
195 VisitEvent::EndTraverseObjCProtocolLoc
,
196 VisitEvent::EndTraverseObjCInterface
));