//===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements the RefactoringCallback family of AST-matcher callbacks, which
// turn match results into tooling::Replacements.
//
//===----------------------------------------------------------------------===//
#include "clang/Tooling/RefactoringCallbacks.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include <cassert>

using llvm::StringError;
using llvm::make_error;
22 RefactoringCallback::RefactoringCallback() {}
23 tooling::Replacements
&RefactoringCallback::getReplacements() {
27 ASTMatchRefactorer::ASTMatchRefactorer(
28 std::map
<std::string
, Replacements
> &FileToReplaces
)
29 : FileToReplaces(FileToReplaces
) {}
31 void ASTMatchRefactorer::addDynamicMatcher(
32 const ast_matchers::internal::DynTypedMatcher
&Matcher
,
33 RefactoringCallback
*Callback
) {
34 MatchFinder
.addDynamicMatcher(Matcher
, Callback
);
35 Callbacks
.push_back(Callback
);
38 class RefactoringASTConsumer
: public ASTConsumer
{
40 explicit RefactoringASTConsumer(ASTMatchRefactorer
&Refactoring
)
41 : Refactoring(Refactoring
) {}
43 void HandleTranslationUnit(ASTContext
&Context
) override
{
44 // The ASTMatchRefactorer is re-used between translation units.
45 // Clear the matchers so that each Replacement is only emitted once.
46 for (const auto &Callback
: Refactoring
.Callbacks
) {
47 Callback
->getReplacements().clear();
49 Refactoring
.MatchFinder
.matchAST(Context
);
50 for (const auto &Callback
: Refactoring
.Callbacks
) {
51 for (const auto &Replacement
: Callback
->getReplacements()) {
53 Refactoring
.FileToReplaces
[std::string(Replacement
.getFilePath())]
56 llvm::errs() << "Skipping replacement " << Replacement
.toString()
57 << " due to this error:\n"
58 << toString(std::move(Err
)) << "\n";
65 ASTMatchRefactorer
&Refactoring
;
68 std::unique_ptr
<ASTConsumer
> ASTMatchRefactorer::newASTConsumer() {
69 return std::make_unique
<RefactoringASTConsumer
>(*this);
72 static Replacement
replaceStmtWithText(SourceManager
&Sources
, const Stmt
&From
,
74 return tooling::Replacement(
75 Sources
, CharSourceRange::getTokenRange(From
.getSourceRange()), Text
);
77 static Replacement
replaceStmtWithStmt(SourceManager
&Sources
, const Stmt
&From
,
79 return replaceStmtWithText(
81 Lexer::getSourceText(CharSourceRange::getTokenRange(To
.getSourceRange()),
82 Sources
, LangOptions()));
85 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId
, StringRef ToText
)
86 : FromId(std::string(FromId
)), ToText(std::string(ToText
)) {}
88 void ReplaceStmtWithText::run(
89 const ast_matchers::MatchFinder::MatchResult
&Result
) {
90 if (const Stmt
*FromMatch
= Result
.Nodes
.getNodeAs
<Stmt
>(FromId
)) {
91 auto Err
= Replace
.add(tooling::Replacement(
92 *Result
.SourceManager
,
93 CharSourceRange::getTokenRange(FromMatch
->getSourceRange()), ToText
));
94 // FIXME: better error handling. For now, just print error message in the
97 llvm::errs() << llvm::toString(std::move(Err
)) << "\n";
103 ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId
, StringRef ToId
)
104 : FromId(std::string(FromId
)), ToId(std::string(ToId
)) {}
106 void ReplaceStmtWithStmt::run(
107 const ast_matchers::MatchFinder::MatchResult
&Result
) {
108 const Stmt
*FromMatch
= Result
.Nodes
.getNodeAs
<Stmt
>(FromId
);
109 const Stmt
*ToMatch
= Result
.Nodes
.getNodeAs
<Stmt
>(ToId
);
110 if (FromMatch
&& ToMatch
) {
111 auto Err
= Replace
.add(
112 replaceStmtWithStmt(*Result
.SourceManager
, *FromMatch
, *ToMatch
));
113 // FIXME: better error handling. For now, just print error message in the
116 llvm::errs() << llvm::toString(std::move(Err
)) << "\n";
122 ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id
,
124 : Id(std::string(Id
)), PickTrueBranch(PickTrueBranch
) {}
126 void ReplaceIfStmtWithItsBody::run(
127 const ast_matchers::MatchFinder::MatchResult
&Result
) {
128 if (const IfStmt
*Node
= Result
.Nodes
.getNodeAs
<IfStmt
>(Id
)) {
129 const Stmt
*Body
= PickTrueBranch
? Node
->getThen() : Node
->getElse();
132 Replace
.add(replaceStmtWithStmt(*Result
.SourceManager
, *Node
, *Body
));
133 // FIXME: better error handling. For now, just print error message in the
136 llvm::errs() << llvm::toString(std::move(Err
)) << "\n";
139 } else if (!PickTrueBranch
) {
140 // If we want to use the 'else'-branch, but it doesn't exist, delete
143 Replace
.add(replaceStmtWithText(*Result
.SourceManager
, *Node
, ""));
144 // FIXME: better error handling. For now, just print error message in the
147 llvm::errs() << llvm::toString(std::move(Err
)) << "\n";
154 ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
155 llvm::StringRef FromId
, std::vector
<TemplateElement
> Template
)
156 : FromId(std::string(FromId
)), Template(std::move(Template
)) {}
158 llvm::Expected
<std::unique_ptr
<ReplaceNodeWithTemplate
>>
159 ReplaceNodeWithTemplate::create(StringRef FromId
, StringRef ToTemplate
) {
160 std::vector
<TemplateElement
> ParsedTemplate
;
161 for (size_t Index
= 0; Index
< ToTemplate
.size();) {
162 if (ToTemplate
[Index
] == '$') {
163 if (ToTemplate
.substr(Index
, 2) == "$$") {
165 ParsedTemplate
.push_back(
166 TemplateElement
{TemplateElement::Literal
, "$"});
167 } else if (ToTemplate
.substr(Index
, 2) == "${") {
168 size_t EndOfIdentifier
= ToTemplate
.find("}", Index
);
169 if (EndOfIdentifier
== std::string::npos
) {
170 return make_error
<StringError
>(
171 "Unterminated ${...} in replacement template near " +
172 ToTemplate
.substr(Index
),
173 llvm::inconvertibleErrorCode());
175 std::string SourceNodeName
= std::string(
176 ToTemplate
.substr(Index
+ 2, EndOfIdentifier
- Index
- 2));
177 ParsedTemplate
.push_back(
178 TemplateElement
{TemplateElement::Identifier
, SourceNodeName
});
179 Index
= EndOfIdentifier
+ 1;
181 return make_error
<StringError
>(
182 "Invalid $ in replacement template near " +
183 ToTemplate
.substr(Index
),
184 llvm::inconvertibleErrorCode());
187 size_t NextIndex
= ToTemplate
.find('$', Index
+ 1);
188 ParsedTemplate
.push_back(TemplateElement
{
189 TemplateElement::Literal
,
190 std::string(ToTemplate
.substr(Index
, NextIndex
- Index
))});
194 return std::unique_ptr
<ReplaceNodeWithTemplate
>(
195 new ReplaceNodeWithTemplate(FromId
, std::move(ParsedTemplate
)));
198 void ReplaceNodeWithTemplate::run(
199 const ast_matchers::MatchFinder::MatchResult
&Result
) {
200 const auto &NodeMap
= Result
.Nodes
.getMap();
203 for (const auto &Element
: Template
) {
204 switch (Element
.Type
) {
205 case TemplateElement::Literal
:
206 ToText
+= Element
.Value
;
208 case TemplateElement::Identifier
: {
209 auto NodeIter
= NodeMap
.find(Element
.Value
);
210 if (NodeIter
== NodeMap
.end()) {
211 llvm::errs() << "Node " << Element
.Value
212 << " used in replacement template not bound in Matcher \n";
213 llvm::report_fatal_error("Unbound node in replacement template.");
215 CharSourceRange Source
=
216 CharSourceRange::getTokenRange(NodeIter
->second
.getSourceRange());
217 ToText
+= Lexer::getSourceText(Source
, *Result
.SourceManager
,
218 Result
.Context
->getLangOpts());
223 if (NodeMap
.count(FromId
) == 0) {
224 llvm::errs() << "Node to be replaced " << FromId
225 << " not bound in query.\n";
226 llvm::report_fatal_error("FromId node not bound in MatchResult");
229 tooling::Replacement(*Result
.SourceManager
, &NodeMap
.at(FromId
), ToText
,
230 Result
.Context
->getLangOpts());
231 llvm::Error Err
= Replace
.add(Replacement
);
233 llvm::errs() << "Query and replace failed in " << Replacement
.getFilePath()
234 << "! " << llvm::toString(std::move(Err
)) << "\n";
235 llvm::report_fatal_error("Replacement failed");
239 } // end namespace tooling
240 } // end namespace clang