1 //===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"

using namespace mlir;
17 OpOperand
&detail::defaultGetDestinationOperand(Operation
*op
) {
18 auto dstOp
= dyn_cast
<DestinationStyleOpInterface
>(op
);
19 assert(dstOp
&& "getDestination must be implemented for non-DPS ops");
21 dstOp
.getNumDpsInits() == 1 &&
22 "getDestination must be implemented for ops with 0 or more than 1 init");
23 return *dstOp
.getDpsInitOperand(0);
26 OpResult
detail::defaultGetUpdatedDestination(Operation
*op
) {
27 auto dstOp
= dyn_cast
<DestinationStyleOpInterface
>(op
);
28 assert(dstOp
&& "getUpdatedDestination must be implemented for non-DPS ops");
29 auto insertionOp
= cast
<SubsetInsertionOpInterface
>(op
);
30 return dstOp
.getTiedOpResult(&insertionOp
.getDestinationOperand());
33 bool detail::defaultIsEquivalentSubset(
34 Operation
*op
, Value candidate
,
35 function_ref
<bool(Value
, Value
)> equivalenceFn
) {
36 assert(isa
<SubsetInsertionOpInterface
>(op
) &&
37 "expected SubsetInsertionOpInterface");
38 if (!candidate
.getDefiningOp
<SubsetExtractionOpInterface
>())
40 return cast
<SubsetOpInterface
>(op
).operatesOnEquivalentSubset(
41 candidate
.getDefiningOp
<SubsetOpInterface
>(), equivalenceFn
);
44 bool detail::defaultOperatesOnEquivalentSubset(
45 Operation
*op
, SubsetOpInterface candidate
,
46 function_ref
<bool(Value
, Value
)> equivalenceFn
) {
47 auto subsetOp
= cast
<SubsetOpInterface
>(op
);
48 FailureOr
<HyperrectangularSlice
> slice
=
49 subsetOp
.getAccessedHyperrectangularSlice();
50 assert(succeeded(slice
) &&
51 "operatesOnEquivalentSubset must be implemented if "
52 "getAccessedHyperrectangularSlice is not implemented");
53 FailureOr
<HyperrectangularSlice
> otherSlice
=
54 candidate
.getAccessedHyperrectangularSlice();
55 if (failed(otherSlice
))
57 if (!equivalenceFn(subsetOp
.getTensorContainer(),
58 candidate
.getTensorContainer()))
60 FailureOr
<bool> equivalent
= ValueBoundsConstraintSet::areEquivalentSlices(
61 op
->getContext(), *slice
, *otherSlice
);
62 return succeeded(equivalent
) && *equivalent
;
65 bool detail::defaultOperatesOnDisjointSubset(
66 Operation
*op
, SubsetOpInterface candidate
,
67 function_ref
<bool(Value
, Value
)> equivalenceFn
) {
68 auto subsetOp
= cast
<SubsetOpInterface
>(op
);
69 FailureOr
<HyperrectangularSlice
> slice
=
70 subsetOp
.getAccessedHyperrectangularSlice();
71 assert(succeeded(slice
) &&
72 "defaultOperatesOnDisjointSubset must be implemented if "
73 "getAccessedHyperrectangularSlice is not implemented");
74 FailureOr
<HyperrectangularSlice
> otherSlice
=
75 candidate
.getAccessedHyperrectangularSlice();
76 if (failed(otherSlice
))
78 if (!equivalenceFn(subsetOp
.getTensorContainer(),
79 candidate
.getTensorContainer()))
81 FailureOr
<bool> overlapping
= ValueBoundsConstraintSet::areOverlappingSlices(
82 op
->getContext(), *slice
, *otherSlice
);
83 return succeeded(overlapping
) && !*overlapping
;
86 Value
detail::getTensorContainer(Operation
*op
) {
87 if (auto insertionOp
= dyn_cast
<::mlir::SubsetInsertionOpInterface
>(op
))
88 return insertionOp
.getDestinationOperand().get();
89 return cast
<::mlir::SubsetExtractionOpInterface
>(op
).getSourceOperand().get();
92 LogicalResult
detail::verifySubsetOpInterface(SubsetOpInterface op
) {
93 if (!(isa
<SubsetExtractionOpInterface
>(op
.getOperation()) ^
94 isa
<SubsetInsertionOpInterface
>(op
.getOperation())))
95 return op
->emitOpError(
96 "SubsetOpInterface ops must implement either "
97 "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
102 detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op
) {
103 if (op
->getNumResults() != 1)
104 return op
->emitOpError(
105 "SubsetExtractionOpInterface ops must have one result");