1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Adds brackets in case statements that "contain" initialization of retaining
10 // variable, thus emitting the "switch case is in protected scope" error.
12 //===----------------------------------------------------------------------===//
14 #include "Internals.h"
15 #include "Transforms.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/Basic/SourceManager.h"
18 #include "clang/Sema/SemaDiagnostic.h"
20 using namespace clang
;
21 using namespace arcmt
;
22 using namespace trans
;
26 class LocalRefsCollector
: public RecursiveASTVisitor
<LocalRefsCollector
> {
27 SmallVectorImpl
<DeclRefExpr
*> &Refs
;
30 LocalRefsCollector(SmallVectorImpl
<DeclRefExpr
*> &refs
)
33 bool VisitDeclRefExpr(DeclRefExpr
*E
) {
34 if (ValueDecl
*D
= E
->getDecl())
35 if (D
->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
50 CaseInfo() : SC(nullptr), State(St_Unchecked
) {}
51 CaseInfo(SwitchCase
*S
, SourceRange Range
)
52 : SC(S
), Range(Range
), State(St_Unchecked
) {}
55 class CaseCollector
: public RecursiveASTVisitor
<CaseCollector
> {
57 SmallVectorImpl
<CaseInfo
> &Cases
;
60 CaseCollector(ParentMap
&PMap
, SmallVectorImpl
<CaseInfo
> &Cases
)
61 : PMap(PMap
), Cases(Cases
) { }
63 bool VisitSwitchStmt(SwitchStmt
*S
) {
64 SwitchCase
*Curr
= S
->getSwitchCaseList();
67 Stmt
*Parent
= getCaseParent(Curr
);
68 Curr
= Curr
->getNextSwitchCase();
69 // Make sure all case statements are in the same scope.
71 if (getCaseParent(Curr
) != Parent
)
73 Curr
= Curr
->getNextSwitchCase();
76 SourceLocation NextLoc
= S
->getEndLoc();
77 Curr
= S
->getSwitchCaseList();
78 // We iterate over case statements in reverse source-order.
81 CaseInfo(Curr
, SourceRange(Curr
->getBeginLoc(), NextLoc
)));
82 NextLoc
= Curr
->getBeginLoc();
83 Curr
= Curr
->getNextSwitchCase();
88 Stmt
*getCaseParent(SwitchCase
*S
) {
89 Stmt
*Parent
= PMap
.getParent(S
);
90 while (Parent
&& (isa
<SwitchCase
>(Parent
) || isa
<LabelStmt
>(Parent
)))
91 Parent
= PMap
.getParent(Parent
);
96 class ProtectedScopeFixer
{
99 SmallVector
<CaseInfo
, 16> Cases
;
100 SmallVector
<DeclRefExpr
*, 16> LocalRefs
;
103 ProtectedScopeFixer(BodyContext
&BodyCtx
)
104 : Pass(BodyCtx
.getMigrationContext().Pass
),
105 SM(Pass
.Ctx
.getSourceManager()) {
107 CaseCollector(BodyCtx
.getParentMap(), Cases
)
108 .TraverseStmt(BodyCtx
.getTopStmt());
109 LocalRefsCollector(LocalRefs
).TraverseStmt(BodyCtx
.getTopStmt());
111 SourceRange BodyRange
= BodyCtx
.getTopStmt()->getSourceRange();
112 const CapturedDiagList
&DiagList
= Pass
.getDiags();
113 // Copy the diagnostics so we don't have to worry about invaliding iterators
114 // from the diagnostic list.
115 SmallVector
<StoredDiagnostic
, 16> StoredDiags
;
116 StoredDiags
.append(DiagList
.begin(), DiagList
.end());
117 SmallVectorImpl
<StoredDiagnostic
>::iterator
118 I
= StoredDiags
.begin(), E
= StoredDiags
.end();
120 if (I
->getID() == diag::err_switch_into_protected_scope
&&
121 isInRange(I
->getLocation(), BodyRange
)) {
122 handleProtectedScopeError(I
, E
);
129 void handleProtectedScopeError(
130 SmallVectorImpl
<StoredDiagnostic
>::iterator
&DiagI
,
131 SmallVectorImpl
<StoredDiagnostic
>::iterator DiagE
){
132 Transaction
Trans(Pass
.TA
);
133 assert(DiagI
->getID() == diag::err_switch_into_protected_scope
);
134 SourceLocation ErrLoc
= DiagI
->getLocation();
135 bool handledAllNotes
= true;
137 for (; DiagI
!= DiagE
&& DiagI
->getLevel() == DiagnosticsEngine::Note
;
139 if (!handleProtectedNote(*DiagI
))
140 handledAllNotes
= false;
144 Pass
.TA
.clearDiagnostic(diag::err_switch_into_protected_scope
, ErrLoc
);
147 bool handleProtectedNote(const StoredDiagnostic
&Diag
) {
148 assert(Diag
.getLevel() == DiagnosticsEngine::Note
);
150 for (unsigned i
= 0; i
!= Cases
.size(); i
++) {
151 CaseInfo
&info
= Cases
[i
];
152 if (isInRange(Diag
.getLocation(), info
.Range
)) {
154 if (info
.State
== CaseInfo::St_Unchecked
)
156 assert(info
.State
!= CaseInfo::St_Unchecked
);
158 if (info
.State
== CaseInfo::St_Fixed
) {
159 Pass
.TA
.clearDiagnostic(Diag
.getID(), Diag
.getLocation());
169 void tryFixing(CaseInfo
&info
) {
170 assert(info
.State
== CaseInfo::St_Unchecked
);
171 if (hasVarReferencedOutside(info
)) {
172 info
.State
= CaseInfo::St_CannotFix
;
176 Pass
.TA
.insertAfterToken(info
.SC
->getColonLoc(), " {");
177 Pass
.TA
.insert(info
.Range
.getEnd(), "}\n");
178 info
.State
= CaseInfo::St_Fixed
;
181 bool hasVarReferencedOutside(CaseInfo
&info
) {
182 for (unsigned i
= 0, e
= LocalRefs
.size(); i
!= e
; ++i
) {
183 DeclRefExpr
*DRE
= LocalRefs
[i
];
184 if (isInRange(DRE
->getDecl()->getLocation(), info
.Range
) &&
185 !isInRange(DRE
->getLocation(), info
.Range
))
191 bool isInRange(SourceLocation Loc
, SourceRange R
) {
194 return !SM
.isBeforeInTranslationUnit(Loc
, R
.getBegin()) &&
195 SM
.isBeforeInTranslationUnit(Loc
, R
.getEnd());
199 } // anonymous namespace
201 void ProtectedScopeTraverser::traverseBody(BodyContext
&BodyCtx
) {
202 ProtectedScopeFixer
Fix(BodyCtx
);