//===--- TransAutoreleasePool.cpp - Transformations to ARC mode -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// rewriteAutoreleasePool:
//
// Uses of NSAutoreleasePool will be rewritten as an @autoreleasepool scope.
//
//  NSAutoreleasePool *pool = [[NSAutoreleasePool alloc] init];
//  ...
//  [pool release];
// ---->
//  @autoreleasepool {
//  ...
//  }
//
// An NSAutoreleasePool will not be touched if:
// - There is not a corresponding -release/-drain in the same scope
// - Not all references of the NSAutoreleasePool variable can be removed
// - There is a variable that is declared inside the intended @autoreleasepool
//   scope which is also used outside it.
//
//===----------------------------------------------------------------------===//
#include "Transforms.h"
#include "Internals.h"
#include "clang/AST/ASTContext.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Sema/SemaDiagnostic.h"
#include <map>

using namespace clang;
using namespace arcmt;
using namespace trans;
42 class ReleaseCollector
: public RecursiveASTVisitor
<ReleaseCollector
> {
44 SmallVectorImpl
<ObjCMessageExpr
*> &Releases
;
47 ReleaseCollector(Decl
*D
, SmallVectorImpl
<ObjCMessageExpr
*> &releases
)
48 : Dcl(D
), Releases(releases
) { }
50 bool VisitObjCMessageExpr(ObjCMessageExpr
*E
) {
51 if (!E
->isInstanceMessage())
53 if (E
->getMethodFamily() != OMF_release
)
55 Expr
*instance
= E
->getInstanceReceiver()->IgnoreParenCasts();
56 if (DeclRefExpr
*DE
= dyn_cast
<DeclRefExpr
>(instance
)) {
57 if (DE
->getDecl() == Dcl
)
58 Releases
.push_back(E
);
68 class AutoreleasePoolRewriter
69 : public RecursiveASTVisitor
<AutoreleasePoolRewriter
> {
71 AutoreleasePoolRewriter(MigrationPass
&pass
)
72 : Body(nullptr), Pass(pass
) {
73 PoolII
= &pass
.Ctx
.Idents
.get("NSAutoreleasePool");
74 DrainSel
= pass
.Ctx
.Selectors
.getNullarySelector(
75 &pass
.Ctx
.Idents
.get("drain"));
78 void transformBody(Stmt
*body
, Decl
*ParentD
) {
83 ~AutoreleasePoolRewriter() {
84 SmallVector
<VarDecl
*, 8> VarsToHandle
;
86 for (std::map
<VarDecl
*, PoolVarInfo
>::iterator
87 I
= PoolVars
.begin(), E
= PoolVars
.end(); I
!= E
; ++I
) {
88 VarDecl
*var
= I
->first
;
89 PoolVarInfo
&info
= I
->second
;
91 // Check that we can handle/rewrite all references of the pool.
93 clearRefsIn(info
.Dcl
, info
.Refs
);
94 for (SmallVectorImpl
<PoolScope
>::iterator
95 scpI
= info
.Scopes
.begin(),
96 scpE
= info
.Scopes
.end(); scpI
!= scpE
; ++scpI
) {
97 PoolScope
&scope
= *scpI
;
98 clearRefsIn(*scope
.Begin
, info
.Refs
);
99 clearRefsIn(*scope
.End
, info
.Refs
);
100 clearRefsIn(scope
.Releases
.begin(), scope
.Releases
.end(), info
.Refs
);
103 // Even if one reference is not handled we will not do anything about that
105 if (info
.Refs
.empty())
106 VarsToHandle
.push_back(var
);
109 for (unsigned i
= 0, e
= VarsToHandle
.size(); i
!= e
; ++i
) {
110 PoolVarInfo
&info
= PoolVars
[VarsToHandle
[i
]];
112 Transaction
Trans(Pass
.TA
);
114 clearUnavailableDiags(info
.Dcl
);
115 Pass
.TA
.removeStmt(info
.Dcl
);
117 // Add "@autoreleasepool { }"
118 for (SmallVectorImpl
<PoolScope
>::iterator
119 scpI
= info
.Scopes
.begin(),
120 scpE
= info
.Scopes
.end(); scpI
!= scpE
; ++scpI
) {
121 PoolScope
&scope
= *scpI
;
122 clearUnavailableDiags(*scope
.Begin
);
123 clearUnavailableDiags(*scope
.End
);
124 if (scope
.IsFollowedBySimpleReturnStmt
) {
125 // Include the return in the scope.
126 Pass
.TA
.replaceStmt(*scope
.Begin
, "@autoreleasepool {");
127 Pass
.TA
.removeStmt(*scope
.End
);
128 Stmt::child_iterator retI
= scope
.End
;
130 SourceLocation afterSemi
=
131 findLocationAfterSemi((*retI
)->getEndLoc(), Pass
.Ctx
);
132 assert(afterSemi
.isValid() &&
133 "Didn't we check before setting IsFollowedBySimpleReturnStmt "
135 Pass
.TA
.insertAfterToken(afterSemi
, "\n}");
136 Pass
.TA
.increaseIndentation(
137 SourceRange(scope
.getIndentedRange().getBegin(),
138 (*retI
)->getEndLoc()),
139 scope
.CompoundParent
->getBeginLoc());
141 Pass
.TA
.replaceStmt(*scope
.Begin
, "@autoreleasepool {");
142 Pass
.TA
.replaceStmt(*scope
.End
, "}");
143 Pass
.TA
.increaseIndentation(scope
.getIndentedRange(),
144 scope
.CompoundParent
->getBeginLoc());
148 // Remove rest of pool var references.
149 for (SmallVectorImpl
<PoolScope
>::iterator
150 scpI
= info
.Scopes
.begin(),
151 scpE
= info
.Scopes
.end(); scpI
!= scpE
; ++scpI
) {
152 PoolScope
&scope
= *scpI
;
153 for (SmallVectorImpl
<ObjCMessageExpr
*>::iterator
154 relI
= scope
.Releases
.begin(),
155 relE
= scope
.Releases
.end(); relI
!= relE
; ++relI
) {
156 clearUnavailableDiags(*relI
);
157 Pass
.TA
.removeStmt(*relI
);
163 bool VisitCompoundStmt(CompoundStmt
*S
) {
164 SmallVector
<PoolScope
, 4> Scopes
;
166 for (Stmt::child_iterator
167 I
= S
->body_begin(), E
= S
->body_end(); I
!= E
; ++I
) {
168 Stmt
*child
= getEssential(*I
);
169 if (DeclStmt
*DclS
= dyn_cast
<DeclStmt
>(child
)) {
170 if (DclS
->isSingleDecl()) {
171 if (VarDecl
*VD
= dyn_cast
<VarDecl
>(DclS
->getSingleDecl())) {
172 if (isNSAutoreleasePool(VD
->getType())) {
173 PoolVarInfo
&info
= PoolVars
[VD
];
175 collectRefs(VD
, S
, info
.Refs
);
176 // Does this statement follow the pattern:
177 // NSAutoreleasePool * pool = [NSAutoreleasePool new];
178 if (isPoolCreation(VD
->getInit())) {
179 Scopes
.push_back(PoolScope());
180 Scopes
.back().PoolVar
= VD
;
181 Scopes
.back().CompoundParent
= S
;
182 Scopes
.back().Begin
= I
;
187 } else if (BinaryOperator
*bop
= dyn_cast
<BinaryOperator
>(child
)) {
188 if (DeclRefExpr
*dref
= dyn_cast
<DeclRefExpr
>(bop
->getLHS())) {
189 if (VarDecl
*VD
= dyn_cast
<VarDecl
>(dref
->getDecl())) {
190 // Does this statement follow the pattern:
191 // pool = [NSAutoreleasePool new];
192 if (isNSAutoreleasePool(VD
->getType()) &&
193 isPoolCreation(bop
->getRHS())) {
194 Scopes
.push_back(PoolScope());
195 Scopes
.back().PoolVar
= VD
;
196 Scopes
.back().CompoundParent
= S
;
197 Scopes
.back().Begin
= I
;
206 if (isPoolDrain(Scopes
.back().PoolVar
, child
)) {
207 PoolScope
&scope
= Scopes
.back();
209 handlePoolScope(scope
, S
);
217 void clearUnavailableDiags(Stmt
*S
) {
219 Pass
.TA
.clearDiagnostic(diag::err_unavailable
,
220 diag::err_unavailable_message
,
221 S
->getSourceRange());
226 CompoundStmt
*CompoundParent
;
227 Stmt::child_iterator Begin
;
228 Stmt::child_iterator End
;
229 bool IsFollowedBySimpleReturnStmt
;
230 SmallVector
<ObjCMessageExpr
*, 4> Releases
;
233 : PoolVar(nullptr), CompoundParent(nullptr),
234 IsFollowedBySimpleReturnStmt(false) {}
236 SourceRange
getIndentedRange() const {
237 Stmt::child_iterator rangeS
= Begin
;
240 return SourceRange();
241 Stmt::child_iterator rangeE
= Begin
;
242 for (Stmt::child_iterator I
= rangeS
; I
!= End
; ++I
)
244 return SourceRange((*rangeS
)->getBeginLoc(), (*rangeE
)->getEndLoc());
248 class NameReferenceChecker
: public RecursiveASTVisitor
<NameReferenceChecker
>{
250 SourceRange ScopeRange
;
251 SourceLocation
&referenceLoc
, &declarationLoc
;
254 NameReferenceChecker(ASTContext
&ctx
, PoolScope
&scope
,
255 SourceLocation
&referenceLoc
,
256 SourceLocation
&declarationLoc
)
257 : Ctx(ctx
), referenceLoc(referenceLoc
),
258 declarationLoc(declarationLoc
) {
259 ScopeRange
= SourceRange((*scope
.Begin
)->getBeginLoc(),
260 (*scope
.End
)->getBeginLoc());
263 bool VisitDeclRefExpr(DeclRefExpr
*E
) {
264 return checkRef(E
->getLocation(), E
->getDecl()->getLocation());
267 bool VisitTypedefTypeLoc(TypedefTypeLoc TL
) {
268 return checkRef(TL
.getBeginLoc(), TL
.getTypedefNameDecl()->getLocation());
271 bool VisitTagTypeLoc(TagTypeLoc TL
) {
272 return checkRef(TL
.getBeginLoc(), TL
.getDecl()->getLocation());
276 bool checkRef(SourceLocation refLoc
, SourceLocation declLoc
) {
277 if (isInScope(declLoc
)) {
278 referenceLoc
= refLoc
;
279 declarationLoc
= declLoc
;
285 bool isInScope(SourceLocation loc
) {
289 SourceManager
&SM
= Ctx
.getSourceManager();
290 if (SM
.isBeforeInTranslationUnit(loc
, ScopeRange
.getBegin()))
292 return SM
.isBeforeInTranslationUnit(loc
, ScopeRange
.getEnd());
296 void handlePoolScope(PoolScope
&scope
, CompoundStmt
*compoundS
) {
297 // Check that all names declared inside the scope are not used
298 // outside the scope.
300 bool nameUsedOutsideScope
= false;
301 SourceLocation referenceLoc
, declarationLoc
;
302 Stmt::child_iterator SI
= scope
.End
, SE
= compoundS
->body_end();
304 // Check if the autoreleasepool scope is followed by a simple return
305 // statement, in which case we will include the return in the scope.
307 if (ReturnStmt
*retS
= dyn_cast
<ReturnStmt
>(*SI
))
308 if ((retS
->getRetValue() == nullptr ||
309 isa
<DeclRefExpr
>(retS
->getRetValue()->IgnoreParenCasts())) &&
310 findLocationAfterSemi(retS
->getEndLoc(), Pass
.Ctx
).isValid()) {
311 scope
.IsFollowedBySimpleReturnStmt
= true;
312 ++SI
; // the return will be included in scope, don't check it.
315 for (; SI
!= SE
; ++SI
) {
316 nameUsedOutsideScope
= !NameReferenceChecker(Pass
.Ctx
, scope
,
318 declarationLoc
).TraverseStmt(*SI
);
319 if (nameUsedOutsideScope
)
323 // If not all references were cleared it means some variables/typenames/etc
324 // declared inside the pool scope are used outside of it.
325 // We won't try to rewrite the pool.
326 if (nameUsedOutsideScope
) {
327 Pass
.TA
.reportError("a name is referenced outside the "
328 "NSAutoreleasePool scope that it was declared in", referenceLoc
);
329 Pass
.TA
.reportNote("name declared here", declarationLoc
);
330 Pass
.TA
.reportNote("intended @autoreleasepool scope begins here",
331 (*scope
.Begin
)->getBeginLoc());
332 Pass
.TA
.reportNote("intended @autoreleasepool scope ends here",
333 (*scope
.End
)->getBeginLoc());
338 // Collect all releases of the pool; they will be removed.
340 ReleaseCollector
releaseColl(scope
.PoolVar
, scope
.Releases
);
341 Stmt::child_iterator I
= scope
.Begin
;
343 for (; I
!= scope
.End
; ++I
)
344 releaseColl
.TraverseStmt(*I
);
347 PoolVars
[scope
.PoolVar
].Scopes
.push_back(scope
);
350 bool isPoolCreation(Expr
*E
) {
351 if (!E
) return false;
353 ObjCMessageExpr
*ME
= dyn_cast
<ObjCMessageExpr
>(E
);
354 if (!ME
) return false;
355 if (ME
->getMethodFamily() == OMF_new
&&
356 ME
->getReceiverKind() == ObjCMessageExpr::Class
&&
357 isNSAutoreleasePool(ME
->getReceiverInterface()))
359 if (ME
->getReceiverKind() == ObjCMessageExpr::Instance
&&
360 ME
->getMethodFamily() == OMF_init
) {
361 Expr
*rec
= getEssential(ME
->getInstanceReceiver());
362 if (ObjCMessageExpr
*recME
= dyn_cast_or_null
<ObjCMessageExpr
>(rec
)) {
363 if (recME
->getMethodFamily() == OMF_alloc
&&
364 recME
->getReceiverKind() == ObjCMessageExpr::Class
&&
365 isNSAutoreleasePool(recME
->getReceiverInterface()))
373 bool isPoolDrain(VarDecl
*poolVar
, Stmt
*S
) {
374 if (!S
) return false;
376 ObjCMessageExpr
*ME
= dyn_cast
<ObjCMessageExpr
>(S
);
377 if (!ME
) return false;
378 if (ME
->getReceiverKind() == ObjCMessageExpr::Instance
) {
379 Expr
*rec
= getEssential(ME
->getInstanceReceiver());
380 if (DeclRefExpr
*dref
= dyn_cast
<DeclRefExpr
>(rec
))
381 if (dref
->getDecl() == poolVar
)
382 return ME
->getMethodFamily() == OMF_release
||
383 ME
->getSelector() == DrainSel
;
389 bool isNSAutoreleasePool(ObjCInterfaceDecl
*IDecl
) {
390 return IDecl
&& IDecl
->getIdentifier() == PoolII
;
393 bool isNSAutoreleasePool(QualType Ty
) {
394 QualType pointee
= Ty
->getPointeeType();
395 if (pointee
.isNull())
397 if (const ObjCInterfaceType
*interT
= pointee
->getAs
<ObjCInterfaceType
>())
398 return isNSAutoreleasePool(interT
->getDecl());
402 static Expr
*getEssential(Expr
*E
) {
403 return cast
<Expr
>(getEssential((Stmt
*)E
));
405 static Stmt
*getEssential(Stmt
*S
) {
406 if (FullExpr
*FE
= dyn_cast
<FullExpr
>(S
))
407 S
= FE
->getSubExpr();
408 if (Expr
*E
= dyn_cast
<Expr
>(S
))
409 S
= E
->IgnoreParenCasts();
416 IdentifierInfo
*PoolII
;
422 SmallVector
<PoolScope
, 2> Scopes
;
424 PoolVarInfo() : Dcl(nullptr) { }
427 std::map
<VarDecl
*, PoolVarInfo
> PoolVars
;
430 } // anonymous namespace
432 void trans::rewriteAutoreleasePool(MigrationPass
&pass
) {
433 BodyTransform
<AutoreleasePoolRewriter
> trans(pass
);
434 trans
.TraverseDecl(pass
.Ctx
.getTranslationUnitDecl());