[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Interfaces / DestinationStyleOpInterface.cpp
blob4e5ef66887cadf8405248feb586cb6de58505bae
1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

#include "llvm/ADT/STLExtras.h"
11 using namespace mlir;
13 namespace mlir {
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15 } // namespace mlir
17 namespace {
18 size_t getNumTensorResults(Operation *op) {
19 size_t numTensorResults = 0;
20 for (auto t : op->getResultTypes()) {
21 if (isa<TensorType>(t)) {
22 ++numTensorResults;
25 return numTensorResults;
27 } // namespace
29 LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
30 DestinationStyleOpInterface dstStyleOp =
31 cast<DestinationStyleOpInterface>(op);
33 SmallVector<OpOperand *> outputTensorOperands;
34 for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35 Type type = operand.get().getType();
36 if (isa<RankedTensorType>(type)) {
37 outputTensorOperands.push_back(&operand);
38 } else if (!isa<MemRefType>(type)) {
39 return op->emitOpError("expected that operand #")
40 << operand.getOperandNumber()
41 << " is a ranked tensor or a ranked memref";
45 // Verify the number of tensor results matches the number of output tensors.
46 if (getNumTensorResults(op) != outputTensorOperands.size())
47 return op->emitOpError("expected the number of tensor results (")
48 << getNumTensorResults(op)
49 << ") to be equal to the number of output tensors ("
50 << outputTensorOperands.size() << ")";
52 for (OpOperand *opOperand : outputTensorOperands) {
53 OpResult result = dstStyleOp.getTiedOpResult(opOperand);
54 if (result.getType() != opOperand->get().getType())
55 return op->emitOpError("expected type of operand #")
56 << opOperand->getOperandNumber() << " ("
57 << opOperand->get().getType() << ")"
58 << " to match type of corresponding result (" << result.getType()
59 << ")";
61 return success();