1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
18 size_t getNumTensorResults(Operation
*op
) {
19 size_t numTensorResults
= 0;
20 for (auto t
: op
->getResultTypes()) {
21 if (isa
<TensorType
>(t
)) {
25 return numTensorResults
;
29 LogicalResult
detail::verifyDestinationStyleOpInterface(Operation
*op
) {
30 DestinationStyleOpInterface dstStyleOp
=
31 cast
<DestinationStyleOpInterface
>(op
);
33 SmallVector
<OpOperand
*> outputTensorOperands
;
34 for (OpOperand
&operand
: dstStyleOp
.getDpsInitsMutable()) {
35 Type type
= operand
.get().getType();
36 if (isa
<RankedTensorType
>(type
)) {
37 outputTensorOperands
.push_back(&operand
);
38 } else if (!isa
<MemRefType
>(type
)) {
39 return op
->emitOpError("expected that operand #")
40 << operand
.getOperandNumber()
41 << " is a ranked tensor or a ranked memref";
45 // Verify the number of tensor results matches the number of output tensors.
46 if (getNumTensorResults(op
) != outputTensorOperands
.size())
47 return op
->emitOpError("expected the number of tensor results (")
48 << getNumTensorResults(op
)
49 << ") to be equal to the number of output tensors ("
50 << outputTensorOperands
.size() << ")";
52 for (OpOperand
*opOperand
: outputTensorOperands
) {
53 OpResult result
= dstStyleOp
.getTiedOpResult(opOperand
);
54 if (result
.getType() != opOperand
->get().getType())
55 return op
->emitOpError("expected type of operand #")
56 << opOperand
->getOperandNumber() << " ("
57 << opOperand
->get().getType() << ")"
58 << " to match type of corresponding result (" << result
.getType()