[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Interfaces / InferTypeOpInterface.cpp
blob3c50c4c37c6f593b4166182cdc6259032967fa0f
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "llvm/Support/FormatVariadic.h"
19 using namespace mlir;
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
25 LogicalResult
26 mlir::reifyResultShapes(OpBuilder &b, Operation *op,
27 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
28 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29 if (!reifiableOp)
30 return failure();
31 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
32 #ifndef NDEBUG
33 if (failed(status))
34 return failure();
35 // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36 // a correct result.
37 int64_t resultIdx = 0;
38 for (OpResult result : op->getResults()) {
39 auto shapedType = dyn_cast<ShapedType>(result.getType());
40 if (!shapedType)
41 continue;
42 if (!shapedType.hasRank()) {
43 // Nothing to check for unranked shaped values.
44 ++resultIdx;
45 continue;
47 // Assert one OpFoldResult per dimension.
48 assert(shapedType.getRank() ==
49 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51 for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
52 // reifyResultShapes must return:
53 // * Attribute for static dimensions
54 // * Value for dynamic dimensions
55 assert(shapedType.isDynamicDim(dim) ==
56 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
57 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
59 ++resultIdx;
61 // Assert that every shaped value result was reified.
62 assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
63 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
64 #endif // NDEBUG
65 return status;
68 bool ShapeAdaptor::hasRank() const {
69 if (val.isNull())
70 return false;
71 if (auto t = llvm::dyn_cast_if_present<Type>(val))
72 return cast<ShapedType>(t).hasRank();
73 if (val.is<Attribute>())
74 return true;
75 return val.get<ShapedTypeComponents *>()->hasRank();
78 Type ShapeAdaptor::getElementType() const {
79 if (val.isNull())
80 return nullptr;
81 if (auto t = llvm::dyn_cast_if_present<Type>(val))
82 return cast<ShapedType>(t).getElementType();
83 if (val.is<Attribute>())
84 return nullptr;
85 return val.get<ShapedTypeComponents *>()->getElementType();
88 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
89 assert(hasRank());
90 if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
91 ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
92 res.assign(vals.begin(), vals.end());
93 } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
94 auto dattr = cast<DenseIntElementsAttr>(attr);
95 res.clear();
96 res.reserve(dattr.size());
97 for (auto it : dattr.getValues<APInt>())
98 res.push_back(it.getSExtValue());
99 } else {
100 auto vals = val.get<ShapedTypeComponents *>()->getDims();
101 res.assign(vals.begin(), vals.end());
105 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
106 assert(hasRank());
107 res.ranked = true;
108 getDims(res.dims);
111 int64_t ShapeAdaptor::getDimSize(int index) const {
112 assert(hasRank());
113 if (auto t = llvm::dyn_cast_if_present<Type>(val))
114 return cast<ShapedType>(t).getDimSize(index);
115 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
116 return cast<DenseIntElementsAttr>(attr)
117 .getValues<APInt>()[index]
118 .getSExtValue();
119 auto *stc = val.get<ShapedTypeComponents *>();
120 return stc->getDims()[index];
123 int64_t ShapeAdaptor::getRank() const {
124 assert(hasRank());
125 if (auto t = llvm::dyn_cast_if_present<Type>(val))
126 return cast<ShapedType>(t).getRank();
127 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
128 return cast<DenseIntElementsAttr>(attr).size();
129 return val.get<ShapedTypeComponents *>()->getDims().size();
132 bool ShapeAdaptor::hasStaticShape() const {
133 if (!hasRank())
134 return false;
136 if (auto t = llvm::dyn_cast_if_present<Type>(val))
137 return cast<ShapedType>(t).hasStaticShape();
138 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
139 auto dattr = cast<DenseIntElementsAttr>(attr);
140 for (auto index : dattr.getValues<APInt>())
141 if (ShapedType::isDynamic(index.getSExtValue()))
142 return false;
143 return true;
145 auto *stc = val.get<ShapedTypeComponents *>();
146 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
149 int64_t ShapeAdaptor::getNumElements() const {
150 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
152 if (auto t = llvm::dyn_cast_if_present<Type>(val))
153 return cast<ShapedType>(t).getNumElements();
155 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
156 auto dattr = cast<DenseIntElementsAttr>(attr);
157 int64_t num = 1;
158 for (auto index : dattr.getValues<APInt>()) {
159 num *= index.getZExtValue();
160 assert(num >= 0 && "integer overflow in element count computation");
162 return num;
165 auto *stc = val.get<ShapedTypeComponents *>();
166 int64_t num = 1;
167 for (int64_t dim : stc->getDims()) {
168 num *= dim;
169 assert(num >= 0 && "integer overflow in element count computation");
171 return num;
174 void ShapeAdaptor::dump() const {
175 if (!hasRank()) {
176 llvm::errs() << "<<unranked>>\n";
177 return;
180 SmallVector<int64_t> dims;
181 getDims(dims);
182 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
183 if (ShapedType::isDynamic(dim))
184 return "?";
185 return llvm::formatv("{0}", dim).str();
187 llvm::errs() << "rank = " << getRank() << " dims = [";
188 llvm::interleave(mapped, llvm::errs(), "x");
189 llvm::errs() << "]\n";
192 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
193 Value val = operator[](index);
194 if (valueToShape)
195 if (ShapeAdaptor ret = valueToShape(val))
196 return ret;
198 DenseIntElementsAttr attr;
199 if (!matchPattern(val, m_Constant(&attr)))
200 return nullptr;
201 if (attr.getType().getRank() != 1)
202 return nullptr;
203 return attr;
206 ShapeAdaptor ValueShapeRange::getShape(Value val) const {
207 if (operandShape)
208 if (ShapeAdaptor ret = operandShape(val))
209 return ret;
210 return val.getType();
213 ShapeAdaptor ValueShapeRange::getShape(int index) const {
214 if (index < 0 || static_cast<size_t>(index) >= size())
215 return nullptr;
216 return getShape(operator[](index));
219 LogicalResult mlir::detail::inferReturnTensorTypes(
220 ArrayRef<ShapedTypeComponents> retComponents,
221 SmallVectorImpl<Type> &inferredReturnTypes) {
222 for (const auto &shapeAndType : retComponents) {
223 Type elementTy = shapeAndType.getElementType();
224 assert(elementTy && "element type required to construct tensor");
226 Attribute attr = shapeAndType.getAttribute();
227 if (shapeAndType.hasRank()) {
228 inferredReturnTypes.push_back(
229 RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
230 } else {
231 assert(attr == nullptr && "attribute not supported");
232 inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
235 return success();
238 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
239 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
240 auto retTypeFn = cast<InferTypeOpInterface>(op);
241 auto result = retTypeFn.refineReturnTypes(
242 op->getContext(), op->getLoc(), op->getOperands(),
243 op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
244 op->getRegions(), inferredReturnTypes);
245 if (failed(result))
246 op->emitOpError() << "failed to infer returned types";
248 return result;