[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / IR / OperationSupport.cpp
blobfc5ccd23b5108d8374b737e833529e5c52727ca0
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
23 using namespace mlir;
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
29 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
30 assign(attributes.begin(), attributes.end());
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34 : NamedAttrList(attributes ? attributes.getValue()
35 : ArrayRef<NamedAttribute>()) {
36 dictionarySorted.setPointerAndInt(attributes, true);
39 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
40 assign(inStart, inEnd);
43 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46 std::optional<NamedAttribute> duplicate =
47 DictionaryAttr::findDuplicate(attrs, isSorted());
48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49 // state.
50 if (!isSorted())
51 dictionarySorted.setPointerAndInt(nullptr, true);
52 return duplicate;
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56 if (!isSorted()) {
57 DictionaryAttr::sortInPlace(attrs);
58 dictionarySorted.setPointerAndInt(nullptr, true);
60 if (!dictionarySorted.getPointer())
61 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name, Attribute attr) {
67 append(StringAttr::get(attr.getContext(), name), attr);
70 /// Replaces the attributes with new list of attributes.
71 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
72 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
73 dictionarySorted.setPointerAndInt(nullptr, true);
76 void NamedAttrList::push_back(NamedAttribute newAttribute) {
77 if (isSorted())
78 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
79 dictionarySorted.setPointer(nullptr);
80 attrs.push_back(newAttribute);
83 /// Return the specified attribute if present, null otherwise.
84 Attribute NamedAttrList::get(StringRef name) const {
85 auto it = findAttr(*this, name);
86 return it.second ? it.first->getValue() : Attribute();
88 Attribute NamedAttrList::get(StringAttr name) const {
89 auto it = findAttr(*this, name);
90 return it.second ? it.first->getValue() : Attribute();
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
95 auto it = findAttr(*this, name);
96 return it.second ? *it.first : std::optional<NamedAttribute>();
98 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
99 auto it = findAttr(*this, name);
100 return it.second ? *it.first : std::optional<NamedAttribute>();
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
106 assert(value && "attributes may never be null");
108 // Look for an existing attribute with the given name, and set its value
109 // in-place. Return the previous value of the attribute, if there was one.
110 auto it = findAttr(*this, name);
111 if (it.second) {
112 // Update the existing attribute by swapping out the old value for the new
113 // value. Return the old value.
114 Attribute oldValue = it.first->getValue();
115 if (it.first->getValue() != value) {
116 it.first->setValue(value);
118 // If the attributes have changed, the dictionary is invalidated.
119 dictionarySorted.setPointer(nullptr);
121 return oldValue;
123 // Perform a string lookup to insert the new attribute into its sorted
124 // position.
125 if (isSorted())
126 it = findAttr(*this, name.strref());
127 attrs.insert(it.first, {name, value});
128 // Invalidate the dictionary. Return null as there was no previous value.
129 dictionarySorted.setPointer(nullptr);
130 return Attribute();
133 Attribute NamedAttrList::set(StringRef name, Attribute value) {
134 assert(value && "attributes may never be null");
135 return set(mlir::StringAttr::get(value.getContext(), name), value);
138 Attribute
139 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
140 // Erasing does not affect the sorted property.
141 Attribute attr = it->getValue();
142 attrs.erase(it);
143 dictionarySorted.setPointer(nullptr);
144 return attr;
147 Attribute NamedAttrList::erase(StringAttr name) {
148 auto it = findAttr(*this, name);
149 return it.second ? eraseImpl(it.first) : Attribute();
152 Attribute NamedAttrList::erase(StringRef name) {
153 auto it = findAttr(*this, name);
154 return it.second ? eraseImpl(it.first) : Attribute();
157 NamedAttrList &
158 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
159 assign(rhs.begin(), rhs.end());
160 return *this;
163 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
165 //===----------------------------------------------------------------------===//
166 // OperationState
167 //===----------------------------------------------------------------------===//
169 OperationState::OperationState(Location location, StringRef name)
170 : location(location), name(name, location->getContext()) {}
172 OperationState::OperationState(Location location, OperationName name)
173 : location(location), name(name) {}
175 OperationState::OperationState(Location location, OperationName name,
176 ValueRange operands, TypeRange types,
177 ArrayRef<NamedAttribute> attributes,
178 BlockRange successors,
179 MutableArrayRef<std::unique_ptr<Region>> regions)
180 : location(location), name(name),
181 operands(operands.begin(), operands.end()),
182 types(types.begin(), types.end()),
183 attributes(attributes.begin(), attributes.end()),
184 successors(successors.begin(), successors.end()) {
185 for (std::unique_ptr<Region> &r : regions)
186 this->regions.push_back(std::move(r));
188 OperationState::OperationState(Location location, StringRef name,
189 ValueRange operands, TypeRange types,
190 ArrayRef<NamedAttribute> attributes,
191 BlockRange successors,
192 MutableArrayRef<std::unique_ptr<Region>> regions)
193 : OperationState(location, OperationName(name, location.getContext()),
194 operands, types, attributes, successors, regions) {}
196 OperationState::~OperationState() {
197 if (properties)
198 propertiesDeleter(properties);
201 LogicalResult OperationState::setProperties(
202 Operation *op, function_ref<InFlightDiagnostic()> emitError) const {
203 if (LLVM_UNLIKELY(propertiesAttr)) {
204 assert(!properties);
205 return op->setPropertiesFromAttribute(propertiesAttr, emitError);
207 if (properties)
208 propertiesSetter(op->getPropertiesStorage(), properties);
209 return success();
212 void OperationState::addOperands(ValueRange newOperands) {
213 operands.append(newOperands.begin(), newOperands.end());
216 void OperationState::addSuccessors(BlockRange newSuccessors) {
217 successors.append(newSuccessors.begin(), newSuccessors.end());
220 Region *OperationState::addRegion() {
221 regions.emplace_back(new Region);
222 return regions.back().get();
225 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
226 regions.push_back(std::move(region));
229 void OperationState::addRegions(
230 MutableArrayRef<std::unique_ptr<Region>> regions) {
231 for (std::unique_ptr<Region> &region : regions)
232 addRegion(std::move(region));
235 //===----------------------------------------------------------------------===//
236 // OperandStorage
237 //===----------------------------------------------------------------------===//
239 detail::OperandStorage::OperandStorage(Operation *owner,
240 OpOperand *trailingOperands,
241 ValueRange values)
242 : isStorageDynamic(false), operandStorage(trailingOperands) {
243 numOperands = capacity = values.size();
244 for (unsigned i = 0; i < numOperands; ++i)
245 new (&operandStorage[i]) OpOperand(owner, values[i]);
248 detail::OperandStorage::~OperandStorage() {
249 for (auto &operand : getOperands())
250 operand.~OpOperand();
252 // If the storage is dynamic, deallocate it.
253 if (isStorageDynamic)
254 free(operandStorage);
257 /// Replace the operands contained in the storage with the ones provided in
258 /// 'values'.
259 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
260 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
261 for (unsigned i = 0, e = values.size(); i != e; ++i)
262 storageOperands[i].set(values[i]);
265 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
267 /// than the range pointed to by 'start'+'length'.
268 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
269 unsigned length, ValueRange operands) {
270 // If the new size is the same, we can update inplace.
271 unsigned newSize = operands.size();
272 if (newSize == length) {
273 MutableArrayRef<OpOperand> storageOperands = getOperands();
274 for (unsigned i = 0, e = length; i != e; ++i)
275 storageOperands[start + i].set(operands[i]);
276 return;
278 // If the new size is greater, remove the extra operands and set the rest
279 // inplace.
280 if (newSize < length) {
281 eraseOperands(start + operands.size(), length - newSize);
282 setOperands(owner, start, newSize, operands);
283 return;
285 // Otherwise, the new size is greater so we need to grow the storage.
286 auto storageOperands = resize(owner, size() + (newSize - length));
288 // Shift operands to the right to make space for the new operands.
289 unsigned rotateSize = storageOperands.size() - (start + length);
290 auto rbegin = storageOperands.rbegin();
291 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
293 // Update the operands inplace.
294 for (unsigned i = 0, e = operands.size(); i != e; ++i)
295 storageOperands[start + i].set(operands[i]);
298 /// Erase an operand held by the storage.
299 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
300 MutableArrayRef<OpOperand> operands = getOperands();
301 assert((start + length) <= operands.size());
302 numOperands -= length;
304 // Shift all operands down if the operand to remove is not at the end.
305 if (start != numOperands) {
306 auto *indexIt = std::next(operands.begin(), start);
307 std::rotate(indexIt, std::next(indexIt, length), operands.end());
309 for (unsigned i = 0; i != length; ++i)
310 operands[numOperands + i].~OpOperand();
313 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
314 MutableArrayRef<OpOperand> operands = getOperands();
315 assert(eraseIndices.size() == operands.size());
317 // Check that at least one operand is erased.
318 int firstErasedIndice = eraseIndices.find_first();
319 if (firstErasedIndice == -1)
320 return;
322 // Shift all of the removed operands to the end, and destroy them.
323 numOperands = firstErasedIndice;
324 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
325 if (!eraseIndices.test(i))
326 operands[numOperands++] = std::move(operands[i]);
327 for (OpOperand &operand : operands.drop_front(numOperands))
328 operand.~OpOperand();
331 /// Resize the storage to the given size. Returns the array containing the new
332 /// operands.
333 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
334 unsigned newSize) {
335 // If the number of operands is less than or equal to the current amount, we
336 // can just update in place.
337 MutableArrayRef<OpOperand> origOperands = getOperands();
338 if (newSize <= numOperands) {
339 // If the number of new size is less than the current, remove any extra
340 // operands.
341 for (unsigned i = newSize; i != numOperands; ++i)
342 origOperands[i].~OpOperand();
343 numOperands = newSize;
344 return origOperands.take_front(newSize);
347 // If the new size is within the original inline capacity, grow inplace.
348 if (newSize <= capacity) {
349 OpOperand *opBegin = origOperands.data();
350 for (unsigned e = newSize; numOperands != e; ++numOperands)
351 new (&opBegin[numOperands]) OpOperand(owner);
352 return MutableArrayRef<OpOperand>(opBegin, newSize);
355 // Otherwise, we need to allocate a new storage.
356 unsigned newCapacity =
357 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
358 OpOperand *newOperandStorage =
359 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
361 // Move the current operands to the new storage.
362 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
363 std::uninitialized_move(origOperands.begin(), origOperands.end(),
364 newOperands.begin());
366 // Destroy the original operands.
367 for (auto &operand : origOperands)
368 operand.~OpOperand();
370 // Initialize any new operands.
371 for (unsigned e = newSize; numOperands != e; ++numOperands)
372 new (&newOperands[numOperands]) OpOperand(owner);
374 // If the current storage is dynamic, free it.
375 if (isStorageDynamic)
376 free(operandStorage);
378 // Update the storage representation to use the new dynamic storage.
379 operandStorage = newOperandStorage;
380 capacity = newCapacity;
381 isStorageDynamic = true;
382 return newOperands;
385 //===----------------------------------------------------------------------===//
386 // Operation Value-Iterators
387 //===----------------------------------------------------------------------===//
389 //===----------------------------------------------------------------------===//
390 // OperandRange
392 unsigned OperandRange::getBeginOperandIndex() const {
393 assert(!empty() && "range must not be empty");
394 return base->getOperandNumber();
397 OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
398 return OperandRangeRange(*this, segmentSizes);
401 //===----------------------------------------------------------------------===//
402 // OperandRangeRange
404 OperandRangeRange::OperandRangeRange(OperandRange operands,
405 Attribute operandSegments)
406 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
407 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
410 OperandRange OperandRangeRange::join() const {
411 const OwnerT &owner = getBase();
412 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
413 return OperandRange(owner.first,
414 std::accumulate(sizeData.begin(), sizeData.end(), 0));
417 OperandRange OperandRangeRange::dereference(const OwnerT &object,
418 ptrdiff_t index) {
419 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
420 uint32_t startIndex =
421 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
422 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
425 //===----------------------------------------------------------------------===//
426 // MutableOperandRange
428 /// Construct a new mutable range from the given operand, operand start index,
429 /// and range length.
430 MutableOperandRange::MutableOperandRange(
431 Operation *owner, unsigned start, unsigned length,
432 ArrayRef<OperandSegment> operandSegments)
433 : owner(owner), start(start), length(length),
434 operandSegments(operandSegments.begin(), operandSegments.end()) {
435 assert((start + length) <= owner->getNumOperands() && "invalid range");
437 MutableOperandRange::MutableOperandRange(Operation *owner)
438 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
440 /// Construct a new mutable range for the given OpOperand.
441 MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
442 : MutableOperandRange(opOperand.getOwner(),
443 /*start=*/opOperand.getOperandNumber(),
444 /*length=*/1) {}
446 /// Slice this range into a sub range, with the additional operand segment.
447 MutableOperandRange
448 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
449 std::optional<OperandSegment> segment) const {
450 assert((subStart + subLen) <= length && "invalid sub-range");
451 MutableOperandRange subSlice(owner, start + subStart, subLen,
452 operandSegments);
453 if (segment)
454 subSlice.operandSegments.push_back(*segment);
455 return subSlice;
458 /// Append the given values to the range.
459 void MutableOperandRange::append(ValueRange values) {
460 if (values.empty())
461 return;
462 owner->insertOperands(start + length, values);
463 updateLength(length + values.size());
466 /// Assign this range to the given values.
467 void MutableOperandRange::assign(ValueRange values) {
468 owner->setOperands(start, length, values);
469 if (length != values.size())
470 updateLength(/*newLength=*/values.size());
473 /// Assign the range to the given value.
474 void MutableOperandRange::assign(Value value) {
475 if (length == 1) {
476 owner->setOperand(start, value);
477 } else {
478 owner->setOperands(start, length, value);
479 updateLength(/*newLength=*/1);
483 /// Erase the operands within the given sub-range.
484 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
485 assert((subStart + subLen) <= length && "invalid sub-range");
486 if (length == 0)
487 return;
488 owner->eraseOperands(start + subStart, subLen);
489 updateLength(length - subLen);
492 /// Clear this range and erase all of the operands.
493 void MutableOperandRange::clear() {
494 if (length != 0) {
495 owner->eraseOperands(start, length);
496 updateLength(/*newLength=*/0);
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502 return owner->getOperands().slice(start, length);
505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506 return owner->getOpOperands().slice(start, length);
509 MutableOperandRangeRange
510 MutableOperandRange::split(NamedAttribute segmentSizes) const {
511 return MutableOperandRangeRange(*this, segmentSizes);
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength) {
516 int32_t diff = int32_t(newLength) - int32_t(length);
517 length = newLength;
519 // Update any of the provided segment attributes.
520 for (OperandSegment &segment : operandSegments) {
521 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522 SmallVector<int32_t, 8> segments(attr.asArrayRef());
523 segments[segment.first] += diff;
524 segment.second.setValue(
525 DenseI32ArrayAttr::get(attr.getContext(), segments));
526 owner->setAttr(segment.second.getName(), segment.second.getValue());
530 OpOperand &MutableOperandRange::operator[](unsigned index) const {
531 assert(index < length && "index is out of bounds");
532 return owner->getOpOperand(start + index);
535 MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
536 return owner->getOpOperands().slice(start, length).begin();
539 MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
540 return owner->getOpOperands().slice(start, length).end();
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
546 MutableOperandRangeRange::MutableOperandRangeRange(
547 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
548 : MutableOperandRangeRange(
549 OwnerT(operands, operandSegmentAttr), 0,
550 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
553 MutableOperandRange MutableOperandRangeRange::join() const {
554 return getBase().first;
557 MutableOperandRangeRange::operator OperandRangeRange() const {
558 return OperandRangeRange(getBase().first, getBase().second.getValue());
561 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
562 ptrdiff_t index) {
563 ArrayRef<int32_t> sizeData =
564 llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
565 uint32_t startIndex =
566 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
567 return object.first.slice(
568 startIndex, *(sizeData.begin() + index),
569 MutableOperandRange::OperandSegment(index, object.second));
572 //===----------------------------------------------------------------------===//
573 // ResultRange
575 ResultRange::ResultRange(OpResult result)
576 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
577 1) {}
579 ResultRange::use_range ResultRange::getUses() const {
580 return {use_begin(), use_end()};
582 ResultRange::use_iterator ResultRange::use_begin() const {
583 return use_iterator(*this);
585 ResultRange::use_iterator ResultRange::use_end() const {
586 return use_iterator(*this, /*end=*/true);
588 ResultRange::user_range ResultRange::getUsers() {
589 return {user_begin(), user_end()};
591 ResultRange::user_iterator ResultRange::user_begin() {
592 return user_iterator(use_begin());
594 ResultRange::user_iterator ResultRange::user_end() {
595 return user_iterator(use_end());
598 ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
599 : it(end ? results.end() : results.begin()), endIt(results.end()) {
600 // Only initialize current use if there are results/can be uses.
601 if (it != endIt)
602 skipOverResultsWithNoUsers();
605 ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
606 // We increment over uses, if we reach the last use then move to next
607 // result.
608 if (use != (*it).use_end())
609 ++use;
610 if (use == (*it).use_end()) {
611 ++it;
612 skipOverResultsWithNoUsers();
614 return *this;
617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
618 while (it != endIt && (*it).use_empty())
619 ++it;
621 // If we are at the last result, then set use to first use of
622 // first result (sentinel value used for end).
623 if (it == endIt)
624 use = {};
625 else
626 use = (*it).use_begin();
629 void ResultRange::replaceAllUsesWith(Operation *op) {
630 replaceAllUsesWith(op->getResults());
633 void ResultRange::replaceUsesWithIf(
634 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
635 replaceUsesWithIf(op->getResults(), shouldReplace);
638 //===----------------------------------------------------------------------===//
639 // ValueRange
641 ValueRange::ValueRange(ArrayRef<Value> values)
642 : ValueRange(values.data(), values.size()) {}
643 ValueRange::ValueRange(OperandRange values)
644 : ValueRange(values.begin().getBase(), values.size()) {}
645 ValueRange::ValueRange(ResultRange values)
646 : ValueRange(values.getBase(), values.size()) {}
648 /// See `llvm::detail::indexed_accessor_range_base` for details.
649 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
650 ptrdiff_t index) {
651 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
652 return {value + index};
653 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
654 return {operand + index};
655 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
657 /// See `llvm::detail::indexed_accessor_range_base` for details.
658 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
659 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
660 return value[index];
661 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
662 return operand[index].get();
663 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
666 //===----------------------------------------------------------------------===//
667 // Operation Equivalency
668 //===----------------------------------------------------------------------===//
670 llvm::hash_code OperationEquivalence::computeHash(
671 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
672 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
673 // Hash operations based upon their:
674 // - Operation Name
675 // - Attributes
676 // - Result Types
677 llvm::hash_code hash =
678 llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
679 op->getResultTypes(), op->hashProperties());
681 // - Location if required
682 if (!(flags & Flags::IgnoreLocations))
683 hash = llvm::hash_combine(hash, op->getLoc());
685 // - Operands
686 for (Value operand : op->getOperands())
687 hash = llvm::hash_combine(hash, hashOperands(operand));
689 // - Results
690 for (Value result : op->getResults())
691 hash = llvm::hash_combine(hash, hashResults(result));
692 return hash;
695 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
696 Region *lhs, Region *rhs,
697 function_ref<LogicalResult(Value, Value)> checkEquivalent,
698 function_ref<void(Value, Value)> markEquivalent,
699 OperationEquivalence::Flags flags) {
700 DenseMap<Block *, Block *> blocksMap;
701 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
702 // Check block arguments.
703 if (lBlock.getNumArguments() != rBlock.getNumArguments())
704 return false;
706 // Map the two blocks.
707 auto insertion = blocksMap.insert({&lBlock, &rBlock});
708 if (insertion.first->getSecond() != &rBlock)
709 return false;
711 for (auto argPair :
712 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
713 Value curArg = std::get<0>(argPair);
714 Value otherArg = std::get<1>(argPair);
715 if (curArg.getType() != otherArg.getType())
716 return false;
717 if (!(flags & OperationEquivalence::IgnoreLocations) &&
718 curArg.getLoc() != otherArg.getLoc())
719 return false;
720 // Corresponding bbArgs are equivalent.
721 if (markEquivalent)
722 markEquivalent(curArg, otherArg);
725 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
726 // Check for op equality (recursively).
727 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
728 markEquivalent, flags))
729 return false;
730 // Check successor mapping.
731 for (auto successorsPair :
732 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
733 Block *curSuccessor = std::get<0>(successorsPair);
734 Block *otherSuccessor = std::get<1>(successorsPair);
735 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
736 if (insertion.first->getSecond() != otherSuccessor)
737 return false;
739 return true;
741 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
743 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
746 // Value equivalence cache to be used with `isRegionEquivalentTo` and
747 // `isEquivalentTo`.
748 struct ValueEquivalenceCache {
749 DenseMap<Value, Value> equivalentValues;
750 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
751 return success(lhsValue == rhsValue ||
752 equivalentValues.lookup(lhsValue) == rhsValue);
754 void markEquivalent(Value lhsResult, Value rhsResult) {
755 auto insertion = equivalentValues.insert({lhsResult, rhsResult});
756 // Make sure that the value was not already marked equivalent to some other
757 // value.
758 (void)insertion;
759 assert(insertion.first->second == rhsResult &&
760 "inconsistent OperationEquivalence state");
764 /*static*/ bool
765 OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
766 OperationEquivalence::Flags flags) {
767 ValueEquivalenceCache cache;
768 return isRegionEquivalentTo(
769 lhs, rhs,
770 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
771 return cache.checkEquivalent(lhsValue, rhsValue);
773 [&](Value lhsResult, Value rhsResult) {
774 cache.markEquivalent(lhsResult, rhsResult);
776 flags);
779 /*static*/ bool OperationEquivalence::isEquivalentTo(
780 Operation *lhs, Operation *rhs,
781 function_ref<LogicalResult(Value, Value)> checkEquivalent,
782 function_ref<void(Value, Value)> markEquivalent, Flags flags) {
783 if (lhs == rhs)
784 return true;
786 // 1. Compare the operation properties.
787 if (lhs->getName() != rhs->getName() ||
788 lhs->getDiscardableAttrDictionary() !=
789 rhs->getDiscardableAttrDictionary() ||
790 lhs->getNumRegions() != rhs->getNumRegions() ||
791 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
792 lhs->getNumOperands() != rhs->getNumOperands() ||
793 lhs->getNumResults() != rhs->getNumResults() ||
794 !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
795 rhs->getPropertiesStorage()))
796 return false;
797 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
798 return false;
800 // 2. Compare operands.
801 for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
802 Value curArg = std::get<0>(operandPair);
803 Value otherArg = std::get<1>(operandPair);
804 if (curArg == otherArg)
805 continue;
806 if (curArg.getType() != otherArg.getType())
807 return false;
808 if (failed(checkEquivalent(curArg, otherArg)))
809 return false;
812 // 3. Compare result types and mark results as equivalent.
813 for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
814 Value curArg = std::get<0>(resultPair);
815 Value otherArg = std::get<1>(resultPair);
816 if (curArg.getType() != otherArg.getType())
817 return false;
818 if (markEquivalent)
819 markEquivalent(curArg, otherArg);
822 // 4. Compare regions.
823 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
824 if (!isRegionEquivalentTo(&std::get<0>(regionPair),
825 &std::get<1>(regionPair), checkEquivalent,
826 markEquivalent, flags))
827 return false;
829 return true;
832 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
833 Operation *rhs,
834 Flags flags) {
835 ValueEquivalenceCache cache;
836 return OperationEquivalence::isEquivalentTo(
837 lhs, rhs,
838 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
839 return cache.checkEquivalent(lhsValue, rhsValue);
841 [&](Value lhsResult, Value rhsResult) {
842 cache.markEquivalent(lhsResult, rhsResult);
844 flags);
847 //===----------------------------------------------------------------------===//
848 // OperationFingerPrint
849 //===----------------------------------------------------------------------===//
851 template <typename T>
852 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
853 hasher.update(
854 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
857 OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
858 llvm::SHA1 hasher;
860 // Hash each of the operations based upon their mutable bits:
861 topOp->walk([&](Operation *op) {
862 // - Operation pointer
863 addDataToHash(hasher, op);
864 // - Parent operation pointer (to take into account the nesting structure)
865 if (op != topOp)
866 addDataToHash(hasher, op->getParentOp());
867 // - Attributes
868 addDataToHash(hasher, op->getDiscardableAttrDictionary());
869 // - Properties
870 addDataToHash(hasher, op->hashProperties());
871 // - Blocks in Regions
872 for (Region &region : op->getRegions()) {
873 for (Block &block : region) {
874 addDataToHash(hasher, &block);
875 for (BlockArgument arg : block.getArguments())
876 addDataToHash(hasher, arg);
879 // - Location
880 addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
881 // - Operands
882 for (Value operand : op->getOperands())
883 addDataToHash(hasher, operand);
884 // - Successors
885 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
886 addDataToHash(hasher, op->getSuccessor(i));
887 // - Result types
888 for (Type t : op->getResultTypes())
889 addDataToHash(hasher, t);
891 hash = hasher.result();