1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "UseConstraintsCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
14 #include "../utils/LexerUtils.h"
19 using namespace clang::ast_matchers
;
21 namespace clang::tidy::modernize
{
// NOTE(review): this copy of UseConstraintsCheck.cpp is a lossy extraction:
// each line starts with its original line number, statements are split
// across lines, and many original lines are missing.
//
// Presumably the `Loc` member of the `EnableIfData` aggregate used further
// down (built as `EnableIfData{<specialization>, <outer TypeLoc>}` and read
// via `.Loc` / `.Outer`): the TypeLoc of the matched enable_if / enable_if_t
// specialization. The struct header and its other member are on lines
// missing from this copy — TODO confirm.
24 TemplateSpecializationTypeLoc Loc
;
// Presumably matches a FunctionDecl that has at least one other declaration
// in its redeclaration chain (inferred from the name and from its use as
// `unless(hasOtherDeclarations())` in registerMatchers). Only the iterator
// setup over the redeclaration chain is visible; the comparison and return
// statements (original lines 32-40) are missing from this copy.
29 AST_MATCHER(FunctionDecl
, hasOtherDeclarations
) {
30 auto It
= Node
.redecls_begin();
31 auto EndIt
= Node
.redecls_end();
// Registers the matcher that drives check(). From the visible fragments:
//  - nodes expanded from system headers are skipped;
//  - a child `functionDecl` must be a definition with no other declarations,
//    and its return TypeLoc is bound as "return";
//  - a node is bound as "functionTemplate" (check() reads it as a
//    FunctionTemplateDecl).
// check() also reads a "function" binding; presumably it is bound on one of
// the lines missing from this copy (gaps in the embedded numbering).
41 void UseConstraintsCheck::registerMatchers(MatchFinder
*Finder
) {
44 // Skip external libraries included as system headers
45 unless(isExpansionInSystemHeader()),
46 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
47 hasReturnTypeLoc(typeLoc().bind("return")))
49 .bind("functionTemplate"),
// Matches the `typename enable_if<...>::type` spelling of enable_if.
// Visible logic:
//  - If TheType is a dependent name, its identifier must be `type` and it
//    must be spelled with the `typename` keyword; TheType is then replaced
//    by the TypeLoc of its qualifier (the `enable_if<...>` part).
//  - TheType must be a template specialization of a template named
//    "enable_if" with one or two template arguments.
// On success returns the specialization's TypeLoc. The early-exit
// statements taken when a check fails are on lines missing from this copy
// (gaps in the embedded numbering); presumably they return std::nullopt.
53 static std::optional
<TemplateSpecializationTypeLoc
>
54 matchEnableIfSpecializationImplTypename(TypeLoc TheType
) {
55 if (const auto Dep
= TheType
.getAs
<DependentNameTypeLoc
>()) {
56 const IdentifierInfo
*Identifier
= Dep
.getTypePtr()->getIdentifier();
57 if (!Identifier
|| Identifier
->getName() != "type" ||
58 Dep
.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename
) {
// Step from `typename X::type` to X, the enable_if specialization itself.
61 TheType
= Dep
.getQualifierLoc().getTypeLoc();
66 if (const auto SpecializationLoc
=
67 TheType
.getAs
<TemplateSpecializationTypeLoc
>()) {
69 const auto *Specialization
=
70 dyn_cast
<TemplateSpecializationType
>(SpecializationLoc
.getTypePtr());
74 const TemplateDecl
*TD
=
75 Specialization
->getTemplateName().getAsTemplateDecl();
76 if (!TD
|| TD
->getName() != "enable_if")
// enable_if takes a condition and an optional type.
79 int NumArgs
= SpecializationLoc
.getNumArgs();
80 if (NumArgs
!= 1 && NumArgs
!= 2)
83 return SpecializationLoc
;
// Matches the `enable_if_t<...>` alias-template spelling of enable_if.
// Visible logic: an ElaboratedTypeLoc wrapper is stripped; the type must be
// a specialization of a template named "enable_if_t" that is a type alias;
// if the aliased type is a dependent name it must be `typename ...::type`;
// and the specialization must have one or two template arguments. On
// success returns the specialization's TypeLoc. As in the typename variant,
// the early-exit statements are on lines missing from this copy
// (presumably `return std::nullopt;`).
88 static std::optional
<TemplateSpecializationTypeLoc
>
89 matchEnableIfSpecializationImplTrait(TypeLoc TheType
) {
90 if (const auto Elaborated
= TheType
.getAs
<ElaboratedTypeLoc
>())
91 TheType
= Elaborated
.getNamedTypeLoc();
93 if (const auto SpecializationLoc
=
94 TheType
.getAs
<TemplateSpecializationTypeLoc
>()) {
96 const auto *Specialization
=
97 dyn_cast
<TemplateSpecializationType
>(SpecializationLoc
.getTypePtr());
101 const TemplateDecl
*TD
=
102 Specialization
->getTemplateName().getAsTemplateDecl();
103 if (!TD
|| TD
->getName() != "enable_if_t")
106 if (!Specialization
->isTypeAlias())
109 if (const auto *AliasedType
=
110 dyn_cast
<DependentNameType
>(Specialization
->getAliasedType())) {
// NOTE(review): unlike matchEnableIfSpecializationImplTypename, the result
// of getIdentifier() is not null-checked here; confirm a DependentNameType
// always carries an identifier.
111 if (AliasedType
->getIdentifier()->getName() != "type" ||
112 AliasedType
->getKeyword() != ElaboratedTypeKeyword::Typename
) {
// enable_if_t takes a condition and an optional type.
118 int NumArgs
= SpecializationLoc
.getNumArgs();
119 if (NumArgs
!= 1 && NumArgs
!= 2)
122 return SpecializationLoc
;
// Matches either enable_if spelling. It tries the
// `typename enable_if<...>::type` form first, then the `enable_if_t<...>`
// form. The statement that returns the first match is missing from this copy
// (original line 130; presumably `return EnableIf;`).
127 static std::optional
<TemplateSpecializationTypeLoc
>
128 matchEnableIfSpecializationImpl(TypeLoc TheType
) {
129 if (auto EnableIf
= matchEnableIfSpecializationImplTypename(TheType
))
131 return matchEnableIfSpecializationImplTrait(TheType
);
// Looks for enable_if / enable_if_t in TheType. It first peels at most one
// pointer or reference level, then any top-level qualifiers. On success it
// returns EnableIfData{matched specialization, peeled TypeLoc}. The peeled
// TypeLoc is presumably the `Outer` member that handleReturnType replaces.
// The failure return (original lines 145-146) is missing from this copy.
134 static std::optional
<EnableIfData
>
135 matchEnableIfSpecialization(TypeLoc TheType
) {
136 if (const auto Pointer
= TheType
.getAs
<PointerTypeLoc
>())
137 TheType
= Pointer
.getPointeeLoc();
138 else if (const auto Reference
= TheType
.getAs
<ReferenceTypeLoc
>())
139 TheType
= Reference
.getPointeeLoc();
140 if (const auto Qualified
= TheType
.getAs
<QualifiedTypeLoc
>())
141 TheType
= Qualified
.getUnqualifiedLoc();
143 if (auto EnableIf
= matchEnableIfSpecializationImpl(TheType
))
144 return EnableIfData
{std::move(*EnableIf
), TheType
};
// Matches enable_if / enable_if_t in the last template parameter of
// FunctionTemplate. It presumably returns {match, that parameter}; check()
// uses the second element as the diagnostic location.
//  - A non-type last parameter must be unnamed and have a default argument.
//    Its declared type is matched.
//  - A type last parameter must be unnamed and have a default argument. Its
//    default argument is matched.
// check() treats a null first or second element as "no match". The rest of
// this function is missing from this copy (original lines 182-189, including
// the fall-through result).
148 static std::pair
<std::optional
<EnableIfData
>, const Decl
*>
149 matchTrailingTemplateParam(const FunctionTemplateDecl
*FunctionTemplate
) {
150 // For non-type trailing param, match very specifically
151 // 'template <..., enable_if_type<Condition, Type> = Default>' where
152 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
153 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>'
155 // Otherwise, match a trailing default type arg.
156 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
158 const TemplateParameterList
*TemplateParams
=
159 FunctionTemplate
->getTemplateParameters();
160 if (TemplateParams
->size() == 0)
163 const NamedDecl
*LastParam
=
164 TemplateParams
->getParam(TemplateParams
->size() - 1);
165 if (const auto *LastTemplateParam
=
166 dyn_cast
<NonTypeTemplateParmDecl
>(LastParam
)) {
168 if (!LastTemplateParam
->hasDefaultArgument() ||
169 !LastTemplateParam
->getName().empty())
172 return {matchEnableIfSpecialization(
173 LastTemplateParam
->getTypeSourceInfo()->getTypeLoc()),
176 if (const auto *LastTemplateParam
=
177 dyn_cast
<TemplateTypeParmDecl
>(LastParam
)) {
178 if (LastTemplateParam
->hasDefaultArgument() &&
179 LastTemplateParam
->getIdentifier() == nullptr) {
181 matchEnableIfSpecialization(LastTemplateParam
->getDefaultArgument()
// Returns the file location of Element's closing '>'. `Element` is declared
// on original line 192, which is missing from this copy. The visible callers
// pass a TemplateSpecializationTypeLoc (getConditionRange, getTypeRange) and
// a TemplateParameterList (handleTrailingTemplateType).
190 template <typename T
>
191 static SourceLocation
getRAngleFileLoc(const SourceManager
&SM
,
193 // getFileLoc handles the case where the RAngle loc is part of a synthesized
194 // '>>', which ends up allocating a 'scratch space' buffer in the source
// NOTE(review): the last line of the comment above (original line 195) is
// missing from this copy.
196 return SM
.getFileLoc(Element
.getRAngleLoc());
// Returns the source range of EnableIf's condition (its first template
// argument). The range starts just past '<'. It ends at the ',' before the
// second argument, or at the closing '>' when there is only one argument.
// The end location is the delimiter itself. getConditionText reads the range
// as a character range, which excludes that end character.
// NOTE(review): the return-type line (original line 199) is missing from this
// copy. Callers store the result in a SourceRange.
200 getConditionRange(ASTContext
&Context
,
201 const TemplateSpecializationTypeLoc
&EnableIf
) {
202 // TemplateArgumentLoc's SourceRange End is the location of the last token
203 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
204 // location will be the first 'B' in 'BBB'.
205 const LangOptions
&LangOpts
= Context
.getLangOpts();
206 const SourceManager
&SM
= Context
.getSourceManager();
207 if (EnableIf
.getNumArgs() > 1) {
208 TemplateArgumentLoc NextArg
= EnableIf
.getArgLoc(1);
209 return {EnableIf
.getLAngleLoc().getLocWithOffset(1),
210 utils::lexer::findPreviousTokenKind(
211 NextArg
.getSourceRange().getBegin(), SM
, LangOpts
, tok::comma
)};
214 return {EnableIf
.getLAngleLoc().getLocWithOffset(1),
215 getRAngleFileLoc(SM
, EnableIf
)};
// Returns the source range of EnableIf's second template argument (the
// type). The range runs from just past the preceding ',' up to the closing
// '>'. EnableIf must have at least two arguments (getArgLoc(1)); the visible
// caller, getTypeText, checks getNumArgs() > 1 first.
218 static SourceRange
getTypeRange(ASTContext
&Context
,
219 const TemplateSpecializationTypeLoc
&EnableIf
) {
220 TemplateArgumentLoc Arg
= EnableIf
.getArgLoc(1);
221 const LangOptions
&LangOpts
= Context
.getLangOpts();
222 const SourceManager
&SM
= Context
.getSourceManager();
223 return {utils::lexer::findPreviousTokenKind(Arg
.getSourceRange().getBegin(),
224 SM
, LangOpts
, tok::comma
)
225 .getLocWithOffset(1),
226 getRAngleFileLoc(SM
, EnableIf
)};
229 // Returns the original source text of the second template argument of an
230 // enable_if / enable_if_t specialization. E.g., in
231 // enable_if_t<Condition, TheType>, this function returns 'TheType'.
// Only the two-argument path is visible. Handling of `Invalid`, the returned
// value and the single-argument path (original lines 242-251) are missing
// from this copy.
232 static std::optional
<StringRef
>
233 getTypeText(ASTContext
&Context
,
234 const TemplateSpecializationTypeLoc
&EnableIf
) {
235 if (EnableIf
.getNumArgs() > 1) {
236 const LangOptions
&LangOpts
= Context
.getLangOpts();
237 const SourceManager
&SM
= Context
.getSourceManager();
238 bool Invalid
= false;
239 StringRef Text
= Lexer::getSourceText(CharSourceRange::getCharRange(
240 getTypeRange(Context
, EnableIf
)),
241 SM
, LangOpts
, &Invalid
)
// Finds where a `requires` clause should be inserted for Function:
//  - Constructors with member initializers: at the ':' before the first
//    initializer as written (source order 0).
//  - Deleted functions: at the next '=' token found from the end of the
//    function's source range (presumably the one in `= delete`).
//  - Otherwise: at the beginning of the function body.
// Some lines of this function are missing from this copy (gaps in the
// embedded numbering). They include the statement after
// `if (!Constructor->inits().empty())` and any check on Body. Presumably
// those paths return std::nullopt.
252 static std::optional
<SourceLocation
>
253 findInsertionForConstraint(const FunctionDecl
*Function
, ASTContext
&Context
) {
254 SourceManager
&SM
= Context
.getSourceManager();
255 const LangOptions
&LangOpts
= Context
.getLangOpts();
257 if (const auto *Constructor
= dyn_cast
<CXXConstructorDecl
>(Function
)) {
258 for (const CXXCtorInitializer
*Init
: Constructor
->inits()) {
259 if (Init
->getSourceOrder() == 0)
260 return utils::lexer::findPreviousTokenKind(Init
->getSourceLocation(),
261 SM
, LangOpts
, tok::colon
);
263 if (!Constructor
->inits().empty())
266 if (Function
->isDeleted()) {
267 SourceLocation FunctionEnd
= Function
->getSourceRange().getEnd();
// NOTE(review): tok::equal is passed twice below; the second one duplicates
// the first and looks redundant. Confirm no other token kind was intended.
268 return utils::lexer::findNextAnyTokenKind(FunctionEnd
, SM
, LangOpts
,
269 tok::equal
, tok::equal
);
271 const Stmt
*Body
= Function
->getBody();
275 return Body
->getBeginLoc();
// Approximates whether Expression is a primary expression. It looks through
// one implicit cast, then accepts ParenExpr and DependentScopeDeclRefExpr.
// The return statements (original lines 287-291) are missing from this copy.
// NOTE(review): every other helper in this file is `static`, but this one is
// not. Unless UseConstraintsCheck.h declares it, it should probably be
// `static` (or placed in an anonymous namespace) to avoid exporting a symbol.
278 bool isPrimaryExpression(const Expr
*Expression
) {
279 // This function is an incomplete approximation of checking whether
280 // an Expr is a primary expression. In particular, if this function
281 // returns true, the expression is a primary expression. The converse
282 // is not necessarily true.
284 if (const auto *Cast
= dyn_cast
<ImplicitCastExpr
>(Expression
))
285 Expression
= Cast
->getSubExprAsWritten();
286 if (isa
<ParenExpr
, DependentScopeDeclRefExpr
>(Expression
))
292 // Return the original source text of an enable_if_t condition (i.e., the
293 // first template argument). Non-primary conditions are wrapped in
294 // parentheses, so for 'enable_if_t<FirstCondition || SecondCondition, AType>'
295 // the text '(FirstCondition || SecondCondition)' is returned.
// Early returns for an invalid range or invalid source text are on lines
// missing from this copy (presumably returning std::nullopt).
296 static std::optional
<std::string
> getConditionText(const Expr
*ConditionExpr
,
297 SourceRange ConditionRange
,
298 ASTContext
&Context
) {
299 SourceManager
&SM
= Context
.getSourceManager();
300 const LangOptions
&LangOpts
= Context
.getLangOpts();
// The range end is the ',' or '>' that closes the condition. Inspect the
// token just before it.
302 SourceLocation PrevTokenLoc
= ConditionRange
.getEnd();
303 if (PrevTokenLoc
.isInvalid())
306 const bool SkipComments
= false;
// PrevToken is declared on original line 307, which is missing from this copy.
308 std::tie(PrevToken
, PrevTokenLoc
) = utils::lexer::getPreviousTokenAndStart(
309 PrevTokenLoc
, SM
, LangOpts
, SkipComments
);
// True when the condition text ends with a `//` line comment.
310 bool EndsWithDoubleSlash
=
311 PrevToken
.is(tok::comment
) &&
312 Lexer::getSourceText(CharSourceRange::getCharRange(
313 PrevTokenLoc
, PrevTokenLoc
.getLocWithOffset(2)),
314 SM
, LangOpts
) == "//";
316 bool Invalid
= false;
317 llvm::StringRef ConditionText
= Lexer::getSourceText(
318 CharSourceRange::getCharRange(ConditionRange
), SM
, LangOpts
, &Invalid
);
// Without parentheses, a requires-clause accepts only primary expressions
// (and their &&/|| combinations), so other conditions are wrapped. The
// unwrapped return (original line 324) is missing from this copy.
322 auto AddParens
= [&](llvm::StringRef Text
) -> std::string
{
323 if (isPrimaryExpression(ConditionExpr
))
325 return "(" + Text
.str() + ")";
// Keep the trailing newline when the condition ends in a `//` comment.
// Trimming it would presumably let the comment swallow whatever is inserted
// after the condition text.
328 if (EndsWithDoubleSlash
)
329 return AddParens(ConditionText
);
330 return AddParens(ConditionText
.trim());
333 // Handle functions that return enable_if_t, e.g.,
335 // enable_if_t<Condition, ReturnType> function();
337 // Return a vector of FixItHints if the code can be replaced with
338 // a C++20 requires clause. In the example above, returns FixItHints
341 // ReturnType function() requires Condition {}
//
// The visible fix-its do two things:
//  - replace the token range of `EnableIf.Outer` (presumably with the
//    enable_if type text);
//  - insert `requires <condition> ` at the location chosen by
//    findInsertionForConstraint.
// NOTE(review): several lines are missing from this copy. They include the
// bail-outs after the optional results and after the existing-constraints
// check (presumably returning an empty vector), the replacement text of the
// first fix-it, and the final return. The `ReturnType` parameter is not
// referenced in the visible lines.
342 static std::vector
<FixItHint
> handleReturnType(const FunctionDecl
*Function
,
343 const TypeLoc
&ReturnType
,
344 const EnableIfData
&EnableIf
,
345 ASTContext
&Context
) {
346 TemplateArgumentLoc EnableCondition
= EnableIf
.Loc
.getArgLoc(0);
348 SourceRange ConditionRange
= getConditionRange(Context
, EnableIf
.Loc
);
350 std::optional
<std::string
> ConditionText
= getConditionText(
351 EnableCondition
.getSourceExpression(), ConditionRange
, Context
);
355 std::optional
<StringRef
> TypeText
= getTypeText(Context
, EnableIf
.Loc
);
359 SmallVector
<const Expr
*, 3> ExistingConstraints
;
360 Function
->getAssociatedConstraints(ExistingConstraints
);
361 if (!ExistingConstraints
.empty()) {
362 // FIXME - Support adding new constraints to existing ones. Do we need to
363 // consider subsumption?
367 std::optional
<SourceLocation
> ConstraintInsertionLoc
=
368 findInsertionForConstraint(Function
, Context
);
369 if (!ConstraintInsertionLoc
)
372 std::vector
<FixItHint
> FixIts
;
373 FixIts
.push_back(FixItHint::CreateReplacement(
374 CharSourceRange::getTokenRange(EnableIf
.Outer
.getSourceRange()),
// The replacement-text argument (original line 375) is missing from this copy.
376 FixIts
.push_back(FixItHint::CreateInsertion(
377 *ConstraintInsertionLoc
, "requires " + *ConditionText
+ " "));
381 // Handle enable_if_t in a trailing template parameter, e.g.,
382 // template <..., enable_if_t<Condition, Type> = Type{}>
383 // ReturnType function();
385 // Return a vector of FixItHints if the code can be replaced with
386 // a C++20 requires clause. In the example above, returns FixItHints
389 // ReturnType function() requires Condition {}
//
// The visible fix-its remove the enable_if parameter (RemovalRange, used as
// a character range) and insert `requires <condition> ` at the location
// chosen by findInsertionForConstraint. RemovalRange covers:
//  - the whole `template <...>` header, from the `template` keyword through
//    the closing '>', when the template has a single parameter;
//  - otherwise, from the ',' before the last parameter up to, but not
//    including, the closing '>'.
// NOTE(review): the bail-out returns are on lines missing from this copy, as
// are the assignments to RemovalRange and the first `FixIts.push_back(`.
// Presumably the bail-outs return no fix-its when the condition text is
// unavailable, constraints already exist, or no insertion point is found.
390 static std::vector
<FixItHint
>
391 handleTrailingTemplateType(const FunctionTemplateDecl
*FunctionTemplate
,
392 const FunctionDecl
*Function
,
393 const Decl
*LastTemplateParam
,
394 const EnableIfData
&EnableIf
, ASTContext
&Context
) {
395 SourceManager
&SM
= Context
.getSourceManager();
396 const LangOptions
&LangOpts
= Context
.getLangOpts();
398 TemplateArgumentLoc EnableCondition
= EnableIf
.Loc
.getArgLoc(0);
400 SourceRange ConditionRange
= getConditionRange(Context
, EnableIf
.Loc
);
402 std::optional
<std::string
> ConditionText
= getConditionText(
403 EnableCondition
.getSourceExpression(), ConditionRange
, Context
);
407 SmallVector
<const Expr
*, 3> ExistingConstraints
;
408 Function
->getAssociatedConstraints(ExistingConstraints
);
409 if (!ExistingConstraints
.empty()) {
410 // FIXME - Support adding new constraints to existing ones. Do we need to
411 // consider subsumption?
415 SourceRange RemovalRange
;
416 const TemplateParameterList
*TemplateParams
=
417 FunctionTemplate
->getTemplateParameters();
418 if (!TemplateParams
|| TemplateParams
->size() == 0)
421 if (TemplateParams
->size() == 1) {
// Sole parameter: drop the entire template header, including its '>'.
423 SourceRange(TemplateParams
->getTemplateLoc(),
424 getRAngleFileLoc(SM
, *TemplateParams
).getLocWithOffset(1));
// Otherwise drop ", <last parameter>" and keep the closing '>'.
427 SourceRange(utils::lexer::findPreviousTokenKind(
428 LastTemplateParam
->getSourceRange().getBegin(), SM
,
429 LangOpts
, tok::comma
),
430 getRAngleFileLoc(SM
, *TemplateParams
));
433 std::optional
<SourceLocation
> ConstraintInsertionLoc
=
434 findInsertionForConstraint(Function
, Context
);
435 if (!ConstraintInsertionLoc
)
438 std::vector
<FixItHint
> FixIts
;
440 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange
)));
441 FixIts
.push_back(FixItHint::CreateInsertion(
442 *ConstraintInsertionLoc
, "requires " + *ConditionText
+ " "));
// Diagnoses a matched function template and attaches the fix-its. It
// requires the "functionTemplate", "function" and "return" bindings to all
// be present.
//  Case 1: the function's return type is an enable_if / enable_if_t
//          specialization (fix-its from handleReturnType).
//  Case 2: the last template parameter carries the enable_if (fix-its from
//          handleTrailingTemplateType).
// NOTE(review): these parts are on lines missing from this copy: the early
// return after a missing binding, the statements after the Case 1 diagnostic
// (presumably `return;`), and the tail of the Case 2 call.
446 void UseConstraintsCheck::check(const MatchFinder::MatchResult
&Result
) {
447 const auto *FunctionTemplate
=
448 Result
.Nodes
.getNodeAs
<FunctionTemplateDecl
>("functionTemplate");
449 const auto *Function
= Result
.Nodes
.getNodeAs
<FunctionDecl
>("function");
450 const auto *ReturnType
= Result
.Nodes
.getNodeAs
<TypeLoc
>("return");
451 if (!FunctionTemplate
|| !Function
|| !ReturnType
)
456 // Case 1. Return type of function
459 // enable_if_t<Condition, ReturnType> function() {}
// or: typename enable_if<Condition, ReturnType>::type function() {}
461 // Case 2. Trailing template parameter
463 // template <..., enable_if_t<Condition, Type> = Type{}>
464 // ReturnType function() {}
468 // template <..., typename = enable_if_t<Condition, void>>
469 // ReturnType function() {}
472 // Case 1. Return type of function
473 if (auto EnableIf
= matchEnableIfSpecialization(*ReturnType
)) {
474 diag(ReturnType
->getBeginLoc(),
475 "use C++20 requires constraints instead of enable_if")
476 << handleReturnType(Function
, *ReturnType
, *EnableIf
, *Result
.Context
);
480 // Case 2. Trailing template parameter
481 if (auto [EnableIf
, LastTemplateParam
] =
482 matchTrailingTemplateParam(FunctionTemplate
);
483 EnableIf
&& LastTemplateParam
) {
484 diag(LastTemplateParam
->getSourceRange().getBegin(),
485 "use C++20 requires constraints instead of enable_if")
486 << handleTrailingTemplateType(FunctionTemplate
, Function
,
487 LastTemplateParam
, *EnableIf
,
493 } // namespace clang::tidy::modernize