[mlir] FunctionOpInterface: turn required attributes into interface methods (Reland)
[llvm-project.git] / mlir / lib / Dialect / Async / IR / Async.cpp
blob 54acc373018c065ca30b3b7a3c5b4baf025a951a
1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Async/IR/Async.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/FunctionImplementation.h"
14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/TypeSwitch.h"
17 using namespace mlir;
18 using namespace mlir::async;
20 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
22 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
24 void AsyncDialect::initialize() {
25 addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
28 >();
29 addTypes<
30 #define GET_TYPEDEF_LIST
31 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
32 >();
35 //===----------------------------------------------------------------------===//
36 // YieldOp
37 //===----------------------------------------------------------------------===//
39 LogicalResult YieldOp::verify() {
40 // Get the underlying value types from async values returned from the
41 // parent `async.execute` operation.
42 auto executeOp = (*this)->getParentOfType<ExecuteOp>();
43 auto types =
44 llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) {
45 return result.getType().cast<ValueType>().getValueType();
46 });
48 if (getOperandTypes() != types)
49 return emitOpError("operand types do not match the types returned from "
50 "the parent ExecuteOp");
52 return success();
55 MutableOperandRange
56 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
57 return getOperandsMutable();
60 //===----------------------------------------------------------------------===//
61 /// ExecuteOp
62 //===----------------------------------------------------------------------===//
// Name of the attribute that records how many operands belong to each of
// `async.execute`'s two variadic operand groups (dependencies, body operands).
// File-local: a namespace-scope constexpr variable already has internal
// linkage, `static` just makes that explicit.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
66 OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
67 assert(index && *index == 0 && "invalid region index");
68 return getBodyOperands();
71 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
72 const auto getValueOrTokenType = [](Type type) {
73 if (auto value = type.dyn_cast<ValueType>())
74 return value.getValueType();
75 return type;
77 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
80 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
81 ArrayRef<Attribute>,
82 SmallVectorImpl<RegionSuccessor> &regions) {
83 // The `body` region branch back to the parent operation.
84 if (index) {
85 assert(*index == 0 && "invalid region index");
86 regions.push_back(RegionSuccessor(getBodyResults()));
87 return;
90 // Otherwise the successor is the body region.
91 regions.push_back(
92 RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
95 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
96 TypeRange resultTypes, ValueRange dependencies,
97 ValueRange operands, BodyBuilderFn bodyBuilder) {
99 result.addOperands(dependencies);
100 result.addOperands(operands);
102 // Add derived `operand_segment_sizes` attribute based on parsed operands.
103 int32_t numDependencies = dependencies.size();
104 int32_t numOperands = operands.size();
105 auto operandSegmentSizes =
106 builder.getDenseI32ArrayAttr({numDependencies, numOperands});
107 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
109 // First result is always a token, and then `resultTypes` wrapped into
110 // `async.value`.
111 result.addTypes({TokenType::get(result.getContext())});
112 for (Type type : resultTypes)
113 result.addTypes(ValueType::get(type));
115 // Add a body region with block arguments as unwrapped async value operands.
116 Region *bodyRegion = result.addRegion();
117 bodyRegion->push_back(new Block);
118 Block &bodyBlock = bodyRegion->front();
119 for (Value operand : operands) {
120 auto valueType = operand.getType().dyn_cast<ValueType>();
121 bodyBlock.addArgument(valueType ? valueType.getValueType()
122 : operand.getType(),
123 operand.getLoc());
126 // Create the default terminator if the builder is not provided and if the
127 // expected result is empty. Otherwise, leave this to the caller
128 // because we don't know which values to return from the execute op.
129 if (resultTypes.empty() && !bodyBuilder) {
130 OpBuilder::InsertionGuard guard(builder);
131 builder.setInsertionPointToStart(&bodyBlock);
132 builder.create<async::YieldOp>(result.location, ValueRange());
133 } else if (bodyBuilder) {
134 OpBuilder::InsertionGuard guard(builder);
135 builder.setInsertionPointToStart(&bodyBlock);
136 bodyBuilder(builder, result.location, bodyBlock.getArguments());
140 void ExecuteOp::print(OpAsmPrinter &p) {
141 // [%tokens,...]
142 if (!getDependencies().empty())
143 p << " [" << getDependencies() << "]";
145 // (%value as %unwrapped: !async.value<!arg.type>, ...)
146 if (!getBodyOperands().empty()) {
147 p << " (";
148 Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
149 llvm::interleaveComma(
150 getBodyOperands(), p, [&, n = 0](Value operand) mutable {
151 Value argument = entry ? entry->getArgument(n++) : Value();
152 p << operand << " as " << argument << ": " << operand.getType();
154 p << ")";
157 // -> (!async.value<!return.type>, ...)
158 p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
159 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
160 {kOperandSegmentSizesAttr});
161 p << ' ';
162 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
165 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
166 MLIRContext *ctx = result.getContext();
168 // Sizes of parsed variadic operands, will be updated below after parsing.
169 int32_t numDependencies = 0;
171 auto tokenTy = TokenType::get(ctx);
173 // Parse dependency tokens.
174 if (succeeded(parser.parseOptionalLSquare())) {
175 SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
176 if (parser.parseOperandList(tokenArgs) ||
177 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
178 parser.parseRSquare())
179 return failure();
181 numDependencies = tokenArgs.size();
184 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
185 SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
186 SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
187 SmallVector<Type, 4> valueTypes;
189 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
190 auto parseAsyncValueArg = [&]() -> ParseResult {
191 if (parser.parseOperand(valueArgs.emplace_back()) ||
192 parser.parseKeyword("as") ||
193 parser.parseArgument(unwrappedArgs.emplace_back()) ||
194 parser.parseColonType(valueTypes.emplace_back()))
195 return failure();
197 auto valueTy = valueTypes.back().dyn_cast<ValueType>();
198 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
199 return success();
202 auto argsLoc = parser.getCurrentLocation();
203 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
204 parseAsyncValueArg) ||
205 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
206 return failure();
208 int32_t numOperands = valueArgs.size();
210 // Add derived `operand_segment_sizes` attribute based on parsed operands.
211 auto operandSegmentSizes =
212 parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
213 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
215 // Parse the types of results returned from the async execute op.
216 SmallVector<Type, 4> resultTypes;
217 NamedAttrList attrs;
218 if (parser.parseOptionalArrowTypeList(resultTypes) ||
219 // Async execute first result is always a completion token.
220 parser.addTypeToList(tokenTy, result.types) ||
221 parser.addTypesToList(resultTypes, result.types) ||
222 // Parse operation attributes.
223 parser.parseOptionalAttrDictWithKeyword(attrs))
224 return failure();
226 result.addAttributes(attrs);
228 // Parse asynchronous region.
229 Region *body = result.addRegion();
230 return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
233 LogicalResult ExecuteOp::verifyRegions() {
234 // Unwrap async.execute value operands types.
235 auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
236 return operand.getType().cast<ValueType>().getValueType();
239 // Verify that unwrapped argument types matches the body region arguments.
240 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
241 return emitOpError("async body region argument types do not match the "
242 "execute operation arguments types");
244 return success();
247 //===----------------------------------------------------------------------===//
248 /// CreateGroupOp
249 //===----------------------------------------------------------------------===//
251 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
252 PatternRewriter &rewriter) {
253 // Find all `await_all` users of the group.
254 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
256 auto isAwaitAll = [&](Operation *op) -> bool {
257 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
258 awaitAllUsers.push_back(awaitAll);
259 return true;
261 return false;
264 // Check if all users of the group are `await_all` operations.
265 if (!llvm::all_of(op->getUsers(), isAwaitAll))
266 return failure();
268 // If group is only awaited without adding anything to it, we can safely erase
269 // the create operation and all users.
270 for (AwaitAllOp awaitAll : awaitAllUsers)
271 rewriter.eraseOp(awaitAll);
272 rewriter.eraseOp(op);
274 return success();
277 //===----------------------------------------------------------------------===//
278 /// AwaitOp
279 //===----------------------------------------------------------------------===//
281 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
282 ArrayRef<NamedAttribute> attrs) {
283 result.addOperands({operand});
284 result.attributes.append(attrs.begin(), attrs.end());
286 // Add unwrapped async.value type to the returned values types.
287 if (auto valueType = operand.getType().dyn_cast<ValueType>())
288 result.addTypes(valueType.getValueType());
291 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
292 Type &resultType) {
293 if (parser.parseType(operandType))
294 return failure();
296 // Add unwrapped async.value type to the returned values types.
297 if (auto valueType = operandType.dyn_cast<ValueType>())
298 resultType = valueType.getValueType();
300 return success();
303 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
304 Type operandType, Type resultType) {
305 p << operandType;
308 LogicalResult AwaitOp::verify() {
309 Type argType = getOperand().getType();
311 // Awaiting on a token does not have any results.
312 if (argType.isa<TokenType>() && !getResultTypes().empty())
313 return emitOpError("awaiting on a token must have empty result");
315 // Awaiting on a value unwraps the async value type.
316 if (auto value = argType.dyn_cast<ValueType>()) {
317 if (*getResultType() != value.getValueType())
318 return emitOpError() << "result type " << *getResultType()
319 << " does not match async value type "
320 << value.getValueType();
323 return success();
326 //===----------------------------------------------------------------------===//
327 // FuncOp
328 //===----------------------------------------------------------------------===//
330 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
331 FunctionType type, ArrayRef<NamedAttribute> attrs,
332 ArrayRef<DictionaryAttr> argAttrs) {
333 state.addAttribute(SymbolTable::getSymbolAttrName(),
334 builder.getStringAttr(name));
335 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
337 state.attributes.append(attrs.begin(), attrs.end());
338 state.addRegion();
340 if (argAttrs.empty())
341 return;
342 assert(type.getNumInputs() == argAttrs.size());
343 function_interface_impl::addArgAndResultAttrs(
344 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
345 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
348 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
349 auto buildFuncType =
350 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
351 function_interface_impl::VariadicFlag,
352 std::string &) { return builder.getFunctionType(argTypes, results); };
354 return function_interface_impl::parseFunctionOp(
355 parser, result, /*allowVariadic=*/false,
356 getFunctionTypeAttrName(result.name), buildFuncType,
357 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
360 void FuncOp::print(OpAsmPrinter &p) {
361 function_interface_impl::printFunctionOp(
362 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
363 getArgAttrsAttrName(), getResAttrsAttrName());
366 /// Check that the result type of async.func is not void and must be
367 /// some async token or async values.
368 LogicalResult FuncOp::verify() {
369 auto resultTypes = getResultTypes();
370 if (resultTypes.empty())
371 return emitOpError()
372 << "result is expected to be at least of size 1, but got "
373 << resultTypes.size();
375 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
376 auto type = resultTypes[i];
377 if (!type.isa<TokenType>() && !type.isa<ValueType>())
378 return emitOpError() << "result type must be async value type or async "
379 "token type, but got "
380 << type;
381 // We only allow AsyncToken appear as the first return value
382 if (type.isa<TokenType>() && i != 0) {
383 return emitOpError()
384 << " results' (optional) async token type is expected "
385 "to appear as the 1st return value, but got "
386 << i + 1;
390 return success();
393 //===----------------------------------------------------------------------===//
394 /// CallOp
395 //===----------------------------------------------------------------------===//
397 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
398 // Check that the callee attribute was specified.
399 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
400 if (!fnAttr)
401 return emitOpError("requires a 'callee' symbol reference attribute");
402 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
403 if (!fn)
404 return emitOpError() << "'" << fnAttr.getValue()
405 << "' does not reference a valid async function";
407 // Verify that the operand and result types match the callee.
408 auto fnType = fn.getFunctionType();
409 if (fnType.getNumInputs() != getNumOperands())
410 return emitOpError("incorrect number of operands for callee");
412 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
413 if (getOperand(i).getType() != fnType.getInput(i))
414 return emitOpError("operand type mismatch: expected operand type ")
415 << fnType.getInput(i) << ", but provided "
416 << getOperand(i).getType() << " for operand number " << i;
418 if (fnType.getNumResults() != getNumResults())
419 return emitOpError("incorrect number of results for callee");
421 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
422 if (getResult(i).getType() != fnType.getResult(i)) {
423 auto diag = emitOpError("result type mismatch at index ") << i;
424 diag.attachNote() << " op result types: " << getResultTypes();
425 diag.attachNote() << "function result types: " << fnType.getResults();
426 return diag;
429 return success();
432 FunctionType CallOp::getCalleeType() {
433 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
436 //===----------------------------------------------------------------------===//
437 /// ReturnOp
438 //===----------------------------------------------------------------------===//
440 LogicalResult ReturnOp::verify() {
441 auto funcOp = (*this)->getParentOfType<FuncOp>();
442 ArrayRef<Type> resultTypes = funcOp.isStateful()
443 ? funcOp.getResultTypes().drop_front()
444 : funcOp.getResultTypes();
445 // Get the underlying value types from async types returned from the
446 // parent `async.func` operation.
447 auto types = llvm::map_range(resultTypes, [](const Type &result) {
448 return result.cast<ValueType>().getValueType();
451 if (getOperandTypes() != types)
452 return emitOpError("operand types do not match the types returned from "
453 "the parent FuncOp");
455 return success();
458 //===----------------------------------------------------------------------===//
459 // TableGen'd op method definitions
460 //===----------------------------------------------------------------------===//
462 #define GET_OP_CLASSES
463 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
465 //===----------------------------------------------------------------------===//
466 // TableGen'd type method definitions
467 //===----------------------------------------------------------------------===//
469 #define GET_TYPEDEF_CLASSES
470 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
472 void ValueType::print(AsmPrinter &printer) const {
473 printer << "<";
474 printer.printType(getValueType());
475 printer << '>';
478 Type ValueType::parse(mlir::AsmParser &parser) {
479 Type ty;
480 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
481 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
482 return Type();
484 return ValueType::get(ty);