1 //===--- InsertionPoint.cpp - Where should we add new code? ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "refactor/InsertionPoint.h"
10 #include "support/Logger.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/DeclCXX.h"
13 #include "clang/AST/DeclObjC.h"
14 #include "clang/AST/DeclTemplate.h"
15 #include "clang/Basic/SourceManager.h"
22 // Choose the decl to insert before, according to an anchor.
23 // Nullptr means insert at end of DC.
24 // std::nullopt means no valid place to insert.
// NOTE(review): this definition appears damaged in this copy of the file:
// original line numbers are fused into the text and tokens are split across
// lines. The second parameter (used below as `A`, with `.Match` and
// `.Direction`) is not declared. The `case` labels of the switch and all
// `return` statements are missing. `ReturnNext` is never set to true, and
// the closing braces are absent.
// TODO: restore the missing lines from version control before building.
25 std::optional
<const Decl
*> insertionDecl(const DeclContext
&DC
,
27 bool LastMatched
= false;
28 bool ReturnNext
= false;
// Visit the declarations of DC in order.
29 for (const auto *D
: DC
.decls()) {
// For templates, run the anchor predicate on the templated declaration
// rather than on the TemplateDecl wrapper.
35 const Decl
*NonTemplate
= D
;
36 if (auto *TD
= llvm::dyn_cast
<TemplateDecl
>(D
))
37 NonTemplate
= TD
->getTemplatedDecl();
38 bool Matches
= A
.Match(NonTemplate
);
39 dlog(" {0} {1} {2}", Matches
, D
->getDeclKindName(), D
);
41 switch (A
.Direction
) {
// NOTE(review): a `case` label (presumably for the "above" direction)
// appears to be missing here.
// The first declaration of a matching run (match now, no match before).
43 if (Matches
&& !LastMatched
) {
44 // Special case: if "above" matches an access specifier, we actually
45 // want to insert below it!
46 if (llvm::isa
<AccessSpecDecl
>(D
)) {
// NOTE(review): the body of this branch, the rest of the enclosing `if`,
// and the next `case` label appear to be missing here.
// The first declaration after a matching run (no match now, matched before).
54 if (LastMatched
&& !Matches
)
// NOTE(review): the statement guarded by this condition is missing.
// Remember whether this declaration matched, so the next iteration can
// detect the start/end of a matching run.
59 LastMatched
= Matches
;
61 if (ReturnNext
|| (LastMatched
&& A
.Direction
== Anchor::Below
))
// NOTE(review): the statement guarded by this condition, the function's
// final return, and its closing brace are missing.
66 SourceLocation
beginLoc(const Decl
&D
) {
67 auto Loc
= D
.getBeginLoc();
68 if (RawComment
*Comment
= D
.getASTContext().getRawCommentForDeclNoCache(&D
)) {
69 auto CommentLoc
= Comment
->getBeginLoc();
70 if (CommentLoc
.isValid() && Loc
.isValid() &&
71 D
.getASTContext().getSourceManager().isBeforeInTranslationUnit(
78 bool any(const Decl
*D
) { return true; }
80 SourceLocation
endLoc(const DeclContext
&DC
) {
81 const Decl
*D
= llvm::cast
<Decl
>(&DC
);
82 if (auto *OCD
= llvm::dyn_cast
<ObjCContainerDecl
>(D
))
83 return OCD
->getAtEndRange().getBegin();
84 return D
->getEndLoc();
87 AccessSpecifier
getAccessAtEnd(const CXXRecordDecl
&C
) {
88 AccessSpecifier Spec
= (C
.getTagKind() == TTK_Class
? AS_private
: AS_public
);
89 for (const auto *D
: C
.decls())
90 if (const auto *ASD
= llvm::dyn_cast
<AccessSpecDecl
>(D
))
91 Spec
= ASD
->getAccess();
97 SourceLocation
insertionPoint(const DeclContext
&DC
,
98 llvm::ArrayRef
<Anchor
> Anchors
) {
99 dlog("Looking for insertion point in {0}", DC
.getDeclKindName());
100 for (const auto &A
: Anchors
) {
101 dlog(" anchor ({0})", A
.Direction
== Anchor::Above
? "above" : "below");
102 if (auto D
= insertionDecl(DC
, A
)) {
103 dlog(" anchor matched before {0}", *D
);
104 return *D
? beginLoc(**D
) : endLoc(DC
);
107 dlog("no anchor matched");
108 return SourceLocation();
111 llvm::Expected
<tooling::Replacement
>
112 insertDecl(llvm::StringRef Code
, const DeclContext
&DC
,
113 llvm::ArrayRef
<Anchor
> Anchors
) {
114 auto Loc
= insertionPoint(DC
, Anchors
);
115 // Fallback: insert at the end.
118 const auto &SM
= DC
.getParentASTContext().getSourceManager();
119 if (!SM
.isWrittenInSameFile(Loc
, cast
<Decl
>(DC
).getLocation()))
120 return error("{0} body in wrong file: {1}", DC
.getDeclKindName(),
121 Loc
.printToString(SM
));
122 return tooling::Replacement(SM
, Loc
, 0, Code
);
125 SourceLocation
insertionPoint(const CXXRecordDecl
&InClass
,
126 std::vector
<Anchor
> Anchors
,
127 AccessSpecifier Protection
) {
128 for (auto &A
: Anchors
)
129 A
.Match
= [Inner(std::move(A
.Match
)), Protection
](const Decl
*D
) {
130 return D
->getAccess() == Protection
&& Inner(D
);
132 return insertionPoint(InClass
, Anchors
);
135 llvm::Expected
<tooling::Replacement
> insertDecl(llvm::StringRef Code
,
136 const CXXRecordDecl
&InClass
,
137 std::vector
<Anchor
> Anchors
,
138 AccessSpecifier Protection
) {
139 // Fallback: insert at the bottom of the relevant access section.
140 Anchors
.push_back({any
, Anchor::Below
});
141 auto Loc
= insertionPoint(InClass
, std::move(Anchors
), Protection
);
142 std::string CodeBuffer
;
143 auto &SM
= InClass
.getASTContext().getSourceManager();
144 // Fallback: insert at the end of the class. Check if protection matches!
145 if (Loc
.isInvalid()) {
146 Loc
= InClass
.getBraceRange().getEnd();
147 if (Protection
!= getAccessAtEnd(InClass
)) {
148 CodeBuffer
= (getAccessSpelling(Protection
) + ":\n" + Code
).str();
152 if (!SM
.isWrittenInSameFile(Loc
, InClass
.getLocation()))
153 return error("Class body in wrong file: {0}", Loc
.printToString(SM
));
154 return tooling::Replacement(SM
, Loc
, 0, Code
);
157 } // namespace clangd