1 //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/Operation.h"
//===----------------------------------------------------------------------===//
// AttrTypeWalker
//===----------------------------------------------------------------------===//
18 WalkResult
AttrTypeWalker::walkImpl(Attribute attr
, WalkOrder order
) {
19 return walkImpl(attr
, attrWalkFns
, order
);
21 WalkResult
AttrTypeWalker::walkImpl(Type type
, WalkOrder order
) {
22 return walkImpl(type
, typeWalkFns
, order
);
25 template <typename T
, typename WalkFns
>
26 WalkResult
AttrTypeWalker::walkImpl(T element
, WalkFns
&walkFns
,
28 // Check if we've already walk this element before.
29 auto key
= std::make_pair(element
.getAsOpaquePointer(), (int)order
);
30 auto it
= visitedAttrTypes
.find(key
);
31 if (it
!= visitedAttrTypes
.end())
33 visitedAttrTypes
.try_emplace(key
, WalkResult::advance());
35 // If we are walking in post order, walk the sub elements first.
36 if (order
== WalkOrder::PostOrder
) {
37 if (walkSubElements(element
, order
).wasInterrupted())
38 return visitedAttrTypes
[key
] = WalkResult::interrupt();
41 // Walk this element, bailing if skipped or interrupted.
42 for (auto &walkFn
: llvm::reverse(walkFns
)) {
43 WalkResult walkResult
= walkFn(element
);
44 if (walkResult
.wasInterrupted())
45 return visitedAttrTypes
[key
] = WalkResult::interrupt();
46 if (walkResult
.wasSkipped())
47 return WalkResult::advance();
50 // If we are walking in pre-order, walk the sub elements last.
51 if (order
== WalkOrder::PreOrder
) {
52 if (walkSubElements(element
, order
).wasInterrupted())
53 return WalkResult::interrupt();
55 return WalkResult::advance();
59 WalkResult
AttrTypeWalker::walkSubElements(T interface
, WalkOrder order
) {
60 WalkResult result
= WalkResult::advance();
61 auto walkFn
= [&](auto element
) {
62 if (element
&& !result
.wasInterrupted())
63 result
= walkImpl(element
, order
);
65 interface
.walkImmediateSubElements(walkFn
, walkFn
);
66 return result
.wasInterrupted() ? result
: WalkResult::advance();
//===----------------------------------------------------------------------===//
// AttrTypeReplacer
//===----------------------------------------------------------------------===//
73 void AttrTypeReplacer::addReplacement(ReplaceFn
<Attribute
> fn
) {
74 attrReplacementFns
.emplace_back(std::move(fn
));
76 void AttrTypeReplacer::addReplacement(ReplaceFn
<Type
> fn
) {
77 typeReplacementFns
.push_back(std::move(fn
));
80 void AttrTypeReplacer::replaceElementsIn(Operation
*op
, bool replaceAttrs
,
81 bool replaceLocs
, bool replaceTypes
) {
82 // Functor that replaces the given element if the new value is different,
83 // otherwise returns nullptr.
84 auto replaceIfDifferent
= [&](auto element
) {
85 auto replacement
= replace(element
);
86 return (replacement
&& replacement
!= element
) ? replacement
: nullptr;
89 // Update the attribute dictionary.
91 if (auto newAttrs
= replaceIfDifferent(op
->getAttrDictionary()))
92 op
->setAttrs(cast
<DictionaryAttr
>(newAttrs
));
95 // If we aren't updating locations or types, we're done.
96 if (!replaceTypes
&& !replaceLocs
)
99 // Update the location.
101 if (Attribute newLoc
= replaceIfDifferent(op
->getLoc()))
102 op
->setLoc(cast
<LocationAttr
>(newLoc
));
105 // Update the result types.
107 for (OpResult result
: op
->getResults())
108 if (Type newType
= replaceIfDifferent(result
.getType()))
109 result
.setType(newType
);
112 // Update any nested block arguments.
113 for (Region
®ion
: op
->getRegions()) {
114 for (Block
&block
: region
) {
115 for (BlockArgument
&arg
: block
.getArguments()) {
117 if (Attribute newLoc
= replaceIfDifferent(arg
.getLoc()))
118 arg
.setLoc(cast
<LocationAttr
>(newLoc
));
122 if (Type newType
= replaceIfDifferent(arg
.getType()))
123 arg
.setType(newType
);
130 void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation
*op
,
134 op
->walk([&](Operation
*nestedOp
) {
135 replaceElementsIn(nestedOp
, replaceAttrs
, replaceLocs
, replaceTypes
);
139 template <typename T
>
140 static void updateSubElementImpl(T element
, AttrTypeReplacer
&replacer
,
141 SmallVectorImpl
<T
> &newElements
,
142 FailureOr
<bool> &changed
) {
143 // Bail early if we failed at any point.
147 // Guard against potentially null inputs. We always map null to null.
149 newElements
.push_back(nullptr);
153 // Replace the element.
154 if (T result
= replacer
.replace(element
)) {
155 newElements
.push_back(result
);
156 if (result
!= element
)
163 template <typename T
>
164 T
AttrTypeReplacer::replaceSubElements(T interface
) {
165 // Walk the current sub-elements, replacing them as necessary.
166 SmallVector
<Attribute
, 16> newAttrs
;
167 SmallVector
<Type
, 16> newTypes
;
168 FailureOr
<bool> changed
= false;
169 interface
.walkImmediateSubElements(
170 [&](Attribute element
) {
171 updateSubElementImpl(element
, *this, newAttrs
, changed
);
174 updateSubElementImpl(element
, *this, newTypes
, changed
);
179 // If any sub-elements changed, use the new elements during the replacement.
180 T result
= interface
;
182 result
= interface
.replaceImmediateSubElements(newAttrs
, newTypes
);
186 /// Shared implementation of replacing a given attribute or type element.
187 template <typename T
, typename ReplaceFns
>
188 T
AttrTypeReplacer::replaceImpl(T element
, ReplaceFns
&replaceFns
) {
189 const void *opaqueElement
= element
.getAsOpaquePointer();
190 auto [it
, inserted
] = attrTypeMap
.try_emplace(opaqueElement
, opaqueElement
);
192 return T::getFromOpaquePointer(it
->second
);
195 WalkResult walkResult
= WalkResult::advance();
196 for (auto &replaceFn
: llvm::reverse(replaceFns
)) {
197 if (std::optional
<std::pair
<T
, WalkResult
>> newRes
= replaceFn(element
)) {
198 std::tie(result
, walkResult
) = *newRes
;
203 // If an error occurred, return nullptr to indicate failure.
204 if (walkResult
.wasInterrupted() || !result
) {
205 attrTypeMap
[opaqueElement
] = nullptr;
209 // Handle replacing sub-elements if this element is also a container.
210 if (!walkResult
.wasSkipped()) {
211 // Replace the sub elements of this element, bailing if we fail.
212 if (!(result
= replaceSubElements(result
))) {
213 attrTypeMap
[opaqueElement
] = nullptr;
218 attrTypeMap
[opaqueElement
] = result
.getAsOpaquePointer();
222 Attribute
AttrTypeReplacer::replace(Attribute attr
) {
223 return replaceImpl(attr
, attrReplacementFns
);
226 Type
AttrTypeReplacer::replace(Type type
) {
227 return replaceImpl(type
, typeReplacementFns
);
//===----------------------------------------------------------------------===//
// AttrTypeImmediateSubElementWalker
//===----------------------------------------------------------------------===//
234 void AttrTypeImmediateSubElementWalker::walk(Attribute element
) {
236 walkAttrsFn(element
);
239 void AttrTypeImmediateSubElementWalker::walk(Type element
) {
241 walkTypesFn(element
);