1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
12 //===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir
26 mlir::reifyResultShapes(OpBuilder
&b
, Operation
*op
,
27 ReifiedRankedShapedTypeDims
&reifiedReturnShapes
) {
28 auto reifiableOp
= dyn_cast
<ReifyRankedShapedTypeOpInterface
>(op
);
31 LogicalResult status
= reifiableOp
.reifyResultShapes(b
, reifiedReturnShapes
);
35 // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
37 int64_t resultIdx
= 0;
38 for (OpResult result
: op
->getResults()) {
39 auto shapedType
= dyn_cast
<ShapedType
>(result
.getType());
42 if (!shapedType
.hasRank()) {
43 // Nothing to check for unranked shaped values.
47 // Assert one OpFoldResult per dimension.
48 assert(shapedType
.getRank() ==
49 static_cast<int64_t>(reifiedReturnShapes
[resultIdx
].size()) &&
50 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51 for (int64_t dim
= 0; dim
< shapedType
.getRank(); ++dim
) {
52 // reifyResultShapes must return:
53 // * Attribute for static dimensions
54 // * Value for dynamic dimensions
55 assert(shapedType
.isDynamicDim(dim
) ==
56 reifiedReturnShapes
[resultIdx
][dim
].is
<Value
>() &&
57 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
61 // Assert that every shaped value result was reified.
62 assert(resultIdx
== static_cast<int64_t>(reifiedReturnShapes
.size()) &&
63 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
68 bool ShapeAdaptor::hasRank() const {
71 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
72 return cast
<ShapedType
>(t
).hasRank();
73 if (val
.is
<Attribute
>())
75 return val
.get
<ShapedTypeComponents
*>()->hasRank();
78 Type
ShapeAdaptor::getElementType() const {
81 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
82 return cast
<ShapedType
>(t
).getElementType();
83 if (val
.is
<Attribute
>())
85 return val
.get
<ShapedTypeComponents
*>()->getElementType();
88 void ShapeAdaptor::getDims(SmallVectorImpl
<int64_t> &res
) const {
90 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
)) {
91 ArrayRef
<int64_t> vals
= cast
<ShapedType
>(t
).getShape();
92 res
.assign(vals
.begin(), vals
.end());
93 } else if (auto attr
= llvm::dyn_cast_if_present
<Attribute
>(val
)) {
94 auto dattr
= cast
<DenseIntElementsAttr
>(attr
);
96 res
.reserve(dattr
.size());
97 for (auto it
: dattr
.getValues
<APInt
>())
98 res
.push_back(it
.getSExtValue());
100 auto vals
= val
.get
<ShapedTypeComponents
*>()->getDims();
101 res
.assign(vals
.begin(), vals
.end());
105 void ShapeAdaptor::getDims(ShapedTypeComponents
&res
) const {
111 int64_t ShapeAdaptor::getDimSize(int index
) const {
113 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
114 return cast
<ShapedType
>(t
).getDimSize(index
);
115 if (auto attr
= llvm::dyn_cast_if_present
<Attribute
>(val
))
116 return cast
<DenseIntElementsAttr
>(attr
)
117 .getValues
<APInt
>()[index
]
119 auto *stc
= val
.get
<ShapedTypeComponents
*>();
120 return stc
->getDims()[index
];
123 int64_t ShapeAdaptor::getRank() const {
125 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
126 return cast
<ShapedType
>(t
).getRank();
127 if (auto attr
= llvm::dyn_cast_if_present
<Attribute
>(val
))
128 return cast
<DenseIntElementsAttr
>(attr
).size();
129 return val
.get
<ShapedTypeComponents
*>()->getDims().size();
132 bool ShapeAdaptor::hasStaticShape() const {
136 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
137 return cast
<ShapedType
>(t
).hasStaticShape();
138 if (auto attr
= llvm::dyn_cast_if_present
<Attribute
>(val
)) {
139 auto dattr
= cast
<DenseIntElementsAttr
>(attr
);
140 for (auto index
: dattr
.getValues
<APInt
>())
141 if (ShapedType::isDynamic(index
.getSExtValue()))
145 auto *stc
= val
.get
<ShapedTypeComponents
*>();
146 return llvm::none_of(stc
->getDims(), ShapedType::isDynamic
);
149 int64_t ShapeAdaptor::getNumElements() const {
150 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
152 if (auto t
= llvm::dyn_cast_if_present
<Type
>(val
))
153 return cast
<ShapedType
>(t
).getNumElements();
155 if (auto attr
= llvm::dyn_cast_if_present
<Attribute
>(val
)) {
156 auto dattr
= cast
<DenseIntElementsAttr
>(attr
);
158 for (auto index
: dattr
.getValues
<APInt
>()) {
159 num
*= index
.getZExtValue();
160 assert(num
>= 0 && "integer overflow in element count computation");
165 auto *stc
= val
.get
<ShapedTypeComponents
*>();
167 for (int64_t dim
: stc
->getDims()) {
169 assert(num
>= 0 && "integer overflow in element count computation");
174 void ShapeAdaptor::dump() const {
176 llvm::errs() << "<<unranked>>\n";
180 SmallVector
<int64_t> dims
;
182 auto mapped
= llvm::map_range(dims
, [](int64_t dim
) -> std::string
{
183 if (ShapedType::isDynamic(dim
))
185 return llvm::formatv("{0}", dim
).str();
187 llvm::errs() << "rank = " << getRank() << " dims = [";
188 llvm::interleave(mapped
, llvm::errs(), "x");
189 llvm::errs() << "]\n";
192 ShapeAdaptor
ValueShapeRange::getValueAsShape(int index
) {
193 Value val
= operator[](index
);
195 if (ShapeAdaptor ret
= valueToShape(val
))
198 DenseIntElementsAttr attr
;
199 if (!matchPattern(val
, m_Constant(&attr
)))
201 if (attr
.getType().getRank() != 1)
206 ShapeAdaptor
ValueShapeRange::getShape(Value val
) const {
208 if (ShapeAdaptor ret
= operandShape(val
))
210 return val
.getType();
213 ShapeAdaptor
ValueShapeRange::getShape(int index
) const {
214 if (index
< 0 || static_cast<size_t>(index
) >= size())
216 return getShape(operator[](index
));
219 LogicalResult
mlir::detail::inferReturnTensorTypes(
220 ArrayRef
<ShapedTypeComponents
> retComponents
,
221 SmallVectorImpl
<Type
> &inferredReturnTypes
) {
222 for (const auto &shapeAndType
: retComponents
) {
223 Type elementTy
= shapeAndType
.getElementType();
224 assert(elementTy
&& "element type required to construct tensor");
226 Attribute attr
= shapeAndType
.getAttribute();
227 if (shapeAndType
.hasRank()) {
228 inferredReturnTypes
.push_back(
229 RankedTensorType::get(shapeAndType
.getDims(), elementTy
, attr
));
231 assert(attr
== nullptr && "attribute not supported");
232 inferredReturnTypes
.push_back(UnrankedTensorType::get(elementTy
));
238 LogicalResult
mlir::detail::verifyInferredResultTypes(Operation
*op
) {
239 SmallVector
<Type
, 4> inferredReturnTypes(op
->getResultTypes());
240 auto retTypeFn
= cast
<InferTypeOpInterface
>(op
);
241 auto result
= retTypeFn
.refineReturnTypes(
242 op
->getContext(), op
->getLoc(), op
->getOperands(),
243 op
->getDiscardableAttrDictionary(), op
->getPropertiesStorage(),
244 op
->getRegions(), inferredReturnTypes
);
246 op
->emitOpError() << "failed to infer returned types";