//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains attributes defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#include "TestAttributes.h"
#include "TestDialect.h"
#include "TestTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace test;

//===----------------------------------------------------------------------===//
// CompoundAAttr
//===----------------------------------------------------------------------===//
38 Attribute
CompoundAAttr::parse(AsmParser
&parser
, Type type
) {
41 SmallVector
<int, 4> arrayOfInts
;
42 if (parser
.parseLess() || parser
.parseInteger(widthOfSomething
) ||
43 parser
.parseComma() || parser
.parseType(oneType
) || parser
.parseComma() ||
44 parser
.parseLSquare())
48 while (!*parser
.parseOptionalInteger(intVal
)) {
49 arrayOfInts
.push_back(intVal
);
50 if (parser
.parseOptionalComma())
54 if (parser
.parseRSquare() || parser
.parseGreater())
56 return get(parser
.getContext(), widthOfSomething
, oneType
, arrayOfInts
);
59 void CompoundAAttr::print(AsmPrinter
&printer
) const {
60 printer
<< "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
61 llvm::interleaveComma(getArrayOfInts(), printer
);
//===----------------------------------------------------------------------===//
// TestDecimalShapeAttr
//===----------------------------------------------------------------------===//
69 Attribute
TestDecimalShapeAttr::parse(AsmParser
&parser
, Type type
) {
70 if (parser
.parseLess()){
73 SmallVector
<int64_t> shape
;
74 if (parser
.parseOptionalGreater()) {
75 auto parseDecimal
= [&]() {
77 auto parseResult
= parser
.parseOptionalDecimalInteger(shape
.back());
78 if (!parseResult
.has_value() || failed(*parseResult
)) {
79 parser
.emitError(parser
.getCurrentLocation()) << "expected an integer";
84 if (failed(parseDecimal())) {
87 while (failed(parser
.parseOptionalGreater())) {
88 if (failed(parser
.parseXInDimensionList()) || failed(parseDecimal())) {
93 return get(parser
.getContext(), shape
);
96 void TestDecimalShapeAttr::print(AsmPrinter
&printer
) const {
98 llvm::interleave(getShape(), printer
, "x");
102 Attribute
TestI64ElementsAttr::parse(AsmParser
&parser
, Type type
) {
103 SmallVector
<uint64_t> elements
;
104 if (parser
.parseLess() || parser
.parseLSquare())
107 while (succeeded(*parser
.parseOptionalInteger(intVal
))) {
108 elements
.push_back(intVal
);
109 if (parser
.parseOptionalComma())
113 if (parser
.parseRSquare() || parser
.parseGreater())
115 return parser
.getChecked
<TestI64ElementsAttr
>(
116 parser
.getContext(), llvm::cast
<ShapedType
>(type
), elements
);
119 void TestI64ElementsAttr::print(AsmPrinter
&printer
) const {
121 llvm::interleaveComma(getElements(), printer
);
126 TestI64ElementsAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
127 ShapedType type
, ArrayRef
<uint64_t> elements
) {
128 if (type
.getNumElements() != static_cast<int64_t>(elements
.size())) {
130 << "number of elements does not match the provided shape type, got: "
131 << elements
.size() << ", but expected: " << type
.getNumElements();
133 if (type
.getRank() != 1 || !type
.getElementType().isSignlessInteger(64))
134 return emitError() << "expected single rank 64-bit shape type, but got: "
139 LogicalResult
TestAttrWithFormatAttr::verify(
140 function_ref
<InFlightDiagnostic()> emitError
, int64_t one
, std::string two
,
141 IntegerAttr three
, ArrayRef
<int> four
, uint64_t five
, ArrayRef
<int> six
,
142 ArrayRef
<AttrWithTypeBuilderAttr
> arrayOfAttrs
) {
143 if (four
.size() != static_cast<unsigned>(one
))
144 return emitError() << "expected 'one' to equal 'four.size()'";
//===----------------------------------------------------------------------===//
// Utility Functions for Generated Attributes
//===----------------------------------------------------------------------===//
152 static FailureOr
<SmallVector
<int>> parseIntArray(AsmParser
&parser
) {
153 SmallVector
<int> ints
;
154 if (parser
.parseLSquare() || parser
.parseCommaSeparatedList([&]() {
156 return parser
.parseInteger(ints
.back());
158 parser
.parseRSquare())
163 static void printIntArray(AsmPrinter
&printer
, ArrayRef
<int> ints
) {
165 llvm::interleaveComma(ints
, printer
);
//===----------------------------------------------------------------------===//
// TestSubElementsAccessAttr
//===----------------------------------------------------------------------===//
173 Attribute
TestSubElementsAccessAttr::parse(::mlir::AsmParser
&parser
,
175 Attribute first
, second
, third
;
176 if (parser
.parseLess() || parser
.parseAttribute(first
) ||
177 parser
.parseComma() || parser
.parseAttribute(second
) ||
178 parser
.parseComma() || parser
.parseAttribute(third
) ||
179 parser
.parseGreater()) {
182 return get(parser
.getContext(), first
, second
, third
);
185 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter
&printer
) const {
186 printer
<< "<" << getFirst() << ", " << getSecond() << ", " << getThird()
//===----------------------------------------------------------------------===//
// TestExtern1DI64ElementsAttr
//===----------------------------------------------------------------------===//
194 ArrayRef
<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
195 if (auto *blob
= getHandle().getBlob())
196 return blob
->getDataAs
<uint64_t>();
//===----------------------------------------------------------------------===//
// TestCustomAnchorAttr
//===----------------------------------------------------------------------===//
204 static ParseResult
parseTrueFalse(AsmParser
&p
, std::optional
<int> &result
) {
206 if (p
.parseInteger(b
))
212 static void printTrueFalse(AsmPrinter
&p
, std::optional
<int> result
) {
213 p
<< (*result
? "true" : "false");
//===----------------------------------------------------------------------===//
// CopyCountAttr Implementation
//===----------------------------------------------------------------------===//
220 CopyCount::CopyCount(const CopyCount
&rhs
) : value(rhs
.value
) {
221 CopyCount::counter
++;
224 CopyCount
&CopyCount::operator=(const CopyCount
&rhs
) {
225 CopyCount::counter
++;
230 int CopyCount::counter
;
232 static bool operator==(const test::CopyCount
&lhs
, const test::CopyCount
&rhs
) {
233 return lhs
.value
== rhs
.value
;
236 llvm::raw_ostream
&test::operator<<(llvm::raw_ostream
&os
,
237 const test::CopyCount
&value
) {
238 return os
<< value
.value
;
242 struct mlir::FieldParser
<test::CopyCount
> {
243 static FailureOr
<test::CopyCount
> parse(AsmParser
&parser
) {
245 if (parser
.parseKeyword(value
))
247 return test::CopyCount(value
);
251 llvm::hash_code
hash_value(const test::CopyCount
©Count
) {
252 return llvm::hash_value(copyCount
.value
);
//===----------------------------------------------------------------------===//
// TestConditionalAliasAttr
//===----------------------------------------------------------------------===//
260 /// Attempt to parse the conditionally-aliased string attribute as a keyword or
261 /// string, else try to parse an alias.
262 static ParseResult
parseConditionalAlias(AsmParser
&p
, StringAttr
&value
) {
264 if (succeeded(p
.parseOptionalKeywordOrString(&str
))) {
265 value
= StringAttr::get(p
.getContext(), str
);
268 return p
.parseAttribute(value
);
271 /// Print the string attribute as an alias if it has one, otherwise print it as
272 /// a keyword if possible.
273 static void printConditionalAlias(AsmPrinter
&p
, StringAttr value
) {
274 if (succeeded(p
.printAlias(value
)))
276 p
.printKeywordOrString(value
);
//===----------------------------------------------------------------------===//
// Custom Float Attribute
//===----------------------------------------------------------------------===//
283 static void printCustomFloatAttr(AsmPrinter
&p
, StringAttr typeStrAttr
,
285 p
<< typeStrAttr
<< " : " << value
;
288 static ParseResult
parseCustomFloatAttr(AsmParser
&p
, StringAttr
&typeStrAttr
,
289 FailureOr
<APFloat
> &value
) {
292 if (p
.parseString(&str
))
295 typeStrAttr
= StringAttr::get(p
.getContext(), str
);
300 const llvm::fltSemantics
*semantics
;
302 semantics
= &llvm::APFloat::IEEEsingle();
303 else if (str
== "double")
304 semantics
= &llvm::APFloat::IEEEdouble();
305 else if (str
== "fp80")
306 semantics
= &llvm::APFloat::x87DoubleExtended();
308 return p
.emitError(p
.getCurrentLocation(), "unknown float type, expected "
309 "'float', 'double' or 'fp80'");
311 APFloat
parsedValue(0.0);
312 if (p
.parseFloat(*semantics
, parsedValue
))
315 value
.emplace(parsedValue
);
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
323 #include "TestAttrInterfaces.cpp.inc"
324 #include "TestOpEnums.cpp.inc"
325 #define GET_ATTRDEF_CLASSES
326 #include "TestAttrDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// Dynamic Attributes
//===----------------------------------------------------------------------===//
332 /// Define a singleton dynamic attribute.
333 static std::unique_ptr
<DynamicAttrDefinition
>
334 getDynamicSingletonAttr(TestDialect
*testDialect
) {
335 return DynamicAttrDefinition::get(
336 "dynamic_singleton", testDialect
,
337 [](function_ref
<InFlightDiagnostic()> emitError
,
338 ArrayRef
<Attribute
> args
) {
340 emitError() << "expected 0 attribute arguments, but had "
348 /// Define a dynamic attribute representing a pair or attributes.
349 static std::unique_ptr
<DynamicAttrDefinition
>
350 getDynamicPairAttr(TestDialect
*testDialect
) {
351 return DynamicAttrDefinition::get(
352 "dynamic_pair", testDialect
,
353 [](function_ref
<InFlightDiagnostic()> emitError
,
354 ArrayRef
<Attribute
> args
) {
355 if (args
.size() != 2) {
356 emitError() << "expected 2 attribute arguments, but had "
364 static std::unique_ptr
<DynamicAttrDefinition
>
365 getDynamicCustomAssemblyFormatAttr(TestDialect
*testDialect
) {
366 auto verifier
= [](function_ref
<InFlightDiagnostic()> emitError
,
367 ArrayRef
<Attribute
> args
) {
368 if (args
.size() != 2) {
369 emitError() << "expected 2 attribute arguments, but had " << args
.size();
375 auto parser
= [](AsmParser
&parser
,
376 llvm::SmallVectorImpl
<Attribute
> &parsedParams
) {
377 Attribute leftAttr
, rightAttr
;
378 if (parser
.parseLess() || parser
.parseAttribute(leftAttr
) ||
379 parser
.parseColon() || parser
.parseAttribute(rightAttr
) ||
380 parser
.parseGreater())
382 parsedParams
.push_back(leftAttr
);
383 parsedParams
.push_back(rightAttr
);
387 auto printer
= [](AsmPrinter
&printer
, ArrayRef
<Attribute
> params
) {
388 printer
<< "<" << params
[0] << ":" << params
[1] << ">";
391 return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
392 testDialect
, std::move(verifier
),
393 std::move(parser
), std::move(printer
));
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
400 void TestDialect::registerAttributes() {
402 #define GET_ATTRDEF_LIST
403 #include "TestAttrDefs.cpp.inc"
405 registerDynamicAttr(getDynamicSingletonAttr(this));
406 registerDynamicAttr(getDynamicPairAttr(this));
407 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));