//===--- InsertionPoint.cpp - Where should we add new code? ---------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
9 #include "refactor/InsertionPoint.h"
10 #include "support/Logger.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/DeclCXX.h"
13 #include "clang/AST/DeclObjC.h"
14 #include "clang/AST/DeclTemplate.h"
15 #include "clang/Basic/SourceManager.h"
// Choose the decl to insert before, according to an anchor.
// A null Decl pointer means "insert at the end of DC".
// std::nullopt means there is no valid place to insert.
25 std::optional
<const Decl
*> insertionDecl(const DeclContext
&DC
,
27 bool LastMatched
= false;
28 bool ReturnNext
= false;
29 for (const auto *D
: DC
.decls()) {
35 const Decl
*NonTemplate
= D
;
36 if (auto *TD
= llvm::dyn_cast
<TemplateDecl
>(D
))
37 NonTemplate
= TD
->getTemplatedDecl();
38 bool Matches
= A
.Match(NonTemplate
);
39 dlog(" {0} {1} {2}", Matches
, D
->getDeclKindName(), D
);
41 switch (A
.Direction
) {
43 if (Matches
&& !LastMatched
) {
44 // Special case: if "above" matches an access specifier, we actually
45 // want to insert below it!
46 if (llvm::isa
<AccessSpecDecl
>(D
)) {
54 if (LastMatched
&& !Matches
)
59 LastMatched
= Matches
;
61 if (ReturnNext
|| (LastMatched
&& A
.Direction
== Anchor::Below
))
66 SourceLocation
beginLoc(const Decl
&D
) {
67 auto Loc
= D
.getBeginLoc();
68 if (RawComment
*Comment
= D
.getASTContext().getRawCommentForDeclNoCache(&D
)) {
69 auto CommentLoc
= Comment
->getBeginLoc();
70 if (CommentLoc
.isValid() && Loc
.isValid() &&
71 D
.getASTContext().getSourceManager().isBeforeInTranslationUnit(
78 bool any(const Decl
*D
) { return true; }
80 SourceLocation
endLoc(const DeclContext
&DC
) {
81 const Decl
*D
= llvm::cast
<Decl
>(&DC
);
82 if (auto *OCD
= llvm::dyn_cast
<ObjCContainerDecl
>(D
))
83 return OCD
->getAtEndRange().getBegin();
84 return D
->getEndLoc();
87 AccessSpecifier
getAccessAtEnd(const CXXRecordDecl
&C
) {
88 AccessSpecifier Spec
=
89 (C
.getTagKind() == TagTypeKind::Class
? AS_private
: AS_public
);
90 for (const auto *D
: C
.decls())
91 if (const auto *ASD
= llvm::dyn_cast
<AccessSpecDecl
>(D
))
92 Spec
= ASD
->getAccess();
98 SourceLocation
insertionPoint(const DeclContext
&DC
,
99 llvm::ArrayRef
<Anchor
> Anchors
) {
100 dlog("Looking for insertion point in {0}", DC
.getDeclKindName());
101 for (const auto &A
: Anchors
) {
102 dlog(" anchor ({0})", A
.Direction
== Anchor::Above
? "above" : "below");
103 if (auto D
= insertionDecl(DC
, A
)) {
104 dlog(" anchor matched before {0}", *D
);
105 return *D
? beginLoc(**D
) : endLoc(DC
);
108 dlog("no anchor matched");
109 return SourceLocation();
112 llvm::Expected
<tooling::Replacement
>
113 insertDecl(llvm::StringRef Code
, const DeclContext
&DC
,
114 llvm::ArrayRef
<Anchor
> Anchors
) {
115 auto Loc
= insertionPoint(DC
, Anchors
);
116 // Fallback: insert at the end.
119 const auto &SM
= DC
.getParentASTContext().getSourceManager();
120 if (!SM
.isWrittenInSameFile(Loc
, cast
<Decl
>(DC
).getLocation()))
121 return error("{0} body in wrong file: {1}", DC
.getDeclKindName(),
122 Loc
.printToString(SM
));
123 return tooling::Replacement(SM
, Loc
, 0, Code
);
126 SourceLocation
insertionPoint(const CXXRecordDecl
&InClass
,
127 std::vector
<Anchor
> Anchors
,
128 AccessSpecifier Protection
) {
129 for (auto &A
: Anchors
)
130 A
.Match
= [Inner(std::move(A
.Match
)), Protection
](const Decl
*D
) {
131 return D
->getAccess() == Protection
&& Inner(D
);
133 return insertionPoint(InClass
, Anchors
);
136 llvm::Expected
<tooling::Replacement
> insertDecl(llvm::StringRef Code
,
137 const CXXRecordDecl
&InClass
,
138 std::vector
<Anchor
> Anchors
,
139 AccessSpecifier Protection
) {
140 // Fallback: insert at the bottom of the relevant access section.
141 Anchors
.push_back({any
, Anchor::Below
});
142 auto Loc
= insertionPoint(InClass
, std::move(Anchors
), Protection
);
143 std::string CodeBuffer
;
144 auto &SM
= InClass
.getASTContext().getSourceManager();
145 // Fallback: insert at the end of the class. Check if protection matches!
146 if (Loc
.isInvalid()) {
147 Loc
= InClass
.getBraceRange().getEnd();
148 if (Protection
!= getAccessAtEnd(InClass
)) {
149 CodeBuffer
= (getAccessSpelling(Protection
) + ":\n" + Code
).str();
153 if (!SM
.isWrittenInSameFile(Loc
, InClass
.getLocation()))
154 return error("Class body in wrong file: {0}", Loc
.printToString(SM
));
155 return tooling::Replacement(SM
, Loc
, 0, Code
);
158 } // namespace clangd