[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / Interfaces / FunctionInterfaces.cpp
blob8b6f7110c2cf0afa13d7ad6fb76dca557cb4ae53
1 //===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/FunctionInterfaces.h"
11 using namespace mlir;
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
17 #include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
23 static bool isEmptyAttrDict(Attribute attr) {
24 return llvm::cast<DictionaryAttr>(attr).empty();
27 DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
28 unsigned index) {
29 ArrayAttr attrs = op.getArgAttrsAttr();
30 DictionaryAttr argAttrs =
31 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
32 return argAttrs;
35 DictionaryAttr
36 function_interface_impl::getResultAttrDict(FunctionOpInterface op,
37 unsigned index) {
38 ArrayAttr attrs = op.getResAttrsAttr();
39 DictionaryAttr resAttrs =
40 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
41 return resAttrs;
44 ArrayRef<NamedAttribute>
45 function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
46 auto argDict = getArgAttrDict(op, index);
47 return argDict ? argDict.getValue() : std::nullopt;
50 ArrayRef<NamedAttribute>
51 function_interface_impl::getResultAttrs(FunctionOpInterface op,
52 unsigned index) {
53 auto resultDict = getResultAttrDict(op, index);
54 return resultDict ? resultDict.getValue() : std::nullopt;
57 /// Get either the argument or result attributes array.
58 template <bool isArg>
59 static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
60 if constexpr (isArg)
61 return op.getArgAttrsAttr();
62 else
63 return op.getResAttrsAttr();
66 /// Set either the argument or result attributes array.
67 template <bool isArg>
68 static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
69 if constexpr (isArg)
70 op.setArgAttrsAttr(attrs);
71 else
72 op.setResAttrsAttr(attrs);
75 /// Erase either the argument or result attributes array.
76 template <bool isArg>
77 static void removeArgResAttrs(FunctionOpInterface op) {
78 if constexpr (isArg)
79 op.removeArgAttrsAttr();
80 else
81 op.removeResAttrsAttr();
84 /// Set all of the argument or result attribute dictionaries for a function.
85 template <bool isArg>
86 static void setAllArgResAttrDicts(FunctionOpInterface op,
87 ArrayRef<Attribute> attrs) {
88 if (llvm::all_of(attrs, isEmptyAttrDict))
89 removeArgResAttrs<isArg>(op);
90 else
91 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
94 void function_interface_impl::setAllArgAttrDicts(
95 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
96 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
99 void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
100 ArrayRef<Attribute> attrs) {
101 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
102 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
104 setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
107 void function_interface_impl::setAllResultAttrDicts(
108 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
109 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
112 void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
113 ArrayRef<Attribute> attrs) {
114 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
115 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
117 setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
120 /// Update the given index into an argument or result attribute dictionary.
121 template <bool isArg>
122 static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
123 unsigned index, DictionaryAttr attrs) {
124 ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
125 if (!allAttrs) {
126 if (attrs.empty())
127 return;
129 // If this attribute is not empty, we need to create a new attribute array.
130 SmallVector<Attribute, 8> newAttrs(numTotalIndices,
131 DictionaryAttr::get(op->getContext()));
132 newAttrs[index] = attrs;
133 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
134 return;
136 // Check to see if the attribute is different from what we already have.
137 if (allAttrs[index] == attrs)
138 return;
140 // If it is, check to see if the attribute array would now contain only empty
141 // dictionaries.
142 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
143 if (attrs.empty() &&
144 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
145 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
146 return removeArgResAttrs<isArg>(op);
148 // Otherwise, create a new attribute array with the updated dictionary.
149 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
150 newAttrs[index] = attrs;
151 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
154 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
155 unsigned index,
156 ArrayRef<NamedAttribute> attributes) {
157 assert(index < op.getNumArguments() && "invalid argument number");
158 return setArgResAttrDict</*isArg=*/true>(
159 op, op.getNumArguments(), index,
160 DictionaryAttr::get(op->getContext(), attributes));
163 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
164 unsigned index,
165 DictionaryAttr attributes) {
166 return setArgResAttrDict</*isArg=*/true>(
167 op, op.getNumArguments(), index,
168 attributes ? attributes : DictionaryAttr::get(op->getContext()));
171 void function_interface_impl::setResultAttrs(
172 FunctionOpInterface op, unsigned index,
173 ArrayRef<NamedAttribute> attributes) {
174 assert(index < op.getNumResults() && "invalid result number");
175 return setArgResAttrDict</*isArg=*/false>(
176 op, op.getNumResults(), index,
177 DictionaryAttr::get(op->getContext(), attributes));
180 void function_interface_impl::setResultAttrs(FunctionOpInterface op,
181 unsigned index,
182 DictionaryAttr attributes) {
183 assert(index < op.getNumResults() && "invalid result number");
184 return setArgResAttrDict</*isArg=*/false>(
185 op, op.getNumResults(), index,
186 attributes ? attributes : DictionaryAttr::get(op->getContext()));
189 void function_interface_impl::insertFunctionArguments(
190 FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
191 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
192 unsigned originalNumArgs, Type newType) {
193 assert(argIndices.size() == argTypes.size());
194 assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
195 assert(argIndices.size() == argLocs.size());
196 if (argIndices.empty())
197 return;
199 // There are 3 things that need to be updated:
200 // - Function type.
201 // - Arg attrs.
202 // - Block arguments of entry block.
203 Block &entry = op->getRegion(0).front();
205 // Update the argument attributes of the function.
206 ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
207 if (oldArgAttrs || !argAttrs.empty()) {
208 SmallVector<DictionaryAttr, 4> newArgAttrs;
209 newArgAttrs.reserve(originalNumArgs + argIndices.size());
210 unsigned oldIdx = 0;
211 auto migrate = [&](unsigned untilIdx) {
212 if (!oldArgAttrs) {
213 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
214 } else {
215 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
216 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
217 oldArgAttrRange.begin() + untilIdx);
219 oldIdx = untilIdx;
221 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
222 migrate(argIndices[i]);
223 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
225 migrate(originalNumArgs);
226 setAllArgAttrDicts(op, newArgAttrs);
229 // Update the function type and any entry block arguments.
230 op.setFunctionTypeAttr(TypeAttr::get(newType));
231 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
232 entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
235 void function_interface_impl::insertFunctionResults(
236 FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
237 TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
238 unsigned originalNumResults, Type newType) {
239 assert(resultIndices.size() == resultTypes.size());
240 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
241 if (resultIndices.empty())
242 return;
244 // There are 2 things that need to be updated:
245 // - Function type.
246 // - Result attrs.
248 // Update the result attributes of the function.
249 ArrayAttr oldResultAttrs = op.getResAttrsAttr();
250 if (oldResultAttrs || !resultAttrs.empty()) {
251 SmallVector<DictionaryAttr, 4> newResultAttrs;
252 newResultAttrs.reserve(originalNumResults + resultIndices.size());
253 unsigned oldIdx = 0;
254 auto migrate = [&](unsigned untilIdx) {
255 if (!oldResultAttrs) {
256 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
257 } else {
258 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
259 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
260 oldResultAttrsRange.begin() + untilIdx);
262 oldIdx = untilIdx;
264 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
265 migrate(resultIndices[i]);
266 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
267 : resultAttrs[i]);
269 migrate(originalNumResults);
270 setAllResultAttrDicts(op, newResultAttrs);
273 // Update the function type.
274 op.setFunctionTypeAttr(TypeAttr::get(newType));
277 void function_interface_impl::eraseFunctionArguments(
278 FunctionOpInterface op, const BitVector &argIndices, Type newType) {
279 // There are 3 things that need to be updated:
280 // - Function type.
281 // - Arg attrs.
282 // - Block arguments of entry block.
283 Block &entry = op->getRegion(0).front();
285 // Update the argument attributes of the function.
286 if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
287 SmallVector<DictionaryAttr, 4> newArgAttrs;
288 newArgAttrs.reserve(argAttrs.size());
289 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
290 if (!argIndices[i])
291 newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
292 setAllArgAttrDicts(op, newArgAttrs);
295 // Update the function type and any entry block arguments.
296 op.setFunctionTypeAttr(TypeAttr::get(newType));
297 entry.eraseArguments(argIndices);
300 void function_interface_impl::eraseFunctionResults(
301 FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
302 // There are 2 things that need to be updated:
303 // - Function type.
304 // - Result attrs.
306 // Update the result attributes of the function.
307 if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
308 SmallVector<DictionaryAttr, 4> newResultAttrs;
309 newResultAttrs.reserve(resAttrs.size());
310 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
311 if (!resultIndices[i])
312 newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
313 setAllResultAttrDicts(op, newResultAttrs);
316 // Update the function type.
317 op.setFunctionTypeAttr(TypeAttr::get(newType));
320 //===----------------------------------------------------------------------===//
321 // Function type signature.
322 //===----------------------------------------------------------------------===//
324 void function_interface_impl::setFunctionType(FunctionOpInterface op,
325 Type newType) {
326 unsigned oldNumArgs = op.getNumArguments();
327 unsigned oldNumResults = op.getNumResults();
328 op.setFunctionTypeAttr(TypeAttr::get(newType));
329 unsigned newNumArgs = op.getNumArguments();
330 unsigned newNumResults = op.getNumResults();
332 // Functor used to update the argument and result attributes of the function.
333 auto emptyDict = DictionaryAttr::get(op.getContext());
334 auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
335 constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
337 if (oldCount == newCount)
338 return;
339 // The new type has no arguments/results, just drop the attribute.
340 if (newCount == 0)
341 return removeArgResAttrs<isArgVal>(op);
342 ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
343 if (!attrs)
344 return;
346 // The new type has less arguments/results, take the first N attributes.
347 if (newCount < oldCount)
348 return setAllArgResAttrDicts<isArgVal>(
349 op, attrs.getValue().take_front(newCount));
351 // Otherwise, the new type has more arguments/results. Initialize the new
352 // arguments/results with empty dictionary attributes.
353 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
354 newAttrs.resize(newCount, emptyDict);
355 setAllArgResAttrDicts<isArgVal>(op, newAttrs);
358 // Update the argument and result attributes.
359 updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
360 updateAttrFn(std::false_type{}, oldNumResults, newNumResults);