1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// This file contains functions which are used to decide if a loop worth to be
10 /// unrolled. Moreover, these functions manages the stack of loop which is
11 /// tracked by the ProgramState.
13 //===----------------------------------------------------------------------===//
15 #include "clang/ASTMatchers/ASTMatchers.h"
16 #include "clang/ASTMatchers/ASTMatchFinder.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
22 using namespace clang
;
24 using namespace clang::ast_matchers
;
26 static const int MAXIMUM_STEP_UNROLLED
= 128;
31 enum Kind
{ Normal
, Unrolled
} K
;
33 const LocationContext
*LCtx
;
35 LoopState(Kind InK
, const Stmt
*S
, const LocationContext
*L
, unsigned N
)
36 : K(InK
), LoopStmt(S
), LCtx(L
), maxStep(N
) {}
39 static LoopState
getNormal(const Stmt
*S
, const LocationContext
*L
,
41 return LoopState(Normal
, S
, L
, N
);
43 static LoopState
getUnrolled(const Stmt
*S
, const LocationContext
*L
,
45 return LoopState(Unrolled
, S
, L
, N
);
47 bool isUnrolled() const { return K
== Unrolled
; }
48 unsigned getMaxStep() const { return maxStep
; }
49 const Stmt
*getLoopStmt() const { return LoopStmt
; }
50 const LocationContext
*getLocationContext() const { return LCtx
; }
51 bool operator==(const LoopState
&X
) const {
52 return K
== X
.K
&& LoopStmt
== X
.LoopStmt
;
54 void Profile(llvm::FoldingSetNodeID
&ID
) const {
56 ID
.AddPointer(LoopStmt
);
58 ID
.AddInteger(maxStep
);
63 // The tracked stack of loops. The stack indicates that which loops the
64 // simulated element contained by. The loops are marked depending if we decided
66 // TODO: The loop stack should not need to be in the program state since it is
67 // lexical in nature. Instead, the stack of loops should be tracked in the
69 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack
, LoopState
)
74 static bool isLoopStmt(const Stmt
*S
) {
75 return isa_and_nonnull
<ForStmt
, WhileStmt
, DoStmt
>(S
);
78 ProgramStateRef
processLoopEnd(const Stmt
*LoopStmt
, ProgramStateRef State
) {
79 auto LS
= State
->get
<LoopStack
>();
80 if (!LS
.isEmpty() && LS
.getHead().getLoopStmt() == LoopStmt
)
81 State
= State
->set
<LoopStack
>(LS
.getTail());
85 static internal::Matcher
<Stmt
> simpleCondition(StringRef BindName
,
87 return binaryOperator(
88 anyOf(hasOperatorName("<"), hasOperatorName(">"),
89 hasOperatorName("<="), hasOperatorName(">="),
90 hasOperatorName("!=")),
91 hasEitherOperand(ignoringParenImpCasts(
92 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName
)))
95 ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
96 .bind("conditionOperator");
99 static internal::Matcher
<Stmt
>
100 changeIntBoundNode(internal::Matcher
<Decl
> VarNodeMatcher
) {
102 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
103 hasUnaryOperand(ignoringParenImpCasts(
104 declRefExpr(to(varDecl(VarNodeMatcher
)))))),
105 binaryOperator(isAssignmentOperator(),
106 hasLHS(ignoringParenImpCasts(
107 declRefExpr(to(varDecl(VarNodeMatcher
)))))));
110 static internal::Matcher
<Stmt
>
111 callByRef(internal::Matcher
<Decl
> VarNodeMatcher
) {
112 return callExpr(forEachArgumentWithParam(
113 declRefExpr(to(varDecl(VarNodeMatcher
))),
114 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
117 static internal::Matcher
<Stmt
>
118 assignedToRef(internal::Matcher
<Decl
> VarNodeMatcher
) {
119 return declStmt(hasDescendant(varDecl(
120 allOf(hasType(referenceType()),
121 hasInitializer(anyOf(
122 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher
))))),
123 declRefExpr(to(varDecl(VarNodeMatcher
)))))))));
126 static internal::Matcher
<Stmt
>
127 getAddrTo(internal::Matcher
<Decl
> VarNodeMatcher
) {
128 return unaryOperator(
129 hasOperatorName("&"),
130 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher
))));
133 static internal::Matcher
<Stmt
> hasSuspiciousStmt(StringRef NodeName
) {
134 return hasDescendant(stmt(
135 anyOf(gotoStmt(), switchStmt(), returnStmt(),
136 // Escaping and not known mutation of the loop counter is handled
137 // by exclusion of assigning and address-of operators and
138 // pass-by-ref function calls on the loop counter from the body.
139 changeIntBoundNode(equalsBoundNode(std::string(NodeName
))),
140 callByRef(equalsBoundNode(std::string(NodeName
))),
141 getAddrTo(equalsBoundNode(std::string(NodeName
))),
142 assignedToRef(equalsBoundNode(std::string(NodeName
))))));
145 static internal::Matcher
<Stmt
> forLoopMatcher() {
147 hasCondition(simpleCondition("initVarName", "initVarRef")),
148 // Initialization should match the form: 'int i = 6' or 'i = 42'.
150 anyOf(declStmt(hasSingleDecl(
151 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
152 integerLiteral().bind("initNum"))),
153 equalsBoundNode("initVarName"))))),
154 binaryOperator(hasLHS(declRefExpr(to(varDecl(
155 equalsBoundNode("initVarName"))))),
156 hasRHS(ignoringParenImpCasts(
157 integerLiteral().bind("initNum")))))),
158 // Incrementation should be a simple increment or decrement
160 hasIncrement(unaryOperator(
161 anyOf(hasOperatorName("++"), hasOperatorName("--")),
162 hasUnaryOperand(declRefExpr(
163 to(varDecl(allOf(equalsBoundNode("initVarName"),
164 hasType(isInteger())))))))),
165 unless(hasBody(hasSuspiciousStmt("initVarName"))))
169 static bool isCapturedByReference(ExplodedNode
*N
, const DeclRefExpr
*DR
) {
171 // Get the lambda CXXRecordDecl
172 assert(DR
->refersToEnclosingVariableOrCapture());
173 const LocationContext
*LocCtxt
= N
->getLocationContext();
174 const Decl
*D
= LocCtxt
->getDecl();
175 const auto *MD
= cast
<CXXMethodDecl
>(D
);
176 assert(MD
&& MD
->getParent()->isLambda() &&
177 "Captured variable should only be seen while evaluating a lambda");
178 const CXXRecordDecl
*LambdaCXXRec
= MD
->getParent();
180 // Lookup the fields of the lambda
181 llvm::DenseMap
<const ValueDecl
*, FieldDecl
*> LambdaCaptureFields
;
182 FieldDecl
*LambdaThisCaptureField
;
183 LambdaCXXRec
->getCaptureFields(LambdaCaptureFields
, LambdaThisCaptureField
);
185 // Check if the counter is captured by reference
186 const VarDecl
*VD
= cast
<VarDecl
>(DR
->getDecl()->getCanonicalDecl());
188 const FieldDecl
*FD
= LambdaCaptureFields
[VD
];
189 assert(FD
&& "Captured variable without a corresponding field");
190 return FD
->getType()->isReferenceType();
193 // A loop counter is considered escaped if:
194 // case 1: It is a global variable.
195 // case 2: It is a reference parameter or a reference capture.
196 // case 3: It is assigned to a non-const reference variable or parameter.
197 // case 4: Has its address taken.
198 static bool isPossiblyEscaped(ExplodedNode
*N
, const DeclRefExpr
*DR
) {
199 const VarDecl
*VD
= cast
<VarDecl
>(DR
->getDecl()->getCanonicalDecl());
202 if (VD
->hasGlobalStorage())
205 const bool IsRefParamOrCapture
=
206 isa
<ParmVarDecl
>(VD
) || DR
->refersToEnclosingVariableOrCapture();
208 if ((DR
->refersToEnclosingVariableOrCapture() &&
209 isCapturedByReference(N
, DR
)) ||
210 (IsRefParamOrCapture
&& VD
->getType()->isReferenceType()))
213 while (!N
->pred_empty()) {
214 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
215 // a valid statement for body farms, do we need this behavior here?
216 const Stmt
*S
= N
->getStmtForDiagnostics();
218 N
= N
->getFirstPred();
222 if (const DeclStmt
*DS
= dyn_cast
<DeclStmt
>(S
)) {
223 for (const Decl
*D
: DS
->decls()) {
224 // Once we reach the declaration of the VD we can return.
225 if (D
->getCanonicalDecl() == VD
)
229 // Check the usage of the pass-by-ref function calls and adress-of operator
230 // on VD and reference initialized by VD.
232 N
->getLocationContext()->getAnalysisDeclContext()->getASTContext();
235 match(stmt(anyOf(callByRef(equalsNode(VD
)), getAddrTo(equalsNode(VD
)),
236 assignedToRef(equalsNode(VD
)))),
241 N
= N
->getFirstPred();
244 // Reference parameter and reference capture will not be found.
245 if (IsRefParamOrCapture
)
248 llvm_unreachable("Reached root without finding the declaration of VD");
251 bool shouldCompletelyUnroll(const Stmt
*LoopStmt
, ASTContext
&ASTCtx
,
252 ExplodedNode
*Pred
, unsigned &maxStep
) {
254 if (!isLoopStmt(LoopStmt
))
257 // TODO: Match the cases where the bound is not a concrete literal but an
258 // integer with known value
259 auto Matches
= match(forLoopMatcher(), *LoopStmt
, ASTCtx
);
263 const auto *CounterVarRef
= Matches
[0].getNodeAs
<DeclRefExpr
>("initVarRef");
264 llvm::APInt BoundNum
=
265 Matches
[0].getNodeAs
<IntegerLiteral
>("boundNum")->getValue();
266 llvm::APInt InitNum
=
267 Matches
[0].getNodeAs
<IntegerLiteral
>("initNum")->getValue();
268 auto CondOp
= Matches
[0].getNodeAs
<BinaryOperator
>("conditionOperator");
269 if (InitNum
.getBitWidth() != BoundNum
.getBitWidth()) {
270 InitNum
= InitNum
.zext(BoundNum
.getBitWidth());
271 BoundNum
= BoundNum
.zext(InitNum
.getBitWidth());
274 if (CondOp
->getOpcode() == BO_GE
|| CondOp
->getOpcode() == BO_LE
)
275 maxStep
= (BoundNum
- InitNum
+ 1).abs().getZExtValue();
277 maxStep
= (BoundNum
- InitNum
).abs().getZExtValue();
279 // Check if the counter of the loop is not escaped before.
280 return !isPossiblyEscaped(Pred
, CounterVarRef
);
283 bool madeNewBranch(ExplodedNode
*N
, const Stmt
*LoopStmt
) {
284 const Stmt
*S
= nullptr;
285 while (!N
->pred_empty()) {
286 if (N
->succ_size() > 1)
289 ProgramPoint P
= N
->getLocation();
290 if (std::optional
<BlockEntrance
> BE
= P
.getAs
<BlockEntrance
>())
291 S
= BE
->getBlock()->getTerminatorStmt();
296 N
= N
->getFirstPred();
299 llvm_unreachable("Reached root without encountering the previous step");
302 // updateLoopStack is called on every basic block, therefore it needs to be fast
303 ProgramStateRef
updateLoopStack(const Stmt
*LoopStmt
, ASTContext
&ASTCtx
,
304 ExplodedNode
*Pred
, unsigned maxVisitOnPath
) {
305 auto State
= Pred
->getState();
306 auto LCtx
= Pred
->getLocationContext();
308 if (!isLoopStmt(LoopStmt
))
311 auto LS
= State
->get
<LoopStack
>();
312 if (!LS
.isEmpty() && LoopStmt
== LS
.getHead().getLoopStmt() &&
313 LCtx
== LS
.getHead().getLocationContext()) {
314 if (LS
.getHead().isUnrolled() && madeNewBranch(Pred
, LoopStmt
)) {
315 State
= State
->set
<LoopStack
>(LS
.getTail());
316 State
= State
->add
<LoopStack
>(
317 LoopState::getNormal(LoopStmt
, LCtx
, maxVisitOnPath
));
322 if (!shouldCompletelyUnroll(LoopStmt
, ASTCtx
, Pred
, maxStep
)) {
323 State
= State
->add
<LoopStack
>(
324 LoopState::getNormal(LoopStmt
, LCtx
, maxVisitOnPath
));
328 unsigned outerStep
= (LS
.isEmpty() ? 1 : LS
.getHead().getMaxStep());
330 unsigned innerMaxStep
= maxStep
* outerStep
;
331 if (innerMaxStep
> MAXIMUM_STEP_UNROLLED
)
332 State
= State
->add
<LoopStack
>(
333 LoopState::getNormal(LoopStmt
, LCtx
, maxVisitOnPath
));
335 State
= State
->add
<LoopStack
>(
336 LoopState::getUnrolled(LoopStmt
, LCtx
, innerMaxStep
));
340 bool isUnrolledState(ProgramStateRef State
) {
341 auto LS
= State
->get
<LoopStack
>();
342 if (LS
.isEmpty() || !LS
.getHead().isUnrolled())