[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / clang-tools-extra / clang-tidy / mpi / TypeMismatchCheck.cpp
blob5abe4f77d65984dc7b88b7ca41fb9955deaa49bb
1 //===--- TypeMismatchCheck.cpp - clang-tidy--------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TypeMismatchCheck.h"
10 #include "clang/Lex/Lexer.h"
11 #include "clang/Tooling/FixIt.h"
12 #include "llvm/ADT/StringSet.h"
13 #include <map>
15 using namespace clang::ast_matchers;
17 namespace clang::tidy::mpi {
19 /// Check if a BuiltinType::Kind matches the MPI datatype.
20 ///
21 /// \param MultiMap datatype group
22 /// \param Kind buffer type kind
23 /// \param MPIDatatype name of the MPI datatype
24 ///
25 /// \returns true if the pair matches
26 static bool
27 isMPITypeMatching(const std::multimap<BuiltinType::Kind, StringRef> &MultiMap,
28 const BuiltinType::Kind Kind, StringRef MPIDatatype) {
29 auto ItPair = MultiMap.equal_range(Kind);
30 while (ItPair.first != ItPair.second) {
31 if (ItPair.first->second == MPIDatatype)
32 return true;
33 ++ItPair.first;
35 return false;
38 /// Check if the MPI datatype is a standard type.
39 ///
40 /// \param MPIDatatype name of the MPI datatype
41 ///
42 /// \returns true if the type is a standard type
43 static bool isStandardMPIDatatype(StringRef MPIDatatype) {
44 static llvm::StringSet<> AllTypes = {"MPI_C_BOOL",
45 "MPI_CHAR",
46 "MPI_SIGNED_CHAR",
47 "MPI_UNSIGNED_CHAR",
48 "MPI_WCHAR",
49 "MPI_INT",
50 "MPI_LONG",
51 "MPI_SHORT",
52 "MPI_LONG_LONG",
53 "MPI_LONG_LONG_INT",
54 "MPI_UNSIGNED",
55 "MPI_UNSIGNED_SHORT",
56 "MPI_UNSIGNED_LONG",
57 "MPI_UNSIGNED_LONG_LONG",
58 "MPI_FLOAT",
59 "MPI_DOUBLE",
60 "MPI_LONG_DOUBLE",
61 "MPI_C_COMPLEX",
62 "MPI_C_FLOAT_COMPLEX",
63 "MPI_C_DOUBLE_COMPLEX",
64 "MPI_C_LONG_DOUBLE_COMPLEX",
65 "MPI_INT8_T",
66 "MPI_INT16_T",
67 "MPI_INT32_T",
68 "MPI_INT64_T",
69 "MPI_UINT8_T",
70 "MPI_UINT16_T",
71 "MPI_UINT32_T",
72 "MPI_UINT64_T",
73 "MPI_CXX_BOOL",
74 "MPI_CXX_FLOAT_COMPLEX",
75 "MPI_CXX_DOUBLE_COMPLEX",
76 "MPI_CXX_LONG_DOUBLE_COMPLEX"};
78 return AllTypes.contains(MPIDatatype);
81 /// Check if a BuiltinType matches the MPI datatype.
82 ///
83 /// \param Builtin the builtin type
84 /// \param BufferTypeName buffer type name, gets assigned
85 /// \param MPIDatatype name of the MPI datatype
86 /// \param LO language options
87 ///
88 /// \returns true if the type matches
89 static bool isBuiltinTypeMatching(const BuiltinType *Builtin,
90 std::string &BufferTypeName,
91 StringRef MPIDatatype,
92 const LangOptions &LO) {
93 static std::multimap<BuiltinType::Kind, StringRef> BuiltinMatches = {
94 // On some systems like PPC or ARM, 'char' is unsigned by default which is
95 // why distinct signedness for the buffer and MPI type is tolerated.
96 {BuiltinType::SChar, "MPI_CHAR"},
97 {BuiltinType::SChar, "MPI_SIGNED_CHAR"},
98 {BuiltinType::SChar, "MPI_UNSIGNED_CHAR"},
99 {BuiltinType::Char_S, "MPI_CHAR"},
100 {BuiltinType::Char_S, "MPI_SIGNED_CHAR"},
101 {BuiltinType::Char_S, "MPI_UNSIGNED_CHAR"},
102 {BuiltinType::UChar, "MPI_CHAR"},
103 {BuiltinType::UChar, "MPI_SIGNED_CHAR"},
104 {BuiltinType::UChar, "MPI_UNSIGNED_CHAR"},
105 {BuiltinType::Char_U, "MPI_CHAR"},
106 {BuiltinType::Char_U, "MPI_SIGNED_CHAR"},
107 {BuiltinType::Char_U, "MPI_UNSIGNED_CHAR"},
108 {BuiltinType::WChar_S, "MPI_WCHAR"},
109 {BuiltinType::WChar_U, "MPI_WCHAR"},
110 {BuiltinType::Bool, "MPI_C_BOOL"},
111 {BuiltinType::Bool, "MPI_CXX_BOOL"},
112 {BuiltinType::Short, "MPI_SHORT"},
113 {BuiltinType::Int, "MPI_INT"},
114 {BuiltinType::Long, "MPI_LONG"},
115 {BuiltinType::LongLong, "MPI_LONG_LONG"},
116 {BuiltinType::LongLong, "MPI_LONG_LONG_INT"},
117 {BuiltinType::UShort, "MPI_UNSIGNED_SHORT"},
118 {BuiltinType::UInt, "MPI_UNSIGNED"},
119 {BuiltinType::ULong, "MPI_UNSIGNED_LONG"},
120 {BuiltinType::ULongLong, "MPI_UNSIGNED_LONG_LONG"},
121 {BuiltinType::Float, "MPI_FLOAT"},
122 {BuiltinType::Double, "MPI_DOUBLE"},
123 {BuiltinType::LongDouble, "MPI_LONG_DOUBLE"}};
125 if (!isMPITypeMatching(BuiltinMatches, Builtin->getKind(), MPIDatatype)) {
126 BufferTypeName = std::string(Builtin->getName(LO));
127 return false;
130 return true;
133 /// Check if a complex float/double/long double buffer type matches
134 /// the MPI datatype.
136 /// \param Complex buffer type
137 /// \param BufferTypeName buffer type name, gets assigned
138 /// \param MPIDatatype name of the MPI datatype
139 /// \param LO language options
141 /// \returns true if the type matches or the buffer type is unknown
142 static bool isCComplexTypeMatching(const ComplexType *const Complex,
143 std::string &BufferTypeName,
144 StringRef MPIDatatype,
145 const LangOptions &LO) {
146 static std::multimap<BuiltinType::Kind, StringRef> ComplexCMatches = {
147 {BuiltinType::Float, "MPI_C_COMPLEX"},
148 {BuiltinType::Float, "MPI_C_FLOAT_COMPLEX"},
149 {BuiltinType::Double, "MPI_C_DOUBLE_COMPLEX"},
150 {BuiltinType::LongDouble, "MPI_C_LONG_DOUBLE_COMPLEX"}};
152 const auto *Builtin =
153 Complex->getElementType().getTypePtr()->getAs<BuiltinType>();
155 if (Builtin &&
156 !isMPITypeMatching(ComplexCMatches, Builtin->getKind(), MPIDatatype)) {
157 BufferTypeName = (llvm::Twine(Builtin->getName(LO)) + " _Complex").str();
158 return false;
160 return true;
163 /// Check if a complex<float/double/long double> templated buffer type matches
164 /// the MPI datatype.
166 /// \param Template buffer type
167 /// \param BufferTypeName buffer type name, gets assigned
168 /// \param MPIDatatype name of the MPI datatype
169 /// \param LO language options
171 /// \returns true if the type matches or the buffer type is unknown
172 static bool
173 isCXXComplexTypeMatching(const TemplateSpecializationType *const Template,
174 std::string &BufferTypeName, StringRef MPIDatatype,
175 const LangOptions &LO) {
176 static std::multimap<BuiltinType::Kind, StringRef> ComplexCXXMatches = {
177 {BuiltinType::Float, "MPI_CXX_FLOAT_COMPLEX"},
178 {BuiltinType::Double, "MPI_CXX_DOUBLE_COMPLEX"},
179 {BuiltinType::LongDouble, "MPI_CXX_LONG_DOUBLE_COMPLEX"}};
181 if (Template->getAsCXXRecordDecl()->getName() != "complex")
182 return true;
184 const auto *Builtin = Template->template_arguments()[0]
185 .getAsType()
186 .getTypePtr()
187 ->getAs<BuiltinType>();
189 if (Builtin &&
190 !isMPITypeMatching(ComplexCXXMatches, Builtin->getKind(), MPIDatatype)) {
191 BufferTypeName =
192 (llvm::Twine("complex<") + Builtin->getName(LO) + ">").str();
193 return false;
196 return true;
199 /// Check if a fixed size width buffer type matches the MPI datatype.
201 /// \param Typedef buffer type
202 /// \param BufferTypeName buffer type name, gets assigned
203 /// \param MPIDatatype name of the MPI datatype
205 /// \returns true if the type matches or the buffer type is unknown
206 static bool isTypedefTypeMatching(const TypedefType *const Typedef,
207 std::string &BufferTypeName,
208 StringRef MPIDatatype) {
209 static llvm::StringMap<StringRef> FixedWidthMatches = {
210 {"int8_t", "MPI_INT8_T"}, {"int16_t", "MPI_INT16_T"},
211 {"int32_t", "MPI_INT32_T"}, {"int64_t", "MPI_INT64_T"},
212 {"uint8_t", "MPI_UINT8_T"}, {"uint16_t", "MPI_UINT16_T"},
213 {"uint32_t", "MPI_UINT32_T"}, {"uint64_t", "MPI_UINT64_T"}};
215 const auto It = FixedWidthMatches.find(Typedef->getDecl()->getName());
216 // Check if the typedef is known and not matching the MPI datatype.
217 if (It != FixedWidthMatches.end() && It->getValue() != MPIDatatype) {
218 BufferTypeName = std::string(Typedef->getDecl()->getName());
219 return false;
221 return true;
224 /// Get the unqualified, dereferenced type of an argument.
226 /// \param CE call expression
227 /// \param Idx argument index
229 /// \returns type of the argument
230 static const Type *argumentType(const CallExpr *const CE, const size_t Idx) {
231 const QualType QT = CE->getArg(Idx)->IgnoreImpCasts()->getType();
232 return QT.getTypePtr()->getPointeeOrArrayElementType();
235 void TypeMismatchCheck::registerMatchers(MatchFinder *Finder) {
236 Finder->addMatcher(callExpr().bind("CE"), this);
239 void TypeMismatchCheck::check(const MatchFinder::MatchResult &Result) {
240 const auto *const CE = Result.Nodes.getNodeAs<CallExpr>("CE");
241 if (!CE->getDirectCallee())
242 return;
244 if (!FuncClassifier)
245 FuncClassifier.emplace(*Result.Context);
247 const IdentifierInfo *Identifier = CE->getDirectCallee()->getIdentifier();
248 if (!Identifier || !FuncClassifier->isMPIType(Identifier))
249 return;
251 // These containers are used, to capture buffer, MPI datatype pairs.
252 SmallVector<const Type *, 1> BufferTypes;
253 SmallVector<const Expr *, 1> BufferExprs;
254 SmallVector<StringRef, 1> MPIDatatypes;
256 // Adds a buffer, MPI datatype pair of an MPI call expression to the
257 // containers. For buffers, the type and expression is captured.
258 auto AddPair = [&CE, &Result, &BufferTypes, &BufferExprs, &MPIDatatypes](
259 const size_t BufferIdx, const size_t DatatypeIdx) {
260 // Skip null pointer constants and in place 'operators'.
261 if (CE->getArg(BufferIdx)->isNullPointerConstant(
262 *Result.Context, Expr::NPC_ValueDependentIsNull) ||
263 tooling::fixit::getText(*CE->getArg(BufferIdx), *Result.Context) ==
264 "MPI_IN_PLACE")
265 return;
267 StringRef MPIDatatype =
268 tooling::fixit::getText(*CE->getArg(DatatypeIdx), *Result.Context);
270 const Type *ArgType = argumentType(CE, BufferIdx);
271 // Skip unknown MPI datatypes and void pointers.
272 if (!isStandardMPIDatatype(MPIDatatype) || ArgType->isVoidType())
273 return;
275 BufferTypes.push_back(ArgType);
276 BufferExprs.push_back(CE->getArg(BufferIdx));
277 MPIDatatypes.push_back(MPIDatatype);
280 // Collect all buffer, MPI datatype pairs for the inspected call expression.
281 if (FuncClassifier->isPointToPointType(Identifier)) {
282 AddPair(0, 2);
283 } else if (FuncClassifier->isCollectiveType(Identifier)) {
284 if (FuncClassifier->isReduceType(Identifier)) {
285 AddPair(0, 3);
286 AddPair(1, 3);
287 } else if (FuncClassifier->isScatterType(Identifier) ||
288 FuncClassifier->isGatherType(Identifier) ||
289 FuncClassifier->isAlltoallType(Identifier)) {
290 AddPair(0, 2);
291 AddPair(3, 5);
292 } else if (FuncClassifier->isBcastType(Identifier)) {
293 AddPair(0, 2);
296 checkArguments(BufferTypes, BufferExprs, MPIDatatypes, getLangOpts());
299 void TypeMismatchCheck::checkArguments(ArrayRef<const Type *> BufferTypes,
300 ArrayRef<const Expr *> BufferExprs,
301 ArrayRef<StringRef> MPIDatatypes,
302 const LangOptions &LO) {
303 std::string BufferTypeName;
305 for (size_t I = 0; I < MPIDatatypes.size(); ++I) {
306 const Type *const BT = BufferTypes[I];
307 bool Error = false;
309 if (const auto *Typedef = BT->getAs<TypedefType>()) {
310 Error = !isTypedefTypeMatching(Typedef, BufferTypeName, MPIDatatypes[I]);
311 } else if (const auto *Complex = BT->getAs<ComplexType>()) {
312 Error =
313 !isCComplexTypeMatching(Complex, BufferTypeName, MPIDatatypes[I], LO);
314 } else if (const auto *Template = BT->getAs<TemplateSpecializationType>()) {
315 Error = !isCXXComplexTypeMatching(Template, BufferTypeName,
316 MPIDatatypes[I], LO);
317 } else if (const auto *Builtin = BT->getAs<BuiltinType>()) {
318 Error =
319 !isBuiltinTypeMatching(Builtin, BufferTypeName, MPIDatatypes[I], LO);
322 if (Error) {
323 const auto Loc = BufferExprs[I]->getSourceRange().getBegin();
324 diag(Loc, "buffer type '%0' does not match the MPI datatype '%1'")
325 << BufferTypeName << MPIDatatypes[I];
330 void TypeMismatchCheck::onEndOfTranslationUnit() { FuncClassifier.reset(); }
331 } // namespace clang::tidy::mpi