//===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
17 parseFunctionArgumentList(OpAsmParser
&parser
, bool allowVariadic
,
18 SmallVectorImpl
<OpAsmParser::Argument
> &arguments
,
21 // Parse the function arguments. The argument list either has to consistently
22 // have ssa-id's followed by types, or just be a type list. It isn't ok to
23 // sometimes have SSA ID's and sometimes not.
26 return parser
.parseCommaSeparatedList(
27 OpAsmParser::Delimiter::Paren
, [&]() -> ParseResult
{
28 // Ellipsis must be at end of the list.
30 return parser
.emitError(
31 parser
.getCurrentLocation(),
32 "variadic arguments must be in the end of the argument list");
34 // Handle ellipsis as a special case.
35 if (allowVariadic
&& succeeded(parser
.parseOptionalEllipsis())) {
36 // This is a variadic designator.
38 return success(); // Stop parsing arguments.
40 // Parse argument name if present.
41 OpAsmParser::Argument argument
;
42 auto argPresent
= parser
.parseOptionalArgument(
43 argument
, /*allowType=*/true, /*allowAttrs=*/true);
44 if (argPresent
.has_value()) {
45 if (failed(argPresent
.value()))
46 return failure(); // Present but malformed.
48 // Reject this if the preceding argument was missing a name.
49 if (!arguments
.empty() && arguments
.back().ssaName
.name
.empty())
50 return parser
.emitError(argument
.ssaName
.location
,
51 "expected type instead of SSA identifier");
54 argument
.ssaName
.location
= parser
.getCurrentLocation();
55 // Otherwise we just have a type list without SSA names. Reject
56 // this if the preceding argument had a name.
57 if (!arguments
.empty() && !arguments
.back().ssaName
.name
.empty())
58 return parser
.emitError(argument
.ssaName
.location
,
59 "expected SSA identifier");
62 if (parser
.parseType(argument
.type
) ||
63 parser
.parseOptionalAttrDict(attrs
) ||
64 parser
.parseOptionalLocationSpecifier(argument
.sourceLoc
))
66 argument
.attrs
= attrs
.getDictionary(parser
.getContext());
68 arguments
.push_back(argument
);
73 /// Parse a function result list.
75 /// function-result-list ::= function-result-list-parens
76 /// | non-function-type
77 /// function-result-list-parens ::= `(` `)`
78 /// | `(` function-result-list-no-parens `)`
79 /// function-result-list-no-parens ::= function-result (`,` function-result)*
80 /// function-result ::= type attribute-dict?
83 parseFunctionResultList(OpAsmParser
&parser
, SmallVectorImpl
<Type
> &resultTypes
,
84 SmallVectorImpl
<DictionaryAttr
> &resultAttrs
) {
85 if (failed(parser
.parseOptionalLParen())) {
86 // We already know that there is no `(`, so parse a type.
87 // Because there is no `(`, it cannot be a function type.
89 if (parser
.parseType(ty
))
91 resultTypes
.push_back(ty
);
92 resultAttrs
.emplace_back();
96 // Special case for an empty set of parens.
97 if (succeeded(parser
.parseOptionalRParen()))
100 // Parse individual function results.
101 if (parser
.parseCommaSeparatedList([&]() -> ParseResult
{
102 resultTypes
.emplace_back();
103 resultAttrs
.emplace_back();
105 if (parser
.parseType(resultTypes
.back()) ||
106 parser
.parseOptionalAttrDict(attrs
))
108 resultAttrs
.back() = attrs
.getDictionary(parser
.getContext());
113 return parser
.parseRParen();
116 ParseResult
function_interface_impl::parseFunctionSignature(
117 OpAsmParser
&parser
, bool allowVariadic
,
118 SmallVectorImpl
<OpAsmParser::Argument
> &arguments
, bool &isVariadic
,
119 SmallVectorImpl
<Type
> &resultTypes
,
120 SmallVectorImpl
<DictionaryAttr
> &resultAttrs
) {
121 if (parseFunctionArgumentList(parser
, allowVariadic
, arguments
, isVariadic
))
123 if (succeeded(parser
.parseOptionalArrow()))
124 return parseFunctionResultList(parser
, resultTypes
, resultAttrs
);
128 void function_interface_impl::addArgAndResultAttrs(
129 Builder
&builder
, OperationState
&result
, ArrayRef
<DictionaryAttr
> argAttrs
,
130 ArrayRef
<DictionaryAttr
> resultAttrs
, StringAttr argAttrsName
,
131 StringAttr resAttrsName
) {
132 auto nonEmptyAttrsFn
= [](DictionaryAttr attrs
) {
133 return attrs
&& !attrs
.empty();
135 // Convert the specified array of dictionary attrs (which may have null
136 // entries) to an ArrayAttr of dictionaries.
137 auto getArrayAttr
= [&](ArrayRef
<DictionaryAttr
> dictAttrs
) {
138 SmallVector
<Attribute
> attrs
;
139 for (auto &dict
: dictAttrs
)
140 attrs
.push_back(dict
? dict
: builder
.getDictionaryAttr({}));
141 return builder
.getArrayAttr(attrs
);
144 // Add the attributes to the function arguments.
145 if (llvm::any_of(argAttrs
, nonEmptyAttrsFn
))
146 result
.addAttribute(argAttrsName
, getArrayAttr(argAttrs
));
148 // Add the attributes to the function results.
149 if (llvm::any_of(resultAttrs
, nonEmptyAttrsFn
))
150 result
.addAttribute(resAttrsName
, getArrayAttr(resultAttrs
));
153 void function_interface_impl::addArgAndResultAttrs(
154 Builder
&builder
, OperationState
&result
,
155 ArrayRef
<OpAsmParser::Argument
> args
, ArrayRef
<DictionaryAttr
> resultAttrs
,
156 StringAttr argAttrsName
, StringAttr resAttrsName
) {
157 SmallVector
<DictionaryAttr
> argAttrs
;
158 for (const auto &arg
: args
)
159 argAttrs
.push_back(arg
.attrs
);
160 addArgAndResultAttrs(builder
, result
, argAttrs
, resultAttrs
, argAttrsName
,
164 ParseResult
function_interface_impl::parseFunctionOp(
165 OpAsmParser
&parser
, OperationState
&result
, bool allowVariadic
,
166 StringAttr typeAttrName
, FuncTypeBuilder funcTypeBuilder
,
167 StringAttr argAttrsName
, StringAttr resAttrsName
) {
168 SmallVector
<OpAsmParser::Argument
> entryArgs
;
169 SmallVector
<DictionaryAttr
> resultAttrs
;
170 SmallVector
<Type
> resultTypes
;
171 auto &builder
= parser
.getBuilder();
174 (void)impl::parseOptionalVisibilityKeyword(parser
, result
.attributes
);
176 // Parse the name as a symbol.
178 if (parser
.parseSymbolName(nameAttr
, SymbolTable::getSymbolAttrName(),
182 // Parse the function signature.
183 SMLoc signatureLocation
= parser
.getCurrentLocation();
184 bool isVariadic
= false;
185 if (parseFunctionSignature(parser
, allowVariadic
, entryArgs
, isVariadic
,
186 resultTypes
, resultAttrs
))
189 std::string errorMessage
;
190 SmallVector
<Type
> argTypes
;
191 argTypes
.reserve(entryArgs
.size());
192 for (auto &arg
: entryArgs
)
193 argTypes
.push_back(arg
.type
);
194 Type type
= funcTypeBuilder(builder
, argTypes
, resultTypes
,
195 VariadicFlag(isVariadic
), errorMessage
);
197 return parser
.emitError(signatureLocation
)
198 << "failed to construct function type"
199 << (errorMessage
.empty() ? "" : ": ") << errorMessage
;
201 result
.addAttribute(typeAttrName
, TypeAttr::get(type
));
203 // If function attributes are present, parse them.
204 NamedAttrList parsedAttributes
;
205 SMLoc attributeDictLocation
= parser
.getCurrentLocation();
206 if (parser
.parseOptionalAttrDictWithKeyword(parsedAttributes
))
209 // Disallow attributes that are inferred from elsewhere in the attribute
211 for (StringRef disallowed
:
212 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
213 typeAttrName
.getValue()}) {
214 if (parsedAttributes
.get(disallowed
))
215 return parser
.emitError(attributeDictLocation
, "'")
217 << "' is an inferred attribute and should not be specified in the "
218 "explicit attribute dictionary";
220 result
.attributes
.append(parsedAttributes
);
222 // Add the attributes to the function arguments.
223 assert(resultAttrs
.size() == resultTypes
.size());
224 addArgAndResultAttrs(builder
, result
, entryArgs
, resultAttrs
, argAttrsName
,
227 // Parse the optional function body. The printer will not print the body if
228 // its empty, so disallow parsing of empty body in the parser.
229 auto *body
= result
.addRegion();
230 SMLoc loc
= parser
.getCurrentLocation();
231 OptionalParseResult parseResult
=
232 parser
.parseOptionalRegion(*body
, entryArgs
,
233 /*enableNameShadowing=*/false);
234 if (parseResult
.has_value()) {
235 if (failed(*parseResult
))
237 // Function body was parsed, make sure its not empty.
239 return parser
.emitError(loc
, "expected non-empty function body");
244 /// Print a function result list. The provided `attrs` must either be null, or
245 /// contain a set of DictionaryAttrs of the same arity as `types`.
246 static void printFunctionResultList(OpAsmPrinter
&p
, ArrayRef
<Type
> types
,
248 assert(!types
.empty() && "Should not be called for empty result list.");
249 assert((!attrs
|| attrs
.size() == types
.size()) &&
250 "Invalid number of attributes.");
252 auto &os
= p
.getStream();
253 bool needsParens
= types
.size() > 1 || llvm::isa
<FunctionType
>(types
[0]) ||
254 (attrs
&& !llvm::cast
<DictionaryAttr
>(attrs
[0]).empty());
257 llvm::interleaveComma(llvm::seq
<size_t>(0, types
.size()), os
, [&](size_t i
) {
258 p
.printType(types
[i
]);
260 p
.printOptionalAttrDict(llvm::cast
<DictionaryAttr
>(attrs
[i
]).getValue());
266 void function_interface_impl::printFunctionSignature(
267 OpAsmPrinter
&p
, FunctionOpInterface op
, ArrayRef
<Type
> argTypes
,
268 bool isVariadic
, ArrayRef
<Type
> resultTypes
) {
269 Region
&body
= op
->getRegion(0);
270 bool isExternal
= body
.empty();
273 ArrayAttr argAttrs
= op
.getArgAttrsAttr();
274 for (unsigned i
= 0, e
= argTypes
.size(); i
< e
; ++i
) {
279 ArrayRef
<NamedAttribute
> attrs
;
281 attrs
= llvm::cast
<DictionaryAttr
>(argAttrs
[i
]).getValue();
282 p
.printRegionArgument(body
.getArgument(i
), attrs
);
284 p
.printType(argTypes
[i
]);
286 p
.printOptionalAttrDict(
287 llvm::cast
<DictionaryAttr
>(argAttrs
[i
]).getValue());
292 if (!argTypes
.empty())
299 if (!resultTypes
.empty()) {
300 p
.getStream() << " -> ";
301 auto resultAttrs
= op
.getResAttrsAttr();
302 printFunctionResultList(p
, resultTypes
, resultAttrs
);
306 void function_interface_impl::printFunctionAttributes(
307 OpAsmPrinter
&p
, Operation
*op
, ArrayRef
<StringRef
> elided
) {
308 // Print out function attributes, if present.
309 SmallVector
<StringRef
, 8> ignoredAttrs
= {SymbolTable::getSymbolAttrName()};
310 ignoredAttrs
.append(elided
.begin(), elided
.end());
312 p
.printOptionalAttrDictWithKeyword(op
->getAttrs(), ignoredAttrs
);
315 void function_interface_impl::printFunctionOp(
316 OpAsmPrinter
&p
, FunctionOpInterface op
, bool isVariadic
,
317 StringRef typeAttrName
, StringAttr argAttrsName
, StringAttr resAttrsName
) {
318 // Print the operation and the function name.
320 op
->getAttrOfType
<StringAttr
>(SymbolTable::getSymbolAttrName())
324 StringRef visibilityAttrName
= SymbolTable::getVisibilityAttrName();
325 if (auto visibility
= op
->getAttrOfType
<StringAttr
>(visibilityAttrName
))
326 p
<< visibility
.getValue() << ' ';
327 p
.printSymbolName(funcName
);
329 ArrayRef
<Type
> argTypes
= op
.getArgumentTypes();
330 ArrayRef
<Type
> resultTypes
= op
.getResultTypes();
331 printFunctionSignature(p
, op
, argTypes
, isVariadic
, resultTypes
);
332 printFunctionAttributes(
333 p
, op
, {visibilityAttrName
, typeAttrName
, argAttrsName
, resAttrsName
});
334 // Print the body if this is not an external function.
335 Region
&body
= op
->getRegion(0);
338 p
.printRegion(body
, /*printEntryBlockArgs=*/false,
339 /*printBlockTerminators=*/true);