1 //===- OperationSupport.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
12 //===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/SHA1.h"
#include <memory>
#include <numeric>
#include <optional>

using namespace mlir;
25 //===----------------------------------------------------------------------===//
27 //===----------------------------------------------------------------------===//
29 NamedAttrList::NamedAttrList(ArrayRef
<NamedAttribute
> attributes
) {
30 assign(attributes
.begin(), attributes
.end());
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes
)
34 : NamedAttrList(attributes
? attributes
.getValue()
35 : ArrayRef
<NamedAttribute
>()) {
36 dictionarySorted
.setPointerAndInt(attributes
, true);
39 NamedAttrList::NamedAttrList(const_iterator inStart
, const_iterator inEnd
) {
40 assign(inStart
, inEnd
);
43 ArrayRef
<NamedAttribute
> NamedAttrList::getAttrs() const { return attrs
; }
45 std::optional
<NamedAttribute
> NamedAttrList::findDuplicate() const {
46 std::optional
<NamedAttribute
> duplicate
=
47 DictionaryAttr::findDuplicate(attrs
, isSorted());
48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
51 dictionarySorted
.setPointerAndInt(nullptr, true);
55 DictionaryAttr
NamedAttrList::getDictionary(MLIRContext
*context
) const {
57 DictionaryAttr::sortInPlace(attrs
);
58 dictionarySorted
.setPointerAndInt(nullptr, true);
60 if (!dictionarySorted
.getPointer())
61 dictionarySorted
.setPointer(DictionaryAttr::getWithSorted(context
, attrs
));
62 return llvm::cast
<DictionaryAttr
>(dictionarySorted
.getPointer());
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name
, Attribute attr
) {
67 append(StringAttr::get(attr
.getContext(), name
), attr
);
70 /// Replaces the attributes with new list of attributes.
71 void NamedAttrList::assign(const_iterator inStart
, const_iterator inEnd
) {
72 DictionaryAttr::sort(ArrayRef
<NamedAttribute
>{inStart
, inEnd
}, attrs
);
73 dictionarySorted
.setPointerAndInt(nullptr, true);
76 void NamedAttrList::push_back(NamedAttribute newAttribute
) {
78 dictionarySorted
.setInt(attrs
.empty() || attrs
.back() < newAttribute
);
79 dictionarySorted
.setPointer(nullptr);
80 attrs
.push_back(newAttribute
);
83 /// Return the specified attribute if present, null otherwise.
84 Attribute
NamedAttrList::get(StringRef name
) const {
85 auto it
= findAttr(*this, name
);
86 return it
.second
? it
.first
->getValue() : Attribute();
88 Attribute
NamedAttrList::get(StringAttr name
) const {
89 auto it
= findAttr(*this, name
);
90 return it
.second
? it
.first
->getValue() : Attribute();
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional
<NamedAttribute
> NamedAttrList::getNamed(StringRef name
) const {
95 auto it
= findAttr(*this, name
);
96 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
98 std::optional
<NamedAttribute
> NamedAttrList::getNamed(StringAttr name
) const {
99 auto it
= findAttr(*this, name
);
100 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute
NamedAttrList::set(StringAttr name
, Attribute value
) {
106 assert(value
&& "attributes may never be null");
108 // Look for an existing attribute with the given name, and set its value
109 // in-place. Return the previous value of the attribute, if there was one.
110 auto it
= findAttr(*this, name
);
112 // Update the existing attribute by swapping out the old value for the new
113 // value. Return the old value.
114 Attribute oldValue
= it
.first
->getValue();
115 if (it
.first
->getValue() != value
) {
116 it
.first
->setValue(value
);
118 // If the attributes have changed, the dictionary is invalidated.
119 dictionarySorted
.setPointer(nullptr);
123 // Perform a string lookup to insert the new attribute into its sorted
126 it
= findAttr(*this, name
.strref());
127 attrs
.insert(it
.first
, {name
, value
});
128 // Invalidate the dictionary. Return null as there was no previous value.
129 dictionarySorted
.setPointer(nullptr);
133 Attribute
NamedAttrList::set(StringRef name
, Attribute value
) {
134 assert(value
&& "attributes may never be null");
135 return set(mlir::StringAttr::get(value
.getContext(), name
), value
);
139 NamedAttrList::eraseImpl(SmallVectorImpl
<NamedAttribute
>::iterator it
) {
140 // Erasing does not affect the sorted property.
141 Attribute attr
= it
->getValue();
143 dictionarySorted
.setPointer(nullptr);
147 Attribute
NamedAttrList::erase(StringAttr name
) {
148 auto it
= findAttr(*this, name
);
149 return it
.second
? eraseImpl(it
.first
) : Attribute();
152 Attribute
NamedAttrList::erase(StringRef name
) {
153 auto it
= findAttr(*this, name
);
154 return it
.second
? eraseImpl(it
.first
) : Attribute();
158 NamedAttrList::operator=(const SmallVectorImpl
<NamedAttribute
> &rhs
) {
159 assign(rhs
.begin(), rhs
.end());
163 NamedAttrList::operator ArrayRef
<NamedAttribute
>() const { return attrs
; }
165 //===----------------------------------------------------------------------===//
167 //===----------------------------------------------------------------------===//
169 OperationState::OperationState(Location location
, StringRef name
)
170 : location(location
), name(name
, location
->getContext()) {}
172 OperationState::OperationState(Location location
, OperationName name
)
173 : location(location
), name(name
) {}
175 OperationState::OperationState(Location location
, OperationName name
,
176 ValueRange operands
, TypeRange types
,
177 ArrayRef
<NamedAttribute
> attributes
,
178 BlockRange successors
,
179 MutableArrayRef
<std::unique_ptr
<Region
>> regions
)
180 : location(location
), name(name
),
181 operands(operands
.begin(), operands
.end()),
182 types(types
.begin(), types
.end()),
183 attributes(attributes
.begin(), attributes
.end()),
184 successors(successors
.begin(), successors
.end()) {
185 for (std::unique_ptr
<Region
> &r
: regions
)
186 this->regions
.push_back(std::move(r
));
188 OperationState::OperationState(Location location
, StringRef name
,
189 ValueRange operands
, TypeRange types
,
190 ArrayRef
<NamedAttribute
> attributes
,
191 BlockRange successors
,
192 MutableArrayRef
<std::unique_ptr
<Region
>> regions
)
193 : OperationState(location
, OperationName(name
, location
.getContext()),
194 operands
, types
, attributes
, successors
, regions
) {}
196 OperationState::~OperationState() {
198 propertiesDeleter(properties
);
201 LogicalResult
OperationState::setProperties(
202 Operation
*op
, function_ref
<InFlightDiagnostic()> emitError
) const {
203 if (LLVM_UNLIKELY(propertiesAttr
)) {
205 return op
->setPropertiesFromAttribute(propertiesAttr
, emitError
);
208 propertiesSetter(op
->getPropertiesStorage(), properties
);
212 void OperationState::addOperands(ValueRange newOperands
) {
213 operands
.append(newOperands
.begin(), newOperands
.end());
216 void OperationState::addSuccessors(BlockRange newSuccessors
) {
217 successors
.append(newSuccessors
.begin(), newSuccessors
.end());
220 Region
*OperationState::addRegion() {
221 regions
.emplace_back(new Region
);
222 return regions
.back().get();
225 void OperationState::addRegion(std::unique_ptr
<Region
> &®ion
) {
226 regions
.push_back(std::move(region
));
229 void OperationState::addRegions(
230 MutableArrayRef
<std::unique_ptr
<Region
>> regions
) {
231 for (std::unique_ptr
<Region
> ®ion
: regions
)
232 addRegion(std::move(region
));
235 //===----------------------------------------------------------------------===//
237 //===----------------------------------------------------------------------===//
239 detail::OperandStorage::OperandStorage(Operation
*owner
,
240 OpOperand
*trailingOperands
,
242 : isStorageDynamic(false), operandStorage(trailingOperands
) {
243 numOperands
= capacity
= values
.size();
244 for (unsigned i
= 0; i
< numOperands
; ++i
)
245 new (&operandStorage
[i
]) OpOperand(owner
, values
[i
]);
248 detail::OperandStorage::~OperandStorage() {
249 for (auto &operand
: getOperands())
250 operand
.~OpOperand();
252 // If the storage is dynamic, deallocate it.
253 if (isStorageDynamic
)
254 free(operandStorage
);
257 /// Replace the operands contained in the storage with the ones provided in
259 void detail::OperandStorage::setOperands(Operation
*owner
, ValueRange values
) {
260 MutableArrayRef
<OpOperand
> storageOperands
= resize(owner
, values
.size());
261 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
)
262 storageOperands
[i
].set(values
[i
]);
265 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
267 /// than the range pointed to by 'start'+'length'.
268 void detail::OperandStorage::setOperands(Operation
*owner
, unsigned start
,
269 unsigned length
, ValueRange operands
) {
270 // If the new size is the same, we can update inplace.
271 unsigned newSize
= operands
.size();
272 if (newSize
== length
) {
273 MutableArrayRef
<OpOperand
> storageOperands
= getOperands();
274 for (unsigned i
= 0, e
= length
; i
!= e
; ++i
)
275 storageOperands
[start
+ i
].set(operands
[i
]);
278 // If the new size is greater, remove the extra operands and set the rest
280 if (newSize
< length
) {
281 eraseOperands(start
+ operands
.size(), length
- newSize
);
282 setOperands(owner
, start
, newSize
, operands
);
285 // Otherwise, the new size is greater so we need to grow the storage.
286 auto storageOperands
= resize(owner
, size() + (newSize
- length
));
288 // Shift operands to the right to make space for the new operands.
289 unsigned rotateSize
= storageOperands
.size() - (start
+ length
);
290 auto rbegin
= storageOperands
.rbegin();
291 std::rotate(rbegin
, std::next(rbegin
, newSize
- length
), rbegin
+ rotateSize
);
293 // Update the operands inplace.
294 for (unsigned i
= 0, e
= operands
.size(); i
!= e
; ++i
)
295 storageOperands
[start
+ i
].set(operands
[i
]);
298 /// Erase an operand held by the storage.
299 void detail::OperandStorage::eraseOperands(unsigned start
, unsigned length
) {
300 MutableArrayRef
<OpOperand
> operands
= getOperands();
301 assert((start
+ length
) <= operands
.size());
302 numOperands
-= length
;
304 // Shift all operands down if the operand to remove is not at the end.
305 if (start
!= numOperands
) {
306 auto *indexIt
= std::next(operands
.begin(), start
);
307 std::rotate(indexIt
, std::next(indexIt
, length
), operands
.end());
309 for (unsigned i
= 0; i
!= length
; ++i
)
310 operands
[numOperands
+ i
].~OpOperand();
313 void detail::OperandStorage::eraseOperands(const BitVector
&eraseIndices
) {
314 MutableArrayRef
<OpOperand
> operands
= getOperands();
315 assert(eraseIndices
.size() == operands
.size());
317 // Check that at least one operand is erased.
318 int firstErasedIndice
= eraseIndices
.find_first();
319 if (firstErasedIndice
== -1)
322 // Shift all of the removed operands to the end, and destroy them.
323 numOperands
= firstErasedIndice
;
324 for (unsigned i
= firstErasedIndice
+ 1, e
= operands
.size(); i
< e
; ++i
)
325 if (!eraseIndices
.test(i
))
326 operands
[numOperands
++] = std::move(operands
[i
]);
327 for (OpOperand
&operand
: operands
.drop_front(numOperands
))
328 operand
.~OpOperand();
331 /// Resize the storage to the given size. Returns the array containing the new
333 MutableArrayRef
<OpOperand
> detail::OperandStorage::resize(Operation
*owner
,
335 // If the number of operands is less than or equal to the current amount, we
336 // can just update in place.
337 MutableArrayRef
<OpOperand
> origOperands
= getOperands();
338 if (newSize
<= numOperands
) {
339 // If the number of new size is less than the current, remove any extra
341 for (unsigned i
= newSize
; i
!= numOperands
; ++i
)
342 origOperands
[i
].~OpOperand();
343 numOperands
= newSize
;
344 return origOperands
.take_front(newSize
);
347 // If the new size is within the original inline capacity, grow inplace.
348 if (newSize
<= capacity
) {
349 OpOperand
*opBegin
= origOperands
.data();
350 for (unsigned e
= newSize
; numOperands
!= e
; ++numOperands
)
351 new (&opBegin
[numOperands
]) OpOperand(owner
);
352 return MutableArrayRef
<OpOperand
>(opBegin
, newSize
);
355 // Otherwise, we need to allocate a new storage.
356 unsigned newCapacity
=
357 std::max(unsigned(llvm::NextPowerOf2(capacity
+ 2)), newSize
);
358 OpOperand
*newOperandStorage
=
359 reinterpret_cast<OpOperand
*>(malloc(sizeof(OpOperand
) * newCapacity
));
361 // Move the current operands to the new storage.
362 MutableArrayRef
<OpOperand
> newOperands(newOperandStorage
, newSize
);
363 std::uninitialized_move(origOperands
.begin(), origOperands
.end(),
364 newOperands
.begin());
366 // Destroy the original operands.
367 for (auto &operand
: origOperands
)
368 operand
.~OpOperand();
370 // Initialize any new operands.
371 for (unsigned e
= newSize
; numOperands
!= e
; ++numOperands
)
372 new (&newOperands
[numOperands
]) OpOperand(owner
);
374 // If the current storage is dynamic, free it.
375 if (isStorageDynamic
)
376 free(operandStorage
);
378 // Update the storage representation to use the new dynamic storage.
379 operandStorage
= newOperandStorage
;
380 capacity
= newCapacity
;
381 isStorageDynamic
= true;
385 //===----------------------------------------------------------------------===//
386 // Operation Value-Iterators
387 //===----------------------------------------------------------------------===//
389 //===----------------------------------------------------------------------===//
392 unsigned OperandRange::getBeginOperandIndex() const {
393 assert(!empty() && "range must not be empty");
394 return base
->getOperandNumber();
397 OperandRangeRange
OperandRange::split(DenseI32ArrayAttr segmentSizes
) const {
398 return OperandRangeRange(*this, segmentSizes
);
401 //===----------------------------------------------------------------------===//
404 OperandRangeRange::OperandRangeRange(OperandRange operands
,
405 Attribute operandSegments
)
406 : OperandRangeRange(OwnerT(operands
.getBase(), operandSegments
), 0,
407 llvm::cast
<DenseI32ArrayAttr
>(operandSegments
).size()) {
410 OperandRange
OperandRangeRange::join() const {
411 const OwnerT
&owner
= getBase();
412 ArrayRef
<int32_t> sizeData
= llvm::cast
<DenseI32ArrayAttr
>(owner
.second
);
413 return OperandRange(owner
.first
,
414 std::accumulate(sizeData
.begin(), sizeData
.end(), 0));
417 OperandRange
OperandRangeRange::dereference(const OwnerT
&object
,
419 ArrayRef
<int32_t> sizeData
= llvm::cast
<DenseI32ArrayAttr
>(object
.second
);
420 uint32_t startIndex
=
421 std::accumulate(sizeData
.begin(), sizeData
.begin() + index
, 0);
422 return OperandRange(object
.first
+ startIndex
, *(sizeData
.begin() + index
));
425 //===----------------------------------------------------------------------===//
426 // MutableOperandRange
428 /// Construct a new mutable range from the given operand, operand start index,
429 /// and range length.
430 MutableOperandRange::MutableOperandRange(
431 Operation
*owner
, unsigned start
, unsigned length
,
432 ArrayRef
<OperandSegment
> operandSegments
)
433 : owner(owner
), start(start
), length(length
),
434 operandSegments(operandSegments
.begin(), operandSegments
.end()) {
435 assert((start
+ length
) <= owner
->getNumOperands() && "invalid range");
437 MutableOperandRange::MutableOperandRange(Operation
*owner
)
438 : MutableOperandRange(owner
, /*start=*/0, owner
->getNumOperands()) {}
440 /// Construct a new mutable range for the given OpOperand.
441 MutableOperandRange::MutableOperandRange(OpOperand
&opOperand
)
442 : MutableOperandRange(opOperand
.getOwner(),
443 /*start=*/opOperand
.getOperandNumber(),
446 /// Slice this range into a sub range, with the additional operand segment.
448 MutableOperandRange::slice(unsigned subStart
, unsigned subLen
,
449 std::optional
<OperandSegment
> segment
) const {
450 assert((subStart
+ subLen
) <= length
&& "invalid sub-range");
451 MutableOperandRange
subSlice(owner
, start
+ subStart
, subLen
,
454 subSlice
.operandSegments
.push_back(*segment
);
458 /// Append the given values to the range.
459 void MutableOperandRange::append(ValueRange values
) {
462 owner
->insertOperands(start
+ length
, values
);
463 updateLength(length
+ values
.size());
466 /// Assign this range to the given values.
467 void MutableOperandRange::assign(ValueRange values
) {
468 owner
->setOperands(start
, length
, values
);
469 if (length
!= values
.size())
470 updateLength(/*newLength=*/values
.size());
473 /// Assign the range to the given value.
474 void MutableOperandRange::assign(Value value
) {
476 owner
->setOperand(start
, value
);
478 owner
->setOperands(start
, length
, value
);
479 updateLength(/*newLength=*/1);
483 /// Erase the operands within the given sub-range.
484 void MutableOperandRange::erase(unsigned subStart
, unsigned subLen
) {
485 assert((subStart
+ subLen
) <= length
&& "invalid sub-range");
488 owner
->eraseOperands(start
+ subStart
, subLen
);
489 updateLength(length
- subLen
);
492 /// Clear this range and erase all of the operands.
493 void MutableOperandRange::clear() {
495 owner
->eraseOperands(start
, length
);
496 updateLength(/*newLength=*/0);
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502 return owner
->getOperands().slice(start
, length
);
505 MutableOperandRange::operator MutableArrayRef
<OpOperand
>() const {
506 return owner
->getOpOperands().slice(start
, length
);
509 MutableOperandRangeRange
510 MutableOperandRange::split(NamedAttribute segmentSizes
) const {
511 return MutableOperandRangeRange(*this, segmentSizes
);
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength
) {
516 int32_t diff
= int32_t(newLength
) - int32_t(length
);
519 // Update any of the provided segment attributes.
520 for (OperandSegment
&segment
: operandSegments
) {
521 auto attr
= llvm::cast
<DenseI32ArrayAttr
>(segment
.second
.getValue());
522 SmallVector
<int32_t, 8> segments(attr
.asArrayRef());
523 segments
[segment
.first
] += diff
;
524 segment
.second
.setValue(
525 DenseI32ArrayAttr::get(attr
.getContext(), segments
));
526 owner
->setAttr(segment
.second
.getName(), segment
.second
.getValue());
530 OpOperand
&MutableOperandRange::operator[](unsigned index
) const {
531 assert(index
< length
&& "index is out of bounds");
532 return owner
->getOpOperand(start
+ index
);
535 MutableArrayRef
<OpOperand
>::iterator
MutableOperandRange::begin() const {
536 return owner
->getOpOperands().slice(start
, length
).begin();
539 MutableArrayRef
<OpOperand
>::iterator
MutableOperandRange::end() const {
540 return owner
->getOpOperands().slice(start
, length
).end();
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
546 MutableOperandRangeRange::MutableOperandRangeRange(
547 const MutableOperandRange
&operands
, NamedAttribute operandSegmentAttr
)
548 : MutableOperandRangeRange(
549 OwnerT(operands
, operandSegmentAttr
), 0,
550 llvm::cast
<DenseI32ArrayAttr
>(operandSegmentAttr
.getValue()).size()) {
553 MutableOperandRange
MutableOperandRangeRange::join() const {
554 return getBase().first
;
557 MutableOperandRangeRange::operator OperandRangeRange() const {
558 return OperandRangeRange(getBase().first
, getBase().second
.getValue());
561 MutableOperandRange
MutableOperandRangeRange::dereference(const OwnerT
&object
,
563 ArrayRef
<int32_t> sizeData
=
564 llvm::cast
<DenseI32ArrayAttr
>(object
.second
.getValue());
565 uint32_t startIndex
=
566 std::accumulate(sizeData
.begin(), sizeData
.begin() + index
, 0);
567 return object
.first
.slice(
568 startIndex
, *(sizeData
.begin() + index
),
569 MutableOperandRange::OperandSegment(index
, object
.second
));
572 //===----------------------------------------------------------------------===//
575 ResultRange::ResultRange(OpResult result
)
576 : ResultRange(static_cast<detail::OpResultImpl
*>(Value(result
).getImpl()),
579 ResultRange::use_range
ResultRange::getUses() const {
580 return {use_begin(), use_end()};
582 ResultRange::use_iterator
ResultRange::use_begin() const {
583 return use_iterator(*this);
585 ResultRange::use_iterator
ResultRange::use_end() const {
586 return use_iterator(*this, /*end=*/true);
588 ResultRange::user_range
ResultRange::getUsers() {
589 return {user_begin(), user_end()};
591 ResultRange::user_iterator
ResultRange::user_begin() {
592 return user_iterator(use_begin());
594 ResultRange::user_iterator
ResultRange::user_end() {
595 return user_iterator(use_end());
598 ResultRange::UseIterator::UseIterator(ResultRange results
, bool end
)
599 : it(end
? results
.end() : results
.begin()), endIt(results
.end()) {
600 // Only initialize current use if there are results/can be uses.
602 skipOverResultsWithNoUsers();
605 ResultRange::UseIterator
&ResultRange::UseIterator::operator++() {
606 // We increment over uses, if we reach the last use then move to next
608 if (use
!= (*it
).use_end())
610 if (use
== (*it
).use_end()) {
612 skipOverResultsWithNoUsers();
617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
618 while (it
!= endIt
&& (*it
).use_empty())
621 // If we are at the last result, then set use to first use of
622 // first result (sentinel value used for end).
626 use
= (*it
).use_begin();
629 void ResultRange::replaceAllUsesWith(Operation
*op
) {
630 replaceAllUsesWith(op
->getResults());
633 void ResultRange::replaceUsesWithIf(
634 Operation
*op
, function_ref
<bool(OpOperand
&)> shouldReplace
) {
635 replaceUsesWithIf(op
->getResults(), shouldReplace
);
638 //===----------------------------------------------------------------------===//
641 ValueRange::ValueRange(ArrayRef
<Value
> values
)
642 : ValueRange(values
.data(), values
.size()) {}
643 ValueRange::ValueRange(OperandRange values
)
644 : ValueRange(values
.begin().getBase(), values
.size()) {}
645 ValueRange::ValueRange(ResultRange values
)
646 : ValueRange(values
.getBase(), values
.size()) {}
648 /// See `llvm::detail::indexed_accessor_range_base` for details.
649 ValueRange::OwnerT
ValueRange::offset_base(const OwnerT
&owner
,
651 if (const auto *value
= llvm::dyn_cast_if_present
<const Value
*>(owner
))
652 return {value
+ index
};
653 if (auto *operand
= llvm::dyn_cast_if_present
<OpOperand
*>(owner
))
654 return {operand
+ index
};
655 return owner
.get
<detail::OpResultImpl
*>()->getNextResultAtOffset(index
);
657 /// See `llvm::detail::indexed_accessor_range_base` for details.
658 Value
ValueRange::dereference_iterator(const OwnerT
&owner
, ptrdiff_t index
) {
659 if (const auto *value
= llvm::dyn_cast_if_present
<const Value
*>(owner
))
661 if (auto *operand
= llvm::dyn_cast_if_present
<OpOperand
*>(owner
))
662 return operand
[index
].get();
663 return owner
.get
<detail::OpResultImpl
*>()->getNextResultAtOffset(index
);
666 //===----------------------------------------------------------------------===//
667 // Operation Equivalency
668 //===----------------------------------------------------------------------===//
670 llvm::hash_code
OperationEquivalence::computeHash(
671 Operation
*op
, function_ref
<llvm::hash_code(Value
)> hashOperands
,
672 function_ref
<llvm::hash_code(Value
)> hashResults
, Flags flags
) {
673 // Hash operations based upon their:
677 llvm::hash_code hash
=
678 llvm::hash_combine(op
->getName(), op
->getDiscardableAttrDictionary(),
679 op
->getResultTypes(), op
->hashProperties());
681 // - Location if required
682 if (!(flags
& Flags::IgnoreLocations
))
683 hash
= llvm::hash_combine(hash
, op
->getLoc());
686 for (Value operand
: op
->getOperands())
687 hash
= llvm::hash_combine(hash
, hashOperands(operand
));
690 for (Value result
: op
->getResults())
691 hash
= llvm::hash_combine(hash
, hashResults(result
));
695 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
696 Region
*lhs
, Region
*rhs
,
697 function_ref
<LogicalResult(Value
, Value
)> checkEquivalent
,
698 function_ref
<void(Value
, Value
)> markEquivalent
,
699 OperationEquivalence::Flags flags
) {
700 DenseMap
<Block
*, Block
*> blocksMap
;
701 auto blocksEquivalent
= [&](Block
&lBlock
, Block
&rBlock
) {
702 // Check block arguments.
703 if (lBlock
.getNumArguments() != rBlock
.getNumArguments())
706 // Map the two blocks.
707 auto insertion
= blocksMap
.insert({&lBlock
, &rBlock
});
708 if (insertion
.first
->getSecond() != &rBlock
)
712 llvm::zip(lBlock
.getArguments(), rBlock
.getArguments())) {
713 Value curArg
= std::get
<0>(argPair
);
714 Value otherArg
= std::get
<1>(argPair
);
715 if (curArg
.getType() != otherArg
.getType())
717 if (!(flags
& OperationEquivalence::IgnoreLocations
) &&
718 curArg
.getLoc() != otherArg
.getLoc())
720 // Corresponding bbArgs are equivalent.
722 markEquivalent(curArg
, otherArg
);
725 auto opsEquivalent
= [&](Operation
&lOp
, Operation
&rOp
) {
726 // Check for op equality (recursively).
727 if (!OperationEquivalence::isEquivalentTo(&lOp
, &rOp
, checkEquivalent
,
728 markEquivalent
, flags
))
730 // Check successor mapping.
731 for (auto successorsPair
:
732 llvm::zip(lOp
.getSuccessors(), rOp
.getSuccessors())) {
733 Block
*curSuccessor
= std::get
<0>(successorsPair
);
734 Block
*otherSuccessor
= std::get
<1>(successorsPair
);
735 auto insertion
= blocksMap
.insert({curSuccessor
, otherSuccessor
});
736 if (insertion
.first
->getSecond() != otherSuccessor
)
741 return llvm::all_of_zip(lBlock
, rBlock
, opsEquivalent
);
743 return llvm::all_of_zip(*lhs
, *rhs
, blocksEquivalent
);
746 // Value equivalence cache to be used with `isRegionEquivalentTo` and
748 struct ValueEquivalenceCache
{
749 DenseMap
<Value
, Value
> equivalentValues
;
750 LogicalResult
checkEquivalent(Value lhsValue
, Value rhsValue
) {
751 return success(lhsValue
== rhsValue
||
752 equivalentValues
.lookup(lhsValue
) == rhsValue
);
754 void markEquivalent(Value lhsResult
, Value rhsResult
) {
755 auto insertion
= equivalentValues
.insert({lhsResult
, rhsResult
});
756 // Make sure that the value was not already marked equivalent to some other
759 assert(insertion
.first
->second
== rhsResult
&&
760 "inconsistent OperationEquivalence state");
765 OperationEquivalence::isRegionEquivalentTo(Region
*lhs
, Region
*rhs
,
766 OperationEquivalence::Flags flags
) {
767 ValueEquivalenceCache cache
;
768 return isRegionEquivalentTo(
770 [&](Value lhsValue
, Value rhsValue
) -> LogicalResult
{
771 return cache
.checkEquivalent(lhsValue
, rhsValue
);
773 [&](Value lhsResult
, Value rhsResult
) {
774 cache
.markEquivalent(lhsResult
, rhsResult
);
779 /*static*/ bool OperationEquivalence::isEquivalentTo(
780 Operation
*lhs
, Operation
*rhs
,
781 function_ref
<LogicalResult(Value
, Value
)> checkEquivalent
,
782 function_ref
<void(Value
, Value
)> markEquivalent
, Flags flags
) {
786 // 1. Compare the operation properties.
787 if (lhs
->getName() != rhs
->getName() ||
788 lhs
->getDiscardableAttrDictionary() !=
789 rhs
->getDiscardableAttrDictionary() ||
790 lhs
->getNumRegions() != rhs
->getNumRegions() ||
791 lhs
->getNumSuccessors() != rhs
->getNumSuccessors() ||
792 lhs
->getNumOperands() != rhs
->getNumOperands() ||
793 lhs
->getNumResults() != rhs
->getNumResults() ||
794 !lhs
->getName().compareOpProperties(lhs
->getPropertiesStorage(),
795 rhs
->getPropertiesStorage()))
797 if (!(flags
& IgnoreLocations
) && lhs
->getLoc() != rhs
->getLoc())
800 // 2. Compare operands.
801 for (auto operandPair
: llvm::zip(lhs
->getOperands(), rhs
->getOperands())) {
802 Value curArg
= std::get
<0>(operandPair
);
803 Value otherArg
= std::get
<1>(operandPair
);
804 if (curArg
== otherArg
)
806 if (curArg
.getType() != otherArg
.getType())
808 if (failed(checkEquivalent(curArg
, otherArg
)))
812 // 3. Compare result types and mark results as equivalent.
813 for (auto resultPair
: llvm::zip(lhs
->getResults(), rhs
->getResults())) {
814 Value curArg
= std::get
<0>(resultPair
);
815 Value otherArg
= std::get
<1>(resultPair
);
816 if (curArg
.getType() != otherArg
.getType())
819 markEquivalent(curArg
, otherArg
);
822 // 4. Compare regions.
823 for (auto regionPair
: llvm::zip(lhs
->getRegions(), rhs
->getRegions()))
824 if (!isRegionEquivalentTo(&std::get
<0>(regionPair
),
825 &std::get
<1>(regionPair
), checkEquivalent
,
826 markEquivalent
, flags
))
832 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation
*lhs
,
835 ValueEquivalenceCache cache
;
836 return OperationEquivalence::isEquivalentTo(
838 [&](Value lhsValue
, Value rhsValue
) -> LogicalResult
{
839 return cache
.checkEquivalent(lhsValue
, rhsValue
);
841 [&](Value lhsResult
, Value rhsResult
) {
842 cache
.markEquivalent(lhsResult
, rhsResult
);
847 //===----------------------------------------------------------------------===//
848 // OperationFingerPrint
849 //===----------------------------------------------------------------------===//
851 template <typename T
>
852 static void addDataToHash(llvm::SHA1
&hasher
, const T
&data
) {
854 ArrayRef
<uint8_t>(reinterpret_cast<const uint8_t *>(&data
), sizeof(T
)));
857 OperationFingerPrint::OperationFingerPrint(Operation
*topOp
) {
860 // Hash each of the operations based upon their mutable bits:
861 topOp
->walk([&](Operation
*op
) {
862 // - Operation pointer
863 addDataToHash(hasher
, op
);
864 // - Parent operation pointer (to take into account the nesting structure)
866 addDataToHash(hasher
, op
->getParentOp());
868 addDataToHash(hasher
, op
->getDiscardableAttrDictionary());
870 addDataToHash(hasher
, op
->hashProperties());
871 // - Blocks in Regions
872 for (Region
®ion
: op
->getRegions()) {
873 for (Block
&block
: region
) {
874 addDataToHash(hasher
, &block
);
875 for (BlockArgument arg
: block
.getArguments())
876 addDataToHash(hasher
, arg
);
880 addDataToHash(hasher
, op
->getLoc().getAsOpaquePointer());
882 for (Value operand
: op
->getOperands())
883 addDataToHash(hasher
, operand
);
885 for (unsigned i
= 0, e
= op
->getNumSuccessors(); i
!= e
; ++i
)
886 addDataToHash(hasher
, op
->getSuccessor(i
));
888 for (Type t
: op
->getResultTypes())
889 addDataToHash(hasher
, t
);
891 hash
= hasher
.result();