Version 6.1.0.2, tag libreoffice-6.1.0.2
[LibreOffice.git] / compilerplugins / clang / flatten.cxx
blobdd116d7a4ea54f75542938c451d71c778f5eba80
1 /* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
2 /*
3 * This file is part of the LibreOffice project.
5 * This Source Code Form is subject to the terms of the Mozilla Public
6 * License, v. 2.0. If a copy of the MPL was not distributed with this
7 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 */
10 #include "plugin.hxx"
11 #include <cassert>
12 #include <string>
13 #include <iostream>
14 #include <fstream>
15 #include <set>
16 #include <stack>
/**
  Look for places where the control flow of a function can be flattened by returning
  (or throwing) early.
*/
21 namespace {
23 class Flatten:
24 public RecursiveASTVisitor<Flatten>, public loplugin::RewritePlugin
26 public:
27 explicit Flatten(loplugin::InstantiationData const & data):
28 RewritePlugin(data) {}
30 virtual void run() override
32 TraverseDecl(compiler.getASTContext().getTranslationUnitDecl());
35 bool TraverseIfStmt(IfStmt *);
36 bool TraverseCXXCatchStmt(CXXCatchStmt * );
37 bool TraverseCompoundStmt(CompoundStmt *);
38 bool TraverseFunctionDecl(FunctionDecl *);
39 bool TraverseCXXMethodDecl(CXXMethodDecl *);
40 bool TraverseCXXConstructorDecl(CXXConstructorDecl *);
41 bool TraverseCXXConversionDecl(CXXConversionDecl *);
42 bool TraverseCXXDestructorDecl(CXXDestructorDecl *);
43 bool VisitIfStmt(IfStmt const * );
44 private:
45 bool rewrite1(IfStmt const * );
46 bool rewrite2(IfStmt const * );
47 bool rewriteLargeIf(IfStmt const * );
48 SourceRange ignoreMacroExpansions(SourceRange range);
49 SourceRange extendOverComments(SourceRange range);
50 std::string getSourceAsString(SourceRange range);
51 std::string invertCondition(Expr const * condExpr, SourceRange conditionRange);
52 bool isLargeCompoundStmt(Stmt const *);
54 Stmt const * lastStmtInCompoundStmt = nullptr;
55 FunctionDecl const * functionDecl = nullptr;
56 CompoundStmt const * functionDeclBody = nullptr;
57 Stmt const * mElseBranch = nullptr;
60 static Stmt const * containsSingleThrowExpr(Stmt const * stmt)
62 if (auto compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
63 if (compoundStmt->size() != 1)
64 return nullptr;
65 stmt = *compoundStmt->body_begin();
67 if (auto exprWithCleanups = dyn_cast<ExprWithCleanups>(stmt)) {
68 stmt = exprWithCleanups->getSubExpr();
70 return dyn_cast<CXXThrowExpr>(stmt);
73 static bool containsVarDecl(Stmt const * stmt)
75 if (auto compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
76 for (auto i = compoundStmt->body_begin(); i != compoundStmt->body_end(); ++i) {
77 auto declStmt = dyn_cast<DeclStmt>(*i);
78 if (declStmt && isa<VarDecl>(*declStmt->decl_begin()))
79 return true;
81 return false;
83 auto declStmt = dyn_cast<DeclStmt>(stmt);
84 return declStmt && isa<VarDecl>(*declStmt->decl_begin());
87 bool Flatten::TraverseCXXCatchStmt(CXXCatchStmt* )
89 // ignore stuff inside catch statements, where doing a "if...else..throw" is more natural
90 return true;
93 bool Flatten::TraverseIfStmt(IfStmt * ifStmt)
95 if (!WalkUpFromIfStmt(ifStmt)) {
96 return false;
98 auto const saved = mElseBranch;
99 mElseBranch = ifStmt->getElse();
100 auto ret = true;
101 for (auto const sub: ifStmt->children()) {
102 if (!TraverseStmt(sub)) {
103 ret = false;
104 break;
107 mElseBranch = saved;
108 return ret;
111 bool Flatten::TraverseCompoundStmt(CompoundStmt * compoundStmt)
113 auto copy = lastStmtInCompoundStmt;
114 if (compoundStmt->size() > 0)
115 lastStmtInCompoundStmt = compoundStmt->body_back();
116 else
117 lastStmtInCompoundStmt = nullptr;
119 bool rv = RecursiveASTVisitor<Flatten>::TraverseCompoundStmt(compoundStmt);
121 lastStmtInCompoundStmt = copy;
122 return rv;
125 bool Flatten::TraverseFunctionDecl(FunctionDecl * fd)
127 auto copy1 = functionDeclBody;
128 auto copy2 = fd;
129 functionDeclBody = dyn_cast_or_null<CompoundStmt>(fd->getBody());
130 functionDecl = fd;
131 bool rv = RecursiveASTVisitor<Flatten>::TraverseFunctionDecl(fd);
132 functionDeclBody = copy1;
133 functionDecl = copy2;
134 return rv;
137 bool Flatten::TraverseCXXMethodDecl(CXXMethodDecl * fd)
139 auto copy1 = functionDeclBody;
140 auto copy2 = fd;
141 functionDeclBody = dyn_cast_or_null<CompoundStmt>(fd->getBody());
142 functionDecl = fd;
143 bool rv = RecursiveASTVisitor<Flatten>::TraverseCXXMethodDecl(fd);
144 functionDeclBody = copy1;
145 functionDecl = copy2;
146 return rv;
149 bool Flatten::TraverseCXXConstructorDecl(CXXConstructorDecl * fd)
151 auto copy1 = functionDeclBody;
152 auto copy2 = fd;
153 functionDeclBody = dyn_cast_or_null<CompoundStmt>(fd->getBody());
154 functionDecl = fd;
155 bool rv = RecursiveASTVisitor<Flatten>::TraverseCXXConstructorDecl(fd);
156 functionDeclBody = copy1;
157 functionDecl = copy2;
158 return rv;
161 bool Flatten::TraverseCXXConversionDecl(CXXConversionDecl * fd)
163 auto copy1 = functionDeclBody;
164 auto copy2 = fd;
165 functionDeclBody = dyn_cast_or_null<CompoundStmt>(fd->getBody());
166 functionDecl = fd;
167 bool rv = RecursiveASTVisitor<Flatten>::TraverseCXXConversionDecl(fd);
168 functionDeclBody = copy1;
169 functionDecl = copy2;
170 return rv;
173 bool Flatten::TraverseCXXDestructorDecl(CXXDestructorDecl * fd)
175 auto copy1 = functionDeclBody;
176 auto copy2 = fd;
177 functionDeclBody = dyn_cast_or_null<CompoundStmt>(fd->getBody());
178 functionDecl = fd;
179 bool rv = RecursiveASTVisitor<Flatten>::TraverseCXXDestructorDecl(fd);
180 functionDeclBody = copy1;
181 functionDecl = copy2;
182 return rv;
186 bool Flatten::VisitIfStmt(IfStmt const * ifStmt)
188 if (ignoreLocation(ifStmt))
189 return true;
191 // ignore if we are part of an if/then/else/if chain
192 if (ifStmt == mElseBranch || (ifStmt->getElse() && isa<IfStmt>(ifStmt->getElse())))
193 return true;
195 // look for a large if(){} block at the end of a function
196 if (!ifStmt->getElse()
197 && (functionDecl->getReturnType().isNull() || functionDecl->getReturnType()->isVoidType())
198 && functionDeclBody && functionDeclBody->size()
199 && functionDeclBody->body_back() == ifStmt
200 && isLargeCompoundStmt(ifStmt->getThen()))
202 if (!rewriteLargeIf(ifStmt))
204 report(
205 DiagnosticsEngine::Warning,
206 "large if statement at end of function, rather invert the condition and exit early, and flatten the function",
207 ifStmt->getLocStart())
208 << ifStmt->getSourceRange();
210 return true;
213 if (!ifStmt->getElse())
214 return true;
216 auto const thenThrowExpr = containsSingleThrowExpr(ifStmt->getThen());
217 auto const elseThrowExpr = containsSingleThrowExpr(ifStmt->getElse());
218 // If neither contains a throw, nothing to do; if both contain throws, no
219 // improvement:
220 if ((thenThrowExpr == nullptr) == (elseThrowExpr == nullptr)) {
221 return true;
224 if (containsPreprocessingConditionalInclusion(ifStmt->getSourceRange())) {
225 return true;
228 if (elseThrowExpr)
230 // if the "if" statement is not the last statement in its block, and it contains
231 // var decls in its then block, we cannot de-indent the then block without
232 // extending the lifetime of some variables, which may be problematic
233 if (ifStmt != lastStmtInCompoundStmt && containsVarDecl(ifStmt->getThen()))
234 return true;
236 if (!rewrite1(ifStmt))
238 report(
239 DiagnosticsEngine::Warning,
240 "unconditional throw in else branch, rather invert the condition, throw early, and flatten the normal case",
241 elseThrowExpr->getLocStart())
242 << elseThrowExpr->getSourceRange();
243 report(
244 DiagnosticsEngine::Note,
245 "if condition here",
246 ifStmt->getLocStart())
247 << ifStmt->getSourceRange();
250 if (thenThrowExpr)
252 // if the "if" statement is not the last statement in it's block, and it contains
253 // var decls in it's else block, we cannot de-indent the else block without
254 // extending the lifetime of some variables, which may be problematic
255 if (ifStmt != lastStmtInCompoundStmt && containsVarDecl(ifStmt->getElse()))
256 return true;
258 if (!rewrite2(ifStmt))
260 report(
261 DiagnosticsEngine::Warning,
262 "unconditional throw in then branch, just flatten the else",
263 thenThrowExpr->getLocStart())
264 << thenThrowExpr->getSourceRange();
267 return true;
// String helpers used to reformat the source text of rewritten statements.

// Splitting into lines, and simple tests on a line.
static std::vector<std::string> split(std::string s);
static bool startswith(std::string const & rStr, char const * pSubStr);
static int countLeadingSpaces(std::string const &);

// Building and transforming statement text.
static std::string padSpace(int iNoSpaces);
static std::string deindent(std::string const & s);
static std::string stripOpenAndCloseBrace(std::string s);
static std::string stripTrailingEmptyLines(std::string s);
static void replace(std::string & s, std::string const & from, std::string const & to);
279 bool Flatten::rewrite1(IfStmt const * ifStmt)
281 if (!rewriter)
282 return false;
284 auto conditionRange = ignoreMacroExpansions(ifStmt->getCond()->getSourceRange());
285 if (!conditionRange.isValid()) {
286 return false;
288 auto thenRange = ignoreMacroExpansions(ifStmt->getThen()->getSourceRange());
289 if (!thenRange.isValid()) {
290 return false;
292 auto elseRange = ignoreMacroExpansions(ifStmt->getElse()->getSourceRange());
293 if (!elseRange.isValid()) {
294 return false;
296 SourceRange elseKeywordRange = ifStmt->getElseLoc();
298 thenRange = extendOverComments(thenRange);
299 elseRange = extendOverComments(elseRange);
300 elseKeywordRange = extendOverComments(elseKeywordRange);
302 // in adjusting the formatting I assume that "{" starts on a new line
304 std::string conditionString = invertCondition(ifStmt->getCond(), conditionRange);
306 std::string thenString = getSourceAsString(thenRange);
307 if (auto compoundStmt = dyn_cast<CompoundStmt>(ifStmt->getThen())) {
308 if (compoundStmt->getLBracLoc().isValid()) {
309 thenString = stripOpenAndCloseBrace(thenString);
312 thenString = deindent(thenString);
314 std::string elseString = getSourceAsString(elseRange);
316 if (!replaceText(elseRange, thenString)) {
317 return false;
319 if (!removeText(elseKeywordRange)) {
320 return false;
322 if (!replaceText(thenRange, elseString)) {
323 return false;
325 if (!replaceText(conditionRange, conditionString)) {
326 return false;
329 return true;
332 bool Flatten::rewrite2(IfStmt const * ifStmt)
334 if (!rewriter)
335 return false;
337 auto conditionRange = ignoreMacroExpansions(ifStmt->getCond()->getSourceRange());
338 if (!conditionRange.isValid()) {
339 return false;
341 auto thenRange = ignoreMacroExpansions(ifStmt->getThen()->getSourceRange());
342 if (!thenRange.isValid()) {
343 return false;
345 auto elseRange = ignoreMacroExpansions(ifStmt->getElse()->getSourceRange());
346 if (!elseRange.isValid()) {
347 return false;
349 SourceRange elseKeywordRange = ifStmt->getElseLoc();
351 elseRange = extendOverComments(elseRange);
352 elseKeywordRange = extendOverComments(elseKeywordRange);
354 // in adjusting the formatting I assume that "{" starts on a new line
356 std::string elseString = getSourceAsString(elseRange);
357 if (auto compoundStmt = dyn_cast<CompoundStmt>(ifStmt->getElse())) {
358 if (compoundStmt->getLBracLoc().isValid()) {
359 elseString = stripOpenAndCloseBrace(elseString);
362 elseString = deindent(elseString);
364 if (!replaceText(elseRange, elseString)) {
365 return false;
367 if (!removeText(elseKeywordRange)) {
368 return false;
371 return true;
374 bool Flatten::rewriteLargeIf(IfStmt const * ifStmt)
376 if (!rewriter)
377 return false;
379 auto conditionRange = ignoreMacroExpansions(ifStmt->getCond()->getSourceRange());
380 if (!conditionRange.isValid()) {
381 return false;
383 auto thenRange = ignoreMacroExpansions(ifStmt->getThen()->getSourceRange());
384 if (!thenRange.isValid()) {
385 return false;
388 thenRange = extendOverComments(thenRange);
390 // in adjusting the formatting I assume that "{" starts on a new line
392 std::string conditionString = invertCondition(ifStmt->getCond(), conditionRange);
394 std::string thenString = getSourceAsString(thenRange);
395 if (auto compoundStmt = dyn_cast<CompoundStmt>(ifStmt->getThen())) {
396 if (compoundStmt->getLBracLoc().isValid()) {
397 thenString = stripOpenAndCloseBrace(thenString);
400 int iNoSpaces = countLeadingSpaces(thenString);
401 thenString = padSpace(iNoSpaces) + "return;\n\n" + deindent(thenString);
402 thenString = stripTrailingEmptyLines(thenString);
404 if (!replaceText(thenRange, thenString)) {
405 return false;
407 if (!replaceText(conditionRange, conditionString)) {
408 return false;
411 return true;
414 std::string Flatten::invertCondition(Expr const * condExpr, SourceRange conditionRange)
416 std::string s = getSourceAsString(conditionRange);
418 condExpr = condExpr->IgnoreImpCasts();
420 if (auto exprWithCleanups = dyn_cast<ExprWithCleanups>(condExpr))
421 condExpr = exprWithCleanups->getSubExpr()->IgnoreImpCasts();
423 // an if statement will automatically invoke a bool-conversion method
424 if (auto memberCallExpr = dyn_cast<CXXMemberCallExpr>(condExpr))
426 if (isa<CXXConversionDecl>(memberCallExpr->getMethodDecl()))
427 condExpr = memberCallExpr->getImplicitObjectArgument()->IgnoreImpCasts();
430 if (auto unaryOp = dyn_cast<UnaryOperator>(condExpr))
432 if (unaryOp->getOpcode() != UO_LNot)
433 return "!(" + s + ")";
434 auto i = s.find("!");
435 assert (i != std::string::npos);
436 s = s.substr(i+1);
438 else if (auto binaryOp = dyn_cast<BinaryOperator>(condExpr))
440 switch (binaryOp->getOpcode())
442 case BO_LT: replace(s, "<", ">="); break;
443 case BO_GT: replace(s, ">", "<="); break;
444 case BO_LE: replace(s, "<=", ">"); break;
445 case BO_GE: replace(s, ">=", "<"); break;
446 case BO_EQ: replace(s, "==", "!="); break;
447 case BO_NE: replace(s, "!=", "=="); break;
448 default:
449 s = "!(" + s + ")";
452 else if (auto opCallExpr = dyn_cast<CXXOperatorCallExpr>(condExpr))
454 switch (opCallExpr->getOperator())
456 case OO_Less: replace(s, "<", ">="); break;
457 case OO_Greater: replace(s, ">", "<="); break;
458 case OO_LessEqual: replace(s, "<=", ">"); break;
459 case OO_GreaterEqual: replace(s, ">=", "<"); break;
460 case OO_EqualEqual: replace(s, "==", "!="); break;
461 case OO_ExclaimEqual: replace(s, "!=", "=="); break;
462 default:
463 s = "!(" + s + ")";
466 else if (isa<DeclRefExpr>(condExpr) || isa<CallExpr>(condExpr) || isa<MemberExpr>(condExpr))
467 s = "!" + s;
468 else
469 s = "!(" + s + ")";
470 return s;
/// Return the text between the first "{" and the last "}" of s. The rest of the line
/// holding the "{" is dropped. The indentation in front of the "}" is dropped, and so
/// is the line break before it when the "}" stands on a line of its own. Code that
/// shares a line with the "}" keeps all of its characters.
/// Throws a const char* message if either brace is missing.
std::string stripOpenAndCloseBrace(std::string s)
{
    std::string::size_type i = s.find("{");
    if (i == std::string::npos)
        throw "did not find {";
    ++i;
    // strip to line end
    while (i < s.size() && s[i] == ' ')
        ++i;
    if (i < s.size() && s[i] == '\n')
        ++i;
    s = s.substr(i);

    i = s.rfind("}");
    if (i == std::string::npos)
        throw "did not find }";
    // strip the indentation in front of the "}" (never stepping before the start,
    // which an empty "{}" body would otherwise do) ...
    while (i > 0 && s[i - 1] == ' ')
        --i;
    // ... and the line break ending the body, but only if there is one: for
    // "{ foo(); }" the ';' must be kept
    if (i > 0 && s[i - 1] == '\n')
        --i;
    s = s.substr(0, i);
    return s;
}
497 std::string deindent(std::string const & s)
499 std::vector<std::string> lines = split(s);
500 std::string rv;
501 for (auto s : lines) {
502 if (startswith(s, " "))
503 rv += s.substr(4);
504 else
505 rv += s;
506 rv += "\n";
508 return rv;
/// Split s into lines at each "\n". A single trailing "\n" terminates the last line
/// instead of starting an extra empty one. An empty input yields one empty line.
std::vector<std::string> split(std::string s)
{
    // guard: back() on an empty string is undefined behaviour
    if (!s.empty() && s.back() == '\n')
        s.pop_back();
    std::vector<std::string> rv;
    std::string::size_type current = 0;
    for (;;)
    {
        auto const next = s.find('\n', current);
        rv.push_back(s.substr(current, next - current));
        if (next == std::string::npos)
            break;
        current = next + 1;
    }
    return rv;
}
/// True if rStr begins with the C string pSubStr (an empty prefix always matches).
bool startswith(std::string const & rStr, char const * pSubStr)
{
    // searching backwards from position 0 can only find a match at the very start
    return rStr.rfind(pSubStr, 0) == 0;
}
/// Number of ' ' characters at the start of s.
int countLeadingSpaces(std::string const & s)
{
    // the index of the first non-space character equals the count of leading spaces
    auto const firstNonSpace = s.find_first_not_of(' ');
    return static_cast<int>(firstNonSpace == std::string::npos ? s.size() : firstNonSpace);
}
/// A string of iNoSpaces spaces; empty for zero or negative counts.
std::string padSpace(int iNoSpaces)
{
    if (iNoSpaces <= 0)
        return std::string();
    return std::string(static_cast<std::string::size_type>(iNoSpaces), ' ');
}
/// Remove all trailing "\n" characters from s.
std::string stripTrailingEmptyLines(std::string s)
{
    // check for emptiness first: a string of only line breaks would otherwise end up
    // calling back() on an empty string (undefined behaviour)
    while (!s.empty() && s.back() == '\n')
        s.pop_back();
    return s;
}
/// Replace the first occurrence of from in s by to. Asserts that from occurs in s,
/// and that it no longer occurs after the replacement.
void replace(std::string & s, std::string const & from, std::string const & to)
{
    auto const where = s.find(from);
    assert (where != std::string::npos);
    s.erase(where, from.size());
    s.insert(where, to);
    // just in case we have something really weird, like the operator token is also present in the rest of the condition somehow
    assert (s.find(from) == std::string::npos);
}
564 SourceRange Flatten::ignoreMacroExpansions(SourceRange range) {
565 while (compiler.getSourceManager().isMacroArgExpansion(range.getBegin())) {
566 range.setBegin(
567 compiler.getSourceManager().getImmediateMacroCallerLoc(
568 range.getBegin()));
570 if (range.getBegin().isMacroID()) {
571 SourceLocation loc;
572 if (Lexer::isAtStartOfMacroExpansion(
573 range.getBegin(), compiler.getSourceManager(),
574 compiler.getLangOpts(), &loc))
576 range.setBegin(loc);
579 while (compiler.getSourceManager().isMacroArgExpansion(range.getEnd())) {
580 range.setEnd(
581 compiler.getSourceManager().getImmediateMacroCallerLoc(
582 range.getEnd()));
584 if (range.getEnd().isMacroID()) {
585 SourceLocation loc;
586 if (Lexer::isAtEndOfMacroExpansion(
587 range.getEnd(), compiler.getSourceManager(),
588 compiler.getLangOpts(), &loc))
590 range.setEnd(loc);
593 return range.getBegin().isMacroID() || range.getEnd().isMacroID()
594 ? SourceRange() : range;
598 * Extend the SourceRange to include any leading and trailing whitespace, and any comments.
600 SourceRange Flatten::extendOverComments(SourceRange range)
602 SourceManager& SM = compiler.getSourceManager();
603 SourceLocation startLoc = range.getBegin();
604 SourceLocation endLoc = range.getEnd();
605 char const *p1 = SM.getCharacterData( startLoc );
606 char const *p2 = SM.getCharacterData( endLoc );
608 // scan backwards from the beginning to include any spaces on that line
609 while (*(p1-1) == ' ')
610 --p1;
611 startLoc = startLoc.getLocWithOffset(p1 - SM.getCharacterData( startLoc ));
613 // look for trailing ";"
614 while (*(p2+1) == ';')
615 ++p2;
616 // look for trailing " "
617 while (*(p2+1) == ' ')
618 ++p2;
619 // look for single line comments attached to the end of the statement
620 if (*(p2+1) == '/' && *(p2+2) == '/')
622 p2 += 2;
623 while (*(p2+1) && *(p2+1) != '\n')
624 ++p2;
625 if (*(p2+1) == '\n')
626 ++p2;
628 else
630 // make the source code we extract include any trailing "\n"
631 if (*(p2+1) == '\n')
632 ++p2;
634 endLoc = endLoc.getLocWithOffset(p2 - SM.getCharacterData( endLoc ));
636 return SourceRange(startLoc, endLoc);
639 std::string Flatten::getSourceAsString(SourceRange range)
641 SourceManager& SM = compiler.getSourceManager();
642 SourceLocation startLoc = range.getBegin();
643 SourceLocation endLoc = range.getEnd();
644 char const *p1 = SM.getCharacterData( startLoc );
645 char const *p2 = SM.getCharacterData( endLoc );
646 p2 += Lexer::MeasureTokenLength( endLoc, SM, compiler.getLangOpts());
647 return std::string( p1, p2 - p1);
650 bool Flatten::isLargeCompoundStmt(Stmt const * stmt)
652 auto stmtRange = stmt->getSourceRange();
653 std::string s = getSourceAsString(stmtRange);
654 return std::count(s.begin(), s.end(), '\n') > 10;
657 loplugin::Plugin::Registration< Flatten > X("flatten", false);
661 /* vim:set shiftwidth=4 softtabstop=4 expandtab: */