//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestOpsSyntax.h"
#include "TestDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Base64.h"

using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// Test Format* operations
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Parsing
25 static ParseResult
parseCustomOptionalOperand(
27 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
) {
28 if (succeeded(parser
.parseOptionalLParen())) {
30 if (parser
.parseOperand(*optOperand
) || parser
.parseRParen())
36 static ParseResult
parseCustomDirectiveOperands(
37 OpAsmParser
&parser
, OpAsmParser::UnresolvedOperand
&operand
,
38 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
,
39 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &varOperands
) {
40 if (parser
.parseOperand(operand
))
42 if (succeeded(parser
.parseOptionalComma())) {
44 if (parser
.parseOperand(*optOperand
))
47 if (parser
.parseArrow() || parser
.parseLParen() ||
48 parser
.parseOperandList(varOperands
) || parser
.parseRParen())
53 parseCustomDirectiveResults(OpAsmParser
&parser
, Type
&operandType
,
55 SmallVectorImpl
<Type
> &varOperandTypes
) {
56 if (parser
.parseColon())
59 if (parser
.parseType(operandType
))
61 if (succeeded(parser
.parseOptionalComma())) {
62 if (parser
.parseType(optOperandType
))
65 if (parser
.parseArrow() || parser
.parseLParen() ||
66 parser
.parseTypeList(varOperandTypes
) || parser
.parseRParen())
71 parseCustomDirectiveWithTypeRefs(OpAsmParser
&parser
, Type operandType
,
73 const SmallVectorImpl
<Type
> &varOperandTypes
) {
74 if (parser
.parseKeyword("type_refs_capture"))
77 Type operandType2
, optOperandType2
;
78 SmallVector
<Type
, 1> varOperandTypes2
;
79 if (parseCustomDirectiveResults(parser
, operandType2
, optOperandType2
,
83 if (operandType
!= operandType2
|| optOperandType
!= optOperandType2
||
84 varOperandTypes
!= varOperandTypes2
)
89 static ParseResult
parseCustomDirectiveOperandsAndTypes(
90 OpAsmParser
&parser
, OpAsmParser::UnresolvedOperand
&operand
,
91 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
,
92 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &varOperands
,
93 Type
&operandType
, Type
&optOperandType
,
94 SmallVectorImpl
<Type
> &varOperandTypes
) {
95 if (parseCustomDirectiveOperands(parser
, operand
, optOperand
, varOperands
) ||
96 parseCustomDirectiveResults(parser
, operandType
, optOperandType
,
101 static ParseResult
parseCustomDirectiveRegions(
102 OpAsmParser
&parser
, Region
®ion
,
103 SmallVectorImpl
<std::unique_ptr
<Region
>> &varRegions
) {
104 if (parser
.parseRegion(region
))
106 if (failed(parser
.parseOptionalComma()))
108 std::unique_ptr
<Region
> varRegion
= std::make_unique
<Region
>();
109 if (parser
.parseRegion(*varRegion
))
111 varRegions
.emplace_back(std::move(varRegion
));
115 parseCustomDirectiveSuccessors(OpAsmParser
&parser
, Block
*&successor
,
116 SmallVectorImpl
<Block
*> &varSuccessors
) {
117 if (parser
.parseSuccessor(successor
))
119 if (failed(parser
.parseOptionalComma()))
122 if (parser
.parseSuccessor(varSuccessor
))
124 varSuccessors
.append(2, varSuccessor
);
127 static ParseResult
parseCustomDirectiveAttributes(OpAsmParser
&parser
,
129 IntegerAttr
&optAttr
) {
130 if (parser
.parseAttribute(attr
))
132 if (succeeded(parser
.parseOptionalComma())) {
133 if (parser
.parseAttribute(optAttr
))
138 static ParseResult
parseCustomDirectiveSpacing(OpAsmParser
&parser
,
139 mlir::StringAttr
&attr
) {
140 return parser
.parseAttribute(attr
);
142 static ParseResult
parseCustomDirectiveAttrDict(OpAsmParser
&parser
,
143 NamedAttrList
&attrs
) {
144 return parser
.parseOptionalAttrDict(attrs
);
146 static ParseResult
parseCustomDirectiveOptionalOperandRef(
148 std::optional
<OpAsmParser::UnresolvedOperand
> &optOperand
) {
149 int64_t operandCount
= 0;
150 if (parser
.parseInteger(operandCount
))
152 bool expectedOptionalOperand
= operandCount
== 0;
153 return success(expectedOptionalOperand
!= optOperand
.has_value());
//===----------------------------------------------------------------------===//
// Printing
159 static void printCustomOptionalOperand(OpAsmPrinter
&printer
, Operation
*,
162 printer
<< "(" << optOperand
<< ") ";
165 static void printCustomDirectiveOperands(OpAsmPrinter
&printer
, Operation
*,
166 Value operand
, Value optOperand
,
167 OperandRange varOperands
) {
170 printer
<< ", " << optOperand
;
171 printer
<< " -> (" << varOperands
<< ")";
173 static void printCustomDirectiveResults(OpAsmPrinter
&printer
, Operation
*,
174 Type operandType
, Type optOperandType
,
175 TypeRange varOperandTypes
) {
176 printer
<< " : " << operandType
;
178 printer
<< ", " << optOperandType
;
179 printer
<< " -> (" << varOperandTypes
<< ")";
181 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter
&printer
,
182 Operation
*op
, Type operandType
,
184 TypeRange varOperandTypes
) {
185 printer
<< " type_refs_capture ";
186 printCustomDirectiveResults(printer
, op
, operandType
, optOperandType
,
189 static void printCustomDirectiveOperandsAndTypes(
190 OpAsmPrinter
&printer
, Operation
*op
, Value operand
, Value optOperand
,
191 OperandRange varOperands
, Type operandType
, Type optOperandType
,
192 TypeRange varOperandTypes
) {
193 printCustomDirectiveOperands(printer
, op
, operand
, optOperand
, varOperands
);
194 printCustomDirectiveResults(printer
, op
, operandType
, optOperandType
,
197 static void printCustomDirectiveRegions(OpAsmPrinter
&printer
, Operation
*,
199 MutableArrayRef
<Region
> varRegions
) {
200 printer
.printRegion(region
);
201 if (!varRegions
.empty()) {
203 for (Region
®ion
: varRegions
)
204 printer
.printRegion(region
);
207 static void printCustomDirectiveSuccessors(OpAsmPrinter
&printer
, Operation
*,
209 SuccessorRange varSuccessors
) {
210 printer
<< successor
;
211 if (!varSuccessors
.empty())
212 printer
<< ", " << varSuccessors
.front();
214 static void printCustomDirectiveAttributes(OpAsmPrinter
&printer
, Operation
*,
216 Attribute optAttribute
) {
217 printer
<< attribute
;
219 printer
<< ", " << optAttribute
;
221 static void printCustomDirectiveSpacing(OpAsmPrinter
&printer
, Operation
*op
,
222 Attribute attribute
) {
223 printer
<< attribute
;
225 static void printCustomDirectiveAttrDict(OpAsmPrinter
&printer
, Operation
*op
,
226 DictionaryAttr attrs
) {
227 printer
.printOptionalAttrDict(attrs
.getValue());
230 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter
&printer
,
233 printer
<< (optOperand
? "1" : "0");
//===----------------------------------------------------------------------===//
// Test Parse* operations
//===----------------------------------------------------------------------===//
239 ParseResult
ParseIntegerLiteralOp::parse(OpAsmParser
&parser
,
240 OperationState
&result
) {
241 if (parser
.parseOptionalColon())
244 if (parser
.parseInteger(numResults
))
247 IndexType type
= parser
.getBuilder().getIndexType();
248 for (unsigned i
= 0; i
< numResults
; ++i
)
249 result
.addTypes(type
);
253 void ParseIntegerLiteralOp::print(OpAsmPrinter
&p
) {
254 if (unsigned numResults
= getNumResults())
255 p
<< " : " << numResults
;
258 ParseResult
ParseWrappedKeywordOp::parse(OpAsmParser
&parser
,
259 OperationState
&result
) {
261 if (parser
.parseKeyword(&keyword
))
263 result
.addAttribute("keyword", parser
.getBuilder().getStringAttr(keyword
));
267 void ParseWrappedKeywordOp::print(OpAsmPrinter
&p
) { p
<< " " << getKeyword(); }
269 ParseResult
ParseB64BytesOp::parse(OpAsmParser
&parser
,
270 OperationState
&result
) {
271 std::vector
<char> bytes
;
272 if (parser
.parseBase64Bytes(&bytes
))
274 result
.addAttribute("b64", parser
.getBuilder().getStringAttr(
275 StringRef(&bytes
.front(), bytes
.size())));
279 void ParseB64BytesOp::print(OpAsmPrinter
&p
) {
280 p
<< " \"" << llvm::encodeBase64(getB64()) << "\"";
283 ::llvm::LogicalResult
FormatInferType2Op::inferReturnTypes(
284 ::mlir::MLIRContext
*context
, ::std::optional
<::mlir::Location
> location
,
285 ::mlir::ValueRange operands
, ::mlir::DictionaryAttr attributes
,
286 OpaqueProperties properties
, ::mlir::RegionRange regions
,
287 ::llvm::SmallVectorImpl
<::mlir::Type
> &inferredReturnTypes
) {
288 inferredReturnTypes
.assign({::mlir::IntegerType::get(context
, 16)});
289 return ::mlir::success();
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
295 ParseResult
WrappingRegionOp::parse(OpAsmParser
&parser
,
296 OperationState
&result
) {
297 if (parser
.parseKeyword("wraps"))
300 // Parse the wrapped op in a region
301 Region
&body
= *result
.addRegion();
302 body
.push_back(new Block
);
303 Block
&block
= body
.back();
304 Operation
*wrappedOp
= parser
.parseGenericOperation(&block
, block
.begin());
308 // Create a return terminator in the inner region, pass as operand to the
309 // terminator the returned values from the wrapped operation.
310 SmallVector
<Value
, 8> returnOperands(wrappedOp
->getResults());
311 OpBuilder
builder(parser
.getContext());
312 builder
.setInsertionPointToEnd(&block
);
313 builder
.create
<TestReturnOp
>(wrappedOp
->getLoc(), returnOperands
);
315 // Get the results type for the wrapping op from the terminator operands.
316 Operation
&returnOp
= body
.back().back();
317 result
.types
.append(returnOp
.operand_type_begin(),
318 returnOp
.operand_type_end());
320 // Use the location of the wrapped op for the "test.wrapping_region" op.
321 result
.location
= wrappedOp
->getLoc();
326 void WrappingRegionOp::print(OpAsmPrinter
&p
) {
328 p
.printGenericOp(&getRegion().front().front());
//===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs:
//   parseGenericOperationAfterOpName
//   parseCustomOperationName
//===----------------------------------------------------------------------===//
337 ParseResult
PrettyPrintedRegionOp::parse(OpAsmParser
&parser
,
338 OperationState
&result
) {
340 SMLoc loc
= parser
.getCurrentLocation();
341 Location currLocation
= parser
.getEncodedSourceLoc(loc
);
343 // Parse the operands.
344 SmallVector
<OpAsmParser::UnresolvedOperand
, 2> operands
;
345 if (parser
.parseOperandList(operands
))
348 // Check if we are parsing the pretty-printed version
349 // test.pretty_printed_region start <inner-op> end : <functional-type>
350 // Else fallback to parsing the "non pretty-printed" version.
351 if (!succeeded(parser
.parseOptionalKeyword("start")))
352 return parser
.parseGenericOperationAfterOpName(result
,
353 llvm::ArrayRef(operands
));
355 FailureOr
<OperationName
> parseOpNameInfo
= parser
.parseCustomOperationName();
356 if (failed(parseOpNameInfo
))
359 StringAttr innerOpName
= parseOpNameInfo
->getIdentifier();
361 FunctionType opFntype
;
362 std::optional
<Location
> explicitLoc
;
363 if (parser
.parseKeyword("end") || parser
.parseColon() ||
364 parser
.parseType(opFntype
) ||
365 parser
.parseOptionalLocationSpecifier(explicitLoc
))
368 // If location of the op is explicitly provided, then use it; Else use
369 // the parser's current location.
370 Location opLoc
= explicitLoc
.value_or(currLocation
);
372 // Derive the SSA-values for op's operands.
373 if (parser
.resolveOperands(operands
, opFntype
.getInputs(), loc
,
377 // Add a region for op.
378 Region
®ion
= *result
.addRegion();
380 // Create a basic-block inside op's region.
381 Block
&block
= region
.emplaceBlock();
383 // Create and insert an "inner-op" operation in the block.
384 // Just for testing purposes, we can assume that inner op is a binary op with
385 // result and operand types all same as the test-op's first operand.
386 Type innerOpType
= opFntype
.getInput(0);
387 Value lhs
= block
.addArgument(innerOpType
, opLoc
);
388 Value rhs
= block
.addArgument(innerOpType
, opLoc
);
390 OpBuilder
builder(parser
.getBuilder().getContext());
391 builder
.setInsertionPointToStart(&block
);
394 builder
.create(opLoc
, innerOpName
, /*operands=*/{lhs
, rhs
}, innerOpType
);
396 // Insert a return statement in the block returning the inner-op's result.
397 builder
.create
<TestReturnOp
>(innerOp
->getLoc(), innerOp
->getResults());
399 // Populate the op operation-state with result-type and location.
400 result
.addTypes(opFntype
.getResults());
401 result
.location
= innerOp
->getLoc();
406 void PrettyPrintedRegionOp::print(OpAsmPrinter
&p
) {
408 p
.printOperands(getOperands());
410 Operation
&innerOp
= getRegion().front().front();
411 // Assuming that region has a single non-terminator inner-op, if the inner-op
412 // meets some criteria (which in this case is a simple one based on the name
413 // of inner-op), then we can print the entire region in a succinct way.
414 // Here we assume that the prototype of "test.special.op" can be trivially
415 // derived while parsing it back.
416 if (innerOp
.getName().getStringRef() == "test.special.op") {
417 p
<< " start test.special.op end";
420 p
.printRegion(getRegion());
425 p
.printFunctionalType(*this);
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
432 ParseResult
PolyForOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
433 SmallVector
<OpAsmParser::Argument
, 4> ivsInfo
;
434 // Parse list of region arguments without a delimiter.
435 if (parser
.parseArgumentList(ivsInfo
, OpAsmParser::Delimiter::None
))
438 // Parse the body region.
439 Region
*body
= result
.addRegion();
440 for (auto &iv
: ivsInfo
)
441 iv
.type
= parser
.getBuilder().getIndexType();
442 return parser
.parseRegion(*body
, ivsInfo
);
445 void PolyForOp::print(OpAsmPrinter
&p
) {
447 llvm::interleaveComma(getRegion().getArguments(), p
, [&](auto arg
) {
448 p
.printRegionArgument(arg
, /*argAttrs =*/{}, /*omitType=*/true);
451 p
.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
454 void PolyForOp::getAsmBlockArgumentNames(Region
®ion
,
455 OpAsmSetValueNameFn setNameFn
) {
456 auto arrayAttr
= getOperation()->getAttrOfType
<ArrayAttr
>("arg_names");
459 auto args
= getRegion().front().getArguments();
460 auto e
= std::min(arrayAttr
.size(), args
.size());
461 for (unsigned i
= 0; i
< e
; ++i
) {
462 if (auto strAttr
= dyn_cast
<StringAttr
>(arrayAttr
[i
]))
463 setNameFn(args
[i
], strAttr
.getValue());
//===----------------------------------------------------------------------===//
// TestAttrWithLoc - exercising parse/printOptionalLocationSpecifier
//===----------------------------------------------------------------------===//
471 static ParseResult
parseOptionalLoc(OpAsmParser
&p
, Attribute
&loc
) {
472 std::optional
<Location
> result
;
473 SMLoc sourceLoc
= p
.getCurrentLocation();
474 if (p
.parseOptionalLocationSpecifier(result
))
479 loc
= p
.getEncodedSourceLoc(sourceLoc
);
483 static void printOptionalLoc(OpAsmPrinter
&p
, Operation
*op
, Attribute loc
) {
484 p
.printOptionalLocationSpecifier(cast
<LocationAttr
>(loc
));
487 #define GET_OP_CLASSES
488 #include "TestOpsSyntax.cpp.inc"
490 void TestDialect::registerOpsSyntax() {
493 #include "TestOpsSyntax.cpp.inc"