[VPlan] Run recipe removal and simplification after optimizeForVFAndUF. (#125926)
[llvm-project.git] / mlir / lib / Dialect / PDLInterp / IR / PDLInterp.cpp
blob9b1f11d8352827aae195560f918e8532752f6e8b
1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
24 void PDLInterpDialect::initialize() {
25 addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
28 >();
31 template <typename OpT>
32 static LogicalResult verifySwitchOp(OpT op) {
33 // Verify that the number of case destinations matches the number of case
34 // values.
35 size_t numDests = op.getCases().size();
36 size_t numValues = op.getCaseValues().size();
37 if (numDests != numValues) {
38 return op.emitOpError(
39 "expected number of cases to match the number of case "
40 "values, got ")
41 << numDests << " but expected " << numValues;
43 return success();
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
50 LogicalResult CreateOperationOp::verify() {
51 if (!getInferredResultTypes())
52 return success();
53 if (!getInputResultTypes().empty()) {
54 return emitOpError("with inferred results cannot also have "
55 "explicit result types");
57 OperationName opName(getName(), getContext());
58 if (!opName.hasInterface<InferTypeOpInterface>()) {
59 return emitOpError()
60 << "has inferred results, but the created operation '" << opName
61 << "' does not support result type inference (or is not "
62 "registered)";
64 return success();
67 static ParseResult parseCreateOperationOpAttributes(
68 OpAsmParser &p,
69 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
70 ArrayAttr &attrNamesAttr) {
71 Builder &builder = p.getBuilder();
72 SmallVector<Attribute, 4> attrNames;
73 if (succeeded(p.parseOptionalLBrace())) {
74 auto parseOperands = [&]() {
75 StringAttr nameAttr;
76 OpAsmParser::UnresolvedOperand operand;
77 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78 p.parseOperand(operand))
79 return failure();
80 attrNames.push_back(nameAttr);
81 attrOperands.push_back(operand);
82 return success();
84 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
85 return failure();
87 attrNamesAttr = builder.getArrayAttr(attrNames);
88 return success();
91 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
92 CreateOperationOp op,
93 OperandRange attrArgs,
94 ArrayAttr attrNames) {
95 if (attrNames.empty())
96 return;
97 p << " {";
98 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
99 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
100 p << '}';
103 static ParseResult parseCreateOperationOpResults(
104 OpAsmParser &p,
105 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
106 SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
107 if (failed(p.parseOptionalArrow()))
108 return success();
110 // Handle the case of inferred results.
111 if (succeeded(p.parseOptionalLess())) {
112 if (p.parseKeyword("inferred") || p.parseGreater())
113 return failure();
114 inferredResultTypes = p.getBuilder().getUnitAttr();
115 return success();
118 // Otherwise, parse the explicit results.
119 return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
120 p.parseColonTypeList(resultTypes) || p.parseRParen());
123 static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
124 OperandRange resultOperands,
125 TypeRange resultTypes,
126 UnitAttr inferredResultTypes) {
127 // Handle the case of inferred results.
128 if (inferredResultTypes) {
129 p << " -> <inferred>";
130 return;
133 // Otherwise, handle the explicit results.
134 if (!resultTypes.empty())
135 p << " -> (" << resultOperands << " : " << resultTypes << ")";
138 //===----------------------------------------------------------------------===//
139 // pdl_interp::ForEachOp
140 //===----------------------------------------------------------------------===//
142 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
143 Value range, Block *successor, bool initLoop) {
144 build(builder, state, range, successor);
145 if (initLoop) {
146 // Create the block and the loop variable.
147 // FIXME: Allow passing in a proper location for the loop variable.
148 auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
149 state.regions.front()->emplaceBlock();
150 state.regions.front()->addArgument(rangeType.getElementType(),
151 state.location);
155 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
156 // Parse the loop variable followed by type.
157 OpAsmParser::Argument loopVariable;
158 OpAsmParser::UnresolvedOperand operandInfo;
159 if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
160 parser.parseKeyword("in", " after loop variable") ||
161 // Parse the operand (value range).
162 parser.parseOperand(operandInfo))
163 return failure();
165 // Resolve the operand.
166 Type rangeType = pdl::RangeType::get(loopVariable.type);
167 if (parser.resolveOperand(operandInfo, rangeType, result.operands))
168 return failure();
170 // Parse the body region.
171 Region *body = result.addRegion();
172 Block *successor;
173 if (parser.parseRegion(*body, loopVariable) ||
174 parser.parseOptionalAttrDict(result.attributes) ||
175 // Parse the successor.
176 parser.parseArrow() || parser.parseSuccessor(successor))
177 return failure();
179 result.addSuccessors(successor);
180 return success();
183 void ForEachOp::print(OpAsmPrinter &p) {
184 BlockArgument arg = getLoopVariable();
185 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
186 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
187 p.printOptionalAttrDict((*this)->getAttrs());
188 p << " -> ";
189 p.printSuccessor(getSuccessor());
192 LogicalResult ForEachOp::verify() {
193 // Verify that the operation has exactly one argument.
194 if (getRegion().getNumArguments() != 1)
195 return emitOpError("requires exactly one argument");
197 // Verify that the loop variable and the operand (value range)
198 // have compatible types.
199 BlockArgument arg = getLoopVariable();
200 Type rangeType = pdl::RangeType::get(arg.getType());
201 if (rangeType != getValues().getType())
202 return emitOpError("operand must be a range of loop variable type");
204 return success();
207 //===----------------------------------------------------------------------===//
208 // pdl_interp::FuncOp
209 //===----------------------------------------------------------------------===//
211 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
212 FunctionType type, ArrayRef<NamedAttribute> attrs) {
213 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
216 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
217 auto buildFuncType =
218 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
219 function_interface_impl::VariadicFlag,
220 std::string &) { return builder.getFunctionType(argTypes, results); };
222 return function_interface_impl::parseFunctionOp(
223 parser, result, /*allowVariadic=*/false,
224 getFunctionTypeAttrName(result.name), buildFuncType,
225 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
228 void FuncOp::print(OpAsmPrinter &p) {
229 function_interface_impl::printFunctionOp(
230 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
231 getArgAttrsAttrName(), getResAttrsAttrName());
234 //===----------------------------------------------------------------------===//
235 // pdl_interp::GetValueTypeOp
236 //===----------------------------------------------------------------------===//
238 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
239 static Type getGetValueTypeOpValueType(Type type) {
240 Type valueTy = pdl::ValueType::get(type.getContext());
241 return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
242 : valueTy;
245 //===----------------------------------------------------------------------===//
246 // pdl_interp::CreateRangeOp
247 //===----------------------------------------------------------------------===//
249 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
250 Type &resultType) {
251 // If arguments were provided, infer the result type from the argument list.
252 if (!argumentTypes.empty()) {
253 resultType =
254 pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
255 return success();
257 // Otherwise, parse the type as a trailing type.
258 return p.parseColonType(resultType);
261 static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
262 TypeRange argumentTypes, Type resultType) {
263 if (argumentTypes.empty())
264 p << ": " << resultType;
267 LogicalResult CreateRangeOp::verify() {
268 Type elementType = getType().getElementType();
269 for (Type operandType : getOperandTypes()) {
270 Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
271 if (operandElementType != elementType) {
272 return emitOpError("expected operand to have element type ")
273 << elementType << ", but got " << operandElementType;
276 return success();
279 //===----------------------------------------------------------------------===//
280 // pdl_interp::SwitchAttributeOp
281 //===----------------------------------------------------------------------===//
283 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
285 //===----------------------------------------------------------------------===//
286 // pdl_interp::SwitchOperandCountOp
287 //===----------------------------------------------------------------------===//
289 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
291 //===----------------------------------------------------------------------===//
292 // pdl_interp::SwitchOperationNameOp
293 //===----------------------------------------------------------------------===//
295 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
297 //===----------------------------------------------------------------------===//
298 // pdl_interp::SwitchResultCountOp
299 //===----------------------------------------------------------------------===//
301 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
303 //===----------------------------------------------------------------------===//
304 // pdl_interp::SwitchTypeOp
305 //===----------------------------------------------------------------------===//
307 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
309 //===----------------------------------------------------------------------===//
310 // pdl_interp::SwitchTypesOp
311 //===----------------------------------------------------------------------===//
313 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
315 //===----------------------------------------------------------------------===//
316 // TableGen Auto-Generated Op and Interface Definitions
317 //===----------------------------------------------------------------------===//
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"