1 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/Support/FormatVariadic.h"
17 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef
<int64_t> shape1
,
18 ArrayRef
<int64_t> shape2
) {
19 SmallVector
<SmallVector
<int64_t, 6>, 2> extents
;
20 extents
.emplace_back(shape1
.begin(), shape1
.end());
21 extents
.emplace_back(shape2
.begin(), shape2
.end());
22 return staticallyKnownBroadcastable(extents
);
25 bool OpTrait::util::staticallyKnownBroadcastable(
26 ArrayRef
<SmallVector
<int64_t, 6>> shapes
) {
27 assert(!shapes
.empty() && "Expected at least one shape");
28 size_t maxRank
= shapes
[0].size();
29 for (size_t i
= 1; i
!= shapes
.size(); ++i
)
30 maxRank
= std::max(maxRank
, shapes
[i
].size());
32 // We look backwards through every column of `shapes`.
33 for (size_t i
= 0; i
!= maxRank
; ++i
) {
34 bool seenDynamic
= false;
35 std::optional
<int64_t> nonOneDim
;
36 for (ArrayRef
<int64_t> extent
: shapes
) {
37 int64_t dim
= i
>= extent
.size() ? 1 : extent
[extent
.size() - i
- 1];
42 // Dimensions are compatible when
43 // 1. One is dynamic, the rest are 1
44 if (ShapedType::isDynamic(dim
)) {
45 if (seenDynamic
|| nonOneDim
)
50 // 2. All are 1 or a specific constant.
51 if (nonOneDim
&& dim
!= *nonOneDim
)
60 bool OpTrait::util::getBroadcastedShape(ArrayRef
<int64_t> shape1
,
61 ArrayRef
<int64_t> shape2
,
62 SmallVectorImpl
<int64_t> &resultShape
) {
63 // To compute the result broadcasted shape, we compare operand shapes
64 // element-wise: starting with the trailing dimensions, and working the
65 // way backward. Two dimensions are compatible when
66 // 1. they are equal, or
67 // 2. one of them is 1
68 // The result shape has the maximum among the two inputs at every
72 if (shape1
.size() > shape2
.size()) {
73 std::copy(shape1
.begin(), shape1
.end(), std::back_inserter(resultShape
));
75 std::copy(shape2
.begin(), shape2
.end(), std::back_inserter(resultShape
));
78 auto i1
= shape1
.rbegin(), e1
= shape1
.rend();
79 auto i2
= shape2
.rbegin(), e2
= shape2
.rend();
80 auto iR
= resultShape
.rbegin();
82 // Check each dimension is consistent.
83 for (; i1
!= e1
&& i2
!= e2
; ++i1
, ++i2
, ++iR
) {
84 if (ShapedType::isDynamic(*i1
) || ShapedType::isDynamic(*i2
)) {
85 // One or both dimensions are unknown. Follow TensorFlow behavior:
86 // - If either dimension is greater than 1, we assume that the program is
87 // correct, and the other dimension will be broadcast to match it.
88 // - If either dimension is 1, the other dimension is the output.
93 } else if (*i1
== 1) {
95 } else if (*i2
== 1) {
98 *iR
= ShapedType::kDynamic
;
101 if (*i1
== *i2
|| *i2
== 1) {
103 } else if (*i1
== 1) {
106 // This dimension of the two operand types is incompatible.
116 /// Returns the shape of the given type. Scalars will be considered as having a
117 /// shape with zero dimensions.
118 static ArrayRef
<int64_t> getShape(Type type
) {
119 if (auto sType
= dyn_cast
<ShapedType
>(type
))
120 return sType
.getShape();
124 /// Returns the result broadcast composition type from the two given types by
125 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
126 /// either of the input types has dynamic shape. Returns null type if the two
127 /// given types are not broadcast-compatible.
129 /// elementType, if specified, will be used as the element type of the
130 /// broadcasted result type. Otherwise it is required that the element type of
131 /// type1 and type2 is the same and this element type will be used as the
132 /// resultant element type.
133 Type
OpTrait::util::getBroadcastedType(Type type1
, Type type2
,
135 // If the elementType is not specified, then use the common element type
136 // of the inputs or fail if there is no common element type.
138 elementType
= getElementTypeOrSelf(type1
);
139 if (elementType
!= getElementTypeOrSelf(type2
))
143 // If one of the types is an unranked tensor, then the other type shouldn't be
144 // a vector and the result should have an unranked tensor type.
145 if (isa
<UnrankedTensorType
>(type1
) || isa
<UnrankedTensorType
>(type2
)) {
146 if (isa
<VectorType
>(type1
) || isa
<VectorType
>(type2
))
148 return UnrankedTensorType::get(elementType
);
151 // Returns the type kind if the given type is a vector or ranked tensor type.
152 // Returns std::nullopt otherwise.
153 auto getCompositeTypeKind
= [](Type type
) -> std::optional
<TypeID
> {
154 if (isa
<VectorType
, RankedTensorType
>(type
))
155 return type
.getTypeID();
159 // Make sure the composite types, if any, are consistent.
160 std::optional
<TypeID
> compositeKind1
= getCompositeTypeKind(type1
);
161 std::optional
<TypeID
> compositeKind2
= getCompositeTypeKind(type2
);
162 std::optional
<TypeID
> resultCompositeKind
;
164 if (compositeKind1
&& compositeKind2
) {
165 // Disallow mixing vector and tensor.
166 if (compositeKind1
!= compositeKind2
)
168 resultCompositeKind
= compositeKind1
;
169 } else if (compositeKind1
) {
170 resultCompositeKind
= compositeKind1
;
171 } else if (compositeKind2
) {
172 resultCompositeKind
= compositeKind2
;
175 // Get the shape of each type.
176 SmallVector
<int64_t, 4> resultShape
;
177 if (!getBroadcastedShape(getShape(type1
), getShape(type2
), resultShape
))
180 // Compose the final broadcasted type
181 if (resultCompositeKind
== VectorType::getTypeID())
182 return VectorType::get(resultShape
, elementType
);
183 if (resultCompositeKind
== RankedTensorType::getTypeID())
184 return RankedTensorType::get(resultShape
, elementType
);
188 /// Returns a (hasTensor, hasVector) tuple: whether any type in the range is a tensor type, and whether any is a vector type.
189 template <typename iterator_range
>
190 static std::tuple
<bool, bool> hasTensorOrVectorType(iterator_range types
) {
191 return std::make_tuple(
192 llvm::any_of(types
, [](Type t
) { return isa
<TensorType
>(t
); }),
193 llvm::any_of(types
, [](Type t
) { return isa
<VectorType
>(t
); }));
196 static bool isCompatibleInferredReturnShape(ArrayRef
<int64_t> inferred
,
197 ArrayRef
<int64_t> existing
) {
198 // If both inferred and existing dimensions are static, they must be equal.
199 auto isCompatible
= [](int64_t inferredDim
, int64_t existingDim
) {
200 return ShapedType::isDynamic(existingDim
) ||
201 ShapedType::isDynamic(inferredDim
) || inferredDim
== existingDim
;
203 if (inferred
.size() != existing
.size())
205 for (auto [inferredDim
, existingDim
] : llvm::zip(inferred
, existing
))
206 if (!isCompatible(inferredDim
, existingDim
))
211 static std::string
getShapeString(ArrayRef
<int64_t> shape
) {
212 // TODO: should replace with printing shape more uniformly across here and
215 llvm::raw_string_ostream
ss(ret
);
220 if (ShapedType::isDynamic(dim
))
230 LogicalResult
OpTrait::impl::verifyCompatibleOperandBroadcast(Operation
*op
) {
231 // Ensure broadcasting only tensor or only vector types.
232 auto operandsHasTensorVectorType
=
233 hasTensorOrVectorType(op
->getOperandTypes());
234 auto resultsHasTensorVectorType
= hasTensorOrVectorType(op
->getResultTypes());
235 if ((std::get
<0>(operandsHasTensorVectorType
) ||
236 std::get
<0>(resultsHasTensorVectorType
)) &&
237 (std::get
<1>(operandsHasTensorVectorType
) ||
238 std::get
<1>(resultsHasTensorVectorType
)))
239 return op
->emitError("cannot broadcast vector with tensor");
241 auto rankedOperands
= make_filter_range(
242 op
->getOperandTypes(), [](Type t
) { return isa
<RankedTensorType
>(t
); });
244 // If all operands are unranked, then all result shapes are possible.
245 if (rankedOperands
.empty())
248 // Compute broadcasted shape of operands (which requires that operands are
249 // broadcast compatible). The results need to be broadcast compatible with
250 // this result shape.
251 SmallVector
<int64_t, 4> resultShape
;
252 (void)util::getBroadcastedShape(getShape(*rankedOperands
.begin()), {},
254 for (auto other
: make_early_inc_range(rankedOperands
)) {
255 SmallVector
<int64_t, 4> temp
= resultShape
;
256 if (!util::getBroadcastedShape(temp
, getShape(other
), resultShape
))
257 return op
->emitOpError("operands don't have broadcast-compatible shapes");
260 auto rankedResults
= make_filter_range(
261 op
->getResultTypes(), [](Type t
) { return isa
<RankedTensorType
>(t
); });
263 // If all of the results are unranked then no further verification.
264 if (rankedResults
.empty())
267 for (auto type
: rankedResults
) {
268 ArrayRef
<int64_t> actualSuffix
=
269 getShape(type
).take_back(resultShape
.size());
270 if (!isCompatibleInferredReturnShape(resultShape
, actualSuffix
))
271 return op
->emitOpError()
272 << "result type " << getShapeString(getShape(type
))
273 << " not broadcast compatible with broadcasted operands's shapes "
274 << getShapeString(resultShape
);