[mlir] Use StringRef::{starts,ends}_with (NFC)
[llvm-project.git] / mlir / lib / TableGen / Class.cpp
blobf71d7e07ed49989705a4516fcab852fa4c24cce0
1 //===- Class.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/TableGen/Class.h"
10 #include "mlir/TableGen/Format.h"
11 #include "llvm/ADT/Sequence.h"
12 #include "llvm/ADT/Twine.h"
13 #include "llvm/Support/Debug.h"
15 using namespace mlir;
16 using namespace mlir::tblgen;
18 /// Returns space to be emitted after the given C++ `type`. return "" if the
19 /// ends with '&' or '*', or is empty, else returns " ".
20 static StringRef getSpaceAfterType(StringRef type) {
21 return (type.empty() || type.ends_with("&") || type.ends_with("*")) ? ""
22 : " ";
25 //===----------------------------------------------------------------------===//
26 // MethodParameter definitions
27 //===----------------------------------------------------------------------===//
29 void MethodParameter::writeDeclTo(raw_indented_ostream &os) const {
30 if (optional)
31 os << "/*optional*/";
32 os << type << getSpaceAfterType(type) << name;
33 if (hasDefaultValue())
34 os << " = " << defaultValue;
37 void MethodParameter::writeDefTo(raw_indented_ostream &os) const {
38 if (optional)
39 os << "/*optional*/";
40 os << type << getSpaceAfterType(type) << name;
43 //===----------------------------------------------------------------------===//
44 // MethodParameters definitions
45 //===----------------------------------------------------------------------===//
47 void MethodParameters::writeDeclTo(raw_indented_ostream &os) const {
48 llvm::interleaveComma(parameters, os,
49 [&os](auto &param) { param.writeDeclTo(os); });
51 void MethodParameters::writeDefTo(raw_indented_ostream &os) const {
52 llvm::interleaveComma(parameters, os,
53 [&os](auto &param) { param.writeDefTo(os); });
56 bool MethodParameters::subsumes(const MethodParameters &other) const {
57 // These parameters do not subsume the others if there are fewer parameters
58 // or their types do not match.
59 if (parameters.size() < other.parameters.size())
60 return false;
61 if (!std::equal(
62 other.parameters.begin(), other.parameters.end(), parameters.begin(),
63 [](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
64 return false;
66 // If all the common parameters have the same type, we can elide the other
67 // method if this method has the same number of parameters as other or if the
68 // first paramater after the common parameters has a default value (and, as
69 // required by C++, subsequent parameters will have default values too).
70 return parameters.size() == other.parameters.size() ||
71 parameters[other.parameters.size()].hasDefaultValue();
74 //===----------------------------------------------------------------------===//
75 // MethodSignature definitions
76 //===----------------------------------------------------------------------===//
78 bool MethodSignature::makesRedundant(const MethodSignature &other) const {
79 return methodName == other.methodName &&
80 parameters.subsumes(other.parameters);
83 void MethodSignature::writeDeclTo(raw_indented_ostream &os) const {
84 os << returnType << getSpaceAfterType(returnType) << methodName << "(";
85 parameters.writeDeclTo(os);
86 os << ")";
89 void MethodSignature::writeDefTo(raw_indented_ostream &os,
90 StringRef namePrefix) const {
91 os << returnType << getSpaceAfterType(returnType) << namePrefix
92 << (namePrefix.empty() ? "" : "::") << methodName << "(";
93 parameters.writeDefTo(os);
94 os << ")";
97 void MethodSignature::writeTemplateParamsTo(
98 mlir::raw_indented_ostream &os) const {
99 if (templateParams.empty())
100 return;
102 os << "template <";
103 llvm::interleaveComma(templateParams, os,
104 [&](StringRef param) { os << "typename " << param; });
105 os << ">\n";
108 //===----------------------------------------------------------------------===//
109 // MethodBody definitions
110 //===----------------------------------------------------------------------===//
112 MethodBody::MethodBody(bool declOnly)
113 : declOnly(declOnly), stringOs(body), os(stringOs) {}
115 void MethodBody::writeTo(raw_indented_ostream &os) const {
116 auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
117 os << bodyRef;
118 if (bodyRef.empty())
119 return;
120 if (bodyRef.back() != '\n')
121 os << "\n";
124 //===----------------------------------------------------------------------===//
125 // Method definitions
126 //===----------------------------------------------------------------------===//
128 void Method::writeDeclTo(raw_indented_ostream &os) const {
129 methodSignature.writeTemplateParamsTo(os);
130 if (deprecationMessage) {
131 os << "[[deprecated(\"";
132 os.write_escaped(*deprecationMessage);
133 os << "\")]]\n";
135 if (isStatic())
136 os << "static ";
137 if (properties & ConstexprValue)
138 os << "constexpr ";
139 methodSignature.writeDeclTo(os);
140 if (isConst())
141 os << " const";
142 if (!isInline()) {
143 os << ";\n";
144 return;
146 os << " {\n";
147 methodBody.writeTo(os);
148 os << "}\n\n";
151 void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
152 // The method has no definition to write if it is declaration only or inline.
153 if (properties & Declaration || isInline())
154 return;
156 methodSignature.writeDefTo(os, namePrefix);
157 if (isConst())
158 os << " const";
159 os << " {\n";
160 methodBody.writeTo(os);
161 os << "}\n\n";
164 //===----------------------------------------------------------------------===//
165 // Constructor definitions
166 //===----------------------------------------------------------------------===//
168 void Constructor::writeDeclTo(raw_indented_ostream &os) const {
169 methodSignature.writeTemplateParamsTo(os);
170 if (properties & ConstexprValue)
171 os << "constexpr ";
172 methodSignature.writeDeclTo(os);
173 if (!isInline()) {
174 os << ";\n\n";
175 return;
177 os << ' ';
178 if (!initializers.empty())
179 os << ": ";
180 llvm::interleaveComma(initializers, os,
181 [&](auto &initializer) { initializer.writeTo(os); });
182 if (!initializers.empty())
183 os << ' ';
184 os << "{";
185 methodBody.writeTo(os);
186 os << "}\n\n";
189 void Constructor::writeDefTo(raw_indented_ostream &os,
190 StringRef namePrefix) const {
191 // The method has no definition to write if it is declaration only or inline.
192 if (properties & Declaration || isInline())
193 return;
195 methodSignature.writeDefTo(os, namePrefix);
196 os << ' ';
197 if (!initializers.empty())
198 os << ": ";
199 llvm::interleaveComma(initializers, os,
200 [&](auto &initializer) { initializer.writeTo(os); });
201 if (!initializers.empty())
202 os << ' ';
203 os << "{";
204 methodBody.writeTo(os);
205 os << "}\n\n";
208 void Constructor::MemberInitializer::writeTo(raw_indented_ostream &os) const {
209 os << name << '(' << value << ')';
212 //===----------------------------------------------------------------------===//
213 // Visibility definitions
214 //===----------------------------------------------------------------------===//
216 namespace mlir {
217 namespace tblgen {
218 raw_ostream &operator<<(raw_ostream &os, Visibility visibility) {
219 switch (visibility) {
220 case Visibility::Public:
221 return os << "public";
222 case Visibility::Protected:
223 return os << "protected";
224 case Visibility::Private:
225 return os << "private";
227 return os;
229 } // namespace tblgen
230 } // namespace mlir
232 //===----------------------------------------------------------------------===//
233 // ParentClass definitions
234 //===----------------------------------------------------------------------===//
236 void ParentClass::writeTo(raw_indented_ostream &os) const {
237 os << visibility << ' ' << name;
238 if (!templateParams.empty()) {
239 auto scope = os.scope("<", ">", /*indent=*/false);
240 llvm::interleaveComma(templateParams, os,
241 [&](auto &param) { os << param; });
245 //===----------------------------------------------------------------------===//
246 // UsingDeclaration definitions
247 //===----------------------------------------------------------------------===//
249 void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
250 if (!templateParams.empty()) {
251 os << "template <";
252 llvm::interleaveComma(templateParams, os, [&](StringRef paramName) {
253 os << "typename " << paramName;
255 os << ">\n";
257 os << "using " << name;
258 if (!value.empty())
259 os << " = " << value;
260 os << ";\n";
263 //===----------------------------------------------------------------------===//
264 // Field definitions
265 //===----------------------------------------------------------------------===//
267 void Field::writeDeclTo(raw_indented_ostream &os) const {
268 os << type << ' ' << name << ";\n";
271 //===----------------------------------------------------------------------===//
272 // VisibilityDeclaration definitions
273 //===----------------------------------------------------------------------===//
275 void VisibilityDeclaration::writeDeclTo(raw_indented_ostream &os) const {
276 os.unindent();
277 os << visibility << ":\n";
278 os.indent();
281 //===----------------------------------------------------------------------===//
282 // ExtraClassDeclaration definitions
283 //===----------------------------------------------------------------------===//
285 void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const {
286 os.printReindented(extraClassDeclaration);
289 void ExtraClassDeclaration::writeDefTo(raw_indented_ostream &os,
290 StringRef namePrefix) const {
291 os.printReindented(extraClassDefinition);
294 //===----------------------------------------------------------------------===//
295 // Class definitions
296 //===----------------------------------------------------------------------===//
298 ParentClass &Class::addParent(ParentClass parent) {
299 parents.push_back(std::move(parent));
300 return parents.back();
303 void Class::writeDeclTo(raw_indented_ostream &os) const {
304 if (!templateParams.empty()) {
305 os << "template <";
306 llvm::interleaveComma(templateParams, os,
307 [&](StringRef param) { os << "typename " << param; });
308 os << ">\n";
311 // Declare the class.
312 os << (isStruct ? "struct" : "class") << ' ' << className << ' ';
314 // Declare the parent classes, if any.
315 if (!parents.empty()) {
316 os << ": ";
317 llvm::interleaveComma(parents, os,
318 [&](auto &parent) { parent.writeTo(os); });
319 os << ' ';
321 auto classScope = os.scope("{\n", "};\n", /*indent=*/true);
323 // Print all the class declarations.
324 for (auto &decl : declarations)
325 decl->writeDeclTo(os);
328 void Class::writeDefTo(raw_indented_ostream &os) const {
329 // Print all the definitions.
330 for (auto &decl : declarations)
331 decl->writeDefTo(os, className);
334 void Class::finalize() {
335 // Sort the methods by public and private. Remove them from the pending list
336 // of methods.
337 SmallVector<std::unique_ptr<Method>> publicMethods, privateMethods;
338 for (auto &method : methods) {
339 if (method->isPrivate())
340 privateMethods.push_back(std::move(method));
341 else
342 publicMethods.push_back(std::move(method));
344 methods.clear();
346 // If the last visibility declaration wasn't `public`, add one that is. Then,
347 // declare the public methods.
348 if (!publicMethods.empty() && getLastVisibilityDecl() != Visibility::Public)
349 declare<VisibilityDeclaration>(Visibility::Public);
350 for (auto &method : publicMethods)
351 declarations.push_back(std::move(method));
353 // If the last visibility declaration wasn't `private`, add one that is. Then,
354 // declare the private methods.
355 if (!privateMethods.empty() && getLastVisibilityDecl() != Visibility::Private)
356 declare<VisibilityDeclaration>(Visibility::Private);
357 for (auto &method : privateMethods)
358 declarations.push_back(std::move(method));
360 // All fields added to the pending list are private and declared at the bottom
361 // of the class. If the last visibility declaration wasn't `private`, add one
362 // that is, then declare the fields.
363 if (!fields.empty() && getLastVisibilityDecl() != Visibility::Private)
364 declare<VisibilityDeclaration>(Visibility::Private);
365 for (auto &field : fields)
366 declare<Field>(std::move(field));
367 fields.clear();
370 Visibility Class::getLastVisibilityDecl() const {
371 auto reverseDecls = llvm::reverse(declarations);
372 auto it = llvm::find_if(reverseDecls, [](auto &decl) {
373 return isa<VisibilityDeclaration>(decl);
375 return it == reverseDecls.end()
376 ? (isStruct ? Visibility::Public : Visibility::Private)
377 : cast<VisibilityDeclaration>(**it).getVisibility();
380 Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,
381 std::unique_ptr<Method> newMethod) {
382 if (llvm::any_of(methods, [&](auto &method) {
383 return method->makesRedundant(*newMethod);
385 return nullptr;
387 llvm::erase_if(methods, [&](auto &method) {
388 return newMethod->makesRedundant(*method);
390 methods.push_back(std::move(newMethod));
391 return methods.back().get();
394 Method *Class::addMethodAndPrune(Method &&newMethod) {
395 return insertAndPruneMethods(methods,
396 std::make_unique<Method>(std::move(newMethod)));
399 Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
400 return dyn_cast_or_null<Constructor>(insertAndPruneMethods(
401 methods, std::make_unique<Constructor>(std::move(newCtor))));