//===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "IncludeFixer.h"
10 #include "clang/Format/Format.h"
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "clang/Lex/HeaderSearch.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Parse/ParseAST.h"
15 #include "clang/Sema/Sema.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/raw_ostream.h"
19 #define DEBUG_TYPE "clang-include-fixer"
21 using namespace clang
;
24 namespace include_fixer
{
26 /// Manages the parse, gathers include suggestions.
27 class Action
: public clang::ASTFrontendAction
{
29 explicit Action(SymbolIndexManager
&SymbolIndexMgr
, bool MinimizeIncludePaths
)
30 : SemaSource(new IncludeFixerSemaSource(SymbolIndexMgr
,
32 /*GenerateDiagnostics=*/false)) {}
34 std::unique_ptr
<clang::ASTConsumer
>
35 CreateASTConsumer(clang::CompilerInstance
&Compiler
,
36 StringRef InFile
) override
{
37 SemaSource
->setFilePath(InFile
);
38 return std::make_unique
<clang::ASTConsumer
>();
41 void ExecuteAction() override
{
42 clang::CompilerInstance
*Compiler
= &getCompilerInstance();
43 assert(!Compiler
->hasSema() && "CI already has Sema");
45 // Set up our hooks into sema and parse the AST.
46 if (hasCodeCompletionSupport() &&
47 !Compiler
->getFrontendOpts().CodeCompletionAt
.FileName
.empty())
48 Compiler
->createCodeCompletionConsumer();
50 clang::CodeCompleteConsumer
*CompletionConsumer
= nullptr;
51 if (Compiler
->hasCodeCompletionConsumer())
52 CompletionConsumer
= &Compiler
->getCodeCompletionConsumer();
54 Compiler
->createSema(getTranslationUnitKind(), CompletionConsumer
);
55 SemaSource
->setCompilerInstance(Compiler
);
56 Compiler
->getSema().addExternalSource(SemaSource
.get());
58 clang::ParseAST(Compiler
->getSema(), Compiler
->getFrontendOpts().ShowStats
,
59 Compiler
->getFrontendOpts().SkipFunctionBodies
);
63 getIncludeFixerContext(const clang::SourceManager
&SourceManager
,
64 clang::HeaderSearch
&HeaderSearch
) const {
65 return SemaSource
->getIncludeFixerContext(SourceManager
, HeaderSearch
,
66 SemaSource
->getMatchedSymbols());
70 IntrusiveRefCntPtr
<IncludeFixerSemaSource
> SemaSource
;
75 IncludeFixerActionFactory::IncludeFixerActionFactory(
76 SymbolIndexManager
&SymbolIndexMgr
,
77 std::vector
<IncludeFixerContext
> &Contexts
, StringRef StyleName
,
78 bool MinimizeIncludePaths
)
79 : SymbolIndexMgr(SymbolIndexMgr
), Contexts(Contexts
),
80 MinimizeIncludePaths(MinimizeIncludePaths
) {}
82 IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
84 bool IncludeFixerActionFactory::runInvocation(
85 std::shared_ptr
<clang::CompilerInvocation
> Invocation
,
86 clang::FileManager
*Files
,
87 std::shared_ptr
<clang::PCHContainerOperations
> PCHContainerOps
,
88 clang::DiagnosticConsumer
*Diagnostics
) {
89 assert(Invocation
->getFrontendOpts().Inputs
.size() == 1);
92 clang::CompilerInstance
Compiler(PCHContainerOps
);
93 Compiler
.setInvocation(std::move(Invocation
));
94 Compiler
.setFileManager(Files
);
96 // Create the compiler's actual diagnostics engine. We want to drop all
98 Compiler
.createDiagnostics(new clang::IgnoringDiagConsumer
,
99 /*ShouldOwnClient=*/true);
100 Compiler
.createSourceManager(*Files
);
102 // We abort on fatal errors so don't let a large number of errors become
103 // fatal. A missing #include can cause thousands of errors.
104 Compiler
.getDiagnostics().setErrorLimit(0);
106 // Run the parser, gather missing includes.
107 auto ScopedToolAction
=
108 std::make_unique
<Action
>(SymbolIndexMgr
, MinimizeIncludePaths
);
109 Compiler
.ExecuteAction(*ScopedToolAction
);
111 Contexts
.push_back(ScopedToolAction
->getIncludeFixerContext(
112 Compiler
.getSourceManager(),
113 Compiler
.getPreprocessor().getHeaderSearchInfo()));
115 // Technically this should only return true if we're sure that we have a
116 // parseable file. We don't know that though. Only inform users of fatal
118 return !Compiler
.getDiagnostics().hasFatalErrorOccurred();
121 static bool addDiagnosticsForContext(TypoCorrection
&Correction
,
122 const IncludeFixerContext
&Context
,
123 StringRef Code
, SourceLocation StartOfFile
,
125 auto Reps
= createIncludeFixerReplacements(
126 Code
, Context
, format::getLLVMStyle(), /*AddQualifiers=*/false);
127 if (!Reps
|| Reps
->size() != 1)
130 unsigned DiagID
= Ctx
.getDiagnostics().getCustomDiagID(
131 DiagnosticsEngine::Note
, "Add '#include %0' to provide the missing "
132 "declaration [clang-include-fixer]");
134 // FIXME: Currently we only generate a diagnostic for the first header. Give
136 const tooling::Replacement
&Placed
= *Reps
->begin();
138 auto Begin
= StartOfFile
.getLocWithOffset(Placed
.getOffset());
139 auto End
= Begin
.getLocWithOffset(std::max(0, (int)Placed
.getLength() - 1));
140 PartialDiagnostic
PD(DiagID
, Ctx
.getDiagAllocator());
141 PD
<< Context
.getHeaderInfos().front().Header
142 << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin
, End
),
143 Placed
.getReplacementText());
144 Correction
.addExtraDiagnostic(std::move(PD
));
148 /// Callback for incomplete types. If we encounter a forward declaration we
149 /// have the fully qualified name ready. Just query that.
150 bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
151 clang::SourceLocation Loc
, clang::QualType T
) {
152 // Ignore spurious callbacks from SFINAE contexts.
153 if (CI
->getSema().isSFINAEContext())
156 clang::ASTContext
&context
= CI
->getASTContext();
157 std::string QueryString
= QualType(T
->getUnqualifiedDesugaredType(), 0)
158 .getAsString(context
.getPrintingPolicy());
159 LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
161 // Pass an empty range here since we don't add qualifier in this case.
162 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
163 query(QueryString
, "", tooling::Range());
165 if (!MatchedSymbols
.empty() && GenerateDiagnostics
) {
166 TypoCorrection Correction
;
167 FileID FID
= CI
->getSourceManager().getFileID(Loc
);
168 StringRef Code
= CI
->getSourceManager().getBufferData(FID
);
169 SourceLocation StartOfFile
=
170 CI
->getSourceManager().getLocForStartOfFile(FID
);
171 addDiagnosticsForContext(
173 getIncludeFixerContext(CI
->getSourceManager(),
174 CI
->getPreprocessor().getHeaderSearchInfo(),
176 Code
, StartOfFile
, CI
->getASTContext());
177 for (const PartialDiagnostic
&PD
: Correction
.getExtraDiagnostics())
178 CI
->getSema().Diag(Loc
, PD
);
183 /// Callback for unknown identifiers. Try to piece together as much
184 /// qualification as we can get and do a query.
185 clang::TypoCorrection
IncludeFixerSemaSource::CorrectTypo(
186 const DeclarationNameInfo
&Typo
, int LookupKind
, Scope
*S
, CXXScopeSpec
*SS
,
187 CorrectionCandidateCallback
&CCC
, DeclContext
*MemberContext
,
188 bool EnteringContext
, const ObjCObjectPointerType
*OPT
) {
189 // Ignore spurious callbacks from SFINAE contexts.
190 if (CI
->getSema().isSFINAEContext())
191 return clang::TypoCorrection();
193 // We currently ignore the unidentified symbol which is not from the
196 // However, this is not always true due to templates in a non-self contained
197 // header, consider the case:
200 // template <typename T>
206 // // We need to add <bar.h> in test.cc instead of header.h.
210 // FIXME: Add the missing header to the header file where the symbol comes
212 if (!CI
->getSourceManager().isWrittenInMainFile(Typo
.getLoc()))
213 return clang::TypoCorrection();
215 std::string TypoScopeString
;
217 // FIXME: Currently we only use namespace contexts. Use other context
219 for (const auto *Context
= S
->getEntity(); Context
;
220 Context
= Context
->getParent()) {
221 if (const auto *ND
= dyn_cast
<NamespaceDecl
>(Context
)) {
222 if (!ND
->getName().empty())
223 TypoScopeString
= ND
->getNameAsString() + "::" + TypoScopeString
;
228 auto ExtendNestedNameSpecifier
= [this](CharSourceRange Range
) {
230 Lexer::getSourceText(Range
, CI
->getSourceManager(), CI
->getLangOpts());
232 // Skip forward until we find a character that's neither identifier nor
233 // colon. This is a bit of a hack around the fact that we will only get a
234 // single callback for a long nested name if a part of the beginning is
235 // unknown. For example:
237 // llvm::sys::path::parent_path(...)
241 // unknown, last callback
245 // With the extension we get the full nested name specifier including
247 // FIXME: Don't rely on source text.
248 const char *End
= Source
.end();
249 while (isAsciiIdentifierContinue(*End
) || *End
== ':')
252 return std::string(Source
.begin(), End
);
255 /// If we have a scope specification, use that to get more precise results.
256 std::string QueryString
;
257 tooling::Range SymbolRange
;
258 const auto &SM
= CI
->getSourceManager();
259 auto CreateToolingRange
= [&QueryString
, &SM
](SourceLocation BeginLoc
) {
260 return tooling::Range(SM
.getDecomposedLoc(BeginLoc
).second
,
263 if (SS
&& SS
->getRange().isValid()) {
264 auto Range
= CharSourceRange::getTokenRange(SS
->getRange().getBegin(),
267 QueryString
= ExtendNestedNameSpecifier(Range
);
268 SymbolRange
= CreateToolingRange(Range
.getBegin());
269 } else if (Typo
.getName().isIdentifier() && !Typo
.getLoc().isMacroID()) {
271 CharSourceRange::getTokenRange(Typo
.getBeginLoc(), Typo
.getEndLoc());
273 QueryString
= ExtendNestedNameSpecifier(Range
);
274 SymbolRange
= CreateToolingRange(Range
.getBegin());
276 QueryString
= Typo
.getAsString();
277 SymbolRange
= CreateToolingRange(Typo
.getLoc());
280 LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
282 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
283 query(QueryString
, TypoScopeString
, SymbolRange
);
285 if (!MatchedSymbols
.empty() && GenerateDiagnostics
) {
286 TypoCorrection
Correction(Typo
.getName());
287 Correction
.setCorrectionRange(SS
, Typo
);
288 FileID FID
= SM
.getFileID(Typo
.getLoc());
289 StringRef Code
= SM
.getBufferData(FID
);
290 SourceLocation StartOfFile
= SM
.getLocForStartOfFile(FID
);
291 if (addDiagnosticsForContext(
292 Correction
, getIncludeFixerContext(
293 SM
, CI
->getPreprocessor().getHeaderSearchInfo(),
295 Code
, StartOfFile
, CI
->getASTContext()))
298 return TypoCorrection();
301 /// Get the minimal include for a given path.
302 std::string
IncludeFixerSemaSource::minimizeInclude(
303 StringRef Include
, const clang::SourceManager
&SourceManager
,
304 clang::HeaderSearch
&HeaderSearch
) const {
305 if (!MinimizeIncludePaths
)
306 return std::string(Include
);
308 // Get the FileEntry for the include.
309 StringRef StrippedInclude
= Include
.trim("\"<>");
311 SourceManager
.getFileManager().getOptionalFileRef(StrippedInclude
);
313 // If the file doesn't exist return the path from the database.
314 // FIXME: This should never happen.
316 return std::string(Include
);
318 bool IsAngled
= false;
319 std::string Suggestion
=
320 HeaderSearch
.suggestPathToFileForDiagnostics(*Entry
, "", &IsAngled
);
322 return IsAngled
? '<' + Suggestion
+ '>' : '"' + Suggestion
+ '"';
325 /// Get the include fixer context for the queried symbol.
326 IncludeFixerContext
IncludeFixerSemaSource::getIncludeFixerContext(
327 const clang::SourceManager
&SourceManager
,
328 clang::HeaderSearch
&HeaderSearch
,
329 ArrayRef
<find_all_symbols::SymbolInfo
> MatchedSymbols
) const {
330 std::vector
<find_all_symbols::SymbolInfo
> SymbolCandidates
;
331 for (const auto &Symbol
: MatchedSymbols
) {
332 std::string FilePath
= Symbol
.getFilePath().str();
333 std::string MinimizedFilePath
= minimizeInclude(
334 ((FilePath
[0] == '"' || FilePath
[0] == '<') ? FilePath
335 : "\"" + FilePath
+ "\""),
336 SourceManager
, HeaderSearch
);
337 SymbolCandidates
.emplace_back(Symbol
.getName(), Symbol
.getSymbolKind(),
338 MinimizedFilePath
, Symbol
.getContexts());
340 return IncludeFixerContext(FilePath
, QuerySymbolInfos
, SymbolCandidates
);
343 std::vector
<find_all_symbols::SymbolInfo
>
344 IncludeFixerSemaSource::query(StringRef Query
, StringRef ScopedQualifiers
,
345 tooling::Range Range
) {
346 assert(!Query
.empty() && "Empty query!");
348 // Save all instances of an unidentified symbol.
350 // We use conservative behavior for detecting the same unidentified symbol
351 // here. The symbols which have the same ScopedQualifier and RawIdentifier
352 // are considered equal. So that clang-include-fixer avoids false positives,
353 // and always adds missing qualifiers to correct symbols.
354 if (!GenerateDiagnostics
&& !QuerySymbolInfos
.empty()) {
355 if (ScopedQualifiers
== QuerySymbolInfos
.front().ScopedQualifiers
&&
356 Query
== QuerySymbolInfos
.front().RawIdentifier
) {
357 QuerySymbolInfos
.push_back(
358 {Query
.str(), std::string(ScopedQualifiers
), Range
});
363 LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query
<< "' at ");
364 LLVM_DEBUG(CI
->getSourceManager()
365 .getLocForStartOfFile(CI
->getSourceManager().getMainFileID())
366 .getLocWithOffset(Range
.getOffset())
367 .print(llvm::dbgs(), CI
->getSourceManager()));
368 LLVM_DEBUG(llvm::dbgs() << " ...");
369 llvm::StringRef FileName
= CI
->getSourceManager().getFilename(
370 CI
->getSourceManager().getLocForStartOfFile(
371 CI
->getSourceManager().getMainFileID()));
373 QuerySymbolInfos
.push_back(
374 {Query
.str(), std::string(ScopedQualifiers
), Range
});
376 // Query the symbol based on C++ name Lookup rules.
377 // Firstly, lookup the identifier with scoped namespace contexts;
378 // If that fails, falls back to look up the identifier directly.
386 // 1. lookup a::b::foo.
388 std::string QueryString
= ScopedQualifiers
.str() + Query
.str();
389 // It's unsafe to do nested search for the identifier with scoped namespace
390 // context, it might treat the identifier as a nested class of the scoped
392 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
393 SymbolIndexMgr
.search(QueryString
, /*IsNestedSearch=*/false, FileName
);
394 if (MatchedSymbols
.empty())
396 SymbolIndexMgr
.search(Query
, /*IsNestedSearch=*/true, FileName
);
397 LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols
.size()
399 // We store a copy of MatchedSymbols in a place where it's globally reachable.
400 // This is used by the standalone version of the tool.
401 this->MatchedSymbols
= MatchedSymbols
;
402 return MatchedSymbols
;
405 llvm::Expected
<tooling::Replacements
> createIncludeFixerReplacements(
406 StringRef Code
, const IncludeFixerContext
&Context
,
407 const clang::format::FormatStyle
&Style
, bool AddQualifiers
) {
408 if (Context
.getHeaderInfos().empty())
409 return tooling::Replacements();
410 StringRef FilePath
= Context
.getFilePath();
411 std::string IncludeName
=
412 "#include " + Context
.getHeaderInfos().front().Header
+ "\n";
413 // Create replacements for the new header.
414 clang::tooling::Replacements Insertions
;
416 Insertions
.add(tooling::Replacement(FilePath
, UINT_MAX
, 0, IncludeName
));
418 return std::move(Err
);
420 auto CleanReplaces
= cleanupAroundReplacements(Code
, Insertions
, Style
);
422 return CleanReplaces
;
424 auto Replaces
= std::move(*CleanReplaces
);
426 for (const auto &Info
: Context
.getQuerySymbolInfos()) {
427 // Ignore the empty range.
428 if (Info
.Range
.getLength() > 0) {
429 auto R
= tooling::Replacement(
430 {FilePath
, Info
.Range
.getOffset(), Info
.Range
.getLength(),
431 Context
.getHeaderInfos().front().QualifiedName
});
432 auto Err
= Replaces
.add(R
);
434 llvm::consumeError(std::move(Err
));
435 R
= tooling::Replacement(
436 R
.getFilePath(), Replaces
.getShiftedCodePosition(R
.getOffset()),
437 R
.getLength(), R
.getReplacementText());
438 Replaces
= Replaces
.merge(tooling::Replacements(R
));
443 return formatReplacements(Code
, Replaces
, Style
);
446 } // namespace include_fixer