1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
18 size_t getNumTensorResults(Operation
*op
) {
19 size_t numTensorResults
= 0;
20 for (auto t
: op
->getResultTypes()) {
21 if (isa
<TensorType
>(t
)) {
25 return numTensorResults
;
29 LogicalResult
detail::verifyDestinationStyleOpInterface(Operation
*op
) {
30 DestinationStyleOpInterface dstStyleOp
=
31 cast
<DestinationStyleOpInterface
>(op
);
33 SmallVector
<OpOperand
*> outputTensorOperands
;
34 for (OpOperand
&operand
: dstStyleOp
.getDpsInitsMutable()) {
35 Type type
= operand
.get().getType();
36 if (isa
<TensorType
>(type
)) {
37 outputTensorOperands
.push_back(&operand
);
38 } else if (!isa
<BaseMemRefType
>(type
)) {
39 return op
->emitOpError("expected that operand #")
40 << operand
.getOperandNumber() << " is a tensor or a memref";
44 // Verify the number of tensor results matches the number of output tensors.
45 if (getNumTensorResults(op
) != outputTensorOperands
.size())
46 return op
->emitOpError("expected the number of tensor results (")
47 << getNumTensorResults(op
)
48 << ") to be equal to the number of output tensors ("
49 << outputTensorOperands
.size() << ")";
51 for (OpOperand
*opOperand
: outputTensorOperands
) {
52 OpResult result
= dstStyleOp
.getTiedOpResult(opOperand
);
53 if (result
.getType() != opOperand
->get().getType())
54 return op
->emitOpError("expected type of operand #")
55 << opOperand
->getOperandNumber() << " ("
56 << opOperand
->get().getType() << ")"
57 << " to match type of corresponding result (" << result
.getType()