Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Interfaces / DestinationStyleOpInterface.cpp
blob 496238fcaa3ff164aed4df9b4ac589a3340565ed
1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
11 using namespace mlir;
13 namespace mlir {
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15 } // namespace mlir
17 namespace {
18 size_t getNumTensorResults(Operation *op) {
19 size_t numTensorResults = 0;
20 for (auto t : op->getResultTypes()) {
21 if (isa<TensorType>(t)) {
22 ++numTensorResults;
25 return numTensorResults;
27 } // namespace
29 LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
30 DestinationStyleOpInterface dstStyleOp =
31 cast<DestinationStyleOpInterface>(op);
33 SmallVector<OpOperand *> outputTensorOperands;
34 for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35 Type type = operand.get().getType();
36 if (isa<TensorType>(type)) {
37 outputTensorOperands.push_back(&operand);
38 } else if (!isa<BaseMemRefType>(type)) {
39 return op->emitOpError("expected that operand #")
40 << operand.getOperandNumber() << " is a tensor or a memref";
44 // Verify the number of tensor results matches the number of output tensors.
45 if (getNumTensorResults(op) != outputTensorOperands.size())
46 return op->emitOpError("expected the number of tensor results (")
47 << getNumTensorResults(op)
48 << ") to be equal to the number of output tensors ("
49 << outputTensorOperands.size() << ")";
51 for (OpOperand *opOperand : outputTensorOperands) {
52 OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53 if (result.getType() != opOperand->get().getType())
54 return op->emitOpError("expected type of operand #")
55 << opOperand->getOperandNumber() << " ("
56 << opOperand->get().getType() << ")"
57 << " to match type of corresponding result (" << result.getType()
58 << ")";
61 return success();