1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <string>
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"

using namespace mlir;
using namespace mlir::func;
36 //===----------------------------------------------------------------------===//
37 // FuncDialect Interfaces
38 //===----------------------------------------------------------------------===//
40 /// This class defines the interface for handling inlining with func operations.
41 struct FuncInlinerInterface
: public DialectInlinerInterface
{
42 using DialectInlinerInterface::DialectInlinerInterface
;
44 //===--------------------------------------------------------------------===//
46 //===--------------------------------------------------------------------===//
48 /// All call operations can be inlined.
49 bool isLegalToInline(Operation
*call
, Operation
*callable
,
50 bool wouldBeCloned
) const final
{
54 /// All operations can be inlined.
55 bool isLegalToInline(Operation
*, Region
*, bool,
56 BlockAndValueMapping
&) const final
{
60 /// All functions can be inlined.
61 bool isLegalToInline(Region
*, Region
*, bool,
62 BlockAndValueMapping
&) const final
{
66 //===--------------------------------------------------------------------===//
67 // Transformation Hooks
68 //===--------------------------------------------------------------------===//
70 /// Handle the given inlined terminator by replacing it with a new operation
72 void handleTerminator(Operation
*op
, Block
*newDest
) const final
{
73 // Only return needs to be handled here.
74 auto returnOp
= dyn_cast
<ReturnOp
>(op
);
78 // Replace the return with a branch to the dest.
79 OpBuilder
builder(op
);
80 builder
.create
<cf::BranchOp
>(op
->getLoc(), newDest
, returnOp
.getOperands());
84 /// Handle the given inlined terminator by replacing it with a new operation
86 void handleTerminator(Operation
*op
,
87 ArrayRef
<Value
> valuesToRepl
) const final
{
88 // Only return needs to be handled here.
89 auto returnOp
= cast
<ReturnOp
>(op
);
91 // Replace the values directly with the return operands.
92 assert(returnOp
.getNumOperands() == valuesToRepl
.size());
93 for (const auto &it
: llvm::enumerate(returnOp
.getOperands()))
94 valuesToRepl
[it
.index()].replaceAllUsesWith(it
.value());
99 //===----------------------------------------------------------------------===//
101 //===----------------------------------------------------------------------===//
103 void FuncDialect::initialize() {
106 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
108 addInterfaces
<FuncInlinerInterface
>();
111 /// Materialize a single constant operation from a given attribute value with
112 /// the desired resultant type.
113 Operation
*FuncDialect::materializeConstant(OpBuilder
&builder
, Attribute value
,
114 Type type
, Location loc
) {
115 if (ConstantOp::isBuildableWith(value
, type
))
116 return builder
.create
<ConstantOp
>(loc
, type
,
117 value
.cast
<FlatSymbolRefAttr
>());
121 //===----------------------------------------------------------------------===//
123 //===----------------------------------------------------------------------===//
125 LogicalResult
CallOp::verifySymbolUses(SymbolTableCollection
&symbolTable
) {
126 // Check that the callee attribute was specified.
127 auto fnAttr
= (*this)->getAttrOfType
<FlatSymbolRefAttr
>("callee");
129 return emitOpError("requires a 'callee' symbol reference attribute");
130 FuncOp fn
= symbolTable
.lookupNearestSymbolFrom
<FuncOp
>(*this, fnAttr
);
132 return emitOpError() << "'" << fnAttr
.getValue()
133 << "' does not reference a valid function";
135 // Verify that the operand and result types match the callee.
136 auto fnType
= fn
.getFunctionType();
137 if (fnType
.getNumInputs() != getNumOperands())
138 return emitOpError("incorrect number of operands for callee");
140 for (unsigned i
= 0, e
= fnType
.getNumInputs(); i
!= e
; ++i
)
141 if (getOperand(i
).getType() != fnType
.getInput(i
))
142 return emitOpError("operand type mismatch: expected operand type ")
143 << fnType
.getInput(i
) << ", but provided "
144 << getOperand(i
).getType() << " for operand number " << i
;
146 if (fnType
.getNumResults() != getNumResults())
147 return emitOpError("incorrect number of results for callee");
149 for (unsigned i
= 0, e
= fnType
.getNumResults(); i
!= e
; ++i
)
150 if (getResult(i
).getType() != fnType
.getResult(i
)) {
151 auto diag
= emitOpError("result type mismatch at index ") << i
;
152 diag
.attachNote() << " op result types: " << getResultTypes();
153 diag
.attachNote() << "function result types: " << fnType
.getResults();
160 FunctionType
CallOp::getCalleeType() {
161 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
164 //===----------------------------------------------------------------------===//
166 //===----------------------------------------------------------------------===//
168 /// Fold indirect calls that have a constant function as the callee operand.
169 LogicalResult
CallIndirectOp::canonicalize(CallIndirectOp indirectCall
,
170 PatternRewriter
&rewriter
) {
171 // Check that the callee is a constant callee.
172 SymbolRefAttr calledFn
;
173 if (!matchPattern(indirectCall
.getCallee(), m_Constant(&calledFn
)))
176 // Replace with a direct call.
177 rewriter
.replaceOpWithNewOp
<CallOp
>(indirectCall
, calledFn
,
178 indirectCall
.getResultTypes(),
179 indirectCall
.getArgOperands());
183 //===----------------------------------------------------------------------===//
185 //===----------------------------------------------------------------------===//
187 LogicalResult
ConstantOp::verify() {
188 StringRef fnName
= getValue();
189 Type type
= getType();
191 // Try to find the referenced function.
192 auto fn
= (*this)->getParentOfType
<ModuleOp
>().lookupSymbol
<FuncOp
>(fnName
);
194 return emitOpError() << "reference to undefined function '" << fnName
197 // Check that the referenced function has the correct type.
198 if (fn
.getFunctionType() != type
)
199 return emitOpError("reference to function with mismatched type");
204 OpFoldResult
ConstantOp::fold(ArrayRef
<Attribute
> operands
) {
205 assert(operands
.empty() && "constant has no operands");
206 return getValueAttr();
209 void ConstantOp::getAsmResultNames(
210 function_ref
<void(Value
, StringRef
)> setNameFn
) {
211 setNameFn(getResult(), "f");
214 bool ConstantOp::isBuildableWith(Attribute value
, Type type
) {
215 return value
.isa
<FlatSymbolRefAttr
>() && type
.isa
<FunctionType
>();
218 //===----------------------------------------------------------------------===//
220 //===----------------------------------------------------------------------===//
222 FuncOp
FuncOp::create(Location location
, StringRef name
, FunctionType type
,
223 ArrayRef
<NamedAttribute
> attrs
) {
224 OpBuilder
builder(location
->getContext());
225 OperationState
state(location
, getOperationName());
226 FuncOp::build(builder
, state
, name
, type
, attrs
);
227 return cast
<FuncOp
>(Operation::create(state
));
229 FuncOp
FuncOp::create(Location location
, StringRef name
, FunctionType type
,
230 Operation::dialect_attr_range attrs
) {
231 SmallVector
<NamedAttribute
, 8> attrRef(attrs
);
232 return create(location
, name
, type
, llvm::makeArrayRef(attrRef
));
234 FuncOp
FuncOp::create(Location location
, StringRef name
, FunctionType type
,
235 ArrayRef
<NamedAttribute
> attrs
,
236 ArrayRef
<DictionaryAttr
> argAttrs
) {
237 FuncOp func
= create(location
, name
, type
, attrs
);
238 func
.setAllArgAttrs(argAttrs
);
242 void FuncOp::build(OpBuilder
&builder
, OperationState
&state
, StringRef name
,
243 FunctionType type
, ArrayRef
<NamedAttribute
> attrs
,
244 ArrayRef
<DictionaryAttr
> argAttrs
) {
245 state
.addAttribute(SymbolTable::getSymbolAttrName(),
246 builder
.getStringAttr(name
));
247 state
.addAttribute(getFunctionTypeAttrName(state
.name
), TypeAttr::get(type
));
248 state
.attributes
.append(attrs
.begin(), attrs
.end());
251 if (argAttrs
.empty())
253 assert(type
.getNumInputs() == argAttrs
.size());
254 function_interface_impl::addArgAndResultAttrs(
255 builder
, state
, argAttrs
, /*resultAttrs=*/std::nullopt
,
256 getArgAttrsAttrName(state
.name
), getResAttrsAttrName(state
.name
));
259 ParseResult
FuncOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
261 [](Builder
&builder
, ArrayRef
<Type
> argTypes
, ArrayRef
<Type
> results
,
262 function_interface_impl::VariadicFlag
,
263 std::string
&) { return builder
.getFunctionType(argTypes
, results
); };
265 return function_interface_impl::parseFunctionOp(
266 parser
, result
, /*allowVariadic=*/false,
267 getFunctionTypeAttrName(result
.name
), buildFuncType
,
268 getArgAttrsAttrName(result
.name
), getResAttrsAttrName(result
.name
));
271 void FuncOp::print(OpAsmPrinter
&p
) {
272 function_interface_impl::printFunctionOp(
273 p
, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
274 getArgAttrsAttrName(), getResAttrsAttrName());
277 /// Clone the internal blocks from this function into dest and all attributes
278 /// from this function to dest.
279 void FuncOp::cloneInto(FuncOp dest
, BlockAndValueMapping
&mapper
) {
280 // Add the attributes of this function to dest.
281 llvm::MapVector
<StringAttr
, Attribute
> newAttrMap
;
282 for (const auto &attr
: dest
->getAttrs())
283 newAttrMap
.insert({attr
.getName(), attr
.getValue()});
284 for (const auto &attr
: (*this)->getAttrs())
285 newAttrMap
.insert({attr
.getName(), attr
.getValue()});
287 auto newAttrs
= llvm::to_vector(llvm::map_range(
288 newAttrMap
, [](std::pair
<StringAttr
, Attribute
> attrPair
) {
289 return NamedAttribute(attrPair
.first
, attrPair
.second
);
291 dest
->setAttrs(DictionaryAttr::get(getContext(), newAttrs
));
294 getBody().cloneInto(&dest
.getBody(), mapper
);
297 /// Create a deep copy of this function and all of its blocks, remapping
298 /// any operands that use values outside of the function using the map that is
299 /// provided (leaving them alone if no entry is present). Replaces references
300 /// to cloned sub-values with the corresponding value that is copied, and adds
301 /// those mappings to the mapper.
302 FuncOp
FuncOp::clone(BlockAndValueMapping
&mapper
) {
303 // Create the new function.
304 FuncOp newFunc
= cast
<FuncOp
>(getOperation()->cloneWithoutRegions());
306 // If the function has a body, then the user might be deleting arguments to
307 // the function by specifying them in the mapper. If so, we don't add the
308 // argument to the input type vector.
310 FunctionType oldType
= getFunctionType();
312 unsigned oldNumArgs
= oldType
.getNumInputs();
313 SmallVector
<Type
, 4> newInputs
;
314 newInputs
.reserve(oldNumArgs
);
315 for (unsigned i
= 0; i
!= oldNumArgs
; ++i
)
316 if (!mapper
.contains(getArgument(i
)))
317 newInputs
.push_back(oldType
.getInput(i
));
319 /// If any of the arguments were dropped, update the type and drop any
320 /// necessary argument attributes.
321 if (newInputs
.size() != oldNumArgs
) {
322 newFunc
.setType(FunctionType::get(oldType
.getContext(), newInputs
,
323 oldType
.getResults()));
325 if (ArrayAttr argAttrs
= getAllArgAttrs()) {
326 SmallVector
<Attribute
> newArgAttrs
;
327 newArgAttrs
.reserve(newInputs
.size());
328 for (unsigned i
= 0; i
!= oldNumArgs
; ++i
)
329 if (!mapper
.contains(getArgument(i
)))
330 newArgAttrs
.push_back(argAttrs
[i
]);
331 newFunc
.setAllArgAttrs(newArgAttrs
);
336 /// Clone the current function into the new one and return it.
337 cloneInto(newFunc
, mapper
);
340 FuncOp
FuncOp::clone() {
341 BlockAndValueMapping mapper
;
342 return clone(mapper
);
345 //===----------------------------------------------------------------------===//
347 //===----------------------------------------------------------------------===//
349 LogicalResult
ReturnOp::verify() {
350 auto function
= cast
<FuncOp
>((*this)->getParentOp());
352 // The operand number and types must match the function signature.
353 const auto &results
= function
.getFunctionType().getResults();
354 if (getNumOperands() != results
.size())
355 return emitOpError("has ")
356 << getNumOperands() << " operands, but enclosing function (@"
357 << function
.getName() << ") returns " << results
.size();
359 for (unsigned i
= 0, e
= results
.size(); i
!= e
; ++i
)
360 if (getOperand(i
).getType() != results
[i
])
361 return emitError() << "type of return operand " << i
<< " ("
362 << getOperand(i
).getType()
363 << ") doesn't match function result type ("
365 << " in function @" << function
.getName();
370 //===----------------------------------------------------------------------===//
371 // TableGen'd op method definitions
372 //===----------------------------------------------------------------------===//
374 #define GET_OP_CLASSES
375 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"