1 //===- OperationSupport.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
25 //===----------------------------------------------------------------------===//
27 //===----------------------------------------------------------------------===//
29 NamedAttrList::NamedAttrList(ArrayRef
<NamedAttribute
> attributes
) {
30 assign(attributes
.begin(), attributes
.end());
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes
)
34 : NamedAttrList(attributes
? attributes
.getValue()
35 : ArrayRef
<NamedAttribute
>()) {
36 dictionarySorted
.setPointerAndInt(attributes
, true);
39 NamedAttrList::NamedAttrList(const_iterator inStart
, const_iterator inEnd
) {
40 assign(inStart
, inEnd
);
43 ArrayRef
<NamedAttribute
> NamedAttrList::getAttrs() const { return attrs
; }
45 std::optional
<NamedAttribute
> NamedAttrList::findDuplicate() const {
46 std::optional
<NamedAttribute
> duplicate
=
47 DictionaryAttr::findDuplicate(attrs
, isSorted());
48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
51 dictionarySorted
.setPointerAndInt(nullptr, true);
55 DictionaryAttr
NamedAttrList::getDictionary(MLIRContext
*context
) const {
57 DictionaryAttr::sortInPlace(attrs
);
58 dictionarySorted
.setPointerAndInt(nullptr, true);
60 if (!dictionarySorted
.getPointer())
61 dictionarySorted
.setPointer(DictionaryAttr::getWithSorted(context
, attrs
));
62 return llvm::cast
<DictionaryAttr
>(dictionarySorted
.getPointer());
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name
, Attribute attr
) {
67 append(StringAttr::get(attr
.getContext(), name
), attr
);
70 /// Replaces the attributes with new list of attributes.
71 void NamedAttrList::assign(const_iterator inStart
, const_iterator inEnd
) {
72 DictionaryAttr::sort(ArrayRef
<NamedAttribute
>{inStart
, inEnd
}, attrs
);
73 dictionarySorted
.setPointerAndInt(nullptr, true);
76 void NamedAttrList::push_back(NamedAttribute newAttribute
) {
78 dictionarySorted
.setInt(attrs
.empty() || attrs
.back() < newAttribute
);
79 dictionarySorted
.setPointer(nullptr);
80 attrs
.push_back(newAttribute
);
83 /// Return the specified attribute if present, null otherwise.
84 Attribute
NamedAttrList::get(StringRef name
) const {
85 auto it
= findAttr(*this, name
);
86 return it
.second
? it
.first
->getValue() : Attribute();
88 Attribute
NamedAttrList::get(StringAttr name
) const {
89 auto it
= findAttr(*this, name
);
90 return it
.second
? it
.first
->getValue() : Attribute();
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional
<NamedAttribute
> NamedAttrList::getNamed(StringRef name
) const {
95 auto it
= findAttr(*this, name
);
96 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
98 std::optional
<NamedAttribute
> NamedAttrList::getNamed(StringAttr name
) const {
99 auto it
= findAttr(*this, name
);
100 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute
NamedAttrList::set(StringAttr name
, Attribute value
) {
106 assert(value
&& "attributes may never be null");
108 // Look for an existing attribute with the given name, and set its value
109 // in-place. Return the previous value of the attribute, if there was one.
110 auto it
= findAttr(*this, name
);
112 // Update the existing attribute by swapping out the old value for the new
113 // value. Return the old value.
114 Attribute oldValue
= it
.first
->getValue();
115 if (it
.first
->getValue() != value
) {
116 it
.first
->setValue(value
);
118 // If the attributes have changed, the dictionary is invalidated.
119 dictionarySorted
.setPointer(nullptr);
123 // Perform a string lookup to insert the new attribute into its sorted
126 it
= findAttr(*this, name
.strref());
127 attrs
.insert(it
.first
, {name
, value
});
128 // Invalidate the dictionary. Return null as there was no previous value.
129 dictionarySorted
.setPointer(nullptr);
133 Attribute
NamedAttrList::set(StringRef name
, Attribute value
) {
134 assert(value
&& "attributes may never be null");
135 return set(mlir::StringAttr::get(value
.getContext(), name
), value
);
139 NamedAttrList::eraseImpl(SmallVectorImpl
<NamedAttribute
>::iterator it
) {
140 // Erasing does not affect the sorted property.
141 Attribute attr
= it
->getValue();
143 dictionarySorted
.setPointer(nullptr);
147 Attribute
NamedAttrList::erase(StringAttr name
) {
148 auto it
= findAttr(*this, name
);
149 return it
.second
? eraseImpl(it
.first
) : Attribute();
152 Attribute
NamedAttrList::erase(StringRef name
) {
153 auto it
= findAttr(*this, name
);
154 return it
.second
? eraseImpl(it
.first
) : Attribute();
158 NamedAttrList::operator=(const SmallVectorImpl
<NamedAttribute
> &rhs
) {
159 assign(rhs
.begin(), rhs
.end());
163 NamedAttrList::operator ArrayRef
<NamedAttribute
>() const { return attrs
; }
165 //===----------------------------------------------------------------------===//
167 //===----------------------------------------------------------------------===//
169 OperationState::OperationState(Location location
, StringRef name
)
170 : location(location
), name(name
, location
->getContext()) {}
172 OperationState::OperationState(Location location
, OperationName name
)
173 : location(location
), name(name
) {}
175 OperationState::OperationState(Location location
, OperationName name
,
176 ValueRange operands
, TypeRange types
,
177 ArrayRef
<NamedAttribute
> attributes
,
178 BlockRange successors
,
179 MutableArrayRef
<std::unique_ptr
<Region
>> regions
)
180 : location(location
), name(name
),
181 operands(operands
.begin(), operands
.end()),
182 types(types
.begin(), types
.end()),
183 attributes(attributes
.begin(), attributes
.end()),
184 successors(successors
.begin(), successors
.end()) {
185 for (std::unique_ptr
<Region
> &r
: regions
)
186 this->regions
.push_back(std::move(r
));
188 OperationState::OperationState(Location location
, StringRef name
,
189 ValueRange operands
, TypeRange types
,
190 ArrayRef
<NamedAttribute
> attributes
,
191 BlockRange successors
,
192 MutableArrayRef
<std::unique_ptr
<Region
>> regions
)
193 : OperationState(location
, OperationName(name
, location
.getContext()),
194 operands
, types
, attributes
, successors
, regions
) {}
196 OperationState::~OperationState() {
198 propertiesDeleter(properties
);
201 LogicalResult
OperationState::setProperties(
202 Operation
*op
, function_ref
<InFlightDiagnostic()> emitError
) const {
203 if (LLVM_UNLIKELY(propertiesAttr
)) {
205 return op
->setPropertiesFromAttribute(propertiesAttr
, emitError
);
208 propertiesSetter(op
->getPropertiesStorage(), properties
);
212 void OperationState::addOperands(ValueRange newOperands
) {
213 operands
.append(newOperands
.begin(), newOperands
.end());
216 void OperationState::addSuccessors(BlockRange newSuccessors
) {
217 successors
.append(newSuccessors
.begin(), newSuccessors
.end());
220 Region
*OperationState::addRegion() {
221 regions
.emplace_back(new Region
);
222 return regions
.back().get();
225 void OperationState::addRegion(std::unique_ptr
<Region
> &®ion
) {
226 regions
.push_back(std::move(region
));
229 void OperationState::addRegions(
230 MutableArrayRef
<std::unique_ptr
<Region
>> regions
) {
231 for (std::unique_ptr
<Region
> ®ion
: regions
)
232 addRegion(std::move(region
));
235 //===----------------------------------------------------------------------===//
237 //===----------------------------------------------------------------------===//
239 detail::OperandStorage::OperandStorage(Operation
*owner
,
240 OpOperand
*trailingOperands
,
242 : isStorageDynamic(false), operandStorage(trailingOperands
) {
243 numOperands
= capacity
= values
.size();
244 for (unsigned i
= 0; i
< numOperands
; ++i
)
245 new (&operandStorage
[i
]) OpOperand(owner
, values
[i
]);
248 detail::OperandStorage::~OperandStorage() {
249 for (auto &operand
: getOperands())
250 operand
.~OpOperand();
252 // If the storage is dynamic, deallocate it.
253 if (isStorageDynamic
)
254 free(operandStorage
);
257 /// Replace the operands contained in the storage with the ones provided in
259 void detail::OperandStorage::setOperands(Operation
*owner
, ValueRange values
) {
260 MutableArrayRef
<OpOperand
> storageOperands
= resize(owner
, values
.size());
261 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
)
262 storageOperands
[i
].set(values
[i
]);
265 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
267 /// than the range pointed to by 'start'+'length'.
268 void detail::OperandStorage::setOperands(Operation
*owner
, unsigned start
,
269 unsigned length
, ValueRange operands
) {
270 // If the new size is the same, we can update inplace.
271 unsigned newSize
= operands
.size();
272 if (newSize
== length
) {
273 MutableArrayRef
<OpOperand
> storageOperands
= getOperands();
274 for (unsigned i
= 0, e
= length
; i
!= e
; ++i
)
275 storageOperands
[start
+ i
].set(operands
[i
]);
278 // If the new size is greater, remove the extra operands and set the rest
280 if (newSize
< length
) {
281 eraseOperands(start
+ operands
.size(), length
- newSize
);
282 setOperands(owner
, start
, newSize
, operands
);
285 // Otherwise, the new size is greater so we need to grow the storage.
286 auto storageOperands
= resize(owner
, size() + (newSize
- length
));
288 // Shift operands to the right to make space for the new operands.
289 unsigned rotateSize
= storageOperands
.size() - (start
+ length
);
290 auto rbegin
= storageOperands
.rbegin();
291 std::rotate(rbegin
, std::next(rbegin
, newSize
- length
), rbegin
+ rotateSize
);
293 // Update the operands inplace.
294 for (unsigned i
= 0, e
= operands
.size(); i
!= e
; ++i
)
295 storageOperands
[start
+ i
].set(operands
[i
]);
298 /// Erase an operand held by the storage.
299 void detail::OperandStorage::eraseOperands(unsigned start
, unsigned length
) {
300 MutableArrayRef
<OpOperand
> operands
= getOperands();
301 assert((start
+ length
) <= operands
.size());
302 numOperands
-= length
;
304 // Shift all operands down if the operand to remove is not at the end.
305 if (start
!= numOperands
) {
306 auto *indexIt
= std::next(operands
.begin(), start
);
307 std::rotate(indexIt
, std::next(indexIt
, length
), operands
.end());
309 for (unsigned i
= 0; i
!= length
; ++i
)
310 operands
[numOperands
+ i
].~OpOperand();
313 void detail::OperandStorage::eraseOperands(const BitVector
&eraseIndices
) {
314 MutableArrayRef
<OpOperand
> operands
= getOperands();
315 assert(eraseIndices
.size() == operands
.size());
317 // Check that at least one operand is erased.
318 int firstErasedIndice
= eraseIndices
.find_first();
319 if (firstErasedIndice
== -1)
322 // Shift all of the removed operands to the end, and destroy them.
323 numOperands
= firstErasedIndice
;
324 for (unsigned i
= firstErasedIndice
+ 1, e
= operands
.size(); i
< e
; ++i
)
325 if (!eraseIndices
.test(i
))
326 operands
[numOperands
++] = std::move(operands
[i
]);
327 for (OpOperand
&operand
: operands
.drop_front(numOperands
))
328 operand
.~OpOperand();
331 /// Resize the storage to the given size. Returns the array containing the new
333 MutableArrayRef
<OpOperand
> detail::OperandStorage::resize(Operation
*owner
,
335 // If the number of operands is less than or equal to the current amount, we
336 // can just update in place.
337 MutableArrayRef
<OpOperand
> origOperands
= getOperands();
338 if (newSize
<= numOperands
) {
339 // If the number of new size is less than the current, remove any extra
341 for (unsigned i
= newSize
; i
!= numOperands
; ++i
)
342 origOperands
[i
].~OpOperand();
343 numOperands
= newSize
;
344 return origOperands
.take_front(newSize
);
347 // If the new size is within the original inline capacity, grow inplace.
348 if (newSize
<= capacity
) {
349 OpOperand
*opBegin
= origOperands
.data();
350 for (unsigned e
= newSize
; numOperands
!= e
; ++numOperands
)
351 new (&opBegin
[numOperands
]) OpOperand(owner
);
352 return MutableArrayRef
<OpOperand
>(opBegin
, newSize
);
355 // Otherwise, we need to allocate a new storage.
356 unsigned newCapacity
=
357 std::max(unsigned(llvm::NextPowerOf2(capacity
+ 2)), newSize
);
358 OpOperand
*newOperandStorage
=
359 reinterpret_cast<OpOperand
*>(malloc(sizeof(OpOperand
) * newCapacity
));
361 // Move the current operands to the new storage.
362 MutableArrayRef
<OpOperand
> newOperands(newOperandStorage
, newSize
);
363 std::uninitialized_move(origOperands
.begin(), origOperands
.end(),
364 newOperands
.begin());
366 // Destroy the original operands.
367 for (auto &operand
: origOperands
)
368 operand
.~OpOperand();
370 // Initialize any new operands.
371 for (unsigned e
= newSize
; numOperands
!= e
; ++numOperands
)
372 new (&newOperands
[numOperands
]) OpOperand(owner
);
374 // If the current storage is dynamic, free it.
375 if (isStorageDynamic
)
376 free(operandStorage
);
378 // Update the storage representation to use the new dynamic storage.
379 operandStorage
= newOperandStorage
;
380 capacity
= newCapacity
;
381 isStorageDynamic
= true;
385 //===----------------------------------------------------------------------===//
386 // Operation Value-Iterators
387 //===----------------------------------------------------------------------===//
389 //===----------------------------------------------------------------------===//
392 unsigned OperandRange::getBeginOperandIndex() const {
393 assert(!empty() && "range must not be empty");
394 return base
->getOperandNumber();
397 OperandRangeRange
OperandRange::split(DenseI32ArrayAttr segmentSizes
) const {
398 return OperandRangeRange(*this, segmentSizes
);
401 //===----------------------------------------------------------------------===//
404 OperandRangeRange::OperandRangeRange(OperandRange operands
,
405 Attribute operandSegments
)
406 : OperandRangeRange(OwnerT(operands
.getBase(), operandSegments
), 0,
407 llvm::cast
<DenseI32ArrayAttr
>(operandSegments
).size()) {
410 OperandRange
OperandRangeRange::join() const {
411 const OwnerT
&owner
= getBase();
412 ArrayRef
<int32_t> sizeData
= llvm::cast
<DenseI32ArrayAttr
>(owner
.second
);
413 return OperandRange(owner
.first
,
414 std::accumulate(sizeData
.begin(), sizeData
.end(), 0));
417 OperandRange
OperandRangeRange::dereference(const OwnerT
&object
,
419 ArrayRef
<int32_t> sizeData
= llvm::cast
<DenseI32ArrayAttr
>(object
.second
);
420 uint32_t startIndex
=
421 std::accumulate(sizeData
.begin(), sizeData
.begin() + index
, 0);
422 return OperandRange(object
.first
+ startIndex
, *(sizeData
.begin() + index
));
425 //===----------------------------------------------------------------------===//
426 // MutableOperandRange
428 /// Construct a new mutable range from the given operand, operand start index,
429 /// and range length.
430 MutableOperandRange::MutableOperandRange(
431 Operation
*owner
, unsigned start
, unsigned length
,
432 ArrayRef
<OperandSegment
> operandSegments
)
433 : owner(owner
), start(start
), length(length
),
434 operandSegments(operandSegments
) {
435 assert((start
+ length
) <= owner
->getNumOperands() && "invalid range");
437 MutableOperandRange::MutableOperandRange(Operation
*owner
)
438 : MutableOperandRange(owner
, /*start=*/0, owner
->getNumOperands()) {}
440 /// Construct a new mutable range for the given OpOperand.
441 MutableOperandRange::MutableOperandRange(OpOperand
&opOperand
)
442 : MutableOperandRange(opOperand
.getOwner(),
443 /*start=*/opOperand
.getOperandNumber(),
446 /// Slice this range into a sub range, with the additional operand segment.
448 MutableOperandRange::slice(unsigned subStart
, unsigned subLen
,
449 std::optional
<OperandSegment
> segment
) const {
450 assert((subStart
+ subLen
) <= length
&& "invalid sub-range");
451 MutableOperandRange
subSlice(owner
, start
+ subStart
, subLen
,
454 subSlice
.operandSegments
.push_back(*segment
);
458 /// Append the given values to the range.
459 void MutableOperandRange::append(ValueRange values
) {
462 owner
->insertOperands(start
+ length
, values
);
463 updateLength(length
+ values
.size());
466 /// Assign this range to the given values.
467 void MutableOperandRange::assign(ValueRange values
) {
468 owner
->setOperands(start
, length
, values
);
469 if (length
!= values
.size())
470 updateLength(/*newLength=*/values
.size());
473 /// Assign the range to the given value.
474 void MutableOperandRange::assign(Value value
) {
476 owner
->setOperand(start
, value
);
478 owner
->setOperands(start
, length
, value
);
479 updateLength(/*newLength=*/1);
483 /// Erase the operands within the given sub-range.
484 void MutableOperandRange::erase(unsigned subStart
, unsigned subLen
) {
485 assert((subStart
+ subLen
) <= length
&& "invalid sub-range");
488 owner
->eraseOperands(start
+ subStart
, subLen
);
489 updateLength(length
- subLen
);
492 /// Clear this range and erase all of the operands.
493 void MutableOperandRange::clear() {
495 owner
->eraseOperands(start
, length
);
496 updateLength(/*newLength=*/0);
500 /// Explicit conversion to an OperandRange.
501 OperandRange
MutableOperandRange::getAsOperandRange() const {
502 return owner
->getOperands().slice(start
, length
);
505 /// Allow implicit conversion to an OperandRange.
506 MutableOperandRange::operator OperandRange() const {
507 return getAsOperandRange();
510 MutableOperandRange::operator MutableArrayRef
<OpOperand
>() const {
511 return owner
->getOpOperands().slice(start
, length
);
514 MutableOperandRangeRange
515 MutableOperandRange::split(NamedAttribute segmentSizes
) const {
516 return MutableOperandRangeRange(*this, segmentSizes
);
519 /// Update the length of this range to the one provided.
520 void MutableOperandRange::updateLength(unsigned newLength
) {
521 int32_t diff
= int32_t(newLength
) - int32_t(length
);
524 // Update any of the provided segment attributes.
525 for (OperandSegment
&segment
: operandSegments
) {
526 auto attr
= llvm::cast
<DenseI32ArrayAttr
>(segment
.second
.getValue());
527 SmallVector
<int32_t, 8> segments(attr
.asArrayRef());
528 segments
[segment
.first
] += diff
;
529 segment
.second
.setValue(
530 DenseI32ArrayAttr::get(attr
.getContext(), segments
));
531 owner
->setAttr(segment
.second
.getName(), segment
.second
.getValue());
535 OpOperand
&MutableOperandRange::operator[](unsigned index
) const {
536 assert(index
< length
&& "index is out of bounds");
537 return owner
->getOpOperand(start
+ index
);
540 MutableArrayRef
<OpOperand
>::iterator
MutableOperandRange::begin() const {
541 return owner
->getOpOperands().slice(start
, length
).begin();
544 MutableArrayRef
<OpOperand
>::iterator
MutableOperandRange::end() const {
545 return owner
->getOpOperands().slice(start
, length
).end();
548 //===----------------------------------------------------------------------===//
549 // MutableOperandRangeRange
551 MutableOperandRangeRange::MutableOperandRangeRange(
552 const MutableOperandRange
&operands
, NamedAttribute operandSegmentAttr
)
553 : MutableOperandRangeRange(
554 OwnerT(operands
, operandSegmentAttr
), 0,
555 llvm::cast
<DenseI32ArrayAttr
>(operandSegmentAttr
.getValue()).size()) {
558 MutableOperandRange
MutableOperandRangeRange::join() const {
559 return getBase().first
;
562 MutableOperandRangeRange::operator OperandRangeRange() const {
563 return OperandRangeRange(getBase().first
, getBase().second
.getValue());
566 MutableOperandRange
MutableOperandRangeRange::dereference(const OwnerT
&object
,
568 ArrayRef
<int32_t> sizeData
=
569 llvm::cast
<DenseI32ArrayAttr
>(object
.second
.getValue());
570 uint32_t startIndex
=
571 std::accumulate(sizeData
.begin(), sizeData
.begin() + index
, 0);
572 return object
.first
.slice(
573 startIndex
, *(sizeData
.begin() + index
),
574 MutableOperandRange::OperandSegment(index
, object
.second
));
577 //===----------------------------------------------------------------------===//
580 ResultRange::ResultRange(OpResult result
)
581 : ResultRange(static_cast<detail::OpResultImpl
*>(Value(result
).getImpl()),
584 ResultRange::use_range
ResultRange::getUses() const {
585 return {use_begin(), use_end()};
587 ResultRange::use_iterator
ResultRange::use_begin() const {
588 return use_iterator(*this);
590 ResultRange::use_iterator
ResultRange::use_end() const {
591 return use_iterator(*this, /*end=*/true);
593 ResultRange::user_range
ResultRange::getUsers() {
594 return {user_begin(), user_end()};
596 ResultRange::user_iterator
ResultRange::user_begin() {
597 return user_iterator(use_begin());
599 ResultRange::user_iterator
ResultRange::user_end() {
600 return user_iterator(use_end());
603 ResultRange::UseIterator::UseIterator(ResultRange results
, bool end
)
604 : it(end
? results
.end() : results
.begin()), endIt(results
.end()) {
605 // Only initialize current use if there are results/can be uses.
607 skipOverResultsWithNoUsers();
610 ResultRange::UseIterator
&ResultRange::UseIterator::operator++() {
611 // We increment over uses, if we reach the last use then move to next
613 if (use
!= (*it
).use_end())
615 if (use
== (*it
).use_end()) {
617 skipOverResultsWithNoUsers();
622 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623 while (it
!= endIt
&& (*it
).use_empty())
626 // If we are at the last result, then set use to first use of
627 // first result (sentinel value used for end).
631 use
= (*it
).use_begin();
634 void ResultRange::replaceAllUsesWith(Operation
*op
) {
635 replaceAllUsesWith(op
->getResults());
638 void ResultRange::replaceUsesWithIf(
639 Operation
*op
, function_ref
<bool(OpOperand
&)> shouldReplace
) {
640 replaceUsesWithIf(op
->getResults(), shouldReplace
);
643 //===----------------------------------------------------------------------===//
646 ValueRange::ValueRange(ArrayRef
<Value
> values
)
647 : ValueRange(values
.data(), values
.size()) {}
648 ValueRange::ValueRange(OperandRange values
)
649 : ValueRange(values
.begin().getBase(), values
.size()) {}
650 ValueRange::ValueRange(ResultRange values
)
651 : ValueRange(values
.getBase(), values
.size()) {}
653 /// See `llvm::detail::indexed_accessor_range_base` for details.
654 ValueRange::OwnerT
ValueRange::offset_base(const OwnerT
&owner
,
656 if (const auto *value
= llvm::dyn_cast_if_present
<const Value
*>(owner
))
657 return {value
+ index
};
658 if (auto *operand
= llvm::dyn_cast_if_present
<OpOperand
*>(owner
))
659 return {operand
+ index
};
660 return cast
<detail::OpResultImpl
*>(owner
)->getNextResultAtOffset(index
);
662 /// See `llvm::detail::indexed_accessor_range_base` for details.
663 Value
ValueRange::dereference_iterator(const OwnerT
&owner
, ptrdiff_t index
) {
664 if (const auto *value
= llvm::dyn_cast_if_present
<const Value
*>(owner
))
666 if (auto *operand
= llvm::dyn_cast_if_present
<OpOperand
*>(owner
))
667 return operand
[index
].get();
668 return cast
<detail::OpResultImpl
*>(owner
)->getNextResultAtOffset(index
);
671 //===----------------------------------------------------------------------===//
672 // Operation Equivalency
673 //===----------------------------------------------------------------------===//
675 llvm::hash_code
OperationEquivalence::computeHash(
676 Operation
*op
, function_ref
<llvm::hash_code(Value
)> hashOperands
,
677 function_ref
<llvm::hash_code(Value
)> hashResults
, Flags flags
) {
678 // Hash operations based upon their:
682 llvm::hash_code hash
=
683 llvm::hash_combine(op
->getName(), op
->getRawDictionaryAttrs(),
684 op
->getResultTypes(), op
->hashProperties());
686 // - Location if required
687 if (!(flags
& Flags::IgnoreLocations
))
688 hash
= llvm::hash_combine(hash
, op
->getLoc());
691 if (op
->hasTrait
<mlir::OpTrait::IsCommutative
>() &&
692 op
->getNumOperands() > 0) {
693 size_t operandHash
= hashOperands(op
->getOperand(0));
694 for (auto operand
: op
->getOperands().drop_front())
695 operandHash
+= hashOperands(operand
);
696 hash
= llvm::hash_combine(hash
, operandHash
);
698 for (Value operand
: op
->getOperands())
699 hash
= llvm::hash_combine(hash
, hashOperands(operand
));
703 for (Value result
: op
->getResults())
704 hash
= llvm::hash_combine(hash
, hashResults(result
));
708 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
709 Region
*lhs
, Region
*rhs
,
710 function_ref
<LogicalResult(Value
, Value
)> checkEquivalent
,
711 function_ref
<void(Value
, Value
)> markEquivalent
,
712 OperationEquivalence::Flags flags
,
713 function_ref
<LogicalResult(ValueRange
, ValueRange
)>
714 checkCommutativeEquivalent
) {
715 DenseMap
<Block
*, Block
*> blocksMap
;
716 auto blocksEquivalent
= [&](Block
&lBlock
, Block
&rBlock
) {
717 // Check block arguments.
718 if (lBlock
.getNumArguments() != rBlock
.getNumArguments())
721 // Map the two blocks.
722 auto insertion
= blocksMap
.insert({&lBlock
, &rBlock
});
723 if (insertion
.first
->getSecond() != &rBlock
)
727 llvm::zip(lBlock
.getArguments(), rBlock
.getArguments())) {
728 Value curArg
= std::get
<0>(argPair
);
729 Value otherArg
= std::get
<1>(argPair
);
730 if (curArg
.getType() != otherArg
.getType())
732 if (!(flags
& OperationEquivalence::IgnoreLocations
) &&
733 curArg
.getLoc() != otherArg
.getLoc())
735 // Corresponding bbArgs are equivalent.
737 markEquivalent(curArg
, otherArg
);
740 auto opsEquivalent
= [&](Operation
&lOp
, Operation
&rOp
) {
741 // Check for op equality (recursively).
742 if (!OperationEquivalence::isEquivalentTo(&lOp
, &rOp
, checkEquivalent
,
743 markEquivalent
, flags
,
744 checkCommutativeEquivalent
))
746 // Check successor mapping.
747 for (auto successorsPair
:
748 llvm::zip(lOp
.getSuccessors(), rOp
.getSuccessors())) {
749 Block
*curSuccessor
= std::get
<0>(successorsPair
);
750 Block
*otherSuccessor
= std::get
<1>(successorsPair
);
751 auto insertion
= blocksMap
.insert({curSuccessor
, otherSuccessor
});
752 if (insertion
.first
->getSecond() != otherSuccessor
)
757 return llvm::all_of_zip(lBlock
, rBlock
, opsEquivalent
);
759 return llvm::all_of_zip(*lhs
, *rhs
, blocksEquivalent
);
762 // Value equivalence cache to be used with `isRegionEquivalentTo` and
764 struct ValueEquivalenceCache
{
765 DenseMap
<Value
, Value
> equivalentValues
;
766 LogicalResult
checkEquivalent(Value lhsValue
, Value rhsValue
) {
767 return success(lhsValue
== rhsValue
||
768 equivalentValues
.lookup(lhsValue
) == rhsValue
);
770 LogicalResult
checkCommutativeEquivalent(ValueRange lhsRange
,
771 ValueRange rhsRange
) {
772 // Handle simple case where sizes mismatch.
773 if (lhsRange
.size() != rhsRange
.size())
776 // Handle where operands in order are equivalent.
777 auto lhsIt
= lhsRange
.begin();
778 auto rhsIt
= rhsRange
.begin();
779 for (; lhsIt
!= lhsRange
.end(); ++lhsIt
, ++rhsIt
) {
780 if (failed(checkEquivalent(*lhsIt
, *rhsIt
)))
783 if (lhsIt
== lhsRange
.end())
786 // Handle another simple case where operands are just a permutation.
787 // Note: This is not sufficient, this handles simple cases relatively
789 auto sortValues
= [](ValueRange values
) {
790 SmallVector
<Value
> sortedValues
= llvm::to_vector(values
);
791 llvm::sort(sortedValues
, [](Value a
, Value b
) {
792 return a
.getAsOpaquePointer() < b
.getAsOpaquePointer();
796 auto lhsSorted
= sortValues({lhsIt
, lhsRange
.end()});
797 auto rhsSorted
= sortValues({rhsIt
, rhsRange
.end()});
798 return success(lhsSorted
== rhsSorted
);
800 void markEquivalent(Value lhsResult
, Value rhsResult
) {
801 auto insertion
= equivalentValues
.insert({lhsResult
, rhsResult
});
802 // Make sure that the value was not already marked equivalent to some other
805 assert(insertion
.first
->second
== rhsResult
&&
806 "inconsistent OperationEquivalence state");
811 OperationEquivalence::isRegionEquivalentTo(Region
*lhs
, Region
*rhs
,
812 OperationEquivalence::Flags flags
) {
813 ValueEquivalenceCache cache
;
814 return isRegionEquivalentTo(
816 [&](Value lhsValue
, Value rhsValue
) -> LogicalResult
{
817 return cache
.checkEquivalent(lhsValue
, rhsValue
);
819 [&](Value lhsResult
, Value rhsResult
) {
820 cache
.markEquivalent(lhsResult
, rhsResult
);
823 [&](ValueRange lhs
, ValueRange rhs
) -> LogicalResult
{
824 return cache
.checkCommutativeEquivalent(lhs
, rhs
);
828 /*static*/ bool OperationEquivalence::isEquivalentTo(
829 Operation
*lhs
, Operation
*rhs
,
830 function_ref
<LogicalResult(Value
, Value
)> checkEquivalent
,
831 function_ref
<void(Value
, Value
)> markEquivalent
, Flags flags
,
832 function_ref
<LogicalResult(ValueRange
, ValueRange
)>
833 checkCommutativeEquivalent
) {
837 // 1. Compare the operation properties.
838 if (lhs
->getName() != rhs
->getName() ||
839 lhs
->getRawDictionaryAttrs() != rhs
->getRawDictionaryAttrs() ||
840 lhs
->getNumRegions() != rhs
->getNumRegions() ||
841 lhs
->getNumSuccessors() != rhs
->getNumSuccessors() ||
842 lhs
->getNumOperands() != rhs
->getNumOperands() ||
843 lhs
->getNumResults() != rhs
->getNumResults() ||
844 !lhs
->getName().compareOpProperties(lhs
->getPropertiesStorage(),
845 rhs
->getPropertiesStorage()))
847 if (!(flags
& IgnoreLocations
) && lhs
->getLoc() != rhs
->getLoc())
850 // 2. Compare operands.
851 if (checkCommutativeEquivalent
&&
852 lhs
->hasTrait
<mlir::OpTrait::IsCommutative
>()) {
853 auto lhsRange
= lhs
->getOperands();
854 auto rhsRange
= rhs
->getOperands();
855 if (failed(checkCommutativeEquivalent(lhsRange
, rhsRange
)))
858 // Check pair wise for equivalence.
859 for (auto operandPair
: llvm::zip(lhs
->getOperands(), rhs
->getOperands())) {
860 Value curArg
= std::get
<0>(operandPair
);
861 Value otherArg
= std::get
<1>(operandPair
);
862 if (curArg
== otherArg
)
864 if (curArg
.getType() != otherArg
.getType())
866 if (failed(checkEquivalent(curArg
, otherArg
)))
871 // 3. Compare result types and mark results as equivalent.
872 for (auto resultPair
: llvm::zip(lhs
->getResults(), rhs
->getResults())) {
873 Value curArg
= std::get
<0>(resultPair
);
874 Value otherArg
= std::get
<1>(resultPair
);
875 if (curArg
.getType() != otherArg
.getType())
878 markEquivalent(curArg
, otherArg
);
881 // 4. Compare regions.
882 for (auto regionPair
: llvm::zip(lhs
->getRegions(), rhs
->getRegions()))
883 if (!isRegionEquivalentTo(&std::get
<0>(regionPair
),
884 &std::get
<1>(regionPair
), checkEquivalent
,
885 markEquivalent
, flags
))
891 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation
*lhs
,
894 ValueEquivalenceCache cache
;
895 return OperationEquivalence::isEquivalentTo(
897 [&](Value lhsValue
, Value rhsValue
) -> LogicalResult
{
898 return cache
.checkEquivalent(lhsValue
, rhsValue
);
900 [&](Value lhsResult
, Value rhsResult
) {
901 cache
.markEquivalent(lhsResult
, rhsResult
);
904 [&](ValueRange lhs
, ValueRange rhs
) -> LogicalResult
{
905 return cache
.checkCommutativeEquivalent(lhs
, rhs
);
909 //===----------------------------------------------------------------------===//
910 // OperationFingerPrint
911 //===----------------------------------------------------------------------===//
913 template <typename T
>
914 static void addDataToHash(llvm::SHA1
&hasher
, const T
&data
) {
916 ArrayRef
<uint8_t>(reinterpret_cast<const uint8_t *>(&data
), sizeof(T
)));
919 OperationFingerPrint::OperationFingerPrint(Operation
*topOp
,
920 bool includeNested
) {
923 // Helper function that hashes an operation based on its mutable bits:
924 auto addOperationToHash
= [&](Operation
*op
) {
925 // - Operation pointer
926 addDataToHash(hasher
, op
);
927 // - Parent operation pointer (to take into account the nesting structure)
929 addDataToHash(hasher
, op
->getParentOp());
931 addDataToHash(hasher
, op
->getRawDictionaryAttrs());
933 addDataToHash(hasher
, op
->hashProperties());
934 // - Blocks in Regions
935 for (Region
®ion
: op
->getRegions()) {
936 for (Block
&block
: region
) {
937 addDataToHash(hasher
, &block
);
938 for (BlockArgument arg
: block
.getArguments())
939 addDataToHash(hasher
, arg
);
943 addDataToHash(hasher
, op
->getLoc().getAsOpaquePointer());
945 for (Value operand
: op
->getOperands())
946 addDataToHash(hasher
, operand
);
948 for (unsigned i
= 0, e
= op
->getNumSuccessors(); i
!= e
; ++i
)
949 addDataToHash(hasher
, op
->getSuccessor(i
));
951 for (Type t
: op
->getResultTypes())
952 addDataToHash(hasher
, t
);
956 topOp
->walk(addOperationToHash
);
958 addOperationToHash(topOp
);
960 hash
= hasher
.result();