1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
16 using namespace mlir::pdl_interp
;
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
20 //===----------------------------------------------------------------------===//
22 //===----------------------------------------------------------------------===//
24 void PDLInterpDialect::initialize() {
27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
31 template <typename OpT
>
32 static LogicalResult
verifySwitchOp(OpT op
) {
33 // Verify that the number of case destinations matches the number of case
35 size_t numDests
= op
.getCases().size();
36 size_t numValues
= op
.getCaseValues().size();
37 if (numDests
!= numValues
) {
38 return op
.emitOpError(
39 "expected number of cases to match the number of case "
41 << numDests
<< " but expected " << numValues
;
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
50 LogicalResult
CreateOperationOp::verify() {
51 if (!getInferredResultTypes())
53 if (!getInputResultTypes().empty()) {
54 return emitOpError("with inferred results cannot also have "
55 "explicit result types");
57 OperationName
opName(getName(), getContext());
58 if (!opName
.hasInterface
<InferTypeOpInterface
>()) {
60 << "has inferred results, but the created operation '" << opName
61 << "' does not support result type inference (or is not "
67 static ParseResult
parseCreateOperationOpAttributes(
69 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &attrOperands
,
70 ArrayAttr
&attrNamesAttr
) {
71 Builder
&builder
= p
.getBuilder();
72 SmallVector
<Attribute
, 4> attrNames
;
73 if (succeeded(p
.parseOptionalLBrace())) {
74 auto parseOperands
= [&]() {
76 OpAsmParser::UnresolvedOperand operand
;
77 if (p
.parseAttribute(nameAttr
) || p
.parseEqual() ||
78 p
.parseOperand(operand
))
80 attrNames
.push_back(nameAttr
);
81 attrOperands
.push_back(operand
);
84 if (p
.parseCommaSeparatedList(parseOperands
) || p
.parseRBrace())
87 attrNamesAttr
= builder
.getArrayAttr(attrNames
);
91 static void printCreateOperationOpAttributes(OpAsmPrinter
&p
,
93 OperandRange attrArgs
,
94 ArrayAttr attrNames
) {
95 if (attrNames
.empty())
98 interleaveComma(llvm::seq
<int>(0, attrNames
.size()), p
,
99 [&](int i
) { p
<< attrNames
[i
] << " = " << attrArgs
[i
]; });
103 static ParseResult
parseCreateOperationOpResults(
105 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &resultOperands
,
106 SmallVectorImpl
<Type
> &resultTypes
, UnitAttr
&inferredResultTypes
) {
107 if (failed(p
.parseOptionalArrow()))
110 // Handle the case of inferred results.
111 if (succeeded(p
.parseOptionalLess())) {
112 if (p
.parseKeyword("inferred") || p
.parseGreater())
114 inferredResultTypes
= p
.getBuilder().getUnitAttr();
118 // Otherwise, parse the explicit results.
119 return failure(p
.parseLParen() || p
.parseOperandList(resultOperands
) ||
120 p
.parseColonTypeList(resultTypes
) || p
.parseRParen());
123 static void printCreateOperationOpResults(OpAsmPrinter
&p
, CreateOperationOp op
,
124 OperandRange resultOperands
,
125 TypeRange resultTypes
,
126 UnitAttr inferredResultTypes
) {
127 // Handle the case of inferred results.
128 if (inferredResultTypes
) {
129 p
<< " -> <inferred>";
133 // Otherwise, handle the explicit results.
134 if (!resultTypes
.empty())
135 p
<< " -> (" << resultOperands
<< " : " << resultTypes
<< ")";
138 //===----------------------------------------------------------------------===//
139 // pdl_interp::ForEachOp
140 //===----------------------------------------------------------------------===//
142 void ForEachOp::build(::mlir::OpBuilder
&builder
, ::mlir::OperationState
&state
,
143 Value range
, Block
*successor
, bool initLoop
) {
144 build(builder
, state
, range
, successor
);
146 // Create the block and the loop variable.
147 // FIXME: Allow passing in a proper location for the loop variable.
148 auto rangeType
= llvm::cast
<pdl::RangeType
>(range
.getType());
149 state
.regions
.front()->emplaceBlock();
150 state
.regions
.front()->addArgument(rangeType
.getElementType(),
155 ParseResult
ForEachOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
156 // Parse the loop variable followed by type.
157 OpAsmParser::Argument loopVariable
;
158 OpAsmParser::UnresolvedOperand operandInfo
;
159 if (parser
.parseArgument(loopVariable
, /*allowType=*/true) ||
160 parser
.parseKeyword("in", " after loop variable") ||
161 // Parse the operand (value range).
162 parser
.parseOperand(operandInfo
))
165 // Resolve the operand.
166 Type rangeType
= pdl::RangeType::get(loopVariable
.type
);
167 if (parser
.resolveOperand(operandInfo
, rangeType
, result
.operands
))
170 // Parse the body region.
171 Region
*body
= result
.addRegion();
173 if (parser
.parseRegion(*body
, loopVariable
) ||
174 parser
.parseOptionalAttrDict(result
.attributes
) ||
175 // Parse the successor.
176 parser
.parseArrow() || parser
.parseSuccessor(successor
))
179 result
.addSuccessors(successor
);
183 void ForEachOp::print(OpAsmPrinter
&p
) {
184 BlockArgument arg
= getLoopVariable();
185 p
<< ' ' << arg
<< " : " << arg
.getType() << " in " << getValues() << ' ';
186 p
.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
187 p
.printOptionalAttrDict((*this)->getAttrs());
189 p
.printSuccessor(getSuccessor());
192 LogicalResult
ForEachOp::verify() {
193 // Verify that the operation has exactly one argument.
194 if (getRegion().getNumArguments() != 1)
195 return emitOpError("requires exactly one argument");
197 // Verify that the loop variable and the operand (value range)
198 // have compatible types.
199 BlockArgument arg
= getLoopVariable();
200 Type rangeType
= pdl::RangeType::get(arg
.getType());
201 if (rangeType
!= getValues().getType())
202 return emitOpError("operand must be a range of loop variable type");
207 //===----------------------------------------------------------------------===//
208 // pdl_interp::FuncOp
209 //===----------------------------------------------------------------------===//
211 void FuncOp::build(OpBuilder
&builder
, OperationState
&state
, StringRef name
,
212 FunctionType type
, ArrayRef
<NamedAttribute
> attrs
) {
213 buildWithEntryBlock(builder
, state
, name
, type
, attrs
, type
.getInputs());
216 ParseResult
FuncOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
218 [](Builder
&builder
, ArrayRef
<Type
> argTypes
, ArrayRef
<Type
> results
,
219 function_interface_impl::VariadicFlag
,
220 std::string
&) { return builder
.getFunctionType(argTypes
, results
); };
222 return function_interface_impl::parseFunctionOp(
223 parser
, result
, /*allowVariadic=*/false,
224 getFunctionTypeAttrName(result
.name
), buildFuncType
,
225 getArgAttrsAttrName(result
.name
), getResAttrsAttrName(result
.name
));
228 void FuncOp::print(OpAsmPrinter
&p
) {
229 function_interface_impl::printFunctionOp(
230 p
, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
231 getArgAttrsAttrName(), getResAttrsAttrName());
234 //===----------------------------------------------------------------------===//
235 // pdl_interp::GetValueTypeOp
236 //===----------------------------------------------------------------------===//
238 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
239 static Type
getGetValueTypeOpValueType(Type type
) {
240 Type valueTy
= pdl::ValueType::get(type
.getContext());
241 return llvm::isa
<pdl::RangeType
>(type
) ? pdl::RangeType::get(valueTy
)
245 //===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
247 //===----------------------------------------------------------------------===//
249 static ParseResult
parseRangeType(OpAsmParser
&p
, TypeRange argumentTypes
,
251 // If arguments were provided, infer the result type from the argument list.
252 if (!argumentTypes
.empty()) {
254 pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes
[0]));
257 // Otherwise, parse the type as a trailing type.
258 return p
.parseColonType(resultType
);
261 static void printRangeType(OpAsmPrinter
&p
, CreateRangeOp op
,
262 TypeRange argumentTypes
, Type resultType
) {
263 if (argumentTypes
.empty())
264 p
<< ": " << resultType
;
267 LogicalResult
CreateRangeOp::verify() {
268 Type elementType
= getType().getElementType();
269 for (Type operandType
: getOperandTypes()) {
270 Type operandElementType
= pdl::getRangeElementTypeOrSelf(operandType
);
271 if (operandElementType
!= elementType
) {
272 return emitOpError("expected operand to have element type ")
273 << elementType
<< ", but got " << operandElementType
;
279 //===----------------------------------------------------------------------===//
280 // pdl_interp::SwitchAttributeOp
281 //===----------------------------------------------------------------------===//
283 LogicalResult
SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
285 //===----------------------------------------------------------------------===//
286 // pdl_interp::SwitchOperandCountOp
287 //===----------------------------------------------------------------------===//
289 LogicalResult
SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
291 //===----------------------------------------------------------------------===//
292 // pdl_interp::SwitchOperationNameOp
293 //===----------------------------------------------------------------------===//
295 LogicalResult
SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
297 //===----------------------------------------------------------------------===//
298 // pdl_interp::SwitchResultCountOp
299 //===----------------------------------------------------------------------===//
301 LogicalResult
SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
303 //===----------------------------------------------------------------------===//
304 // pdl_interp::SwitchTypeOp
305 //===----------------------------------------------------------------------===//
307 LogicalResult
SwitchTypeOp::verify() { return verifySwitchOp(*this); }
309 //===----------------------------------------------------------------------===//
310 // pdl_interp::SwitchTypesOp
311 //===----------------------------------------------------------------------===//
313 LogicalResult
SwitchTypesOp::verify() { return verifySwitchOp(*this); }
315 //===----------------------------------------------------------------------===//
316 // TableGen Auto-Generated Op and Interface Definitions
317 //===----------------------------------------------------------------------===//
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"