//===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// Function Arguments and Results.
//===----------------------------------------------------------------------===//
23 static bool isEmptyAttrDict(Attribute attr
) {
24 return llvm::cast
<DictionaryAttr
>(attr
).empty();
27 DictionaryAttr
function_interface_impl::getArgAttrDict(FunctionOpInterface op
,
29 ArrayAttr attrs
= op
.getArgAttrsAttr();
30 DictionaryAttr argAttrs
=
31 attrs
? llvm::cast
<DictionaryAttr
>(attrs
[index
]) : DictionaryAttr();
36 function_interface_impl::getResultAttrDict(FunctionOpInterface op
,
38 ArrayAttr attrs
= op
.getResAttrsAttr();
39 DictionaryAttr resAttrs
=
40 attrs
? llvm::cast
<DictionaryAttr
>(attrs
[index
]) : DictionaryAttr();
44 ArrayRef
<NamedAttribute
>
45 function_interface_impl::getArgAttrs(FunctionOpInterface op
, unsigned index
) {
46 auto argDict
= getArgAttrDict(op
, index
);
47 return argDict
? argDict
.getValue() : std::nullopt
;
50 ArrayRef
<NamedAttribute
>
51 function_interface_impl::getResultAttrs(FunctionOpInterface op
,
53 auto resultDict
= getResultAttrDict(op
, index
);
54 return resultDict
? resultDict
.getValue() : std::nullopt
;
57 /// Get either the argument or result attributes array.
59 static ArrayAttr
getArgResAttrs(FunctionOpInterface op
) {
61 return op
.getArgAttrsAttr();
63 return op
.getResAttrsAttr();
66 /// Set either the argument or result attributes array.
68 static void setArgResAttrs(FunctionOpInterface op
, ArrayAttr attrs
) {
70 op
.setArgAttrsAttr(attrs
);
72 op
.setResAttrsAttr(attrs
);
75 /// Erase either the argument or result attributes array.
77 static void removeArgResAttrs(FunctionOpInterface op
) {
79 op
.removeArgAttrsAttr();
81 op
.removeResAttrsAttr();
84 /// Set all of the argument or result attribute dictionaries for a function.
86 static void setAllArgResAttrDicts(FunctionOpInterface op
,
87 ArrayRef
<Attribute
> attrs
) {
88 if (llvm::all_of(attrs
, isEmptyAttrDict
))
89 removeArgResAttrs
<isArg
>(op
);
91 setArgResAttrs
<isArg
>(op
, ArrayAttr::get(op
->getContext(), attrs
));
94 void function_interface_impl::setAllArgAttrDicts(
95 FunctionOpInterface op
, ArrayRef
<DictionaryAttr
> attrs
) {
96 setAllArgAttrDicts(op
, ArrayRef
<Attribute
>(attrs
.data(), attrs
.size()));
99 void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op
,
100 ArrayRef
<Attribute
> attrs
) {
101 auto wrappedAttrs
= llvm::map_range(attrs
, [op
](Attribute attr
) -> Attribute
{
102 return !attr
? DictionaryAttr::get(op
->getContext()) : attr
;
104 setAllArgResAttrDicts
</*isArg=*/true>(op
, llvm::to_vector
<8>(wrappedAttrs
));
107 void function_interface_impl::setAllResultAttrDicts(
108 FunctionOpInterface op
, ArrayRef
<DictionaryAttr
> attrs
) {
109 setAllResultAttrDicts(op
, ArrayRef
<Attribute
>(attrs
.data(), attrs
.size()));
112 void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op
,
113 ArrayRef
<Attribute
> attrs
) {
114 auto wrappedAttrs
= llvm::map_range(attrs
, [op
](Attribute attr
) -> Attribute
{
115 return !attr
? DictionaryAttr::get(op
->getContext()) : attr
;
117 setAllArgResAttrDicts
</*isArg=*/false>(op
, llvm::to_vector
<8>(wrappedAttrs
));
120 /// Update the given index into an argument or result attribute dictionary.
121 template <bool isArg
>
122 static void setArgResAttrDict(FunctionOpInterface op
, unsigned numTotalIndices
,
123 unsigned index
, DictionaryAttr attrs
) {
124 ArrayAttr allAttrs
= getArgResAttrs
<isArg
>(op
);
129 // If this attribute is not empty, we need to create a new attribute array.
130 SmallVector
<Attribute
, 8> newAttrs(numTotalIndices
,
131 DictionaryAttr::get(op
->getContext()));
132 newAttrs
[index
] = attrs
;
133 setArgResAttrs
<isArg
>(op
, ArrayAttr::get(op
->getContext(), newAttrs
));
136 // Check to see if the attribute is different from what we already have.
137 if (allAttrs
[index
] == attrs
)
140 // If it is, check to see if the attribute array would now contain only empty
142 ArrayRef
<Attribute
> rawAttrArray
= allAttrs
.getValue();
144 llvm::all_of(rawAttrArray
.take_front(index
), isEmptyAttrDict
) &&
145 llvm::all_of(rawAttrArray
.drop_front(index
+ 1), isEmptyAttrDict
))
146 return removeArgResAttrs
<isArg
>(op
);
148 // Otherwise, create a new attribute array with the updated dictionary.
149 SmallVector
<Attribute
, 8> newAttrs(rawAttrArray
);
150 newAttrs
[index
] = attrs
;
151 setArgResAttrs
<isArg
>(op
, ArrayAttr::get(op
->getContext(), newAttrs
));
154 void function_interface_impl::setArgAttrs(FunctionOpInterface op
,
156 ArrayRef
<NamedAttribute
> attributes
) {
157 assert(index
< op
.getNumArguments() && "invalid argument number");
158 return setArgResAttrDict
</*isArg=*/true>(
159 op
, op
.getNumArguments(), index
,
160 DictionaryAttr::get(op
->getContext(), attributes
));
163 void function_interface_impl::setArgAttrs(FunctionOpInterface op
,
165 DictionaryAttr attributes
) {
166 return setArgResAttrDict
</*isArg=*/true>(
167 op
, op
.getNumArguments(), index
,
168 attributes
? attributes
: DictionaryAttr::get(op
->getContext()));
171 void function_interface_impl::setResultAttrs(
172 FunctionOpInterface op
, unsigned index
,
173 ArrayRef
<NamedAttribute
> attributes
) {
174 assert(index
< op
.getNumResults() && "invalid result number");
175 return setArgResAttrDict
</*isArg=*/false>(
176 op
, op
.getNumResults(), index
,
177 DictionaryAttr::get(op
->getContext(), attributes
));
180 void function_interface_impl::setResultAttrs(FunctionOpInterface op
,
182 DictionaryAttr attributes
) {
183 assert(index
< op
.getNumResults() && "invalid result number");
184 return setArgResAttrDict
</*isArg=*/false>(
185 op
, op
.getNumResults(), index
,
186 attributes
? attributes
: DictionaryAttr::get(op
->getContext()));
189 void function_interface_impl::insertFunctionArguments(
190 FunctionOpInterface op
, ArrayRef
<unsigned> argIndices
, TypeRange argTypes
,
191 ArrayRef
<DictionaryAttr
> argAttrs
, ArrayRef
<Location
> argLocs
,
192 unsigned originalNumArgs
, Type newType
) {
193 assert(argIndices
.size() == argTypes
.size());
194 assert(argIndices
.size() == argAttrs
.size() || argAttrs
.empty());
195 assert(argIndices
.size() == argLocs
.size());
196 if (argIndices
.empty())
199 // There are 3 things that need to be updated:
202 // - Block arguments of entry block.
203 Block
&entry
= op
->getRegion(0).front();
205 // Update the argument attributes of the function.
206 ArrayAttr oldArgAttrs
= op
.getArgAttrsAttr();
207 if (oldArgAttrs
|| !argAttrs
.empty()) {
208 SmallVector
<DictionaryAttr
, 4> newArgAttrs
;
209 newArgAttrs
.reserve(originalNumArgs
+ argIndices
.size());
211 auto migrate
= [&](unsigned untilIdx
) {
213 newArgAttrs
.resize(newArgAttrs
.size() + untilIdx
- oldIdx
);
215 auto oldArgAttrRange
= oldArgAttrs
.getAsRange
<DictionaryAttr
>();
216 newArgAttrs
.append(oldArgAttrRange
.begin() + oldIdx
,
217 oldArgAttrRange
.begin() + untilIdx
);
221 for (unsigned i
= 0, e
= argIndices
.size(); i
< e
; ++i
) {
222 migrate(argIndices
[i
]);
223 newArgAttrs
.push_back(argAttrs
.empty() ? DictionaryAttr
{} : argAttrs
[i
]);
225 migrate(originalNumArgs
);
226 setAllArgAttrDicts(op
, newArgAttrs
);
229 // Update the function type and any entry block arguments.
230 op
.setFunctionTypeAttr(TypeAttr::get(newType
));
231 for (unsigned i
= 0, e
= argIndices
.size(); i
< e
; ++i
)
232 entry
.insertArgument(argIndices
[i
] + i
, argTypes
[i
], argLocs
[i
]);
235 void function_interface_impl::insertFunctionResults(
236 FunctionOpInterface op
, ArrayRef
<unsigned> resultIndices
,
237 TypeRange resultTypes
, ArrayRef
<DictionaryAttr
> resultAttrs
,
238 unsigned originalNumResults
, Type newType
) {
239 assert(resultIndices
.size() == resultTypes
.size());
240 assert(resultIndices
.size() == resultAttrs
.size() || resultAttrs
.empty());
241 if (resultIndices
.empty())
244 // There are 2 things that need to be updated:
248 // Update the result attributes of the function.
249 ArrayAttr oldResultAttrs
= op
.getResAttrsAttr();
250 if (oldResultAttrs
|| !resultAttrs
.empty()) {
251 SmallVector
<DictionaryAttr
, 4> newResultAttrs
;
252 newResultAttrs
.reserve(originalNumResults
+ resultIndices
.size());
254 auto migrate
= [&](unsigned untilIdx
) {
255 if (!oldResultAttrs
) {
256 newResultAttrs
.resize(newResultAttrs
.size() + untilIdx
- oldIdx
);
258 auto oldResultAttrsRange
= oldResultAttrs
.getAsRange
<DictionaryAttr
>();
259 newResultAttrs
.append(oldResultAttrsRange
.begin() + oldIdx
,
260 oldResultAttrsRange
.begin() + untilIdx
);
264 for (unsigned i
= 0, e
= resultIndices
.size(); i
< e
; ++i
) {
265 migrate(resultIndices
[i
]);
266 newResultAttrs
.push_back(resultAttrs
.empty() ? DictionaryAttr
{}
269 migrate(originalNumResults
);
270 setAllResultAttrDicts(op
, newResultAttrs
);
273 // Update the function type.
274 op
.setFunctionTypeAttr(TypeAttr::get(newType
));
277 void function_interface_impl::eraseFunctionArguments(
278 FunctionOpInterface op
, const BitVector
&argIndices
, Type newType
) {
279 // There are 3 things that need to be updated:
282 // - Block arguments of entry block.
283 Block
&entry
= op
->getRegion(0).front();
285 // Update the argument attributes of the function.
286 if (ArrayAttr argAttrs
= op
.getArgAttrsAttr()) {
287 SmallVector
<DictionaryAttr
, 4> newArgAttrs
;
288 newArgAttrs
.reserve(argAttrs
.size());
289 for (unsigned i
= 0, e
= argIndices
.size(); i
< e
; ++i
)
291 newArgAttrs
.emplace_back(llvm::cast
<DictionaryAttr
>(argAttrs
[i
]));
292 setAllArgAttrDicts(op
, newArgAttrs
);
295 // Update the function type and any entry block arguments.
296 op
.setFunctionTypeAttr(TypeAttr::get(newType
));
297 entry
.eraseArguments(argIndices
);
300 void function_interface_impl::eraseFunctionResults(
301 FunctionOpInterface op
, const BitVector
&resultIndices
, Type newType
) {
302 // There are 2 things that need to be updated:
306 // Update the result attributes of the function.
307 if (ArrayAttr resAttrs
= op
.getResAttrsAttr()) {
308 SmallVector
<DictionaryAttr
, 4> newResultAttrs
;
309 newResultAttrs
.reserve(resAttrs
.size());
310 for (unsigned i
= 0, e
= resultIndices
.size(); i
< e
; ++i
)
311 if (!resultIndices
[i
])
312 newResultAttrs
.emplace_back(llvm::cast
<DictionaryAttr
>(resAttrs
[i
]));
313 setAllResultAttrDicts(op
, newResultAttrs
);
316 // Update the function type.
317 op
.setFunctionTypeAttr(TypeAttr::get(newType
));
//===----------------------------------------------------------------------===//
// Function type signature.
//===----------------------------------------------------------------------===//
324 void function_interface_impl::setFunctionType(FunctionOpInterface op
,
326 unsigned oldNumArgs
= op
.getNumArguments();
327 unsigned oldNumResults
= op
.getNumResults();
328 op
.setFunctionTypeAttr(TypeAttr::get(newType
));
329 unsigned newNumArgs
= op
.getNumArguments();
330 unsigned newNumResults
= op
.getNumResults();
332 // Functor used to update the argument and result attributes of the function.
333 auto emptyDict
= DictionaryAttr::get(op
.getContext());
334 auto updateAttrFn
= [&](auto isArg
, unsigned oldCount
, unsigned newCount
) {
335 constexpr bool isArgVal
= std::is_same_v
<decltype(isArg
), std::true_type
>;
337 if (oldCount
== newCount
)
339 // The new type has no arguments/results, just drop the attribute.
341 return removeArgResAttrs
<isArgVal
>(op
);
342 ArrayAttr attrs
= getArgResAttrs
<isArgVal
>(op
);
346 // The new type has less arguments/results, take the first N attributes.
347 if (newCount
< oldCount
)
348 return setAllArgResAttrDicts
<isArgVal
>(
349 op
, attrs
.getValue().take_front(newCount
));
351 // Otherwise, the new type has more arguments/results. Initialize the new
352 // arguments/results with empty dictionary attributes.
353 SmallVector
<Attribute
> newAttrs(attrs
.begin(), attrs
.end());
354 newAttrs
.resize(newCount
, emptyDict
);
355 setAllArgResAttrDicts
<isArgVal
>(op
, newAttrs
);
358 // Update the argument and result attributes.
359 updateAttrFn(std::true_type
{}, oldNumArgs
, newNumArgs
);
360 updateAttrFn(std::false_type
{}, oldNumResults
, newNumResults
);