1 //===--- Quality.cpp ---------------------------------------------*- C++-*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 #include "ASTSignals.h"
12 #include "FileDistance.h"
13 #include "SourceCode.h"
14 #include "index/Symbol.h"
15 #include "clang/AST/ASTContext.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/DeclTemplate.h"
19 #include "clang/AST/DeclVisitor.h"
20 #include "clang/Basic/SourceManager.h"
21 #include "clang/Sema/CodeCompleteConsumer.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/MathExtras.h"
26 #include "llvm/Support/raw_ostream.h"
34 static bool hasDeclInMainFile(const Decl
&D
) {
35 auto &SourceMgr
= D
.getASTContext().getSourceManager();
36 for (auto *Redecl
: D
.redecls()) {
37 if (isInsideMainFile(Redecl
->getLocation(), SourceMgr
))
43 static bool hasUsingDeclInMainFile(const CodeCompletionResult
&R
) {
44 const auto &Context
= R
.Declaration
->getASTContext();
45 const auto &SourceMgr
= Context
.getSourceManager();
47 if (isInsideMainFile(R
.ShadowDecl
->getLocation(), SourceMgr
))
53 static SymbolQualitySignals::SymbolCategory
categorize(const NamedDecl
&ND
) {
54 if (const auto *FD
= dyn_cast
<FunctionDecl
>(&ND
)) {
55 if (FD
->isOverloadedOperator())
56 return SymbolQualitySignals::Operator
;
59 : public ConstDeclVisitor
<Switch
, SymbolQualitySignals::SymbolCategory
> {
61 #define MAP(DeclType, Category) \
62 SymbolQualitySignals::SymbolCategory Visit##DeclType(const DeclType *) { \
63 return SymbolQualitySignals::Category; \
65 MAP(NamespaceDecl
, Namespace
);
66 MAP(NamespaceAliasDecl
, Namespace
);
68 MAP(TypeAliasTemplateDecl
, Type
);
69 MAP(ClassTemplateDecl
, Type
);
70 MAP(CXXConstructorDecl
, Constructor
);
71 MAP(CXXDestructorDecl
, Destructor
);
72 MAP(ValueDecl
, Variable
);
73 MAP(VarTemplateDecl
, Variable
);
74 MAP(FunctionDecl
, Function
);
75 MAP(FunctionTemplateDecl
, Function
);
79 return Switch().Visit(&ND
);
82 static SymbolQualitySignals::SymbolCategory
83 categorize(const CodeCompletionResult
&R
) {
85 return categorize(*R
.Declaration
);
86 if (R
.Kind
== CodeCompletionResult::RK_Macro
)
87 return SymbolQualitySignals::Macro
;
88 // Everything else is a keyword or a pattern. Patterns are mostly keywords
89 // too, except a few which we recognize by cursor kind.
90 switch (R
.CursorKind
) {
91 case CXCursor_CXXMethod
:
92 return SymbolQualitySignals::Function
;
93 case CXCursor_ModuleImportDecl
:
94 return SymbolQualitySignals::Namespace
;
95 case CXCursor_MacroDefinition
:
96 return SymbolQualitySignals::Macro
;
97 case CXCursor_TypeRef
:
98 return SymbolQualitySignals::Type
;
99 case CXCursor_MemberRef
:
100 return SymbolQualitySignals::Variable
;
101 case CXCursor_Constructor
:
102 return SymbolQualitySignals::Constructor
;
104 return SymbolQualitySignals::Keyword
;
108 static SymbolQualitySignals::SymbolCategory
109 categorize(const index::SymbolInfo
&D
) {
111 case index::SymbolKind::Namespace
:
112 case index::SymbolKind::NamespaceAlias
:
113 return SymbolQualitySignals::Namespace
;
114 case index::SymbolKind::Macro
:
115 return SymbolQualitySignals::Macro
;
116 case index::SymbolKind::Enum
:
117 case index::SymbolKind::Struct
:
118 case index::SymbolKind::Class
:
119 case index::SymbolKind::Protocol
:
120 case index::SymbolKind::Extension
:
121 case index::SymbolKind::Union
:
122 case index::SymbolKind::TypeAlias
:
123 case index::SymbolKind::TemplateTypeParm
:
124 case index::SymbolKind::TemplateTemplateParm
:
125 case index::SymbolKind::Concept
:
126 return SymbolQualitySignals::Type
;
127 case index::SymbolKind::Function
:
128 case index::SymbolKind::ClassMethod
:
129 case index::SymbolKind::InstanceMethod
:
130 case index::SymbolKind::StaticMethod
:
131 case index::SymbolKind::InstanceProperty
:
132 case index::SymbolKind::ClassProperty
:
133 case index::SymbolKind::StaticProperty
:
134 case index::SymbolKind::ConversionFunction
:
135 return SymbolQualitySignals::Function
;
136 case index::SymbolKind::Destructor
:
137 return SymbolQualitySignals::Destructor
;
138 case index::SymbolKind::Constructor
:
139 return SymbolQualitySignals::Constructor
;
140 case index::SymbolKind::Variable
:
141 case index::SymbolKind::Field
:
142 case index::SymbolKind::EnumConstant
:
143 case index::SymbolKind::Parameter
:
144 case index::SymbolKind::NonTypeTemplateParm
:
145 return SymbolQualitySignals::Variable
;
146 case index::SymbolKind::Using
:
147 case index::SymbolKind::Module
:
148 case index::SymbolKind::Unknown
:
149 return SymbolQualitySignals::Unknown
;
151 llvm_unreachable("Unknown index::SymbolKind");
154 static bool isInstanceMember(const NamedDecl
*ND
) {
157 if (const auto *TP
= dyn_cast
<FunctionTemplateDecl
>(ND
))
158 ND
= TP
->TemplateDecl::getTemplatedDecl();
159 if (const auto *CM
= dyn_cast
<CXXMethodDecl
>(ND
))
160 return !CM
->isStatic();
161 return isa
<FieldDecl
>(ND
); // Note that static fields are VarDecl.
164 static bool isInstanceMember(const index::SymbolInfo
&D
) {
166 case index::SymbolKind::InstanceMethod
:
167 case index::SymbolKind::InstanceProperty
:
168 case index::SymbolKind::Field
:
175 void SymbolQualitySignals::merge(const CodeCompletionResult
&SemaCCResult
) {
176 Deprecated
|= (SemaCCResult
.Availability
== CXAvailability_Deprecated
);
177 Category
= categorize(SemaCCResult
);
179 if (SemaCCResult
.Declaration
) {
180 ImplementationDetail
|= isImplementationDetail(SemaCCResult
.Declaration
);
181 if (auto *ID
= SemaCCResult
.Declaration
->getIdentifier())
182 ReservedName
= ReservedName
|| isReservedName(ID
->getName());
183 } else if (SemaCCResult
.Kind
== CodeCompletionResult::RK_Macro
)
185 ReservedName
|| isReservedName(SemaCCResult
.Macro
->getName());
188 void SymbolQualitySignals::merge(const Symbol
&IndexResult
) {
189 Deprecated
|= (IndexResult
.Flags
& Symbol::Deprecated
);
190 ImplementationDetail
|= (IndexResult
.Flags
& Symbol::ImplementationDetail
);
191 References
= std::max(IndexResult
.References
, References
);
192 Category
= categorize(IndexResult
.SymInfo
);
193 ReservedName
= ReservedName
|| isReservedName(IndexResult
.Name
);
196 float SymbolQualitySignals::evaluateHeuristics() const {
199 // This avoids a sharp gradient for tail symbols, and also neatly avoids the
200 // question of whether 0 references means a bad symbol or missing data.
201 if (References
>= 10) {
202 // Use a sigmoid style boosting function, which flattens out nicely for large
203 // numbers (e.g. 2.94 for 1M references).
204 // The following boosting function is equivalent to:
207 // boost = f * sigmoid(m * std::log(References)) - 0.5 * f + 0.59
208 // Sample data points: (10, 1.00), (100, 1.41), (1000, 1.82),
209 // (10K, 2.21), (100K, 2.58), (1M, 2.94)
210 float S
= std::pow(References
, -0.06);
211 Score
*= 6.0 * (1 - S
) / (1 + S
) + 0.59;
218 if (ImplementationDetail
)
222 case Keyword
: // Often relevant, but misses most signals.
223 Score
*= 4; // FIXME: important keywords should have specific boosts.
238 case Constructor
: // No boost constructors so they are after class types.
246 llvm::raw_ostream
&operator<<(llvm::raw_ostream
&OS
,
247 const SymbolQualitySignals
&S
) {
248 OS
<< llvm::formatv("=== Symbol quality: {0}\n", S
.evaluateHeuristics());
249 OS
<< llvm::formatv("\tReferences: {0}\n", S
.References
);
250 OS
<< llvm::formatv("\tDeprecated: {0}\n", S
.Deprecated
);
251 OS
<< llvm::formatv("\tReserved name: {0}\n", S
.ReservedName
);
252 OS
<< llvm::formatv("\tImplementation detail: {0}\n", S
.ImplementationDetail
);
253 OS
<< llvm::formatv("\tCategory: {0}\n", static_cast<int>(S
.Category
));
257 static SymbolRelevanceSignals::AccessibleScope
258 computeScope(const NamedDecl
*D
) {
259 // Injected "Foo" within the class "Foo" has file scope, not class scope.
260 const DeclContext
*DC
= D
->getDeclContext();
261 if (auto *R
= dyn_cast_or_null
<RecordDecl
>(D
))
262 if (R
->isInjectedClassName())
263 DC
= DC
->getParent();
264 // Class constructor should have the same scope as the class.
265 if (isa
<CXXConstructorDecl
>(D
))
266 DC
= DC
->getParent();
267 bool InClass
= false;
268 for (; !DC
->isFileContext(); DC
= DC
->getParent()) {
269 if (DC
->isFunctionOrMethod())
270 return SymbolRelevanceSignals::FunctionScope
;
271 InClass
= InClass
|| DC
->isRecord();
274 return SymbolRelevanceSignals::ClassScope
;
275 // ExternalLinkage threshold could be tweaked, e.g. module-visible as global.
276 // Avoid caching linkage if it may change after enclosing code completion.
277 if (hasUnstableLinkage(D
) || llvm::to_underlying(D
->getLinkageInternal()) <
278 llvm::to_underlying(Linkage::External
))
279 return SymbolRelevanceSignals::FileScope
;
280 return SymbolRelevanceSignals::GlobalScope
;
283 void SymbolRelevanceSignals::merge(const Symbol
&IndexResult
) {
284 SymbolURI
= IndexResult
.CanonicalDeclaration
.FileURI
;
285 SymbolScope
= IndexResult
.Scope
;
286 IsInstanceMember
|= isInstanceMember(IndexResult
.SymInfo
);
287 if (!(IndexResult
.Flags
& Symbol::VisibleOutsideFile
)) {
288 Scope
= AccessibleScope::FileScope
;
290 if (MainFileSignals
) {
292 std::max(MainFileRefs
,
293 MainFileSignals
->ReferencedSymbols
.lookup(IndexResult
.ID
));
295 std::max(ScopeRefsInFile
,
296 MainFileSignals
->RelatedNamespaces
.lookup(IndexResult
.Scope
));
300 void SymbolRelevanceSignals::computeASTSignals(
301 const CodeCompletionResult
&SemaResult
) {
302 if (!MainFileSignals
)
304 if ((SemaResult
.Kind
!= CodeCompletionResult::RK_Declaration
) &&
305 (SemaResult
.Kind
!= CodeCompletionResult::RK_Pattern
))
307 if (const NamedDecl
*ND
= SemaResult
.getDeclaration()) {
308 if (hasUnstableLinkage(ND
))
310 auto ID
= getSymbolID(ND
);
314 std::max(MainFileRefs
, MainFileSignals
->ReferencedSymbols
.lookup(ID
));
315 if (const auto *NSD
= dyn_cast
<NamespaceDecl
>(ND
->getDeclContext())) {
316 if (NSD
->isAnonymousNamespace())
318 std::string Scope
= printNamespaceScope(*NSD
);
320 ScopeRefsInFile
= std::max(
321 ScopeRefsInFile
, MainFileSignals
->RelatedNamespaces
.lookup(Scope
));
326 void SymbolRelevanceSignals::merge(const CodeCompletionResult
&SemaCCResult
) {
327 if (SemaCCResult
.Availability
== CXAvailability_NotAvailable
||
328 SemaCCResult
.Availability
== CXAvailability_NotAccessible
)
331 if (SemaCCResult
.Declaration
) {
332 SemaSaysInScope
= true;
333 // We boost things that have decls in the main file. We give a fixed score
334 // for all other declarations in sema as they are already included in the
336 float DeclProximity
= (hasDeclInMainFile(*SemaCCResult
.Declaration
) ||
337 hasUsingDeclInMainFile(SemaCCResult
))
340 SemaFileProximityScore
= std::max(DeclProximity
, SemaFileProximityScore
);
341 IsInstanceMember
|= isInstanceMember(SemaCCResult
.Declaration
);
342 InBaseClass
|= SemaCCResult
.InBaseClass
;
345 computeASTSignals(SemaCCResult
);
346 // Declarations are scoped, others (like macros) are assumed global.
347 if (SemaCCResult
.Declaration
)
348 Scope
= std::min(Scope
, computeScope(SemaCCResult
.Declaration
));
350 NeedsFixIts
= !SemaCCResult
.FixIts
.empty();
353 static float fileProximityScore(unsigned FileDistance
) {
355 // FileDistance = [0, 1, 2, 3, 4, .., FileDistance::Unreachable]
356 // Score = [1, 0.82, 0.67, 0.55, 0.45, .., 0]
357 if (FileDistance
== FileDistance::Unreachable
)
359 // Assume approximately default options are used for sensible scoring.
360 return std::exp(FileDistance
* -0.4f
/ FileDistanceOptions().UpCost
);
363 static float scopeProximityScore(unsigned ScopeDistance
) {
365 // ScopeDistance = [0, 1, 2, 3, 4, 5, 6, 7, .., FileDistance::Unreachable]
366 // Score = [2.0, 1.55, 1.2, 0.93, 0.72, 0.65, 0.65, 0.65, .., 0.6]
367 if (ScopeDistance
== FileDistance::Unreachable
)
369 return std::max(0.65, 2.0 * std::pow(0.6, ScopeDistance
/ 2.0));
372 static std::optional
<llvm::StringRef
>
373 wordMatching(llvm::StringRef Name
, const llvm::StringSet
<> *ContextWords
) {
375 for (const auto &Word
: ContextWords
->keys())
376 if (Name
.contains_insensitive(Word
))
381 SymbolRelevanceSignals::DerivedSignals
382 SymbolRelevanceSignals::calculateDerivedSignals() const {
383 DerivedSignals Derived
;
384 Derived
.NameMatchesContext
= wordMatching(Name
, ContextWords
).has_value();
385 Derived
.FileProximityDistance
= !FileProximityMatch
|| SymbolURI
.empty()
386 ? FileDistance::Unreachable
387 : FileProximityMatch
->distance(SymbolURI
);
388 if (ScopeProximityMatch
) {
389 // For global symbol, the distance is 0.
390 Derived
.ScopeProximityDistance
=
391 SymbolScope
? ScopeProximityMatch
->distance(*SymbolScope
) : 0;
396 float SymbolRelevanceSignals::evaluateHeuristics() const {
397 DerivedSignals Derived
= calculateDerivedSignals();
405 // File proximity scores are [0,1] and we translate them into a multiplier in
406 // the range from 1 to 3.
407 Score
*= 1 + 2 * std::max(fileProximityScore(Derived
.FileProximityDistance
),
408 SemaFileProximityScore
);
410 if (ScopeProximityMatch
)
411 // Use a constant scope boost for sema results, as scopes of sema results
412 // can be tricky (e.g. class/function scope). Set to the max boost as we
413 // don't load top-level symbols from the preamble and sema results are
414 // always in the accessible scope.
415 Score
*= SemaSaysInScope
417 : scopeProximityScore(Derived
.ScopeProximityDistance
);
419 if (Derived
.NameMatchesContext
)
422 // Symbols like local variables may only be referenced within their scope.
423 // Conversely if we're in that scope, it's likely we'll reference them.
424 if (Query
== CodeComplete
) {
425 // The narrower the scope where a symbol is visible, the more likely it is
426 // to be relevant when it is available.
441 // For non-completion queries, the wider the scope where a symbol is
442 // visible, the more likely it is to be relevant.
450 // TODO: Handle other scopes as we start to use them for index results.
455 if (TypeMatchesPreferred
)
458 // Penalize non-instance members when they are accessed via a class instance.
459 if (!IsInstanceMember
&&
460 (Context
== CodeCompletionContext::CCC_DotMemberAccess
||
461 Context
== CodeCompletionContext::CCC_ArrowMemberAccess
)) {
468 // Penalize for FixIts.
472 // Use a sigmoid style boosting function similar to `References`, which flattens
473 // out nicely for large values. This avoids a sharp gradient for heavily
474 // referenced symbols. Use smaller gradient for ScopeRefsInFile since ideally
475 // MainFileRefs <= ScopeRefsInFile.
476 if (MainFileRefs
>= 2) {
477 // E.g.: (2, 1.12), (9, 2.0), (48, 3.0).
478 float S
= std::pow(MainFileRefs
, -0.11);
479 Score
*= 11.0 * (1 - S
) / (1 + S
) + 0.7;
481 if (ScopeRefsInFile
>= 2) {
482 // E.g.: (2, 1.04), (14, 2.0), (109, 3.0), (400, 3.6).
483 float S
= std::pow(ScopeRefsInFile
, -0.10);
484 Score
*= 10.0 * (1 - S
) / (1 + S
) + 0.7;
490 llvm::raw_ostream
&operator<<(llvm::raw_ostream
&OS
,
491 const SymbolRelevanceSignals
&S
) {
492 OS
<< llvm::formatv("=== Symbol relevance: {0}\n", S
.evaluateHeuristics());
493 OS
<< llvm::formatv("\tName: {0}\n", S
.Name
);
494 OS
<< llvm::formatv("\tName match: {0}\n", S
.NameMatch
);
497 "\tMatching context word: {0}\n",
498 wordMatching(S
.Name
, S
.ContextWords
).value_or("<none>"));
499 OS
<< llvm::formatv("\tForbidden: {0}\n", S
.Forbidden
);
500 OS
<< llvm::formatv("\tNeedsFixIts: {0}\n", S
.NeedsFixIts
);
501 OS
<< llvm::formatv("\tIsInstanceMember: {0}\n", S
.IsInstanceMember
);
502 OS
<< llvm::formatv("\tInBaseClass: {0}\n", S
.InBaseClass
);
503 OS
<< llvm::formatv("\tContext: {0}\n", getCompletionKindString(S
.Context
));
504 OS
<< llvm::formatv("\tQuery type: {0}\n", static_cast<int>(S
.Query
));
505 OS
<< llvm::formatv("\tScope: {0}\n", static_cast<int>(S
.Scope
));
507 OS
<< llvm::formatv("\tSymbol URI: {0}\n", S
.SymbolURI
);
508 OS
<< llvm::formatv("\tSymbol scope: {0}\n",
509 S
.SymbolScope
? *S
.SymbolScope
: "<None>");
511 SymbolRelevanceSignals::DerivedSignals Derived
= S
.calculateDerivedSignals();
512 if (S
.FileProximityMatch
) {
513 unsigned Score
= fileProximityScore(Derived
.FileProximityDistance
);
514 OS
<< llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", Score
,
515 Derived
.FileProximityDistance
);
517 OS
<< llvm::formatv("\tSema file proximity: {0}\n", S
.SemaFileProximityScore
);
519 OS
<< llvm::formatv("\tSema says in scope: {0}\n", S
.SemaSaysInScope
);
520 if (S
.ScopeProximityMatch
)
521 OS
<< llvm::formatv("\tIndex scope boost: {0}\n",
522 scopeProximityScore(Derived
.ScopeProximityDistance
));
525 "\tType matched preferred: {0} (Context type: {1}, Symbol type: {2}\n",
526 S
.TypeMatchesPreferred
, S
.HadContextType
, S
.HadSymbolType
);
// Combines a symbol's quality score and its relevance score into one overall
// score. The combination is a plain product of the two inputs.
float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) {
  const float Combined = SymbolQuality * SymbolRelevance;
  return Combined;
}
535 // Produces an integer that sorts in the same order as F.
536 // That is: a < b <==> encodeFloat(a) < encodeFloat(b).
537 static uint32_t encodeFloat(float F
) {
538 static_assert(std::numeric_limits
<float>::is_iec559
);
539 constexpr uint32_t TopBit
= ~(~uint32_t{0} >> 1);
541 // Get the bits of the float. Endianness is the same as for integers.
542 uint32_t U
= llvm::bit_cast
<uint32_t>(F
);
543 // IEEE 754 floats compare like sign-magnitude integers.
544 if (U
& TopBit
) // Negative float.
545 return 0 - U
; // Map onto the low half of integers, order reversed.
546 return U
+ TopBit
; // Positive floats map onto the high half of integers.
549 std::string
sortText(float Score
, llvm::StringRef Name
) {
550 // We convert -Score to an integer, and hex-encode for readability.
551 // Example: [0.5, "foo"] -> "41000000foo"
553 llvm::raw_string_ostream
OS(S
);
554 llvm::write_hex(OS
, encodeFloat(-Score
), llvm::HexPrintStyle::Lower
,
555 /*Width=*/2 * sizeof(Score
));
560 llvm::raw_ostream
&operator<<(llvm::raw_ostream
&OS
,
561 const SignatureQualitySignals
&S
) {
562 OS
<< llvm::formatv("=== Signature Quality:\n");
563 OS
<< llvm::formatv("\tNumber of parameters: {0}\n", S
.NumberOfParameters
);
564 OS
<< llvm::formatv("\tNumber of optional parameters: {0}\n",
565 S
.NumberOfOptionalParameters
);
566 OS
<< llvm::formatv("\tKind: {0}\n", S
.Kind
);
570 } // namespace clangd