//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
#include <optional>

using namespace mlir;

//===----------------------------------------------------------------------===//
/// AttrTypeWalker
//===----------------------------------------------------------------------===//

18 WalkResult
AttrTypeWalker::walkImpl(Attribute attr
, WalkOrder order
) {
19 return walkImpl(attr
, attrWalkFns
, order
);
21 WalkResult
AttrTypeWalker::walkImpl(Type type
, WalkOrder order
) {
22 return walkImpl(type
, typeWalkFns
, order
);
25 template <typename T
, typename WalkFns
>
26 WalkResult
AttrTypeWalker::walkImpl(T element
, WalkFns
&walkFns
,
28 // Check if we've already walk this element before.
29 auto key
= std::make_pair(element
.getAsOpaquePointer(), (int)order
);
31 visitedAttrTypes
.try_emplace(key
, WalkResult::advance());
35 // If we are walking in post order, walk the sub elements first.
36 if (order
== WalkOrder::PostOrder
) {
37 if (walkSubElements(element
, order
).wasInterrupted())
38 return visitedAttrTypes
[key
] = WalkResult::interrupt();
41 // Walk this element, bailing if skipped or interrupted.
42 for (auto &walkFn
: llvm::reverse(walkFns
)) {
43 WalkResult walkResult
= walkFn(element
);
44 if (walkResult
.wasInterrupted())
45 return visitedAttrTypes
[key
] = WalkResult::interrupt();
46 if (walkResult
.wasSkipped())
47 return WalkResult::advance();
50 // If we are walking in pre-order, walk the sub elements last.
51 if (order
== WalkOrder::PreOrder
) {
52 if (walkSubElements(element
, order
).wasInterrupted())
53 return WalkResult::interrupt();
55 return WalkResult::advance();
59 WalkResult
AttrTypeWalker::walkSubElements(T interface
, WalkOrder order
) {
60 WalkResult result
= WalkResult::advance();
61 auto walkFn
= [&](auto element
) {
62 if (element
&& !result
.wasInterrupted())
63 result
= walkImpl(element
, order
);
65 interface
.walkImmediateSubElements(walkFn
, walkFn
);
66 return result
.wasInterrupted() ? result
: WalkResult::advance();
//===----------------------------------------------------------------------===//
/// AttrTypeReplacerBase
//===----------------------------------------------------------------------===//

73 template <typename Concrete
>
74 void detail::AttrTypeReplacerBase
<Concrete
>::addReplacement(
75 ReplaceFn
<Attribute
> fn
) {
76 attrReplacementFns
.emplace_back(std::move(fn
));
79 template <typename Concrete
>
80 void detail::AttrTypeReplacerBase
<Concrete
>::addReplacement(
82 typeReplacementFns
.push_back(std::move(fn
));
85 template <typename Concrete
>
86 void detail::AttrTypeReplacerBase
<Concrete
>::replaceElementsIn(
87 Operation
*op
, bool replaceAttrs
, bool replaceLocs
, bool replaceTypes
) {
88 // Functor that replaces the given element if the new value is different,
89 // otherwise returns nullptr.
90 auto replaceIfDifferent
= [&](auto element
) {
91 auto replacement
= static_cast<Concrete
*>(this)->replace(element
);
92 return (replacement
&& replacement
!= element
) ? replacement
: nullptr;
95 // Update the attribute dictionary.
97 if (auto newAttrs
= replaceIfDifferent(op
->getAttrDictionary()))
98 op
->setAttrs(cast
<DictionaryAttr
>(newAttrs
));
101 // If we aren't updating locations or types, we're done.
102 if (!replaceTypes
&& !replaceLocs
)
105 // Update the location.
107 if (Attribute newLoc
= replaceIfDifferent(op
->getLoc()))
108 op
->setLoc(cast
<LocationAttr
>(newLoc
));
111 // Update the result types.
113 for (OpResult result
: op
->getResults())
114 if (Type newType
= replaceIfDifferent(result
.getType()))
115 result
.setType(newType
);
118 // Update any nested block arguments.
119 for (Region
®ion
: op
->getRegions()) {
120 for (Block
&block
: region
) {
121 for (BlockArgument
&arg
: block
.getArguments()) {
123 if (Attribute newLoc
= replaceIfDifferent(arg
.getLoc()))
124 arg
.setLoc(cast
<LocationAttr
>(newLoc
));
128 if (Type newType
= replaceIfDifferent(arg
.getType()))
129 arg
.setType(newType
);
136 template <typename Concrete
>
137 void detail::AttrTypeReplacerBase
<Concrete
>::recursivelyReplaceElementsIn(
138 Operation
*op
, bool replaceAttrs
, bool replaceLocs
, bool replaceTypes
) {
139 op
->walk([&](Operation
*nestedOp
) {
140 replaceElementsIn(nestedOp
, replaceAttrs
, replaceLocs
, replaceTypes
);
144 template <typename T
, typename Replacer
>
145 static void updateSubElementImpl(T element
, Replacer
&replacer
,
146 SmallVectorImpl
<T
> &newElements
,
147 FailureOr
<bool> &changed
) {
148 // Bail early if we failed at any point.
152 // Guard against potentially null inputs. We always map null to null.
154 newElements
.push_back(nullptr);
158 // Replace the element.
159 if (T result
= replacer
.replace(element
)) {
160 newElements
.push_back(result
);
161 if (result
!= element
)
168 template <typename T
, typename Replacer
>
169 static T
replaceSubElements(T interface
, Replacer
&replacer
) {
170 // Walk the current sub-elements, replacing them as necessary.
171 SmallVector
<Attribute
, 16> newAttrs
;
172 SmallVector
<Type
, 16> newTypes
;
173 FailureOr
<bool> changed
= false;
174 interface
.walkImmediateSubElements(
175 [&](Attribute element
) {
176 updateSubElementImpl(element
, replacer
, newAttrs
, changed
);
179 updateSubElementImpl(element
, replacer
, newTypes
, changed
);
184 // If any sub-elements changed, use the new elements during the replacement.
185 T result
= interface
;
187 result
= interface
.replaceImmediateSubElements(newAttrs
, newTypes
);
191 /// Shared implementation of replacing a given attribute or type element.
192 template <typename T
, typename ReplaceFns
, typename Replacer
>
193 static T
replaceElementImpl(T element
, ReplaceFns
&replaceFns
,
194 Replacer
&replacer
) {
196 WalkResult walkResult
= WalkResult::advance();
197 for (auto &replaceFn
: llvm::reverse(replaceFns
)) {
198 if (std::optional
<std::pair
<T
, WalkResult
>> newRes
= replaceFn(element
)) {
199 std::tie(result
, walkResult
) = *newRes
;
204 // If an error occurred, return nullptr to indicate failure.
205 if (walkResult
.wasInterrupted() || !result
) {
209 // Handle replacing sub-elements if this element is also a container.
210 if (!walkResult
.wasSkipped()) {
211 // Replace the sub elements of this element, bailing if we fail.
212 if (!(result
= replaceSubElements(result
, replacer
))) {
220 template <typename Concrete
>
221 Attribute
detail::AttrTypeReplacerBase
<Concrete
>::replaceBase(Attribute attr
) {
222 return replaceElementImpl(attr
, attrReplacementFns
,
223 *static_cast<Concrete
*>(this));
226 template <typename Concrete
>
227 Type
detail::AttrTypeReplacerBase
<Concrete
>::replaceBase(Type type
) {
228 return replaceElementImpl(type
, typeReplacementFns
,
229 *static_cast<Concrete
*>(this));
//===----------------------------------------------------------------------===//
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//

236 template class detail::AttrTypeReplacerBase
<AttrTypeReplacer
>;
238 template <typename T
>
239 T
AttrTypeReplacer::cachedReplaceImpl(T element
) {
240 const void *opaqueElement
= element
.getAsOpaquePointer();
241 auto [it
, inserted
] = cache
.try_emplace(opaqueElement
, opaqueElement
);
243 return T::getFromOpaquePointer(it
->second
);
245 T result
= replaceBase(element
);
247 cache
[opaqueElement
] = result
.getAsOpaquePointer();
251 Attribute
AttrTypeReplacer::replace(Attribute attr
) {
252 return cachedReplaceImpl(attr
);
255 Type
AttrTypeReplacer::replace(Type type
) { return cachedReplaceImpl(type
); }
//===----------------------------------------------------------------------===//
/// CyclicAttrTypeReplacer
//===----------------------------------------------------------------------===//

261 template class detail::AttrTypeReplacerBase
<CyclicAttrTypeReplacer
>;
263 CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
264 : cache([&](void *attr
) { return breakCycleImpl(attr
); }) {}
266 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn
<Attribute
> fn
) {
267 attrCycleBreakerFns
.emplace_back(std::move(fn
));
270 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn
<Type
> fn
) {
271 typeCycleBreakerFns
.emplace_back(std::move(fn
));
274 template <typename T
>
275 T
CyclicAttrTypeReplacer::cachedReplaceImpl(T element
) {
276 void *opaqueTaggedElement
= AttrOrType(element
).getOpaqueValue();
277 CyclicReplacerCache
<void *, const void *>::CacheEntry cacheEntry
=
278 cache
.lookupOrInit(opaqueTaggedElement
);
279 if (auto resultOpt
= cacheEntry
.get())
280 return T::getFromOpaquePointer(*resultOpt
);
282 T result
= replaceBase(element
);
284 cacheEntry
.resolve(result
.getAsOpaquePointer());
288 Attribute
CyclicAttrTypeReplacer::replace(Attribute attr
) {
289 return cachedReplaceImpl(attr
);
292 Type
CyclicAttrTypeReplacer::replace(Type type
) {
293 return cachedReplaceImpl(type
);
296 std::optional
<const void *>
297 CyclicAttrTypeReplacer::breakCycleImpl(void *element
) {
298 AttrOrType attrType
= AttrOrType::getFromOpaqueValue(element
);
299 if (auto attr
= dyn_cast
<Attribute
>(attrType
)) {
300 for (auto &cyclicReplaceFn
: llvm::reverse(attrCycleBreakerFns
)) {
301 if (std::optional
<Attribute
> newRes
= cyclicReplaceFn(attr
)) {
302 return newRes
->getAsOpaquePointer();
306 auto type
= dyn_cast
<Type
>(attrType
);
307 for (auto &cyclicReplaceFn
: llvm::reverse(typeCycleBreakerFns
)) {
308 if (std::optional
<Type
> newRes
= cyclicReplaceFn(type
)) {
309 return newRes
->getAsOpaquePointer();
//===----------------------------------------------------------------------===//
// AttrTypeImmediateSubElementWalker
//===----------------------------------------------------------------------===//

320 void AttrTypeImmediateSubElementWalker::walk(Attribute element
) {
322 walkAttrsFn(element
);
325 void AttrTypeImmediateSubElementWalker::walk(Type element
) {
327 walkTypesFn(element
);