1 //===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "TestFormatUtils.h"
10 #include "mlir/IR/Builders.h"
15 //===----------------------------------------------------------------------===//
16 // CustomDirectiveOperands
17 //===----------------------------------------------------------------------===//
19 ParseResult
test::parseCustomDirectiveOperands(
20 OpAsmParser
&parser
, OpAsmParser::UnresolvedOperand
&operand
,
21 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
,
22 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &varOperands
) {
23 if (parser
.parseOperand(operand
))
25 if (succeeded(parser
.parseOptionalComma())) {
27 if (parser
.parseOperand(*optOperand
))
30 if (parser
.parseArrow() || parser
.parseLParen() ||
31 parser
.parseOperandList(varOperands
) || parser
.parseRParen())
36 void test::printCustomDirectiveOperands(OpAsmPrinter
&printer
, Operation
*,
37 Value operand
, Value optOperand
,
38 OperandRange varOperands
) {
41 printer
<< ", " << optOperand
;
42 printer
<< " -> (" << varOperands
<< ")";
45 //===----------------------------------------------------------------------===//
46 // CustomDirectiveResults
47 //===----------------------------------------------------------------------===//
50 test::parseCustomDirectiveResults(OpAsmParser
&parser
, Type
&operandType
,
52 SmallVectorImpl
<Type
> &varOperandTypes
) {
53 if (parser
.parseColon())
56 if (parser
.parseType(operandType
))
58 if (succeeded(parser
.parseOptionalComma()))
59 if (parser
.parseType(optOperandType
))
61 if (parser
.parseArrow() || parser
.parseLParen() ||
62 parser
.parseTypeList(varOperandTypes
) || parser
.parseRParen())
67 void test::printCustomDirectiveResults(OpAsmPrinter
&printer
, Operation
*,
68 Type operandType
, Type optOperandType
,
69 TypeRange varOperandTypes
) {
70 printer
<< " : " << operandType
;
72 printer
<< ", " << optOperandType
;
73 printer
<< " -> (" << varOperandTypes
<< ")";
76 //===----------------------------------------------------------------------===//
77 // CustomDirectiveWithTypeRefs
78 //===----------------------------------------------------------------------===//
80 ParseResult
test::parseCustomDirectiveWithTypeRefs(
81 OpAsmParser
&parser
, Type operandType
, Type optOperandType
,
82 const SmallVectorImpl
<Type
> &varOperandTypes
) {
83 if (parser
.parseKeyword("type_refs_capture"))
86 Type operandType2
, optOperandType2
;
87 SmallVector
<Type
, 1> varOperandTypes2
;
88 if (parseCustomDirectiveResults(parser
, operandType2
, optOperandType2
,
92 if (operandType
!= operandType2
|| optOperandType
!= optOperandType2
||
93 varOperandTypes
!= varOperandTypes2
)
99 void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter
&printer
,
100 Operation
*op
, Type operandType
,
102 TypeRange varOperandTypes
) {
103 printer
<< " type_refs_capture ";
104 printCustomDirectiveResults(printer
, op
, operandType
, optOperandType
,
108 //===----------------------------------------------------------------------===//
109 // CustomDirectiveOperandsAndTypes
110 //===----------------------------------------------------------------------===//
112 ParseResult
test::parseCustomDirectiveOperandsAndTypes(
113 OpAsmParser
&parser
, OpAsmParser::UnresolvedOperand
&operand
,
114 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
,
115 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &varOperands
,
116 Type
&operandType
, Type
&optOperandType
,
117 SmallVectorImpl
<Type
> &varOperandTypes
) {
118 if (parseCustomDirectiveOperands(parser
, operand
, optOperand
, varOperands
) ||
119 parseCustomDirectiveResults(parser
, operandType
, optOperandType
,
125 void test::printCustomDirectiveOperandsAndTypes(
126 OpAsmPrinter
&printer
, Operation
*op
, Value operand
, Value optOperand
,
127 OperandRange varOperands
, Type operandType
, Type optOperandType
,
128 TypeRange varOperandTypes
) {
129 printCustomDirectiveOperands(printer
, op
, operand
, optOperand
, varOperands
);
130 printCustomDirectiveResults(printer
, op
, operandType
, optOperandType
,
134 //===----------------------------------------------------------------------===//
135 // CustomDirectiveRegions
136 //===----------------------------------------------------------------------===//
138 ParseResult
test::parseCustomDirectiveRegions(
139 OpAsmParser
&parser
, Region
®ion
,
140 SmallVectorImpl
<std::unique_ptr
<Region
>> &varRegions
) {
141 if (parser
.parseRegion(region
))
143 if (failed(parser
.parseOptionalComma()))
145 std::unique_ptr
<Region
> varRegion
= std::make_unique
<Region
>();
146 if (parser
.parseRegion(*varRegion
))
148 varRegions
.emplace_back(std::move(varRegion
));
152 void test::printCustomDirectiveRegions(OpAsmPrinter
&printer
, Operation
*,
154 MutableArrayRef
<Region
> varRegions
) {
155 printer
.printRegion(region
);
156 if (!varRegions
.empty()) {
158 for (Region
®ion
: varRegions
)
159 printer
.printRegion(region
);
163 //===----------------------------------------------------------------------===//
164 // CustomDirectiveSuccessors
165 //===----------------------------------------------------------------------===//
168 test::parseCustomDirectiveSuccessors(OpAsmParser
&parser
, Block
*&successor
,
169 SmallVectorImpl
<Block
*> &varSuccessors
) {
170 if (parser
.parseSuccessor(successor
))
172 if (failed(parser
.parseOptionalComma()))
175 if (parser
.parseSuccessor(varSuccessor
))
177 varSuccessors
.append(2, varSuccessor
);
181 void test::printCustomDirectiveSuccessors(OpAsmPrinter
&printer
, Operation
*,
183 SuccessorRange varSuccessors
) {
184 printer
<< successor
;
185 if (!varSuccessors
.empty())
186 printer
<< ", " << varSuccessors
.front();
189 //===----------------------------------------------------------------------===//
190 // CustomDirectiveAttributes
191 //===----------------------------------------------------------------------===//
193 ParseResult
test::parseCustomDirectiveAttributes(OpAsmParser
&parser
,
195 IntegerAttr
&optAttr
) {
196 if (parser
.parseAttribute(attr
))
198 if (succeeded(parser
.parseOptionalComma())) {
199 if (parser
.parseAttribute(optAttr
))
205 void test::printCustomDirectiveAttributes(OpAsmPrinter
&printer
, Operation
*,
207 Attribute optAttribute
) {
208 printer
<< attribute
;
210 printer
<< ", " << optAttribute
;
213 //===----------------------------------------------------------------------===//
214 // CustomDirectiveAttrDict
215 //===----------------------------------------------------------------------===//
217 ParseResult
test::parseCustomDirectiveAttrDict(OpAsmParser
&parser
,
218 NamedAttrList
&attrs
) {
219 return parser
.parseOptionalAttrDict(attrs
);
222 void test::printCustomDirectiveAttrDict(OpAsmPrinter
&printer
, Operation
*op
,
223 DictionaryAttr attrs
) {
224 printer
.printOptionalAttrDict(attrs
.getValue());
227 //===----------------------------------------------------------------------===//
228 // CustomDirectiveOptionalOperandRef
229 //===----------------------------------------------------------------------===//
231 ParseResult
test::parseCustomDirectiveOptionalOperandRef(
233 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
) {
234 int64_t operandCount
= 0;
235 if (parser
.parseInteger(operandCount
))
237 bool expectedOptionalOperand
= operandCount
== 0;
238 return success(expectedOptionalOperand
!= !!optOperand
);
241 void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter
&printer
,
244 printer
<< (optOperand
? "1" : "0");
247 //===----------------------------------------------------------------------===//
248 // CustomDirectiveOptionalOperand
249 //===----------------------------------------------------------------------===//
251 ParseResult
test::parseCustomOptionalOperand(
253 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
) {
254 if (succeeded(parser
.parseOptionalLParen())) {
255 optOperand
.emplace();
256 if (parser
.parseOperand(*optOperand
) || parser
.parseRParen())
262 void test::printCustomOptionalOperand(OpAsmPrinter
&printer
, Operation
*,
265 printer
<< "(" << optOperand
<< ") ";
268 //===----------------------------------------------------------------------===//
269 // CustomDirectiveSwitchCases
270 //===----------------------------------------------------------------------===//
273 test::parseSwitchCases(OpAsmParser
&p
, DenseI64ArrayAttr
&cases
,
274 SmallVectorImpl
<std::unique_ptr
<Region
>> &caseRegions
) {
275 SmallVector
<int64_t> caseValues
;
276 while (succeeded(p
.parseOptionalKeyword("case"))) {
278 Region
®ion
= *caseRegions
.emplace_back(std::make_unique
<Region
>());
279 if (p
.parseInteger(value
) || p
.parseRegion(region
, /*arguments=*/{}))
281 caseValues
.push_back(value
);
283 cases
= p
.getBuilder().getDenseI64ArrayAttr(caseValues
);
287 void test::printSwitchCases(OpAsmPrinter
&p
, Operation
*op
,
288 DenseI64ArrayAttr cases
, RegionRange caseRegions
) {
289 for (auto [value
, region
] : llvm::zip(cases
.asArrayRef(), caseRegions
)) {
291 p
<< "case " << value
<< ' ';
292 p
.printRegion(*region
, /*printEntryBlockArgs=*/false);
296 //===----------------------------------------------------------------------===//
297 // CustomUsingPropertyInCustom
298 //===----------------------------------------------------------------------===//
300 bool test::parseUsingPropertyInCustom(OpAsmParser
&parser
,
301 SmallVector
<int64_t> &value
) {
302 auto elemParser
= [&]() {
304 if (failed(parser
.parseInteger(v
)))
309 return failed(parser
.parseCommaSeparatedList(OpAsmParser::Delimiter::Square
,
313 void test::printUsingPropertyInCustom(OpAsmPrinter
&printer
, Operation
*op
,
314 ArrayRef
<int64_t> value
) {
315 printer
<< '[' << value
<< ']';
318 //===----------------------------------------------------------------------===//
319 // CustomDirectiveIntProperty
320 //===----------------------------------------------------------------------===//
322 bool test::parseIntProperty(OpAsmParser
&parser
, int64_t &value
) {
323 return failed(parser
.parseInteger(value
));
326 void test::printIntProperty(OpAsmPrinter
&printer
, Operation
*op
,
331 //===----------------------------------------------------------------------===//
332 // CustomDirectiveSumProperty
333 //===----------------------------------------------------------------------===//
335 bool test::parseSumProperty(OpAsmParser
&parser
, int64_t &second
,
338 auto loc
= parser
.getCurrentLocation();
339 if (parser
.parseInteger(second
) || parser
.parseEqual() ||
340 parser
.parseInteger(sum
))
342 if (sum
!= second
+ first
) {
343 parser
.emitError(loc
, "Expected sum to equal first + second");
349 void test::printSumProperty(OpAsmPrinter
&printer
, Operation
*op
,
350 int64_t second
, int64_t first
) {
351 printer
<< second
<< " = " << (second
+ first
);
354 //===----------------------------------------------------------------------===//
355 // CustomDirectiveOptionalCustomParser
356 //===----------------------------------------------------------------------===//
358 OptionalParseResult
test::parseOptionalCustomParser(AsmParser
&p
,
359 IntegerAttr
&result
) {
360 if (succeeded(p
.parseOptionalKeyword("foo")))
361 return p
.parseAttribute(result
);
365 void test::printOptionalCustomParser(AsmPrinter
&p
, Operation
*,
366 IntegerAttr result
) {
368 p
.printAttribute(result
);
371 //===----------------------------------------------------------------------===//
372 // CustomDirectiveAttrElideType
373 //===----------------------------------------------------------------------===//
375 ParseResult
test::parseAttrElideType(AsmParser
&parser
, TypeAttr type
,
377 return parser
.parseAttribute(attr
, type
.getValue());
380 void test::printAttrElideType(AsmPrinter
&printer
, Operation
*op
, TypeAttr type
,
382 printer
.printAttributeWithoutType(attr
);