//===--- UnrollLoopsCheck.cpp - clang-tidy --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "UnrollLoopsCheck.h"
#include "clang/AST/APValue.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/ParentMapContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include <cmath>
#include <limits>

using namespace clang::ast_matchers;

20 namespace clang::tidy::altera
{
22 UnrollLoopsCheck::UnrollLoopsCheck(StringRef Name
, ClangTidyContext
*Context
)
23 : ClangTidyCheck(Name
, Context
),
24 MaxLoopIterations(Options
.get("MaxLoopIterations", 100U)) {}
26 void UnrollLoopsCheck::registerMatchers(MatchFinder
*Finder
) {
27 const auto HasLoopBound
= hasDescendant(
28 varDecl(matchesName("__end*"),
29 hasDescendant(integerLiteral().bind("cxx_loop_bound"))));
30 const auto CXXForRangeLoop
=
31 cxxForRangeStmt(anyOf(HasLoopBound
, unless(HasLoopBound
)));
32 const auto AnyLoop
= anyOf(forStmt(), whileStmt(), doStmt(), CXXForRangeLoop
);
34 stmt(AnyLoop
, unless(hasDescendant(stmt(AnyLoop
)))).bind("loop"), this);
37 void UnrollLoopsCheck::check(const MatchFinder::MatchResult
&Result
) {
38 const auto *Loop
= Result
.Nodes
.getNodeAs
<Stmt
>("loop");
39 const auto *CXXLoopBound
=
40 Result
.Nodes
.getNodeAs
<IntegerLiteral
>("cxx_loop_bound");
41 const ASTContext
*Context
= Result
.Context
;
42 switch (unrollType(Loop
, Result
.Context
)) {
44 diag(Loop
->getBeginLoc(),
45 "kernel performance could be improved by unrolling this loop with a "
46 "'#pragma unroll' directive");
48 case PartiallyUnrolled
:
49 // Loop already partially unrolled, do nothing.
52 if (hasKnownBounds(Loop
, CXXLoopBound
, Context
)) {
53 if (hasLargeNumIterations(Loop
, CXXLoopBound
, Context
)) {
54 diag(Loop
->getBeginLoc(),
55 "loop likely has a large number of iterations and thus "
56 "cannot be fully unrolled; to partially unroll this loop, use "
57 "the '#pragma unroll <num>' directive");
62 if (isa
<WhileStmt
, DoStmt
>(Loop
)) {
63 diag(Loop
->getBeginLoc(),
64 "full unrolling requested, but loop bounds may not be known; to "
65 "partially unroll this loop, use the '#pragma unroll <num>' "
70 diag(Loop
->getBeginLoc(),
71 "full unrolling requested, but loop bounds are not known; to "
72 "partially unroll this loop, use the '#pragma unroll <num>' "
78 enum UnrollLoopsCheck::UnrollType
79 UnrollLoopsCheck::unrollType(const Stmt
*Statement
, ASTContext
*Context
) {
80 const DynTypedNodeList Parents
= Context
->getParents
<Stmt
>(*Statement
);
81 for (const DynTypedNode
&Parent
: Parents
) {
82 const auto *ParentStmt
= Parent
.get
<AttributedStmt
>();
85 for (const Attr
*Attribute
: ParentStmt
->getAttrs()) {
86 const auto *LoopHint
= dyn_cast
<LoopHintAttr
>(Attribute
);
89 switch (LoopHint
->getState()) {
90 case LoopHintAttr::Numeric
:
91 return PartiallyUnrolled
;
92 case LoopHintAttr::Disable
:
94 case LoopHintAttr::Full
:
96 case LoopHintAttr::Enable
:
98 case LoopHintAttr::AssumeSafety
:
100 case LoopHintAttr::FixedWidth
:
102 case LoopHintAttr::ScalableWidth
:
110 bool UnrollLoopsCheck::hasKnownBounds(const Stmt
*Statement
,
111 const IntegerLiteral
*CXXLoopBound
,
112 const ASTContext
*Context
) {
113 if (isa
<CXXForRangeStmt
>(Statement
))
114 return CXXLoopBound
!= nullptr;
115 // Too many possibilities in a while statement, so always recommend partial
116 // unrolling for these.
117 if (isa
<WhileStmt
, DoStmt
>(Statement
))
119 // The last loop type is a for loop.
120 const auto *ForLoop
= cast
<ForStmt
>(Statement
);
121 const Stmt
*Initializer
= ForLoop
->getInit();
122 const Expr
*Conditional
= ForLoop
->getCond();
123 const Expr
*Increment
= ForLoop
->getInc();
124 if (!Initializer
|| !Conditional
|| !Increment
)
126 // If the loop variable value isn't known, loop bounds are unknown.
127 if (const auto *InitDeclStatement
= dyn_cast
<DeclStmt
>(Initializer
)) {
128 if (const auto *VariableDecl
=
129 dyn_cast
<VarDecl
>(InitDeclStatement
->getSingleDecl())) {
130 APValue
*Evaluation
= VariableDecl
->evaluateValue();
131 if (!Evaluation
|| !Evaluation
->hasValue())
135 // If increment is unary and not one of ++ and --, loop bounds are unknown.
136 if (const auto *Op
= dyn_cast
<UnaryOperator
>(Increment
))
137 if (!Op
->isIncrementDecrementOp())
140 if (const auto *BinaryOp
= dyn_cast
<BinaryOperator
>(Conditional
)) {
141 const Expr
*LHS
= BinaryOp
->getLHS();
142 const Expr
*RHS
= BinaryOp
->getRHS();
143 // If both sides are value dependent or constant, loop bounds are unknown.
144 return LHS
->isEvaluatable(*Context
) != RHS
->isEvaluatable(*Context
);
146 return false; // If it's not a binary operator, loop bounds are unknown.
149 const Expr
*UnrollLoopsCheck::getCondExpr(const Stmt
*Statement
) {
150 if (const auto *ForLoop
= dyn_cast
<ForStmt
>(Statement
))
151 return ForLoop
->getCond();
152 if (const auto *WhileLoop
= dyn_cast
<WhileStmt
>(Statement
))
153 return WhileLoop
->getCond();
154 if (const auto *DoWhileLoop
= dyn_cast
<DoStmt
>(Statement
))
155 return DoWhileLoop
->getCond();
156 if (const auto *CXXRangeLoop
= dyn_cast
<CXXForRangeStmt
>(Statement
))
157 return CXXRangeLoop
->getCond();
158 llvm_unreachable("Unknown loop");
161 bool UnrollLoopsCheck::hasLargeNumIterations(const Stmt
*Statement
,
162 const IntegerLiteral
*CXXLoopBound
,
163 const ASTContext
*Context
) {
164 // Because hasKnownBounds is called before this, if this is true, then
165 // CXXLoopBound is also matched.
166 if (isa
<CXXForRangeStmt
>(Statement
)) {
167 assert(CXXLoopBound
&& "CXX ranged for loop has no loop bound");
168 return exprHasLargeNumIterations(CXXLoopBound
, Context
);
170 const auto *ForLoop
= cast
<ForStmt
>(Statement
);
171 const Stmt
*Initializer
= ForLoop
->getInit();
172 const Expr
*Conditional
= ForLoop
->getCond();
173 const Expr
*Increment
= ForLoop
->getInc();
175 // If the loop variable value isn't known, we can't know the loop bounds.
176 if (const auto *InitDeclStatement
= dyn_cast
<DeclStmt
>(Initializer
)) {
177 if (const auto *VariableDecl
=
178 dyn_cast
<VarDecl
>(InitDeclStatement
->getSingleDecl())) {
179 APValue
*Evaluation
= VariableDecl
->evaluateValue();
180 if (!Evaluation
|| !Evaluation
->isInt())
182 InitValue
= Evaluation
->getInt().getExtValue();
187 const auto *BinaryOp
= cast
<BinaryOperator
>(Conditional
);
188 if (!extractValue(EndValue
, BinaryOp
, Context
))
191 double Iterations
= 0.0;
193 // If increment is unary and not one of ++, --, we can't know the loop bounds.
194 if (const auto *Op
= dyn_cast
<UnaryOperator
>(Increment
)) {
195 if (Op
->isIncrementOp())
196 Iterations
= EndValue
- InitValue
;
197 else if (Op
->isDecrementOp())
198 Iterations
= InitValue
- EndValue
;
200 llvm_unreachable("Unary operator neither increment nor decrement");
203 // If increment is binary and not one of +, -, *, /, we can't know the loop
205 if (const auto *Op
= dyn_cast
<BinaryOperator
>(Increment
)) {
206 int ConstantValue
= 0;
207 if (!extractValue(ConstantValue
, Op
, Context
))
209 switch (Op
->getOpcode()) {
211 Iterations
= ceil(float(EndValue
- InitValue
) / ConstantValue
);
214 Iterations
= ceil(float(InitValue
- EndValue
) / ConstantValue
);
217 Iterations
= 1 + (log((double)EndValue
) - log((double)InitValue
)) /
218 log((double)ConstantValue
);
221 Iterations
= 1 + (log((double)InitValue
) - log((double)EndValue
)) /
222 log((double)ConstantValue
);
225 // All other operators are not handled; assume large bounds.
229 return Iterations
> MaxLoopIterations
;
232 bool UnrollLoopsCheck::extractValue(int &Value
, const BinaryOperator
*Op
,
233 const ASTContext
*Context
) {
234 const Expr
*LHS
= Op
->getLHS();
235 const Expr
*RHS
= Op
->getRHS();
236 Expr::EvalResult Result
;
237 if (LHS
->isEvaluatable(*Context
))
238 LHS
->EvaluateAsRValue(Result
, *Context
);
239 else if (RHS
->isEvaluatable(*Context
))
240 RHS
->EvaluateAsRValue(Result
, *Context
);
242 return false; // Cannot evaluate either side.
243 if (!Result
.Val
.isInt())
244 return false; // Cannot check number of iterations, return false to be
246 Value
= Result
.Val
.getInt().getExtValue();
250 bool UnrollLoopsCheck::exprHasLargeNumIterations(const Expr
*Expression
,
251 const ASTContext
*Context
) const {
252 Expr::EvalResult Result
;
253 if (Expression
->EvaluateAsRValue(Result
, *Context
)) {
254 if (!Result
.Val
.isInt())
255 return false; // Cannot check number of iterations, return false to be
257 // The following assumes values go from 0 to Val in increments of 1.
258 return Result
.Val
.getInt() > MaxLoopIterations
;
260 // Cannot evaluate Expression as an r-value, so cannot check number of
265 void UnrollLoopsCheck::storeOptions(ClangTidyOptions::OptionMap
&Opts
) {
266 Options
.store(Opts
, "MaxLoopIterations", MaxLoopIterations
);
269 } // namespace clang::tidy::altera