//===- Async.cpp - MLIR Async Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Async/IR/Async.h"

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"
18 using namespace mlir::async
;
20 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
22 constexpr StringRef
AsyncDialect::kAllowedToBlockAttrName
;
24 void AsyncDialect::initialize() {
27 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
30 #define GET_TYPEDEF_LIST
31 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
35 //===----------------------------------------------------------------------===//
37 //===----------------------------------------------------------------------===//
39 LogicalResult
YieldOp::verify() {
40 // Get the underlying value types from async values returned from the
41 // parent `async.execute` operation.
42 auto executeOp
= (*this)->getParentOfType
<ExecuteOp
>();
44 llvm::map_range(executeOp
.getBodyResults(), [](const OpResult
&result
) {
45 return result
.getType().cast
<ValueType
>().getValueType();
48 if (getOperandTypes() != types
)
49 return emitOpError("operand types do not match the types returned from "
50 "the parent ExecuteOp");
56 YieldOp::getMutableSuccessorOperands(Optional
<unsigned> index
) {
57 return getOperandsMutable();
//===----------------------------------------------------------------------===//
// ExecuteOp
//===----------------------------------------------------------------------===//

/// Name of the derived attribute recording how many operands are dependency
/// tokens vs. async value operands (set in `build`/`parse`, elided in `print`).
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
66 OperandRange
ExecuteOp::getSuccessorEntryOperands(Optional
<unsigned> index
) {
67 assert(index
&& *index
== 0 && "invalid region index");
68 return getBodyOperands();
71 bool ExecuteOp::areTypesCompatible(Type lhs
, Type rhs
) {
72 const auto getValueOrTokenType
= [](Type type
) {
73 if (auto value
= type
.dyn_cast
<ValueType
>())
74 return value
.getValueType();
77 return getValueOrTokenType(lhs
) == getValueOrTokenType(rhs
);
80 void ExecuteOp::getSuccessorRegions(Optional
<unsigned> index
,
82 SmallVectorImpl
<RegionSuccessor
> ®ions
) {
83 // The `body` region branch back to the parent operation.
85 assert(*index
== 0 && "invalid region index");
86 regions
.push_back(RegionSuccessor(getBodyResults()));
90 // Otherwise the successor is the body region.
92 RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
95 void ExecuteOp::build(OpBuilder
&builder
, OperationState
&result
,
96 TypeRange resultTypes
, ValueRange dependencies
,
97 ValueRange operands
, BodyBuilderFn bodyBuilder
) {
99 result
.addOperands(dependencies
);
100 result
.addOperands(operands
);
102 // Add derived `operand_segment_sizes` attribute based on parsed operands.
103 int32_t numDependencies
= dependencies
.size();
104 int32_t numOperands
= operands
.size();
105 auto operandSegmentSizes
=
106 builder
.getDenseI32ArrayAttr({numDependencies
, numOperands
});
107 result
.addAttribute(kOperandSegmentSizesAttr
, operandSegmentSizes
);
109 // First result is always a token, and then `resultTypes` wrapped into
111 result
.addTypes({TokenType::get(result
.getContext())});
112 for (Type type
: resultTypes
)
113 result
.addTypes(ValueType::get(type
));
115 // Add a body region with block arguments as unwrapped async value operands.
116 Region
*bodyRegion
= result
.addRegion();
117 bodyRegion
->push_back(new Block
);
118 Block
&bodyBlock
= bodyRegion
->front();
119 for (Value operand
: operands
) {
120 auto valueType
= operand
.getType().dyn_cast
<ValueType
>();
121 bodyBlock
.addArgument(valueType
? valueType
.getValueType()
126 // Create the default terminator if the builder is not provided and if the
127 // expected result is empty. Otherwise, leave this to the caller
128 // because we don't know which values to return from the execute op.
129 if (resultTypes
.empty() && !bodyBuilder
) {
130 OpBuilder::InsertionGuard
guard(builder
);
131 builder
.setInsertionPointToStart(&bodyBlock
);
132 builder
.create
<async::YieldOp
>(result
.location
, ValueRange());
133 } else if (bodyBuilder
) {
134 OpBuilder::InsertionGuard
guard(builder
);
135 builder
.setInsertionPointToStart(&bodyBlock
);
136 bodyBuilder(builder
, result
.location
, bodyBlock
.getArguments());
140 void ExecuteOp::print(OpAsmPrinter
&p
) {
142 if (!getDependencies().empty())
143 p
<< " [" << getDependencies() << "]";
145 // (%value as %unwrapped: !async.value<!arg.type>, ...)
146 if (!getBodyOperands().empty()) {
148 Block
*entry
= getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
149 llvm::interleaveComma(
150 getBodyOperands(), p
, [&, n
= 0](Value operand
) mutable {
151 Value argument
= entry
? entry
->getArgument(n
++) : Value();
152 p
<< operand
<< " as " << argument
<< ": " << operand
.getType();
157 // -> (!async.value<!return.type>, ...)
158 p
.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
159 p
.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
160 {kOperandSegmentSizesAttr
});
162 p
.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
165 ParseResult
ExecuteOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
166 MLIRContext
*ctx
= result
.getContext();
168 // Sizes of parsed variadic operands, will be updated below after parsing.
169 int32_t numDependencies
= 0;
171 auto tokenTy
= TokenType::get(ctx
);
173 // Parse dependency tokens.
174 if (succeeded(parser
.parseOptionalLSquare())) {
175 SmallVector
<OpAsmParser::UnresolvedOperand
, 4> tokenArgs
;
176 if (parser
.parseOperandList(tokenArgs
) ||
177 parser
.resolveOperands(tokenArgs
, tokenTy
, result
.operands
) ||
178 parser
.parseRSquare())
181 numDependencies
= tokenArgs
.size();
184 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
185 SmallVector
<OpAsmParser::UnresolvedOperand
, 4> valueArgs
;
186 SmallVector
<OpAsmParser::Argument
, 4> unwrappedArgs
;
187 SmallVector
<Type
, 4> valueTypes
;
189 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
190 auto parseAsyncValueArg
= [&]() -> ParseResult
{
191 if (parser
.parseOperand(valueArgs
.emplace_back()) ||
192 parser
.parseKeyword("as") ||
193 parser
.parseArgument(unwrappedArgs
.emplace_back()) ||
194 parser
.parseColonType(valueTypes
.emplace_back()))
197 auto valueTy
= valueTypes
.back().dyn_cast
<ValueType
>();
198 unwrappedArgs
.back().type
= valueTy
? valueTy
.getValueType() : Type();
202 auto argsLoc
= parser
.getCurrentLocation();
203 if (parser
.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen
,
204 parseAsyncValueArg
) ||
205 parser
.resolveOperands(valueArgs
, valueTypes
, argsLoc
, result
.operands
))
208 int32_t numOperands
= valueArgs
.size();
210 // Add derived `operand_segment_sizes` attribute based on parsed operands.
211 auto operandSegmentSizes
=
212 parser
.getBuilder().getDenseI32ArrayAttr({numDependencies
, numOperands
});
213 result
.addAttribute(kOperandSegmentSizesAttr
, operandSegmentSizes
);
215 // Parse the types of results returned from the async execute op.
216 SmallVector
<Type
, 4> resultTypes
;
218 if (parser
.parseOptionalArrowTypeList(resultTypes
) ||
219 // Async execute first result is always a completion token.
220 parser
.addTypeToList(tokenTy
, result
.types
) ||
221 parser
.addTypesToList(resultTypes
, result
.types
) ||
222 // Parse operation attributes.
223 parser
.parseOptionalAttrDictWithKeyword(attrs
))
226 result
.addAttributes(attrs
);
228 // Parse asynchronous region.
229 Region
*body
= result
.addRegion();
230 return parser
.parseRegion(*body
, /*arguments=*/unwrappedArgs
);
233 LogicalResult
ExecuteOp::verifyRegions() {
234 // Unwrap async.execute value operands types.
235 auto unwrappedTypes
= llvm::map_range(getBodyOperands(), [](Value operand
) {
236 return operand
.getType().cast
<ValueType
>().getValueType();
239 // Verify that unwrapped argument types matches the body region arguments.
240 if (getBodyRegion().getArgumentTypes() != unwrappedTypes
)
241 return emitOpError("async body region argument types do not match the "
242 "execute operation arguments types");
247 //===----------------------------------------------------------------------===//
249 //===----------------------------------------------------------------------===//
251 LogicalResult
CreateGroupOp::canonicalize(CreateGroupOp op
,
252 PatternRewriter
&rewriter
) {
253 // Find all `await_all` users of the group.
254 llvm::SmallVector
<AwaitAllOp
> awaitAllUsers
;
256 auto isAwaitAll
= [&](Operation
*op
) -> bool {
257 if (AwaitAllOp awaitAll
= dyn_cast
<AwaitAllOp
>(op
)) {
258 awaitAllUsers
.push_back(awaitAll
);
264 // Check if all users of the group are `await_all` operations.
265 if (!llvm::all_of(op
->getUsers(), isAwaitAll
))
268 // If group is only awaited without adding anything to it, we can safely erase
269 // the create operation and all users.
270 for (AwaitAllOp awaitAll
: awaitAllUsers
)
271 rewriter
.eraseOp(awaitAll
);
272 rewriter
.eraseOp(op
);
277 //===----------------------------------------------------------------------===//
279 //===----------------------------------------------------------------------===//
281 void AwaitOp::build(OpBuilder
&builder
, OperationState
&result
, Value operand
,
282 ArrayRef
<NamedAttribute
> attrs
) {
283 result
.addOperands({operand
});
284 result
.attributes
.append(attrs
.begin(), attrs
.end());
286 // Add unwrapped async.value type to the returned values types.
287 if (auto valueType
= operand
.getType().dyn_cast
<ValueType
>())
288 result
.addTypes(valueType
.getValueType());
291 static ParseResult
parseAwaitResultType(OpAsmParser
&parser
, Type
&operandType
,
293 if (parser
.parseType(operandType
))
296 // Add unwrapped async.value type to the returned values types.
297 if (auto valueType
= operandType
.dyn_cast
<ValueType
>())
298 resultType
= valueType
.getValueType();
303 static void printAwaitResultType(OpAsmPrinter
&p
, Operation
*op
,
304 Type operandType
, Type resultType
) {
308 LogicalResult
AwaitOp::verify() {
309 Type argType
= getOperand().getType();
311 // Awaiting on a token does not have any results.
312 if (argType
.isa
<TokenType
>() && !getResultTypes().empty())
313 return emitOpError("awaiting on a token must have empty result");
315 // Awaiting on a value unwraps the async value type.
316 if (auto value
= argType
.dyn_cast
<ValueType
>()) {
317 if (*getResultType() != value
.getValueType())
318 return emitOpError() << "result type " << *getResultType()
319 << " does not match async value type "
320 << value
.getValueType();
326 //===----------------------------------------------------------------------===//
328 //===----------------------------------------------------------------------===//
330 void FuncOp::build(OpBuilder
&builder
, OperationState
&state
, StringRef name
,
331 FunctionType type
, ArrayRef
<NamedAttribute
> attrs
,
332 ArrayRef
<DictionaryAttr
> argAttrs
) {
333 state
.addAttribute(SymbolTable::getSymbolAttrName(),
334 builder
.getStringAttr(name
));
335 state
.addAttribute(getFunctionTypeAttrName(state
.name
), TypeAttr::get(type
));
337 state
.attributes
.append(attrs
.begin(), attrs
.end());
340 if (argAttrs
.empty())
342 assert(type
.getNumInputs() == argAttrs
.size());
343 function_interface_impl::addArgAndResultAttrs(
344 builder
, state
, argAttrs
, /*resultAttrs=*/std::nullopt
,
345 getArgAttrsAttrName(state
.name
), getResAttrsAttrName(state
.name
));
348 ParseResult
FuncOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
350 [](Builder
&builder
, ArrayRef
<Type
> argTypes
, ArrayRef
<Type
> results
,
351 function_interface_impl::VariadicFlag
,
352 std::string
&) { return builder
.getFunctionType(argTypes
, results
); };
354 return function_interface_impl::parseFunctionOp(
355 parser
, result
, /*allowVariadic=*/false,
356 getFunctionTypeAttrName(result
.name
), buildFuncType
,
357 getArgAttrsAttrName(result
.name
), getResAttrsAttrName(result
.name
));
360 void FuncOp::print(OpAsmPrinter
&p
) {
361 function_interface_impl::printFunctionOp(
362 p
, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
363 getArgAttrsAttrName(), getResAttrsAttrName());
366 /// Check that the result type of async.func is not void and must be
367 /// some async token or async values.
368 LogicalResult
FuncOp::verify() {
369 auto resultTypes
= getResultTypes();
370 if (resultTypes
.empty())
372 << "result is expected to be at least of size 1, but got "
373 << resultTypes
.size();
375 for (unsigned i
= 0, e
= resultTypes
.size(); i
!= e
; ++i
) {
376 auto type
= resultTypes
[i
];
377 if (!type
.isa
<TokenType
>() && !type
.isa
<ValueType
>())
378 return emitOpError() << "result type must be async value type or async "
379 "token type, but got "
381 // We only allow AsyncToken appear as the first return value
382 if (type
.isa
<TokenType
>() && i
!= 0) {
384 << " results' (optional) async token type is expected "
385 "to appear as the 1st return value, but got "
393 //===----------------------------------------------------------------------===//
395 //===----------------------------------------------------------------------===//
397 LogicalResult
CallOp::verifySymbolUses(SymbolTableCollection
&symbolTable
) {
398 // Check that the callee attribute was specified.
399 auto fnAttr
= (*this)->getAttrOfType
<FlatSymbolRefAttr
>("callee");
401 return emitOpError("requires a 'callee' symbol reference attribute");
402 FuncOp fn
= symbolTable
.lookupNearestSymbolFrom
<FuncOp
>(*this, fnAttr
);
404 return emitOpError() << "'" << fnAttr
.getValue()
405 << "' does not reference a valid async function";
407 // Verify that the operand and result types match the callee.
408 auto fnType
= fn
.getFunctionType();
409 if (fnType
.getNumInputs() != getNumOperands())
410 return emitOpError("incorrect number of operands for callee");
412 for (unsigned i
= 0, e
= fnType
.getNumInputs(); i
!= e
; ++i
)
413 if (getOperand(i
).getType() != fnType
.getInput(i
))
414 return emitOpError("operand type mismatch: expected operand type ")
415 << fnType
.getInput(i
) << ", but provided "
416 << getOperand(i
).getType() << " for operand number " << i
;
418 if (fnType
.getNumResults() != getNumResults())
419 return emitOpError("incorrect number of results for callee");
421 for (unsigned i
= 0, e
= fnType
.getNumResults(); i
!= e
; ++i
)
422 if (getResult(i
).getType() != fnType
.getResult(i
)) {
423 auto diag
= emitOpError("result type mismatch at index ") << i
;
424 diag
.attachNote() << " op result types: " << getResultTypes();
425 diag
.attachNote() << "function result types: " << fnType
.getResults();
432 FunctionType
CallOp::getCalleeType() {
433 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
436 //===----------------------------------------------------------------------===//
438 //===----------------------------------------------------------------------===//
440 LogicalResult
ReturnOp::verify() {
441 auto funcOp
= (*this)->getParentOfType
<FuncOp
>();
442 ArrayRef
<Type
> resultTypes
= funcOp
.isStateful()
443 ? funcOp
.getResultTypes().drop_front()
444 : funcOp
.getResultTypes();
445 // Get the underlying value types from async types returned from the
446 // parent `async.func` operation.
447 auto types
= llvm::map_range(resultTypes
, [](const Type
&result
) {
448 return result
.cast
<ValueType
>().getValueType();
451 if (getOperandTypes() != types
)
452 return emitOpError("operand types do not match the types returned from "
453 "the parent FuncOp");
458 //===----------------------------------------------------------------------===//
459 // TableGen'd op method definitions
460 //===----------------------------------------------------------------------===//
462 #define GET_OP_CLASSES
463 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
465 //===----------------------------------------------------------------------===//
466 // TableGen'd type method definitions
467 //===----------------------------------------------------------------------===//
469 #define GET_TYPEDEF_CLASSES
470 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
472 void ValueType::print(AsmPrinter
&printer
) const {
474 printer
.printType(getValueType());
478 Type
ValueType::parse(mlir::AsmParser
&parser
) {
480 if (parser
.parseLess() || parser
.parseType(ty
) || parser
.parseGreater()) {
481 parser
.emitError(parser
.getNameLoc(), "failed to parse async value type");
484 return ValueType::get(ty
);