Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestOpsSyntax.cpp
blob664951f2a11bbea3fdf67439e4f84692361c3880
1 //===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestOpsSyntax.h"
10 #include "TestDialect.h"
11 #include "TestOps.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "llvm/Support/Base64.h"
15 using namespace mlir;
16 using namespace test;
18 //===----------------------------------------------------------------------===//
19 // Test Format* operations
20 //===----------------------------------------------------------------------===//
22 //===----------------------------------------------------------------------===//
23 // Parsing
25 static ParseResult parseCustomOptionalOperand(
26 OpAsmParser &parser,
27 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
28 if (succeeded(parser.parseOptionalLParen())) {
29 optOperand.emplace();
30 if (parser.parseOperand(*optOperand) || parser.parseRParen())
31 return failure();
33 return success();
36 static ParseResult parseCustomDirectiveOperands(
37 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
38 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
39 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
40 if (parser.parseOperand(operand))
41 return failure();
42 if (succeeded(parser.parseOptionalComma())) {
43 optOperand.emplace();
44 if (parser.parseOperand(*optOperand))
45 return failure();
47 if (parser.parseArrow() || parser.parseLParen() ||
48 parser.parseOperandList(varOperands) || parser.parseRParen())
49 return failure();
50 return success();
52 static ParseResult
53 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
54 Type &optOperandType,
55 SmallVectorImpl<Type> &varOperandTypes) {
56 if (parser.parseColon())
57 return failure();
59 if (parser.parseType(operandType))
60 return failure();
61 if (succeeded(parser.parseOptionalComma())) {
62 if (parser.parseType(optOperandType))
63 return failure();
65 if (parser.parseArrow() || parser.parseLParen() ||
66 parser.parseTypeList(varOperandTypes) || parser.parseRParen())
67 return failure();
68 return success();
70 static ParseResult
71 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
72 Type optOperandType,
73 const SmallVectorImpl<Type> &varOperandTypes) {
74 if (parser.parseKeyword("type_refs_capture"))
75 return failure();
77 Type operandType2, optOperandType2;
78 SmallVector<Type, 1> varOperandTypes2;
79 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
80 varOperandTypes2))
81 return failure();
83 if (operandType != operandType2 || optOperandType != optOperandType2 ||
84 varOperandTypes != varOperandTypes2)
85 return failure();
87 return success();
89 static ParseResult parseCustomDirectiveOperandsAndTypes(
90 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
91 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
92 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
93 Type &operandType, Type &optOperandType,
94 SmallVectorImpl<Type> &varOperandTypes) {
95 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
96 parseCustomDirectiveResults(parser, operandType, optOperandType,
97 varOperandTypes))
98 return failure();
99 return success();
101 static ParseResult parseCustomDirectiveRegions(
102 OpAsmParser &parser, Region &region,
103 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
104 if (parser.parseRegion(region))
105 return failure();
106 if (failed(parser.parseOptionalComma()))
107 return success();
108 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
109 if (parser.parseRegion(*varRegion))
110 return failure();
111 varRegions.emplace_back(std::move(varRegion));
112 return success();
114 static ParseResult
115 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
116 SmallVectorImpl<Block *> &varSuccessors) {
117 if (parser.parseSuccessor(successor))
118 return failure();
119 if (failed(parser.parseOptionalComma()))
120 return success();
121 Block *varSuccessor;
122 if (parser.parseSuccessor(varSuccessor))
123 return failure();
124 varSuccessors.append(2, varSuccessor);
125 return success();
127 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
128 IntegerAttr &attr,
129 IntegerAttr &optAttr) {
130 if (parser.parseAttribute(attr))
131 return failure();
132 if (succeeded(parser.parseOptionalComma())) {
133 if (parser.parseAttribute(optAttr))
134 return failure();
136 return success();
138 static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
139 mlir::StringAttr &attr) {
140 return parser.parseAttribute(attr);
142 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
143 NamedAttrList &attrs) {
144 return parser.parseOptionalAttrDict(attrs);
146 static ParseResult parseCustomDirectiveOptionalOperandRef(
147 OpAsmParser &parser,
148 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
149 int64_t operandCount = 0;
150 if (parser.parseInteger(operandCount))
151 return failure();
152 bool expectedOptionalOperand = operandCount == 0;
153 return success(expectedOptionalOperand != optOperand.has_value());
156 //===----------------------------------------------------------------------===//
157 // Printing
159 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
160 Value optOperand) {
161 if (optOperand)
162 printer << "(" << optOperand << ") ";
165 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
166 Value operand, Value optOperand,
167 OperandRange varOperands) {
168 printer << operand;
169 if (optOperand)
170 printer << ", " << optOperand;
171 printer << " -> (" << varOperands << ")";
173 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
174 Type operandType, Type optOperandType,
175 TypeRange varOperandTypes) {
176 printer << " : " << operandType;
177 if (optOperandType)
178 printer << ", " << optOperandType;
179 printer << " -> (" << varOperandTypes << ")";
181 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
182 Operation *op, Type operandType,
183 Type optOperandType,
184 TypeRange varOperandTypes) {
185 printer << " type_refs_capture ";
186 printCustomDirectiveResults(printer, op, operandType, optOperandType,
187 varOperandTypes);
189 static void printCustomDirectiveOperandsAndTypes(
190 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
191 OperandRange varOperands, Type operandType, Type optOperandType,
192 TypeRange varOperandTypes) {
193 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
194 printCustomDirectiveResults(printer, op, operandType, optOperandType,
195 varOperandTypes);
197 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
198 Region &region,
199 MutableArrayRef<Region> varRegions) {
200 printer.printRegion(region);
201 if (!varRegions.empty()) {
202 printer << ", ";
203 for (Region &region : varRegions)
204 printer.printRegion(region);
207 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
208 Block *successor,
209 SuccessorRange varSuccessors) {
210 printer << successor;
211 if (!varSuccessors.empty())
212 printer << ", " << varSuccessors.front();
214 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
215 Attribute attribute,
216 Attribute optAttribute) {
217 printer << attribute;
218 if (optAttribute)
219 printer << ", " << optAttribute;
221 static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
222 Attribute attribute) {
223 printer << attribute;
225 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
226 DictionaryAttr attrs) {
227 printer.printOptionalAttrDict(attrs.getValue());
230 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
231 Operation *op,
232 Value optOperand) {
233 printer << (optOperand ? "1" : "0");
235 //===----------------------------------------------------------------------===//
236 // Test parser.
237 //===----------------------------------------------------------------------===//
239 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
240 OperationState &result) {
241 if (parser.parseOptionalColon())
242 return success();
243 uint64_t numResults;
244 if (parser.parseInteger(numResults))
245 return failure();
247 IndexType type = parser.getBuilder().getIndexType();
248 for (unsigned i = 0; i < numResults; ++i)
249 result.addTypes(type);
250 return success();
253 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
254 if (unsigned numResults = getNumResults())
255 p << " : " << numResults;
258 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
259 OperationState &result) {
260 StringRef keyword;
261 if (parser.parseKeyword(&keyword))
262 return failure();
263 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
264 return success();
267 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
269 ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
270 OperationState &result) {
271 std::vector<char> bytes;
272 if (parser.parseBase64Bytes(&bytes))
273 return failure();
274 result.addAttribute("b64", parser.getBuilder().getStringAttr(
275 StringRef(&bytes.front(), bytes.size())));
276 return success();
279 void ParseB64BytesOp::print(OpAsmPrinter &p) {
280 p << " \"" << llvm::encodeBase64(getB64()) << "\"";
283 ::llvm::LogicalResult FormatInferType2Op::inferReturnTypes(
284 ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
285 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
286 OpaqueProperties properties, ::mlir::RegionRange regions,
287 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
288 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
289 return ::mlir::success();
292 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
295 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
296 OperationState &result) {
297 if (parser.parseKeyword("wraps"))
298 return failure();
300 // Parse the wrapped op in a region
301 Region &body = *result.addRegion();
302 body.push_back(new Block);
303 Block &block = body.back();
304 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
305 if (!wrappedOp)
306 return failure();
308 // Create a return terminator in the inner region, pass as operand to the
309 // terminator the returned values from the wrapped operation.
310 SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
311 OpBuilder builder(parser.getContext());
312 builder.setInsertionPointToEnd(&block);
313 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
315 // Get the results type for the wrapping op from the terminator operands.
316 Operation &returnOp = body.back().back();
317 result.types.append(returnOp.operand_type_begin(),
318 returnOp.operand_type_end());
320 // Use the location of the wrapped op for the "test.wrapping_region" op.
321 result.location = wrappedOp->getLoc();
323 return success();
326 void WrappingRegionOp::print(OpAsmPrinter &p) {
327 p << " wraps ";
328 p.printGenericOp(&getRegion().front().front());
331 //===----------------------------------------------------------------------===//
332 // Test PrettyPrintedRegionOp - exercising the following parser APIs
333 // parseGenericOperationAfterOpName
334 // parseCustomOperationName
335 //===----------------------------------------------------------------------===//
337 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
338 OperationState &result) {
340 SMLoc loc = parser.getCurrentLocation();
341 Location currLocation = parser.getEncodedSourceLoc(loc);
343 // Parse the operands.
344 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
345 if (parser.parseOperandList(operands))
346 return failure();
348 // Check if we are parsing the pretty-printed version
349 // test.pretty_printed_region start <inner-op> end : <functional-type>
350 // Else fallback to parsing the "non pretty-printed" version.
351 if (!succeeded(parser.parseOptionalKeyword("start")))
352 return parser.parseGenericOperationAfterOpName(result,
353 llvm::ArrayRef(operands));
355 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
356 if (failed(parseOpNameInfo))
357 return failure();
359 StringAttr innerOpName = parseOpNameInfo->getIdentifier();
361 FunctionType opFntype;
362 std::optional<Location> explicitLoc;
363 if (parser.parseKeyword("end") || parser.parseColon() ||
364 parser.parseType(opFntype) ||
365 parser.parseOptionalLocationSpecifier(explicitLoc))
366 return failure();
368 // If location of the op is explicitly provided, then use it; Else use
369 // the parser's current location.
370 Location opLoc = explicitLoc.value_or(currLocation);
372 // Derive the SSA-values for op's operands.
373 if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
374 result.operands))
375 return failure();
377 // Add a region for op.
378 Region &region = *result.addRegion();
380 // Create a basic-block inside op's region.
381 Block &block = region.emplaceBlock();
383 // Create and insert an "inner-op" operation in the block.
384 // Just for testing purposes, we can assume that inner op is a binary op with
385 // result and operand types all same as the test-op's first operand.
386 Type innerOpType = opFntype.getInput(0);
387 Value lhs = block.addArgument(innerOpType, opLoc);
388 Value rhs = block.addArgument(innerOpType, opLoc);
390 OpBuilder builder(parser.getBuilder().getContext());
391 builder.setInsertionPointToStart(&block);
393 Operation *innerOp =
394 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
396 // Insert a return statement in the block returning the inner-op's result.
397 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
399 // Populate the op operation-state with result-type and location.
400 result.addTypes(opFntype.getResults());
401 result.location = innerOp->getLoc();
403 return success();
406 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
407 p << ' ';
408 p.printOperands(getOperands());
410 Operation &innerOp = getRegion().front().front();
411 // Assuming that region has a single non-terminator inner-op, if the inner-op
412 // meets some criteria (which in this case is a simple one based on the name
413 // of inner-op), then we can print the entire region in a succinct way.
414 // Here we assume that the prototype of "test.special.op" can be trivially
415 // derived while parsing it back.
416 if (innerOp.getName().getStringRef() == "test.special.op") {
417 p << " start test.special.op end";
418 } else {
419 p << " (";
420 p.printRegion(getRegion());
421 p << ")";
424 p << " : ";
425 p.printFunctionalType(*this);
428 //===----------------------------------------------------------------------===//
429 // Test PolyForOp - parse list of region arguments.
430 //===----------------------------------------------------------------------===//
432 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
433 SmallVector<OpAsmParser::Argument, 4> ivsInfo;
434 // Parse list of region arguments without a delimiter.
435 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
436 return failure();
438 // Parse the body region.
439 Region *body = result.addRegion();
440 for (auto &iv : ivsInfo)
441 iv.type = parser.getBuilder().getIndexType();
442 return parser.parseRegion(*body, ivsInfo);
445 void PolyForOp::print(OpAsmPrinter &p) {
446 p << " ";
447 llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
448 p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
450 p << " ";
451 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
454 void PolyForOp::getAsmBlockArgumentNames(Region &region,
455 OpAsmSetValueNameFn setNameFn) {
456 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
457 if (!arrayAttr)
458 return;
459 auto args = getRegion().front().getArguments();
460 auto e = std::min(arrayAttr.size(), args.size());
461 for (unsigned i = 0; i < e; ++i) {
462 if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
463 setNameFn(args[i], strAttr.getValue());
467 //===----------------------------------------------------------------------===//
468 // TestAttrWithLoc - parse/printOptionalLocationSpecifier
469 //===----------------------------------------------------------------------===//
471 static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
472 std::optional<Location> result;
473 SMLoc sourceLoc = p.getCurrentLocation();
474 if (p.parseOptionalLocationSpecifier(result))
475 return failure();
476 if (result)
477 loc = *result;
478 else
479 loc = p.getEncodedSourceLoc(sourceLoc);
480 return success();
483 static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
484 p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
487 #define GET_OP_CLASSES
488 #include "TestOpsSyntax.cpp.inc"
490 void TestDialect::registerOpsSyntax() {
491 addOperations<
492 #define GET_OP_LIST
493 #include "TestOpsSyntax.cpp.inc"
494 >();