1 //===--- unittests/Tooling/RecursiveASTVisitorTests/CallbacksCommon.h -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "TestVisitor.h"
11 using namespace clang
;
// Strongly typed flag choosing pre-order vs. post-order traversal. The
// recording visitor below converts it to bool in shouldTraversePostOrder().
// NOTE(review): the enumerators and the closing "};" of this enum are not
// present in this view (the embedded numbering jumps from 15 to 20) — confirm
// the file was not truncated.
15 enum class ShouldTraversePostOrder
: bool {
20 /// Base class for RecursiveASTVisitor tests that validate the sequence of
21 /// calls to user-defined callbacks such as Traverse*(), WalkUp*(), etc.
23 template <typename Derived
>
24 class RecordingVisitorBase
: public TestVisitor
<Derived
> {
// NOTE(review): this class appears to have lost lines during extraction: the
// embedded numbering skips (e.g. 25->28, 32->35, 45->49, 50->54, 77->80,
// 80->87) and several closing braces are absent. In particular, no "public:"
// appears, yet CallbackLog and runOver are accessed from outside the class in
// visitorCallbackLogEqual. The comments below describe only the lines present.
// Traversal order chosen at construction.
25 ShouldTraversePostOrder ShouldTraversePostOrderValue
;
// Stores the caller's choice of pre- or post-order traversal.
28 RecordingVisitorBase(ShouldTraversePostOrder ShouldTraversePostOrderValue
)
29 : ShouldTraversePostOrderValue(ShouldTraversePostOrderValue
) {}
// Reports the traversal order chosen at construction, converted to bool.
31 bool shouldTraversePostOrder() const {
32 return static_cast<bool>(ShouldTraversePostOrderValue
);
35 // Log of callbacks received during traversal, one line per callback.
36 std::string CallbackLog
;
// Current nesting depth of recorded callbacks. recordCallback iterates this
// many times before appending an entry (the loop body is not visible here).
37 unsigned CallbackLogIndent
= 0;
// Renders S for the callback log. The result is the statement's class name,
// followed for a few node kinds by a parenthesized detail. Nodes not handled
// below yield the bare class name.
39 std::string
stmtToString(Stmt
*S
) {
40 StringRef ClassName
= S
->getStmtClassName();
// Integer literals: the literal's value in radix 10. The `false` argument
// presumably requests unsigned formatting — confirm against
// llvm::toString(const APInt &, ...).
41 if (IntegerLiteral
*IL
= dyn_cast
<IntegerLiteral
>(S
)) {
42 return (ClassName
+ "(" + toString(IL
->getValue(), 10, false) + ")").str();
// Unary and binary operators: the opcode spelling from getOpcodeStr().
// NOTE(review): the remainder of both return expressions (and the closing
// braces) is missing in this view.
44 if (UnaryOperator
*UO
= dyn_cast
<UnaryOperator
>(S
)) {
45 return (ClassName
+ "(" + UnaryOperator::getOpcodeStr(UO
->getOpcode()) +
49 if (BinaryOperator
*BO
= dyn_cast
<BinaryOperator
>(S
)) {
50 return (ClassName
+ "(" + BinaryOperator::getOpcodeStr(BO
->getOpcode()) +
// Calls: the callee's name, when the call has a direct callee that has an
// identifier.
54 if (CallExpr
*CE
= dyn_cast
<CallExpr
>(S
)) {
55 if (FunctionDecl
*Callee
= CE
->getDirectCallee()) {
56 if (Callee
->getIdentifier()) {
57 return (ClassName
+ "(" + Callee
->getName() + ")").str();
// Declaration references: the name of the found declaration, when it has an
// identifier.
61 if (DeclRefExpr
*DRE
= dyn_cast
<DeclRefExpr
>(S
)) {
62 if (NamedDecl
*ND
= DRE
->getFoundDecl()) {
63 if (ND
->getIdentifier()) {
64 return (ClassName
+ "(" + ND
->getName() + ")").str();
// Fallback: just the statement class name.
68 return ClassName
.str();
71 /// Record the fact that the user-defined callback member function
72 /// \p CallbackName was called with the argument \p S. Then, record the
73 /// effects of calling the default implementation \p CallDefaultFn.
// NOTE(review): in the lines present, CallDefaultFn is never invoked and
// CallbackLogIndent is never changed, although the comment above promises to
// record the default implementation's effects. The numbering gaps (77->80,
// 80->87) suggest those statements were lost; verify against the original.
74 template <typename CallDefault
>
75 void recordCallback(StringRef CallbackName
, Stmt
*S
,
76 CallDefault CallDefaultFn
) {
// Presumably emits one indentation unit per nesting level; the loop body is
// not present in this view.
77 for (unsigned i
= 0; i
!= CallbackLogIndent
; ++i
) {
// Append one log line: "<CallbackName> <stmtToString(S)>".
80 CallbackLog
+= (CallbackName
+ " " + stmtToString(S
) + "\n").str();
// Runs Visitor over source code and checks that the callback log it records
// matches ExpectedLog. Leading and trailing whitespace is ignored on both
// sides. Visitor is taken by value, so the caller's instance is not modified.
// NOTE(review): `Code` (passed to runOver below) is not among the parameters
// visible here. The numbering skips line 89, which presumably declared it, and
// the function's closing brace is not in view. Confirm.
87 template <typename VisitorTy
>
88 ::testing::AssertionResult
visitorCallbackLogEqual(VisitorTy Visitor
,
90 StringRef ExpectedLog
) {
91 Visitor
.runOver(Code
);
92 // EXPECT_EQ shows the diff between the two strings if they are different.
93 EXPECT_EQ(ExpectedLog
.trim().str(),
94 StringRef(Visitor
.CallbackLog
).trim().str());
// The same trimmed comparison is repeated so that it can also be returned as
// an AssertionResult. A mismatch is therefore reported by EXPECT_EQ above, and
// again if the caller asserts on the returned result.
95 if (ExpectedLog
.trim() != StringRef(Visitor
.CallbackLog
).trim()) {
96 return ::testing::AssertionFailure();
98 return ::testing::AssertionSuccess();