Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestAttributes.cpp
blob e09ea1090616482d3567dee77a2453fdbbeeac51
1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
12 //===----------------------------------------------------------------------===//
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "TestTypes.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/ExtensibleDialect.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/Types.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/raw_ostream.h"
31 using namespace mlir;
32 using namespace test;
34 //===----------------------------------------------------------------------===//
35 // CompoundAAttr
36 //===----------------------------------------------------------------------===//
38 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
39 int widthOfSomething;
40 Type oneType;
41 SmallVector<int, 4> arrayOfInts;
42 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
43 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
44 parser.parseLSquare())
45 return Attribute();
47 int intVal;
48 while (!*parser.parseOptionalInteger(intVal)) {
49 arrayOfInts.push_back(intVal);
50 if (parser.parseOptionalComma())
51 break;
54 if (parser.parseRSquare() || parser.parseGreater())
55 return Attribute();
56 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
59 void CompoundAAttr::print(AsmPrinter &printer) const {
60 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
61 llvm::interleaveComma(getArrayOfInts(), printer);
62 printer << "]>";
//===----------------------------------------------------------------------===//
// TestDecimalShapeAttr
//===----------------------------------------------------------------------===//
69 Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
70 if (parser.parseLess()){
71 return Attribute();
73 SmallVector<int64_t> shape;
74 if (parser.parseOptionalGreater()) {
75 auto parseDecimal = [&]() {
76 shape.emplace_back();
77 auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
78 if (!parseResult.has_value() || failed(*parseResult)) {
79 parser.emitError(parser.getCurrentLocation()) << "expected an integer";
80 return failure();
82 return success();
84 if (failed(parseDecimal())) {
85 return Attribute();
87 while (failed(parser.parseOptionalGreater())) {
88 if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
89 return Attribute();
93 return get(parser.getContext(), shape);
96 void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
97 printer << "<";
98 llvm::interleave(getShape(), printer, "x");
99 printer << ">";
102 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
103 SmallVector<uint64_t> elements;
104 if (parser.parseLess() || parser.parseLSquare())
105 return Attribute();
106 uint64_t intVal;
107 while (succeeded(*parser.parseOptionalInteger(intVal))) {
108 elements.push_back(intVal);
109 if (parser.parseOptionalComma())
110 break;
113 if (parser.parseRSquare() || parser.parseGreater())
114 return Attribute();
115 return parser.getChecked<TestI64ElementsAttr>(
116 parser.getContext(), llvm::cast<ShapedType>(type), elements);
119 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
120 printer << "<[";
121 llvm::interleaveComma(getElements(), printer);
122 printer << "]>";
125 LogicalResult
126 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
127 ShapedType type, ArrayRef<uint64_t> elements) {
128 if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
129 return emitError()
130 << "number of elements does not match the provided shape type, got: "
131 << elements.size() << ", but expected: " << type.getNumElements();
133 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
134 return emitError() << "expected single rank 64-bit shape type, but got: "
135 << type;
136 return success();
139 LogicalResult TestAttrWithFormatAttr::verify(
140 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
141 IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six,
142 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
143 if (four.size() != static_cast<unsigned>(one))
144 return emitError() << "expected 'one' to equal 'four.size()'";
145 return success();
148 //===----------------------------------------------------------------------===//
149 // Utility Functions for Generated Attributes
150 //===----------------------------------------------------------------------===//
152 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
153 SmallVector<int> ints;
154 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
155 ints.push_back(0);
156 return parser.parseInteger(ints.back());
157 }) ||
158 parser.parseRSquare())
159 return failure();
160 return ints;
163 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
164 printer << '[';
165 llvm::interleaveComma(ints, printer);
166 printer << ']';
169 //===----------------------------------------------------------------------===//
170 // TestSubElementsAccessAttr
171 //===----------------------------------------------------------------------===//
173 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
174 ::mlir::Type type) {
175 Attribute first, second, third;
176 if (parser.parseLess() || parser.parseAttribute(first) ||
177 parser.parseComma() || parser.parseAttribute(second) ||
178 parser.parseComma() || parser.parseAttribute(third) ||
179 parser.parseGreater()) {
180 return {};
182 return get(parser.getContext(), first, second, third);
185 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
186 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
187 << ">";
190 //===----------------------------------------------------------------------===//
191 // TestExtern1DI64ElementsAttr
192 //===----------------------------------------------------------------------===//
194 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
195 if (auto *blob = getHandle().getBlob())
196 return blob->getDataAs<uint64_t>();
197 return std::nullopt;
200 //===----------------------------------------------------------------------===//
201 // TestCustomAnchorAttr
202 //===----------------------------------------------------------------------===//
204 static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) {
205 bool b;
206 if (p.parseInteger(b))
207 return failure();
208 result = b;
209 return success();
212 static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
213 p << (*result ? "true" : "false");
216 //===----------------------------------------------------------------------===//
217 // CopyCountAttr Implementation
218 //===----------------------------------------------------------------------===//
220 CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
221 CopyCount::counter++;
224 CopyCount &CopyCount::operator=(const CopyCount &rhs) {
225 CopyCount::counter++;
226 value = rhs.value;
227 return *this;
230 int CopyCount::counter;
232 static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
233 return lhs.value == rhs.value;
236 llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
237 const test::CopyCount &value) {
238 return os << value.value;
241 template <>
242 struct mlir::FieldParser<test::CopyCount> {
243 static FailureOr<test::CopyCount> parse(AsmParser &parser) {
244 std::string value;
245 if (parser.parseKeyword(value))
246 return failure();
247 return test::CopyCount(value);
250 namespace test {
251 llvm::hash_code hash_value(const test::CopyCount &copyCount) {
252 return llvm::hash_value(copyCount.value);
254 } // namespace test
256 //===----------------------------------------------------------------------===//
257 // TestConditionalAliasAttr
258 //===----------------------------------------------------------------------===//
260 /// Attempt to parse the conditionally-aliased string attribute as a keyword or
261 /// string, else try to parse an alias.
262 static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) {
263 std::string str;
264 if (succeeded(p.parseOptionalKeywordOrString(&str))) {
265 value = StringAttr::get(p.getContext(), str);
266 return success();
268 return p.parseAttribute(value);
271 /// Print the string attribute as an alias if it has one, otherwise print it as
272 /// a keyword if possible.
273 static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
274 if (succeeded(p.printAlias(value)))
275 return;
276 p.printKeywordOrString(value);
279 //===----------------------------------------------------------------------===//
280 // Custom Float Attribute
281 //===----------------------------------------------------------------------===//
283 static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
284 APFloat value) {
285 p << typeStrAttr << " : " << value;
288 static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
289 FailureOr<APFloat> &value) {
291 std::string str;
292 if (p.parseString(&str))
293 return failure();
295 typeStrAttr = StringAttr::get(p.getContext(), str);
297 if (p.parseColon())
298 return failure();
300 const llvm::fltSemantics *semantics;
301 if (str == "float")
302 semantics = &llvm::APFloat::IEEEsingle();
303 else if (str == "double")
304 semantics = &llvm::APFloat::IEEEdouble();
305 else if (str == "fp80")
306 semantics = &llvm::APFloat::x87DoubleExtended();
307 else
308 return p.emitError(p.getCurrentLocation(), "unknown float type, expected "
309 "'float', 'double' or 'fp80'");
311 APFloat parsedValue(0.0);
312 if (p.parseFloat(*semantics, parsedValue))
313 return failure();
315 value.emplace(parsedValue);
316 return success();
319 //===----------------------------------------------------------------------===//
320 // Tablegen Generated Definitions
321 //===----------------------------------------------------------------------===//
323 #include "TestAttrInterfaces.cpp.inc"
324 #include "TestOpEnums.cpp.inc"
325 #define GET_ATTRDEF_CLASSES
326 #include "TestAttrDefs.cpp.inc"
328 //===----------------------------------------------------------------------===//
329 // Dynamic Attributes
330 //===----------------------------------------------------------------------===//
332 /// Define a singleton dynamic attribute.
333 static std::unique_ptr<DynamicAttrDefinition>
334 getDynamicSingletonAttr(TestDialect *testDialect) {
335 return DynamicAttrDefinition::get(
336 "dynamic_singleton", testDialect,
337 [](function_ref<InFlightDiagnostic()> emitError,
338 ArrayRef<Attribute> args) {
339 if (!args.empty()) {
340 emitError() << "expected 0 attribute arguments, but had "
341 << args.size();
342 return failure();
344 return success();
348 /// Define a dynamic attribute representing a pair or attributes.
349 static std::unique_ptr<DynamicAttrDefinition>
350 getDynamicPairAttr(TestDialect *testDialect) {
351 return DynamicAttrDefinition::get(
352 "dynamic_pair", testDialect,
353 [](function_ref<InFlightDiagnostic()> emitError,
354 ArrayRef<Attribute> args) {
355 if (args.size() != 2) {
356 emitError() << "expected 2 attribute arguments, but had "
357 << args.size();
358 return failure();
360 return success();
364 static std::unique_ptr<DynamicAttrDefinition>
365 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
366 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
367 ArrayRef<Attribute> args) {
368 if (args.size() != 2) {
369 emitError() << "expected 2 attribute arguments, but had " << args.size();
370 return failure();
372 return success();
375 auto parser = [](AsmParser &parser,
376 llvm::SmallVectorImpl<Attribute> &parsedParams) {
377 Attribute leftAttr, rightAttr;
378 if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
379 parser.parseColon() || parser.parseAttribute(rightAttr) ||
380 parser.parseGreater())
381 return failure();
382 parsedParams.push_back(leftAttr);
383 parsedParams.push_back(rightAttr);
384 return success();
387 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
388 printer << "<" << params[0] << ":" << params[1] << ">";
391 return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
392 testDialect, std::move(verifier),
393 std::move(parser), std::move(printer));
396 //===----------------------------------------------------------------------===//
397 // TestDialect
398 //===----------------------------------------------------------------------===//
/// Registers every tablegen-generated test attribute with this dialect, then
/// the three dynamic attribute definitions built by the helpers in this file.
void TestDialect::registerAttributes() {
  // With GET_ATTRDEF_LIST defined, the generated .inc expands to the template
  // argument list (the attribute classes) for addAttributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}