1 //===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "IncludeFixer.h"
10 #include "clang/Format/Format.h"
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "clang/Lex/HeaderSearch.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Parse/ParseAST.h"
15 #include "clang/Sema/Sema.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/raw_ostream.h"
19 #define DEBUG_TYPE "clang-include-fixer"
21 using namespace clang
;
24 namespace include_fixer
{
26 /// Manages the parse, gathers include suggestions.
27 class Action
: public clang::ASTFrontendAction
{
29 explicit Action(SymbolIndexManager
&SymbolIndexMgr
, bool MinimizeIncludePaths
)
30 : SemaSource(new IncludeFixerSemaSource(SymbolIndexMgr
,
32 /*GenerateDiagnostics=*/false)) {}
34 std::unique_ptr
<clang::ASTConsumer
>
35 CreateASTConsumer(clang::CompilerInstance
&Compiler
,
36 StringRef InFile
) override
{
37 SemaSource
->setFilePath(InFile
);
38 return std::make_unique
<clang::ASTConsumer
>();
41 void ExecuteAction() override
{
42 clang::CompilerInstance
*Compiler
= &getCompilerInstance();
43 assert(!Compiler
->hasSema() && "CI already has Sema");
45 // Set up our hooks into sema and parse the AST.
46 if (hasCodeCompletionSupport() &&
47 !Compiler
->getFrontendOpts().CodeCompletionAt
.FileName
.empty())
48 Compiler
->createCodeCompletionConsumer();
50 clang::CodeCompleteConsumer
*CompletionConsumer
= nullptr;
51 if (Compiler
->hasCodeCompletionConsumer())
52 CompletionConsumer
= &Compiler
->getCodeCompletionConsumer();
54 Compiler
->createSema(getTranslationUnitKind(), CompletionConsumer
);
55 SemaSource
->setCompilerInstance(Compiler
);
56 Compiler
->getSema().addExternalSource(SemaSource
.get());
58 clang::ParseAST(Compiler
->getSema(), Compiler
->getFrontendOpts().ShowStats
,
59 Compiler
->getFrontendOpts().SkipFunctionBodies
);
63 getIncludeFixerContext(const clang::SourceManager
&SourceManager
,
64 clang::HeaderSearch
&HeaderSearch
) const {
65 return SemaSource
->getIncludeFixerContext(SourceManager
, HeaderSearch
,
66 SemaSource
->getMatchedSymbols());
70 IntrusiveRefCntPtr
<IncludeFixerSemaSource
> SemaSource
;
75 IncludeFixerActionFactory::IncludeFixerActionFactory(
76 SymbolIndexManager
&SymbolIndexMgr
,
77 std::vector
<IncludeFixerContext
> &Contexts
, StringRef StyleName
,
78 bool MinimizeIncludePaths
)
79 : SymbolIndexMgr(SymbolIndexMgr
), Contexts(Contexts
),
80 MinimizeIncludePaths(MinimizeIncludePaths
) {}
82 IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
84 bool IncludeFixerActionFactory::runInvocation(
85 std::shared_ptr
<clang::CompilerInvocation
> Invocation
,
86 clang::FileManager
*Files
,
87 std::shared_ptr
<clang::PCHContainerOperations
> PCHContainerOps
,
88 clang::DiagnosticConsumer
*Diagnostics
) {
89 assert(Invocation
->getFrontendOpts().Inputs
.size() == 1);
92 clang::CompilerInstance
Compiler(PCHContainerOps
);
93 Compiler
.setInvocation(std::move(Invocation
));
94 Compiler
.setFileManager(Files
);
96 // Create the compiler's actual diagnostics engine. We want to drop all
98 Compiler
.createDiagnostics(Files
->getVirtualFileSystem(),
99 new clang::IgnoringDiagConsumer
,
100 /*ShouldOwnClient=*/true);
101 Compiler
.createSourceManager(*Files
);
103 // We abort on fatal errors so don't let a large number of errors become
104 // fatal. A missing #include can cause thousands of errors.
105 Compiler
.getDiagnostics().setErrorLimit(0);
107 // Run the parser, gather missing includes.
108 auto ScopedToolAction
=
109 std::make_unique
<Action
>(SymbolIndexMgr
, MinimizeIncludePaths
);
110 Compiler
.ExecuteAction(*ScopedToolAction
);
112 Contexts
.push_back(ScopedToolAction
->getIncludeFixerContext(
113 Compiler
.getSourceManager(),
114 Compiler
.getPreprocessor().getHeaderSearchInfo()));
116 // Technically this should only return true if we're sure that we have a
117 // parseable file. We don't know that though. Only inform users of fatal
119 return !Compiler
.getDiagnostics().hasFatalErrorOccurred();
122 static bool addDiagnosticsForContext(TypoCorrection
&Correction
,
123 const IncludeFixerContext
&Context
,
124 StringRef Code
, SourceLocation StartOfFile
,
126 auto Reps
= createIncludeFixerReplacements(
127 Code
, Context
, format::getLLVMStyle(), /*AddQualifiers=*/false);
128 if (!Reps
|| Reps
->size() != 1)
131 unsigned DiagID
= Ctx
.getDiagnostics().getCustomDiagID(
132 DiagnosticsEngine::Note
, "Add '#include %0' to provide the missing "
133 "declaration [clang-include-fixer]");
135 // FIXME: Currently we only generate a diagnostic for the first header. Give
137 const tooling::Replacement
&Placed
= *Reps
->begin();
139 auto Begin
= StartOfFile
.getLocWithOffset(Placed
.getOffset());
140 auto End
= Begin
.getLocWithOffset(std::max(0, (int)Placed
.getLength() - 1));
141 PartialDiagnostic
PD(DiagID
, Ctx
.getDiagAllocator());
142 PD
<< Context
.getHeaderInfos().front().Header
143 << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin
, End
),
144 Placed
.getReplacementText());
145 Correction
.addExtraDiagnostic(std::move(PD
));
149 /// Callback for incomplete types. If we encounter a forward declaration we
150 /// have the fully qualified name ready. Just query that.
151 bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
152 clang::SourceLocation Loc
, clang::QualType T
) {
153 // Ignore spurious callbacks from SFINAE contexts.
154 if (CI
->getSema().isSFINAEContext())
157 clang::ASTContext
&context
= CI
->getASTContext();
158 std::string QueryString
= QualType(T
->getUnqualifiedDesugaredType(), 0)
159 .getAsString(context
.getPrintingPolicy());
160 LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
162 // Pass an empty range here since we don't add qualifier in this case.
163 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
164 query(QueryString
, "", tooling::Range());
166 if (!MatchedSymbols
.empty() && GenerateDiagnostics
) {
167 TypoCorrection Correction
;
168 FileID FID
= CI
->getSourceManager().getFileID(Loc
);
169 StringRef Code
= CI
->getSourceManager().getBufferData(FID
);
170 SourceLocation StartOfFile
=
171 CI
->getSourceManager().getLocForStartOfFile(FID
);
172 addDiagnosticsForContext(
174 getIncludeFixerContext(CI
->getSourceManager(),
175 CI
->getPreprocessor().getHeaderSearchInfo(),
177 Code
, StartOfFile
, CI
->getASTContext());
178 for (const PartialDiagnostic
&PD
: Correction
.getExtraDiagnostics())
179 CI
->getSema().Diag(Loc
, PD
);
184 /// Callback for unknown identifiers. Try to piece together as much
185 /// qualification as we can get and do a query.
186 clang::TypoCorrection
IncludeFixerSemaSource::CorrectTypo(
187 const DeclarationNameInfo
&Typo
, int LookupKind
, Scope
*S
, CXXScopeSpec
*SS
,
188 CorrectionCandidateCallback
&CCC
, DeclContext
*MemberContext
,
189 bool EnteringContext
, const ObjCObjectPointerType
*OPT
) {
190 // Ignore spurious callbacks from SFINAE contexts.
191 if (CI
->getSema().isSFINAEContext())
192 return clang::TypoCorrection();
194 // We currently ignore the unidentified symbol which is not from the
197 // However, this is not always true due to templates in a non-self contained
198 // header, consider the case:
201 // template <typename T>
207 // // We need to add <bar.h> in test.cc instead of header.h.
211 // FIXME: Add the missing header to the header file where the symbol comes
213 if (!CI
->getSourceManager().isWrittenInMainFile(Typo
.getLoc()))
214 return clang::TypoCorrection();
216 std::string TypoScopeString
;
218 // FIXME: Currently we only use namespace contexts. Use other context
220 for (const auto *Context
= S
->getEntity(); Context
;
221 Context
= Context
->getParent()) {
222 if (const auto *ND
= dyn_cast
<NamespaceDecl
>(Context
)) {
223 if (!ND
->getName().empty())
224 TypoScopeString
= ND
->getNameAsString() + "::" + TypoScopeString
;
229 auto ExtendNestedNameSpecifier
= [this](CharSourceRange Range
) {
231 Lexer::getSourceText(Range
, CI
->getSourceManager(), CI
->getLangOpts());
233 // Skip forward until we find a character that's neither identifier nor
234 // colon. This is a bit of a hack around the fact that we will only get a
235 // single callback for a long nested name if a part of the beginning is
236 // unknown. For example:
238 // llvm::sys::path::parent_path(...)
242 // unknown, last callback
246 // With the extension we get the full nested name specifier including
248 // FIXME: Don't rely on source text.
249 const char *End
= Source
.end();
250 while (isAsciiIdentifierContinue(*End
) || *End
== ':')
253 return std::string(Source
.begin(), End
);
256 /// If we have a scope specification, use that to get more precise results.
257 std::string QueryString
;
258 tooling::Range SymbolRange
;
259 const auto &SM
= CI
->getSourceManager();
260 auto CreateToolingRange
= [&QueryString
, &SM
](SourceLocation BeginLoc
) {
261 return tooling::Range(SM
.getDecomposedLoc(BeginLoc
).second
,
264 if (SS
&& SS
->getRange().isValid()) {
265 auto Range
= CharSourceRange::getTokenRange(SS
->getRange().getBegin(),
268 QueryString
= ExtendNestedNameSpecifier(Range
);
269 SymbolRange
= CreateToolingRange(Range
.getBegin());
270 } else if (Typo
.getName().isIdentifier() && !Typo
.getLoc().isMacroID()) {
272 CharSourceRange::getTokenRange(Typo
.getBeginLoc(), Typo
.getEndLoc());
274 QueryString
= ExtendNestedNameSpecifier(Range
);
275 SymbolRange
= CreateToolingRange(Range
.getBegin());
277 QueryString
= Typo
.getAsString();
278 SymbolRange
= CreateToolingRange(Typo
.getLoc());
281 LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
283 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
284 query(QueryString
, TypoScopeString
, SymbolRange
);
286 if (!MatchedSymbols
.empty() && GenerateDiagnostics
) {
287 TypoCorrection
Correction(Typo
.getName());
288 Correction
.setCorrectionRange(SS
, Typo
);
289 FileID FID
= SM
.getFileID(Typo
.getLoc());
290 StringRef Code
= SM
.getBufferData(FID
);
291 SourceLocation StartOfFile
= SM
.getLocForStartOfFile(FID
);
292 if (addDiagnosticsForContext(
293 Correction
, getIncludeFixerContext(
294 SM
, CI
->getPreprocessor().getHeaderSearchInfo(),
296 Code
, StartOfFile
, CI
->getASTContext()))
299 return TypoCorrection();
302 /// Get the minimal include for a given path.
303 std::string
IncludeFixerSemaSource::minimizeInclude(
304 StringRef Include
, const clang::SourceManager
&SourceManager
,
305 clang::HeaderSearch
&HeaderSearch
) const {
306 if (!MinimizeIncludePaths
)
307 return std::string(Include
);
309 // Get the FileEntry for the include.
310 StringRef StrippedInclude
= Include
.trim("\"<>");
312 SourceManager
.getFileManager().getOptionalFileRef(StrippedInclude
);
314 // If the file doesn't exist return the path from the database.
315 // FIXME: This should never happen.
317 return std::string(Include
);
319 bool IsAngled
= false;
320 std::string Suggestion
=
321 HeaderSearch
.suggestPathToFileForDiagnostics(*Entry
, "", &IsAngled
);
323 return IsAngled
? '<' + Suggestion
+ '>' : '"' + Suggestion
+ '"';
326 /// Get the include fixer context for the queried symbol.
327 IncludeFixerContext
IncludeFixerSemaSource::getIncludeFixerContext(
328 const clang::SourceManager
&SourceManager
,
329 clang::HeaderSearch
&HeaderSearch
,
330 ArrayRef
<find_all_symbols::SymbolInfo
> MatchedSymbols
) const {
331 std::vector
<find_all_symbols::SymbolInfo
> SymbolCandidates
;
332 for (const auto &Symbol
: MatchedSymbols
) {
333 std::string FilePath
= Symbol
.getFilePath().str();
334 std::string MinimizedFilePath
= minimizeInclude(
335 ((FilePath
[0] == '"' || FilePath
[0] == '<') ? FilePath
336 : "\"" + FilePath
+ "\""),
337 SourceManager
, HeaderSearch
);
338 SymbolCandidates
.emplace_back(Symbol
.getName(), Symbol
.getSymbolKind(),
339 MinimizedFilePath
, Symbol
.getContexts());
341 return IncludeFixerContext(FilePath
, QuerySymbolInfos
, SymbolCandidates
);
344 std::vector
<find_all_symbols::SymbolInfo
>
345 IncludeFixerSemaSource::query(StringRef Query
, StringRef ScopedQualifiers
,
346 tooling::Range Range
) {
347 assert(!Query
.empty() && "Empty query!");
349 // Save all instances of an unidentified symbol.
351 // We use conservative behavior for detecting the same unidentified symbol
352 // here. The symbols which have the same ScopedQualifier and RawIdentifier
353 // are considered equal. So that clang-include-fixer avoids false positives,
354 // and always adds missing qualifiers to correct symbols.
355 if (!GenerateDiagnostics
&& !QuerySymbolInfos
.empty()) {
356 if (ScopedQualifiers
== QuerySymbolInfos
.front().ScopedQualifiers
&&
357 Query
== QuerySymbolInfos
.front().RawIdentifier
) {
358 QuerySymbolInfos
.push_back(
359 {Query
.str(), std::string(ScopedQualifiers
), Range
});
364 LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query
<< "' at ");
365 LLVM_DEBUG(CI
->getSourceManager()
366 .getLocForStartOfFile(CI
->getSourceManager().getMainFileID())
367 .getLocWithOffset(Range
.getOffset())
368 .print(llvm::dbgs(), CI
->getSourceManager()));
369 LLVM_DEBUG(llvm::dbgs() << " ...");
370 llvm::StringRef FileName
= CI
->getSourceManager().getFilename(
371 CI
->getSourceManager().getLocForStartOfFile(
372 CI
->getSourceManager().getMainFileID()));
374 QuerySymbolInfos
.push_back(
375 {Query
.str(), std::string(ScopedQualifiers
), Range
});
377 // Query the symbol based on C++ name Lookup rules.
378 // Firstly, lookup the identifier with scoped namespace contexts;
379 // If that fails, falls back to look up the identifier directly.
387 // 1. lookup a::b::foo.
389 std::string QueryString
= ScopedQualifiers
.str() + Query
.str();
390 // It's unsafe to do nested search for the identifier with scoped namespace
391 // context, it might treat the identifier as a nested class of the scoped
393 std::vector
<find_all_symbols::SymbolInfo
> MatchedSymbols
=
394 SymbolIndexMgr
.search(QueryString
, /*IsNestedSearch=*/false, FileName
);
395 if (MatchedSymbols
.empty())
397 SymbolIndexMgr
.search(Query
, /*IsNestedSearch=*/true, FileName
);
398 LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols
.size()
400 // We store a copy of MatchedSymbols in a place where it's globally reachable.
401 // This is used by the standalone version of the tool.
402 this->MatchedSymbols
= MatchedSymbols
;
403 return MatchedSymbols
;
406 llvm::Expected
<tooling::Replacements
> createIncludeFixerReplacements(
407 StringRef Code
, const IncludeFixerContext
&Context
,
408 const clang::format::FormatStyle
&Style
, bool AddQualifiers
) {
409 if (Context
.getHeaderInfos().empty())
410 return tooling::Replacements();
411 StringRef FilePath
= Context
.getFilePath();
412 std::string IncludeName
=
413 "#include " + Context
.getHeaderInfos().front().Header
+ "\n";
414 // Create replacements for the new header.
415 clang::tooling::Replacements Insertions
;
417 Insertions
.add(tooling::Replacement(FilePath
, UINT_MAX
, 0, IncludeName
));
419 return std::move(Err
);
421 auto CleanReplaces
= cleanupAroundReplacements(Code
, Insertions
, Style
);
423 return CleanReplaces
;
425 auto Replaces
= std::move(*CleanReplaces
);
427 for (const auto &Info
: Context
.getQuerySymbolInfos()) {
428 // Ignore the empty range.
429 if (Info
.Range
.getLength() > 0) {
430 auto R
= tooling::Replacement(
431 {FilePath
, Info
.Range
.getOffset(), Info
.Range
.getLength(),
432 Context
.getHeaderInfos().front().QualifiedName
});
433 auto Err
= Replaces
.add(R
);
435 llvm::consumeError(std::move(Err
));
436 R
= tooling::Replacement(
437 R
.getFilePath(), Replaces
.getShiftedCodePosition(R
.getOffset()),
438 R
.getLength(), R
.getReplacementText());
439 Replaces
= Replaces
.merge(tooling::Replacements(R
));
444 return formatReplacements(Code
, Replaces
, Style
);
447 } // namespace include_fixer