[docs] Add LICENSE.txt to the root of the mono-repo
[llvm-project.git] / clang-tools-extra / clangd / refactor / tweaks / ExtractFunction.cpp
blob66fe4fdbfa2d3a38e77bda1e7e6d4c2a8950469e
1 //===--- ExtractFunction.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Extracts statements to a new function and replaces the statements with a
10 // call to the new function.
11 // Before:
12 // void f(int a) {
13 // [[if(a < 5)
14 // a = 5;]]
15 // }
16 // After:
17 // void extracted(int &a) {
18 // if(a < 5)
19 // a = 5;
20 // }
21 // void f(int a) {
22 // extracted(a);
23 // }
25 // - Only extract statements
26 // - Extracts from non-templated functions and methods only (not lambdas).
27 // - Parameters are const only if the declaration was const
28 // - Always passed by l-value reference
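29 // - Void return type, unless the extracted code always ends in a return, in
29 //   which case the enclosing function's return type is reused.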
30 // - Cannot extract declarations that will be needed in the original function
31 // after extraction.
32 // - Checks for broken control flow (break/continue without loop/switch)
34 // 1. ExtractFunction is the tweak subclass
35 // - Prepare does basic analysis of the selection and is therefore fast.
36 // Successful prepare doesn't always mean we can apply the tweak.
37 // - Apply does a more detailed analysis and can be slower. In case of
38 // failure, we let the user know that we are unable to perform extraction.
39 // 2. ExtractionZone stores information about the range being extracted and the
40 // enclosing function.
41 // 3. NewFunction stores properties of the extracted function and provides
42 // methods for rendering it.
43 // 4. CapturedZoneInfo uses a RecursiveASTVisitor to capture information about
44 // the extraction like declarations, existing return statements, etc.
45 // 5. getExtractedFunction is responsible for analyzing the CapturedZoneInfo and
46 // creating a NewFunction.
47 //===----------------------------------------------------------------------===//
49 #include "AST.h"
50 #include "FindTarget.h"
51 #include "ParsedAST.h"
52 #include "Selection.h"
53 #include "SourceCode.h"
54 #include "refactor/Tweak.h"
55 #include "support/Logger.h"
56 #include "clang/AST/ASTContext.h"
57 #include "clang/AST/Decl.h"
58 #include "clang/AST/DeclBase.h"
59 #include "clang/AST/NestedNameSpecifier.h"
60 #include "clang/AST/RecursiveASTVisitor.h"
61 #include "clang/AST/Stmt.h"
62 #include "clang/Basic/LangOptions.h"
63 #include "clang/Basic/SourceLocation.h"
64 #include "clang/Basic/SourceManager.h"
65 #include "clang/Tooling/Core/Replacement.h"
66 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/SmallSet.h"
71 #include "llvm/ADT/SmallVector.h"
72 #include "llvm/ADT/StringRef.h"
73 #include "llvm/Support/Casting.h"
74 #include "llvm/Support/Error.h"
75 #include "llvm/Support/raw_os_ostream.h"
77 namespace clang {
78 namespace clangd {
79 namespace {
81 using Node = SelectionTree::Node;
// ExtractionZone is the code being extracted; EnclosingFunction is the
// function/method that contains it. Relative to the zone, every location in
// the file belongs to exactly one of these four parts.
enum class ZoneRelative {
  Before,     // Inside EnclosingFunction, ahead of the zone.
  Inside,     // Within the zone itself.
  After,      // Inside EnclosingFunction, past the zone.
  OutsideFunc // Not inside EnclosingFunction at all.
};
// The kinds of declaration NewFunction::renderDeclaration can be asked to
// produce.
enum FunctionDeclKind {
  InlineDefinition,
  ForwardDeclaration,
  OutOfLineDefinition
};
99 // A RootStmt is a statement that's fully selected including all it's children
100 // and it's parent is unselected.
101 // Check if a node is a root statement.
102 bool isRootStmt(const Node *N) {
103 if (!N->ASTNode.get<Stmt>())
104 return false;
105 // Root statement cannot be partially selected.
106 if (N->Selected == SelectionTree::Partial)
107 return false;
108 // Only DeclStmt can be an unselected RootStmt since VarDecls claim the entire
109 // selection range in selectionTree.
110 if (N->Selected == SelectionTree::Unselected && !N->ASTNode.get<DeclStmt>())
111 return false;
112 return true;
115 // Returns the (unselected) parent of all RootStmts given the commonAncestor.
116 // Returns null if:
117 // 1. any node is partially selected
118 // 2. If all completely selected nodes don't have the same common parent
119 // 3. Any child of Parent isn't a RootStmt.
120 // Returns null if any child is not a RootStmt.
121 // We only support extraction of RootStmts since it allows us to extract without
122 // having to change the selection range. Also, this means that any scope that
123 // begins in selection range, ends in selection range and any scope that begins
124 // outside the selection range, ends outside as well.
125 const Node *getParentOfRootStmts(const Node *CommonAnc) {
126 if (!CommonAnc)
127 return nullptr;
128 const Node *Parent = nullptr;
129 switch (CommonAnc->Selected) {
130 case SelectionTree::Selection::Unselected:
131 // Typically a block, with the { and } unselected, could also be ForStmt etc
132 // Ensure all Children are RootStmts.
133 Parent = CommonAnc;
134 break;
135 case SelectionTree::Selection::Partial:
136 // Only a fully-selected single statement can be selected.
137 return nullptr;
138 case SelectionTree::Selection::Complete:
139 // If the Common Ancestor is completely selected, then it's a root statement
140 // and its parent will be unselected.
141 Parent = CommonAnc->Parent;
142 // If parent is a DeclStmt, even though it's unselected, we consider it a
143 // root statement and return its parent. This is done because the VarDecls
144 // claim the entire selection range of the Declaration and DeclStmt is
145 // always unselected.
146 if (Parent->ASTNode.get<DeclStmt>())
147 Parent = Parent->Parent;
148 break;
150 // Ensure all Children are RootStmts.
151 return llvm::all_of(Parent->Children, isRootStmt) ? Parent : nullptr;
154 // The ExtractionZone class forms a view of the code wrt Zone.
155 struct ExtractionZone {
156 // Parent of RootStatements being extracted.
157 const Node *Parent = nullptr;
158 // The half-open file range of the code being extracted.
159 SourceRange ZoneRange;
160 // The function inside which our zone resides.
161 const FunctionDecl *EnclosingFunction = nullptr;
162 // The half-open file range of the enclosing function.
163 SourceRange EnclosingFuncRange;
164 // Set of statements that form the ExtractionZone.
165 llvm::DenseSet<const Stmt *> RootStmts;
167 SourceLocation getInsertionPoint() const {
168 return EnclosingFuncRange.getBegin();
170 bool isRootStmt(const Stmt *S) const;
171 // The last root statement is important to decide where we need to insert a
172 // semicolon after the extraction.
173 const Node *getLastRootStmt() const { return Parent->Children.back(); }
175 // Checks if declarations inside extraction zone are accessed afterwards.
177 // This performs a partial AST traversal proportional to the size of the
178 // enclosing function, so it is possibly expensive.
179 bool requiresHoisting(const SourceManager &SM,
180 const HeuristicResolver *Resolver) const {
181 // First find all the declarations that happened inside extraction zone.
182 llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
183 for (auto *RootStmt : RootStmts) {
184 findExplicitReferences(
185 RootStmt,
186 [&DeclsInExtZone](const ReferenceLoc &Loc) {
187 if (!Loc.IsDecl)
188 return;
189 DeclsInExtZone.insert(Loc.Targets.front());
191 Resolver);
193 // Early exit without performing expensive traversal below.
194 if (DeclsInExtZone.empty())
195 return false;
196 // Then make sure they are not used outside the zone.
197 for (const auto *S : EnclosingFunction->getBody()->children()) {
198 if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
199 ZoneRange.getEnd()))
200 continue;
201 bool HasPostUse = false;
202 findExplicitReferences(
204 [&](const ReferenceLoc &Loc) {
205 if (HasPostUse ||
206 SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
207 return;
208 HasPostUse = llvm::any_of(Loc.Targets,
209 [&DeclsInExtZone](const Decl *Target) {
210 return DeclsInExtZone.contains(Target);
213 Resolver);
214 if (HasPostUse)
215 return true;
217 return false;
221 // Whether the code in the extraction zone is guaranteed to return, assuming
222 // no broken control flow (unbound break/continue).
223 // This is a very naive check (does it end with a return stmt).
224 // Doing some rudimentary control flow analysis would cover more cases.
225 bool alwaysReturns(const ExtractionZone &EZ) {
226 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
227 // Unwrap enclosing (unconditional) compound statement.
228 while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
229 if (CS->body_empty())
230 return false;
231 Last = CS->body_back();
233 return llvm::isa<ReturnStmt>(Last);
236 bool ExtractionZone::isRootStmt(const Stmt *S) const {
237 return RootStmts.contains(S);
240 // Finds the function in which the zone lies.
241 const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
242 // Walk up the SelectionTree until we find a function Decl
243 for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
244 // Don't extract from lambdas
245 if (CurNode->ASTNode.get<LambdaExpr>())
246 return nullptr;
247 if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
248 // FIXME: Support extraction from templated functions.
249 if (Func->isTemplated())
250 return nullptr;
251 return Func;
254 return nullptr;
257 // Zone Range is the union of SourceRanges of all child Nodes in Parent since
258 // all child Nodes are RootStmts
259 llvm::Optional<SourceRange> findZoneRange(const Node *Parent,
260 const SourceManager &SM,
261 const LangOptions &LangOpts) {
262 SourceRange SR;
263 if (auto BeginFileRange = toHalfOpenFileRange(
264 SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange()))
265 SR.setBegin(BeginFileRange->getBegin());
266 else
267 return llvm::None;
268 if (auto EndFileRange = toHalfOpenFileRange(
269 SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange()))
270 SR.setEnd(EndFileRange->getEnd());
271 else
272 return llvm::None;
273 return SR;
276 // Compute the range spanned by the enclosing function.
277 // FIXME: check if EnclosingFunction has any attributes as the AST doesn't
278 // always store the source range of the attributes and thus we end up extracting
279 // between the attributes and the EnclosingFunction.
280 llvm::Optional<SourceRange>
281 computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction,
282 const SourceManager &SM,
283 const LangOptions &LangOpts) {
284 return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange());
287 // returns true if Child can be a single RootStmt being extracted from
288 // EnclosingFunc.
289 bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
290 // Don't extract expressions.
291 // FIXME: We should extract expressions that are "statements" i.e. not
292 // subexpressions
293 if (Child->ASTNode.get<Expr>())
294 return false;
295 // Extracting the body of EnclosingFunc would remove it's definition.
296 assert(EnclosingFunc->hasBody() &&
297 "We should always be extracting from a function body.");
298 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
299 return false;
300 return true;
303 // FIXME: Check we're not extracting from the initializer/condition of a control
304 // flow structure.
305 llvm::Optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
306 const SourceManager &SM,
307 const LangOptions &LangOpts) {
308 ExtractionZone ExtZone;
309 ExtZone.Parent = getParentOfRootStmts(CommonAnc);
310 if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
311 return llvm::None;
312 ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
313 if (!ExtZone.EnclosingFunction)
314 return llvm::None;
315 // When there is a single RootStmt, we must check if it's valid for
316 // extraction.
317 if (ExtZone.Parent->Children.size() == 1 &&
318 !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
319 return llvm::None;
320 if (auto FuncRange =
321 computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
322 ExtZone.EnclosingFuncRange = *FuncRange;
323 if (auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
324 ExtZone.ZoneRange = *ZoneRange;
325 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
326 return llvm::None;
328 for (const Node *Child : ExtZone.Parent->Children)
329 ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
331 return ExtZone;
334 // Stores information about the extracted function and provides methods for
335 // rendering it.
336 struct NewFunction {
337 struct Parameter {
338 std::string Name;
339 QualType TypeInfo;
340 bool PassByReference;
341 unsigned OrderPriority; // Lower value parameters are preferred first.
342 std::string render(const DeclContext *Context) const;
343 bool operator<(const Parameter &Other) const {
344 return OrderPriority < Other.OrderPriority;
347 std::string Name = "extracted";
348 QualType ReturnType;
349 std::vector<Parameter> Parameters;
350 SourceRange BodyRange;
351 SourceLocation DefinitionPoint;
352 llvm::Optional<SourceLocation> ForwardDeclarationPoint;
353 const CXXRecordDecl *EnclosingClass = nullptr;
354 const NestedNameSpecifier *DefinitionQualifier = nullptr;
355 const DeclContext *SemanticDC = nullptr;
356 const DeclContext *SyntacticDC = nullptr;
357 const DeclContext *ForwardDeclarationSyntacticDC = nullptr;
358 bool CallerReturnsValue = false;
359 bool Static = false;
360 ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
361 bool Const = false;
363 // Decides whether the extracted function body and the function call need a
364 // semicolon after extraction.
365 tooling::ExtractionSemicolonPolicy SemicolonPolicy;
366 const LangOptions *LangOpts;
367 NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
368 const LangOptions *LangOpts)
369 : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
370 // Render the call for this function.
371 std::string renderCall() const;
372 // Render the definition for this function.
373 std::string renderDeclaration(FunctionDeclKind K,
374 const DeclContext &SemanticDC,
375 const DeclContext &SyntacticDC,
376 const SourceManager &SM) const;
378 private:
379 std::string
380 renderParametersForDeclaration(const DeclContext &Enclosing) const;
381 std::string renderParametersForCall() const;
382 std::string renderSpecifiers(FunctionDeclKind K) const;
383 std::string renderQualifiers() const;
384 std::string renderDeclarationName(FunctionDeclKind K) const;
385 // Generate the function body.
386 std::string getFuncBody(const SourceManager &SM) const;
389 std::string NewFunction::renderParametersForDeclaration(
390 const DeclContext &Enclosing) const {
391 std::string Result;
392 bool NeedCommaBefore = false;
393 for (const Parameter &P : Parameters) {
394 if (NeedCommaBefore)
395 Result += ", ";
396 NeedCommaBefore = true;
397 Result += P.render(&Enclosing);
399 return Result;
402 std::string NewFunction::renderParametersForCall() const {
403 std::string Result;
404 bool NeedCommaBefore = false;
405 for (const Parameter &P : Parameters) {
406 if (NeedCommaBefore)
407 Result += ", ";
408 NeedCommaBefore = true;
409 Result += P.Name;
411 return Result;
414 std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
415 std::string Attributes;
417 if (Static && K != FunctionDeclKind::OutOfLineDefinition) {
418 Attributes += "static ";
421 switch (Constexpr) {
422 case ConstexprSpecKind::Unspecified:
423 case ConstexprSpecKind::Constinit:
424 break;
425 case ConstexprSpecKind::Constexpr:
426 Attributes += "constexpr ";
427 break;
428 case ConstexprSpecKind::Consteval:
429 Attributes += "consteval ";
430 break;
433 return Attributes;
436 std::string NewFunction::renderQualifiers() const {
437 std::string Attributes;
439 if (Const) {
440 Attributes += " const";
443 return Attributes;
446 std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const {
447 if (DefinitionQualifier == nullptr || K != OutOfLineDefinition) {
448 return Name;
451 std::string QualifierName;
452 llvm::raw_string_ostream Oss(QualifierName);
453 DefinitionQualifier->print(Oss, *LangOpts);
454 return llvm::formatv("{0}{1}", QualifierName, Name);
457 std::string NewFunction::renderCall() const {
458 return std::string(
459 llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
460 renderParametersForCall(),
461 (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
464 std::string NewFunction::renderDeclaration(FunctionDeclKind K,
465 const DeclContext &SemanticDC,
466 const DeclContext &SyntacticDC,
467 const SourceManager &SM) const {
468 std::string Declaration = std::string(llvm::formatv(
469 "{0}{1} {2}({3}){4}", renderSpecifiers(K),
470 printType(ReturnType, SyntacticDC), renderDeclarationName(K),
471 renderParametersForDeclaration(SemanticDC), renderQualifiers()));
473 switch (K) {
474 case ForwardDeclaration:
475 return std::string(llvm::formatv("{0};\n", Declaration));
476 case OutOfLineDefinition:
477 case InlineDefinition:
478 return std::string(
479 llvm::formatv("{0} {\n{1}\n}\n", Declaration, getFuncBody(SM)));
480 break;
482 llvm_unreachable("Unsupported FunctionDeclKind enum");
485 std::string NewFunction::getFuncBody(const SourceManager &SM) const {
486 // FIXME: Generate tooling::Replacements instead of std::string to
487 // - hoist decls
488 // - add return statement
489 // - Add semicolon
490 return toSourceCode(SM, BodyRange).str() +
491 (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
494 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
495 return printType(TypeInfo, *Context) + (PassByReference ? " &" : " ") + Name;
498 // Stores captured information about Extraction Zone.
499 struct CapturedZoneInfo {
500 struct DeclInformation {
501 const Decl *TheDecl;
502 ZoneRelative DeclaredIn;
503 // index of the declaration or first reference.
504 unsigned DeclIndex;
505 bool IsReferencedInZone = false;
506 bool IsReferencedInPostZone = false;
507 // FIXME: Capture mutation information
508 DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn,
509 unsigned DeclIndex)
510 : TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){};
511 // Marks the occurence of a reference for this declaration
512 void markOccurence(ZoneRelative ReferenceLoc);
514 // Maps Decls to their DeclInfo
515 llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
516 bool HasReturnStmt = false; // Are there any return statements in the zone?
517 bool AlwaysReturns = false; // Does the zone always return?
518 // Control flow is broken if we are extracting a break/continue without a
519 // corresponding parent loop/switch
520 bool BrokenControlFlow = false;
521 // FIXME: capture TypeAliasDecl and UsingDirectiveDecl
522 // FIXME: Capture type information as well.
523 DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
524 DeclInformation *getDeclInfoFor(const Decl *D);
527 CapturedZoneInfo::DeclInformation *
528 CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) {
529 // The new Decl's index is the size of the map so far.
530 auto InsertionResult = DeclInfoMap.insert(
531 {D, DeclInformation(D, RelativeLoc, DeclInfoMap.size())});
532 // Return the newly created DeclInfo
533 return &InsertionResult.first->second;
536 CapturedZoneInfo::DeclInformation *
537 CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
538 // If the Decl doesn't exist, we
539 auto Iter = DeclInfoMap.find(D);
540 if (Iter == DeclInfoMap.end())
541 return nullptr;
542 return &Iter->second;
545 void CapturedZoneInfo::DeclInformation::markOccurence(
546 ZoneRelative ReferenceLoc) {
547 switch (ReferenceLoc) {
548 case ZoneRelative::Inside:
549 IsReferencedInZone = true;
550 break;
551 case ZoneRelative::After:
552 IsReferencedInPostZone = true;
553 break;
554 default:
555 break;
559 bool isLoop(const Stmt *S) {
560 return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
561 isa<CXXForRangeStmt>(S);
564 // Captures information from Extraction Zone
565 CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
566 // We use the ASTVisitor instead of using the selection tree since we need to
567 // find references in the PostZone as well.
568 // FIXME: Check which statements we don't allow to extract.
569 class ExtractionZoneVisitor
570 : public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
571 public:
572 ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
573 TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
576 bool TraverseStmt(Stmt *S) {
577 if (!S)
578 return true;
579 bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
580 // If we are starting traversal of a RootStmt, we are somewhere inside
581 // ExtractionZone
582 if (IsRootStmt)
583 CurrentLocation = ZoneRelative::Inside;
584 addToLoopSwitchCounters(S, 1);
585 // Traverse using base class's TraverseStmt
586 RecursiveASTVisitor::TraverseStmt(S);
587 addToLoopSwitchCounters(S, -1);
588 // We set the current location as after since next stmt will either be a
589 // RootStmt (handled at the beginning) or after extractionZone
590 if (IsRootStmt)
591 CurrentLocation = ZoneRelative::After;
592 return true;
595 // Add Increment to CurNumberOf{Loops,Switch} if statement is
596 // {Loop,Switch} and inside Extraction Zone.
597 void addToLoopSwitchCounters(Stmt *S, int Increment) {
598 if (CurrentLocation != ZoneRelative::Inside)
599 return;
600 if (isLoop(S))
601 CurNumberOfNestedLoops += Increment;
602 else if (isa<SwitchStmt>(S))
603 CurNumberOfSwitch += Increment;
606 bool VisitDecl(Decl *D) {
607 Info.createDeclInfo(D, CurrentLocation);
608 return true;
611 bool VisitDeclRefExpr(DeclRefExpr *DRE) {
612 // Find the corresponding Decl and mark it's occurrence.
613 const Decl *D = DRE->getDecl();
614 auto *DeclInfo = Info.getDeclInfoFor(D);
615 // If no Decl was found, the Decl must be outside the enclosingFunc.
616 if (!DeclInfo)
617 DeclInfo = Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
618 DeclInfo->markOccurence(CurrentLocation);
619 // FIXME: check if reference mutates the Decl being referred.
620 return true;
623 bool VisitReturnStmt(ReturnStmt *Return) {
624 if (CurrentLocation == ZoneRelative::Inside)
625 Info.HasReturnStmt = true;
626 return true;
629 bool VisitBreakStmt(BreakStmt *Break) {
630 // Control flow is broken if break statement is selected without any
631 // parent loop or switch statement.
632 if (CurrentLocation == ZoneRelative::Inside &&
633 !(CurNumberOfNestedLoops || CurNumberOfSwitch))
634 Info.BrokenControlFlow = true;
635 return true;
638 bool VisitContinueStmt(ContinueStmt *Continue) {
639 // Control flow is broken if Continue statement is selected without any
640 // parent loop
641 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
642 Info.BrokenControlFlow = true;
643 return true;
645 CapturedZoneInfo Info;
646 const ExtractionZone &ExtZone;
647 ZoneRelative CurrentLocation = ZoneRelative::Before;
648 // Number of {loop,switch} statements that are currently in the traversal
649 // stack inside Extraction Zone. Used to check for broken control flow.
650 unsigned CurNumberOfNestedLoops = 0;
651 unsigned CurNumberOfSwitch = 0;
653 ExtractionZoneVisitor Visitor(ExtZone);
654 CapturedZoneInfo Result = std::move(Visitor.Info);
655 Result.AlwaysReturns = alwaysReturns(ExtZone);
656 return Result;
659 // Adds parameters to ExtractedFunc.
660 // Returns true if able to find the parameters successfully and no hoisting
661 // needed.
662 // FIXME: Check if the declaration has a local/anonymous type
663 bool createParameters(NewFunction &ExtractedFunc,
664 const CapturedZoneInfo &CapturedInfo) {
665 for (const auto &KeyVal : CapturedInfo.DeclInfoMap) {
666 const auto &DeclInfo = KeyVal.second;
667 // If a Decl was Declared in zone and referenced in post zone, it
668 // needs to be hoisted (we bail out in that case).
669 // FIXME: Support Decl Hoisting.
670 if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
671 DeclInfo.IsReferencedInPostZone)
672 return false;
673 if (!DeclInfo.IsReferencedInZone)
674 continue; // no need to pass as parameter, not referenced
675 if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
676 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
677 continue; // no need to pass as parameter, still accessible.
678 // Parameter specific checks.
679 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
680 // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
681 // (this includes the case of recursive call to EnclosingFunc in Zone).
682 if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
683 return false;
684 // Parameter qualifiers are same as the Decl's qualifiers.
685 QualType TypeInfo = VD->getType().getNonReferenceType();
686 // FIXME: Need better qualifier checks: check mutated status for
687 // Decl(e.g. was it assigned, passed as nonconst argument, etc)
688 // FIXME: check if parameter will be a non l-value reference.
689 // FIXME: We don't want to always pass variables of types like int,
690 // pointers, etc by reference.
691 bool IsPassedByReference = true;
692 // We use the index of declaration as the ordering priority for parameters.
693 ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo,
694 IsPassedByReference,
695 DeclInfo.DeclIndex});
697 llvm::sort(ExtractedFunc.Parameters);
698 return true;
701 // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling)
702 // uses closed ranges. Generates the semicolon policy for the extraction and
703 // extends the ZoneRange if necessary.
704 tooling::ExtractionSemicolonPolicy
705 getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
706 const LangOptions &LangOpts) {
707 // Get closed ZoneRange.
708 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
709 ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
710 auto SemicolonPolicy = tooling::ExtractionSemicolonPolicy::compute(
711 ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
712 LangOpts);
713 // Update ZoneRange.
714 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
715 return SemicolonPolicy;
718 // Generate return type for ExtractedFunc. Return false if unable to do so.
719 bool generateReturnProperties(NewFunction &ExtractedFunc,
720 const FunctionDecl &EnclosingFunc,
721 const CapturedZoneInfo &CapturedInfo) {
722 // If the selected code always returns, we preserve those return statements.
723 // The return type should be the same as the enclosing function.
724 // (Others are possible if there are conversions, but this seems clearest).
725 if (CapturedInfo.HasReturnStmt) {
726 // If the return is conditional, neither replacing the code with
727 // `extracted()` nor `return extracted()` is correct.
728 if (!CapturedInfo.AlwaysReturns)
729 return false;
730 QualType Ret = EnclosingFunc.getReturnType();
731 // Once we support members, it'd be nice to support e.g. extracting a method
732 // of Foo<T> that returns T. But it's not clear when that's safe.
733 if (Ret->isDependentType())
734 return false;
735 ExtractedFunc.ReturnType = Ret;
736 return true;
738 // FIXME: Generate new return statement if needed.
739 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
740 return true;
743 void captureMethodInfo(NewFunction &ExtractedFunc,
744 const CXXMethodDecl *Method) {
745 ExtractedFunc.Static = Method->isStatic();
746 ExtractedFunc.Const = Method->isConst();
747 ExtractedFunc.EnclosingClass = Method->getParent();
750 // FIXME: add support for adding other function return types besides void.
751 // FIXME: assign the value returned by non void extracted function.
752 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
753 const SourceManager &SM,
754 const LangOptions &LangOpts) {
755 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
756 // Bail out if any break of continue exists
757 if (CapturedInfo.BrokenControlFlow)
758 return error("Cannot extract break/continue without corresponding "
759 "loop/switch statement.");
760 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
761 &LangOpts);
763 ExtractedFunc.SyntacticDC =
764 ExtZone.EnclosingFunction->getLexicalDeclContext();
765 ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
766 ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
767 ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
769 if (const auto *Method =
770 llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction))
771 captureMethodInfo(ExtractedFunc, Method);
773 if (ExtZone.EnclosingFunction->isOutOfLine()) {
774 // FIXME: Put the extracted method in a private section if it's a class or
775 // maybe in an anonymous namespace
776 const auto *FirstOriginalDecl =
777 ExtZone.EnclosingFunction->getCanonicalDecl();
778 auto DeclPos =
779 toHalfOpenFileRange(SM, LangOpts, FirstOriginalDecl->getSourceRange());
780 if (!DeclPos)
781 return error("Declaration is inside a macro");
782 ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
783 ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
786 ExtractedFunc.BodyRange = ExtZone.ZoneRange;
787 ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
789 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
790 if (!createParameters(ExtractedFunc, CapturedInfo) ||
791 !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
792 CapturedInfo))
793 return error("Too complex to extract.");
794 return ExtractedFunc;
797 class ExtractFunction : public Tweak {
798 public:
799 const char *id() const final;
800 bool prepare(const Selection &Inputs) override;
801 Expected<Effect> apply(const Selection &Inputs) override;
802 std::string title() const override { return "Extract to function"; }
803 llvm::StringLiteral kind() const override {
804 return CodeAction::REFACTOR_KIND;
807 private:
808 ExtractionZone ExtZone;
811 REGISTER_TWEAK(ExtractFunction)
812 tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc,
813 const SourceManager &SM,
814 const LangOptions &LangOpts) {
815 std::string FuncCall = ExtractedFunc.renderCall();
816 return tooling::Replacement(
817 SM, CharSourceRange(ExtractedFunc.BodyRange, false), FuncCall, LangOpts);
820 tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
821 const SourceManager &SM) {
822 FunctionDeclKind DeclKind = InlineDefinition;
823 if (ExtractedFunc.ForwardDeclarationPoint)
824 DeclKind = OutOfLineDefinition;
825 std::string FunctionDef = ExtractedFunc.renderDeclaration(
826 DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM);
828 return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
829 FunctionDef);
832 tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc,
833 const SourceManager &SM) {
834 std::string FunctionDecl = ExtractedFunc.renderDeclaration(
835 ForwardDeclaration, *ExtractedFunc.SemanticDC,
836 *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
837 SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint;
839 return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl);
842 // Returns true if ExtZone contains any ReturnStmts.
843 bool hasReturnStmt(const ExtractionZone &ExtZone) {
844 class ReturnStmtVisitor
845 : public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
846 public:
847 bool VisitReturnStmt(ReturnStmt *Return) {
848 Found = true;
849 return false; // We found the answer, abort the scan.
851 bool Found = false;
854 ReturnStmtVisitor V;
855 for (const Stmt *RootStmt : ExtZone.RootStmts) {
856 V.TraverseStmt(const_cast<Stmt *>(RootStmt));
857 if (V.Found)
858 break;
860 return V.Found;
863 bool ExtractFunction::prepare(const Selection &Inputs) {
864 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
865 if (!LangOpts.CPlusPlus)
866 return false;
867 const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
868 const SourceManager &SM = Inputs.AST->getSourceManager();
869 auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
870 if (!MaybeExtZone ||
871 (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
872 return false;
874 // FIXME: Get rid of this check once we support hoisting.
875 if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
876 return false;
878 ExtZone = std::move(*MaybeExtZone);
879 return true;
882 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
883 const SourceManager &SM = Inputs.AST->getSourceManager();
884 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
885 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
886 // FIXME: Add more types of errors.
887 if (!ExtractedFunc)
888 return ExtractedFunc.takeError();
889 tooling::Replacements Edit;
890 if (auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM)))
891 return std::move(Err);
892 if (auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
893 return std::move(Err);
895 if (auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
896 // If the fwd-declaration goes in the same file, merge into Replacements.
897 // Otherwise it needs to be a separate file edit.
898 if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) {
899 if (auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM)))
900 return std::move(Err);
901 } else {
902 auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit));
903 if (!MultiFileEffect)
904 return MultiFileEffect.takeError();
906 tooling::Replacements OtherEdit(
907 createForwardDeclaration(*ExtractedFunc, SM));
908 if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
909 OtherEdit))
910 MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
911 PathAndEdit->second);
912 else
913 return PathAndEdit.takeError();
914 return MultiFileEffect;
917 return Effect::mainFileEdit(SM, std::move(Edit));
920 } // namespace
921 } // namespace clangd
922 } // namespace clang