Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / clang / unittests / Tooling / TestVisitor.h
blob751ca74d1a881ca811f7dc29feb0a16a526758cb
1 //===--- TestVisitor.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Defines utility templates for RecursiveASTVisitor related tests.
11 ///
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
15 #define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/Frontend/CompilerInstance.h"
21 #include "clang/Frontend/FrontendAction.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "gtest/gtest.h"
24 #include <vector>
26 namespace clang {
28 /// \brief Base class for simple RecursiveASTVisitor based tests.
29 ///
30 /// This is a drop-in replacement for RecursiveASTVisitor itself, with the
31 /// additional capability of running it over a snippet of code.
32 ///
33 /// Visits template instantiations and implicit code by default.
34 template <typename T>
35 class TestVisitor : public RecursiveASTVisitor<T> {
36 public:
37 TestVisitor() { }
39 virtual ~TestVisitor() { }
41 enum Language {
42 Lang_C,
43 Lang_CXX98,
44 Lang_CXX11,
45 Lang_CXX14,
46 Lang_CXX17,
47 Lang_CXX2a,
48 Lang_OBJC,
49 Lang_OBJCXX11,
50 Lang_CXX = Lang_CXX98
53 /// \brief Runs the current AST visitor over the given code.
54 bool runOver(StringRef Code, Language L = Lang_CXX) {
55 std::vector<std::string> Args;
56 switch (L) {
57 case Lang_C:
58 Args.push_back("-x");
59 Args.push_back("c");
60 break;
61 case Lang_CXX98: Args.push_back("-std=c++98"); break;
62 case Lang_CXX11: Args.push_back("-std=c++11"); break;
63 case Lang_CXX14: Args.push_back("-std=c++14"); break;
64 case Lang_CXX17: Args.push_back("-std=c++17"); break;
65 case Lang_CXX2a: Args.push_back("-std=c++2a"); break;
66 case Lang_OBJC:
67 Args.push_back("-ObjC");
68 Args.push_back("-fobjc-runtime=macosx-10.12.0");
69 break;
70 case Lang_OBJCXX11:
71 Args.push_back("-ObjC++");
72 Args.push_back("-std=c++11");
73 Args.push_back("-fblocks");
74 break;
76 return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args);
79 bool shouldVisitTemplateInstantiations() const {
80 return true;
83 bool shouldVisitImplicitCode() const {
84 return true;
87 protected:
88 virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() {
89 return std::make_unique<TestAction>(this);
92 class FindConsumer : public ASTConsumer {
93 public:
94 FindConsumer(TestVisitor *Visitor) : Visitor(Visitor) {}
96 void HandleTranslationUnit(clang::ASTContext &Context) override {
97 Visitor->Context = &Context;
98 Visitor->TraverseDecl(Context.getTranslationUnitDecl());
101 private:
102 TestVisitor *Visitor;
105 class TestAction : public ASTFrontendAction {
106 public:
107 TestAction(TestVisitor *Visitor) : Visitor(Visitor) {}
109 std::unique_ptr<clang::ASTConsumer>
110 CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override {
111 /// TestConsumer will be deleted by the framework calling us.
112 return std::make_unique<FindConsumer>(Visitor);
115 protected:
116 TestVisitor *Visitor;
119 ASTContext *Context;
122 /// \brief A RecursiveASTVisitor to check that certain matches are (or are
123 /// not) observed during visitation.
125 /// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself,
126 /// and allows simple creation of test visitors running matches on only a small
127 /// subset of the Visit* methods.
128 template <typename T, template <typename> class Visitor = TestVisitor>
129 class ExpectedLocationVisitor : public Visitor<T> {
130 public:
131 /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'.
133 /// Any number of matches can be disallowed.
134 void DisallowMatch(Twine Match, unsigned Line, unsigned Column) {
135 DisallowedMatches.push_back(MatchCandidate(Match, Line, Column));
138 /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'.
140 /// Any number of expected matches can be set by calling this repeatedly.
141 /// Each is expected to be matched 'Times' number of times. (This is useful in
142 /// cases in which different AST nodes can match at the same source code
143 /// location.)
144 void ExpectMatch(Twine Match, unsigned Line, unsigned Column,
145 unsigned Times = 1) {
146 ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times));
149 /// \brief Checks that all expected matches have been found.
150 ~ExpectedLocationVisitor() override {
151 for (typename std::vector<ExpectedMatch>::const_iterator
152 It = ExpectedMatches.begin(), End = ExpectedMatches.end();
153 It != End; ++It) {
154 It->ExpectFound();
158 protected:
159 /// \brief Checks an actual match against expected and disallowed matches.
161 /// Implementations are required to call this with appropriate values
162 /// for 'Name' during visitation.
163 void Match(StringRef Name, SourceLocation Location) {
164 const FullSourceLoc FullLocation = this->Context->getFullLoc(Location);
166 for (typename std::vector<MatchCandidate>::const_iterator
167 It = DisallowedMatches.begin(), End = DisallowedMatches.end();
168 It != End; ++It) {
169 EXPECT_FALSE(It->Matches(Name, FullLocation))
170 << "Matched disallowed " << *It;
173 for (typename std::vector<ExpectedMatch>::iterator
174 It = ExpectedMatches.begin(), End = ExpectedMatches.end();
175 It != End; ++It) {
176 It->UpdateFor(Name, FullLocation, this->Context->getSourceManager());
180 private:
181 struct MatchCandidate {
182 std::string ExpectedName;
183 unsigned LineNumber;
184 unsigned ColumnNumber;
186 MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber)
187 : ExpectedName(Name.str()), LineNumber(LineNumber),
188 ColumnNumber(ColumnNumber) {
191 bool Matches(StringRef Name, FullSourceLoc const &Location) const {
192 return MatchesName(Name) && MatchesLocation(Location);
195 bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const {
196 return MatchesName(Name) || MatchesLocation(Location);
199 bool MatchesName(StringRef Name) const {
200 return Name == ExpectedName;
203 bool MatchesLocation(FullSourceLoc const &Location) const {
204 return Location.isValid() &&
205 Location.getSpellingLineNumber() == LineNumber &&
206 Location.getSpellingColumnNumber() == ColumnNumber;
209 friend std::ostream &operator<<(std::ostream &Stream,
210 MatchCandidate const &Match) {
211 return Stream << Match.ExpectedName
212 << " at " << Match.LineNumber << ":" << Match.ColumnNumber;
216 struct ExpectedMatch {
217 ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber,
218 unsigned Times)
219 : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times),
220 TimesSeen(0) {}
222 void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) {
223 if (Candidate.Matches(Name, Location)) {
224 EXPECT_LT(TimesSeen, TimesExpected);
225 ++TimesSeen;
226 } else if (TimesSeen < TimesExpected &&
227 Candidate.PartiallyMatches(Name, Location)) {
228 llvm::raw_string_ostream Stream(PartialMatches);
229 Stream << ", partial match: \"" << Name << "\" at ";
230 Location.print(Stream, SM);
234 void ExpectFound() const {
235 EXPECT_EQ(TimesExpected, TimesSeen)
236 << "Expected \"" << Candidate.ExpectedName
237 << "\" at " << Candidate.LineNumber
238 << ":" << Candidate.ColumnNumber << PartialMatches;
241 MatchCandidate Candidate;
242 std::string PartialMatches;
243 unsigned TimesExpected;
244 unsigned TimesSeen;
247 std::vector<MatchCandidate> DisallowedMatches;
248 std::vector<ExpectedMatch> ExpectedMatches;
252 #endif