1 //===--- TransBlockObjCVariable.cpp - Transformations to ARC mode ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // rewriteBlockObjCVariable:
11 // Adding __block to an obj-c variable could be either because the variable
12 // is used for output storage or the user wanted to break a retain cycle.
13 // This transformation checks whether a reference of the variable for the block
14 // is actually needed (it is assigned to or its address is taken) or not.
15 // If the reference is not needed it will assume __block was added to break a
16 // cycle so it will remove '__block' and add __weak/__unsafe_unretained.
20 // bar(^ { [x cake]; });
23 // bar(^ { [x cake]; });
25 //===----------------------------------------------------------------------===//
27 #include "Transforms.h"
28 #include "Internals.h"
29 #include "clang/AST/ASTContext.h"
30 #include "clang/AST/Attr.h"
31 #include "clang/Basic/SourceManager.h"
33 using namespace clang
;
34 using namespace arcmt
;
35 using namespace trans
;
// AST visitor that inspects each visited block declaration and records, in
// the caller-owned VarsToChange set, the captured variables for which only
// the value (and not a reference to the variable itself) is needed inside
// the block body.
39 class RootBlockObjCVarRewriter
:
40 public RecursiveASTVisitor
<RootBlockObjCVarRewriter
> {
41 llvm::DenseSet
<VarDecl
*> &VarsToChange
;
// Helper visitor bound to a single variable (Var). VisitBlockDecl runs it
// over a block body and interprets a true traversal result as "only the
// value of the variable is needed".
43 class BlockVarChecker
: public RecursiveASTVisitor
<BlockVarChecker
> {
46 typedef RecursiveASTVisitor
<BlockVarChecker
> base
;
48 BlockVarChecker(VarDecl
*var
) : Var(var
) { }
// If the operand of the implicit cast is a DeclRefExpr naming Var and the
// cast is either an lvalue-to-rvalue conversion, or (in C++ mode) an lvalue
// no-op cast, this returns true instead of delegating to the base traversal.
// Every other cast is handed to the base class implementation.
50 bool TraverseImplicitCastExpr(ImplicitCastExpr
*castE
) {
52 ref
= dyn_cast
<DeclRefExpr
>(castE
->getSubExpr())) {
53 if (ref
->getDecl() == Var
) {
54 if (castE
->getCastKind() == CK_LValueToRValue
)
55 return true; // Using the value of the variable.
56 if (castE
->getCastKind() == CK_NoOp
&& castE
->isLValue() &&
57 Var
->getASTContext().getLangOpts().CPlusPlus
)
58 return true; // Binding to const C++ reference.
62 return base::TraverseImplicitCastExpr(castE
);
// Returns false when the visited expression refers to Var.
// NOTE(review): presumably this aborts the traversal so that TraverseStmt
// reports false to VisitBlockDecl; the result for other declarations is not
// visible in this view -- confirm.
65 bool VisitDeclRefExpr(DeclRefExpr
*E
) {
66 if (E
->getDecl() == Var
)
67 return false; // The reference of the variable, and not just its value,
// Binds the visitor to the set that receives the variables to rewrite; the
// set is held by reference, not copied.
74 RootBlockObjCVarRewriter(llvm::DenseSet
<VarDecl
*> &VarsToChange
)
75 : VarsToChange(VarsToChange
) { }
// Two passes over the block's captures: first gather candidate variables,
// then check each candidate against the block body.
77 bool VisitBlockDecl(BlockDecl
*block
) {
78 SmallVector
<VarDecl
*, 4> BlockVars
;
// Candidate selection: the captured variable must have an Objective-C object
// pointer type that isImplicitStrong accepts.
// NOTE(review): the beginning of this condition is not visible here; there
// may be an additional leading check -- confirm.
80 for (const auto &I
: block
->captures()) {
81 VarDecl
*var
= I
.getVariable();
83 var
->getType()->isObjCObjectPointerType() &&
84 isImplicitStrong(var
->getType())) {
85 BlockVars
.push_back(var
);
// Run a BlockVarChecker for each candidate over the block body and add the
// variable to VarsToChange when only its value is needed.
// NOTE(review): the erase() below is presumably the alternative branch of
// the `if` (the variable is dropped from the set otherwise) -- confirm.
89 for (unsigned i
= 0, e
= BlockVars
.size(); i
!= e
; ++i
) {
90 VarDecl
*var
= BlockVars
[i
];
92 BlockVarChecker
checker(var
);
93 bool onlyValueOfVarIsNeeded
= checker
.TraverseStmt(block
->getBody());
94 if (onlyValueOfVarIsNeeded
)
95 VarsToChange
.insert(var
);
97 VarsToChange
.erase(var
);
// Reports whether the local qualifiers of ty carry the OCL_Strong
// Objective-C lifetime. Types represented as AttributedType are tested for
// first.
// NOTE(review): the outcome of the AttributedType branch is not visible in
// this view -- confirm.
104 bool isImplicitStrong(QualType ty
) {
105 if (isa
<AttributedType
>(ty
.getTypePtr()))
107 return ty
.getLocalQualifiers().getObjCLifetime() == Qualifiers::OCL_Strong
;
// Outer AST visitor: for every block declaration it reaches, it runs a
// RootBlockObjCVarRewriter over that block, sharing the same VarsToChange
// set.
111 class BlockObjCVarRewriter
: public RecursiveASTVisitor
<BlockObjCVarRewriter
> {
112 llvm::DenseSet
<VarDecl
*> &VarsToChange
;
// Stores a reference to the caller-owned result set.
115 BlockObjCVarRewriter(llvm::DenseSet
<VarDecl
*> &VarsToChange
)
116 : VarsToChange(VarsToChange
) { }
// Overrides traversal of block declarations: a temporary
// RootBlockObjCVarRewriter traverses the block in place of this visitor.
// NOTE(review): the value returned by this method is not visible in this
// view -- confirm.
118 bool TraverseBlockDecl(BlockDecl
*block
) {
119 RootBlockObjCVarRewriter(VarsToChange
).TraverseDecl(block
);
124 } // anonymous namespace
// Entry point for one body: gathers the variables to rewrite by traversing
// the body's top statement with BlockObjCVarRewriter, then issues one text
// replacement per gathered variable at the location of its BlocksAttr.
126 void BlockObjCVariableTraverser::traverseBody(BodyContext
&BodyCtx
) {
127 MigrationPass
&Pass
= BodyCtx
.getMigrationContext().Pass
;
128 llvm::DenseSet
<VarDecl
*> VarsToChange
;
// Fill VarsToChange by walking the top-level statement of this body.
130 BlockObjCVarRewriter
trans(VarsToChange
);
131 trans
.TraverseStmt(BodyCtx
.getTopStmt());
// Process every collected variable.
// NOTE(review): the declaration of `var` used below, and any null check on
// `attr`, are not visible in this view -- confirm against the full file.
133 for (llvm::DenseSet
<VarDecl
*>::iterator
134 I
= VarsToChange
.begin(), E
= VarsToChange
.end(); I
!= E
; ++I
) {
136 BlocksAttr
*attr
= var
->getAttr
<BlocksAttr
>();
// canApplyWeak decides between the two replacement spellings used below:
// "__weak" when it returns true, "__unsafe_unretained" otherwise.
139 bool useWeak
= canApplyWeak(Pass
.Ctx
, var
->getType());
140 SourceManager
&SM
= Pass
.Ctx
.getSourceManager();
// NOTE(review): presumably groups the following edit into a single
// transaction on Pass.TA -- confirm against the Transaction class.
141 Transaction
Trans(Pass
.TA
);
// The edit is anchored at the expansion location of the attribute.
// NOTE(review): the text being replaced is not visible here; per the file
// header it is presumably the '__block' specifier -- confirm.
142 Pass
.TA
.replaceText(SM
.getExpansionLoc(attr
->getLocation()),
144 useWeak
? "__weak" : "__unsafe_unretained");