[mlir] FunctionOpInterface: turn required attributes into interface methods (Reland)
[llvm-project.git] / mlir / lib / Dialect / Func / IR / FuncOps.cpp
blob7bb3663cc43bed593ce8c6aabd338596beddb63d
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/FunctionImplementation.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
31 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
33 using namespace mlir;
34 using namespace mlir::func;
36 //===----------------------------------------------------------------------===//
37 // FuncDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 namespace {
40 /// This class defines the interface for handling inlining with func operations.
41 struct FuncInlinerInterface : public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
44 //===--------------------------------------------------------------------===//
45 // Analysis Hooks
46 //===--------------------------------------------------------------------===//
48 /// All call operations can be inlined.
49 bool isLegalToInline(Operation *call, Operation *callable,
50 bool wouldBeCloned) const final {
51 return true;
54 /// All operations can be inlined.
55 bool isLegalToInline(Operation *, Region *, bool,
56 BlockAndValueMapping &) const final {
57 return true;
60 /// All functions can be inlined.
61 bool isLegalToInline(Region *, Region *, bool,
62 BlockAndValueMapping &) const final {
63 return true;
66 //===--------------------------------------------------------------------===//
67 // Transformation Hooks
68 //===--------------------------------------------------------------------===//
70 /// Handle the given inlined terminator by replacing it with a new operation
71 /// as necessary.
72 void handleTerminator(Operation *op, Block *newDest) const final {
73 // Only return needs to be handled here.
74 auto returnOp = dyn_cast<ReturnOp>(op);
75 if (!returnOp)
76 return;
78 // Replace the return with a branch to the dest.
79 OpBuilder builder(op);
80 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
81 op->erase();
84 /// Handle the given inlined terminator by replacing it with a new operation
85 /// as necessary.
86 void handleTerminator(Operation *op,
87 ArrayRef<Value> valuesToRepl) const final {
88 // Only return needs to be handled here.
89 auto returnOp = cast<ReturnOp>(op);
91 // Replace the values directly with the return operands.
92 assert(returnOp.getNumOperands() == valuesToRepl.size());
93 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
94 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
97 } // namespace
99 //===----------------------------------------------------------------------===//
100 // FuncDialect
101 //===----------------------------------------------------------------------===//
103 void FuncDialect::initialize() {
104 addOperations<
105 #define GET_OP_LIST
106 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
107 >();
108 addInterfaces<FuncInlinerInterface>();
111 /// Materialize a single constant operation from a given attribute value with
112 /// the desired resultant type.
113 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
114 Type type, Location loc) {
115 if (ConstantOp::isBuildableWith(value, type))
116 return builder.create<ConstantOp>(loc, type,
117 value.cast<FlatSymbolRefAttr>());
118 return nullptr;
121 //===----------------------------------------------------------------------===//
122 // CallOp
123 //===----------------------------------------------------------------------===//
125 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
126 // Check that the callee attribute was specified.
127 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
128 if (!fnAttr)
129 return emitOpError("requires a 'callee' symbol reference attribute");
130 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
131 if (!fn)
132 return emitOpError() << "'" << fnAttr.getValue()
133 << "' does not reference a valid function";
135 // Verify that the operand and result types match the callee.
136 auto fnType = fn.getFunctionType();
137 if (fnType.getNumInputs() != getNumOperands())
138 return emitOpError("incorrect number of operands for callee");
140 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
141 if (getOperand(i).getType() != fnType.getInput(i))
142 return emitOpError("operand type mismatch: expected operand type ")
143 << fnType.getInput(i) << ", but provided "
144 << getOperand(i).getType() << " for operand number " << i;
146 if (fnType.getNumResults() != getNumResults())
147 return emitOpError("incorrect number of results for callee");
149 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
150 if (getResult(i).getType() != fnType.getResult(i)) {
151 auto diag = emitOpError("result type mismatch at index ") << i;
152 diag.attachNote() << " op result types: " << getResultTypes();
153 diag.attachNote() << "function result types: " << fnType.getResults();
154 return diag;
157 return success();
160 FunctionType CallOp::getCalleeType() {
161 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
164 //===----------------------------------------------------------------------===//
165 // CallIndirectOp
166 //===----------------------------------------------------------------------===//
168 /// Fold indirect calls that have a constant function as the callee operand.
169 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
170 PatternRewriter &rewriter) {
171 // Check that the callee is a constant callee.
172 SymbolRefAttr calledFn;
173 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
174 return failure();
176 // Replace with a direct call.
177 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
178 indirectCall.getResultTypes(),
179 indirectCall.getArgOperands());
180 return success();
183 //===----------------------------------------------------------------------===//
184 // ConstantOp
185 //===----------------------------------------------------------------------===//
187 LogicalResult ConstantOp::verify() {
188 StringRef fnName = getValue();
189 Type type = getType();
191 // Try to find the referenced function.
192 auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
193 if (!fn)
194 return emitOpError() << "reference to undefined function '" << fnName
195 << "'";
197 // Check that the referenced function has the correct type.
198 if (fn.getFunctionType() != type)
199 return emitOpError("reference to function with mismatched type");
201 return success();
204 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
205 assert(operands.empty() && "constant has no operands");
206 return getValueAttr();
209 void ConstantOp::getAsmResultNames(
210 function_ref<void(Value, StringRef)> setNameFn) {
211 setNameFn(getResult(), "f");
214 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
215 return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
218 //===----------------------------------------------------------------------===//
219 // FuncOp
220 //===----------------------------------------------------------------------===//
222 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
223 ArrayRef<NamedAttribute> attrs) {
224 OpBuilder builder(location->getContext());
225 OperationState state(location, getOperationName());
226 FuncOp::build(builder, state, name, type, attrs);
227 return cast<FuncOp>(Operation::create(state));
229 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
230 Operation::dialect_attr_range attrs) {
231 SmallVector<NamedAttribute, 8> attrRef(attrs);
232 return create(location, name, type, llvm::makeArrayRef(attrRef));
234 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
235 ArrayRef<NamedAttribute> attrs,
236 ArrayRef<DictionaryAttr> argAttrs) {
237 FuncOp func = create(location, name, type, attrs);
238 func.setAllArgAttrs(argAttrs);
239 return func;
242 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
243 FunctionType type, ArrayRef<NamedAttribute> attrs,
244 ArrayRef<DictionaryAttr> argAttrs) {
245 state.addAttribute(SymbolTable::getSymbolAttrName(),
246 builder.getStringAttr(name));
247 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
248 state.attributes.append(attrs.begin(), attrs.end());
249 state.addRegion();
251 if (argAttrs.empty())
252 return;
253 assert(type.getNumInputs() == argAttrs.size());
254 function_interface_impl::addArgAndResultAttrs(
255 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
256 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
259 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
260 auto buildFuncType =
261 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
262 function_interface_impl::VariadicFlag,
263 std::string &) { return builder.getFunctionType(argTypes, results); };
265 return function_interface_impl::parseFunctionOp(
266 parser, result, /*allowVariadic=*/false,
267 getFunctionTypeAttrName(result.name), buildFuncType,
268 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
271 void FuncOp::print(OpAsmPrinter &p) {
272 function_interface_impl::printFunctionOp(
273 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
274 getArgAttrsAttrName(), getResAttrsAttrName());
277 /// Clone the internal blocks from this function into dest and all attributes
278 /// from this function to dest.
279 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
280 // Add the attributes of this function to dest.
281 llvm::MapVector<StringAttr, Attribute> newAttrMap;
282 for (const auto &attr : dest->getAttrs())
283 newAttrMap.insert({attr.getName(), attr.getValue()});
284 for (const auto &attr : (*this)->getAttrs())
285 newAttrMap.insert({attr.getName(), attr.getValue()});
287 auto newAttrs = llvm::to_vector(llvm::map_range(
288 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
289 return NamedAttribute(attrPair.first, attrPair.second);
290 }));
291 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
293 // Clone the body.
294 getBody().cloneInto(&dest.getBody(), mapper);
297 /// Create a deep copy of this function and all of its blocks, remapping
298 /// any operands that use values outside of the function using the map that is
299 /// provided (leaving them alone if no entry is present). Replaces references
300 /// to cloned sub-values with the corresponding value that is copied, and adds
301 /// those mappings to the mapper.
302 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
303 // Create the new function.
304 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
306 // If the function has a body, then the user might be deleting arguments to
307 // the function by specifying them in the mapper. If so, we don't add the
308 // argument to the input type vector.
309 if (!isExternal()) {
310 FunctionType oldType = getFunctionType();
312 unsigned oldNumArgs = oldType.getNumInputs();
313 SmallVector<Type, 4> newInputs;
314 newInputs.reserve(oldNumArgs);
315 for (unsigned i = 0; i != oldNumArgs; ++i)
316 if (!mapper.contains(getArgument(i)))
317 newInputs.push_back(oldType.getInput(i));
319 /// If any of the arguments were dropped, update the type and drop any
320 /// necessary argument attributes.
321 if (newInputs.size() != oldNumArgs) {
322 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
323 oldType.getResults()));
325 if (ArrayAttr argAttrs = getAllArgAttrs()) {
326 SmallVector<Attribute> newArgAttrs;
327 newArgAttrs.reserve(newInputs.size());
328 for (unsigned i = 0; i != oldNumArgs; ++i)
329 if (!mapper.contains(getArgument(i)))
330 newArgAttrs.push_back(argAttrs[i]);
331 newFunc.setAllArgAttrs(newArgAttrs);
336 /// Clone the current function into the new one and return it.
337 cloneInto(newFunc, mapper);
338 return newFunc;
340 FuncOp FuncOp::clone() {
341 BlockAndValueMapping mapper;
342 return clone(mapper);
345 //===----------------------------------------------------------------------===//
346 // ReturnOp
347 //===----------------------------------------------------------------------===//
349 LogicalResult ReturnOp::verify() {
350 auto function = cast<FuncOp>((*this)->getParentOp());
352 // The operand number and types must match the function signature.
353 const auto &results = function.getFunctionType().getResults();
354 if (getNumOperands() != results.size())
355 return emitOpError("has ")
356 << getNumOperands() << " operands, but enclosing function (@"
357 << function.getName() << ") returns " << results.size();
359 for (unsigned i = 0, e = results.size(); i != e; ++i)
360 if (getOperand(i).getType() != results[i])
361 return emitError() << "type of return operand " << i << " ("
362 << getOperand(i).getType()
363 << ") doesn't match function result type ("
364 << results[i] << ")"
365 << " in function @" << function.getName();
367 return success();
370 //===----------------------------------------------------------------------===//
371 // TableGen'd op method definitions
372 //===----------------------------------------------------------------------===//
374 #define GET_OP_CLASSES
375 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"