1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>

using namespace mlir;
using namespace mlir::detail;
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 //===----------------------------------------------------------------------===//
37 //===----------------------------------------------------------------------===//
39 void BuiltinDialect::registerTypes() {
41 #define GET_TYPEDEF_LIST
42 #include "mlir/IR/BuiltinTypes.cpp.inc"
46 //===----------------------------------------------------------------------===//
48 //===----------------------------------------------------------------------===//
50 /// Verify the construction of an integer type.
51 LogicalResult
ComplexType::verify(function_ref
<InFlightDiagnostic()> emitError
,
53 if (!elementType
.isIntOrFloat())
54 return emitError() << "invalid element type for complex";
58 //===----------------------------------------------------------------------===//
60 //===----------------------------------------------------------------------===//
62 /// Verify the construction of an integer type.
63 LogicalResult
IntegerType::verify(function_ref
<InFlightDiagnostic()> emitError
,
65 SignednessSemantics signedness
) {
66 if (width
> IntegerType::kMaxWidth
) {
67 return emitError() << "integer bitwidth is limited to "
68 << IntegerType::kMaxWidth
<< " bits";
73 unsigned IntegerType::getWidth() const { return getImpl()->width
; }
75 IntegerType::SignednessSemantics
IntegerType::getSignedness() const {
76 return getImpl()->signedness
;
79 IntegerType
IntegerType::scaleElementBitwidth(unsigned scale
) {
82 return IntegerType::get(getContext(), scale
* getWidth(), getSignedness());
85 //===----------------------------------------------------------------------===//
87 //===----------------------------------------------------------------------===//
89 unsigned FloatType::getWidth() {
90 if (llvm::isa
<Float8E5M2Type
, Float8E4M3FNType
, Float8E5M2FNUZType
,
91 Float8E4M3FNUZType
, Float8E4M3B11FNUZType
>(*this))
93 if (llvm::isa
<Float16Type
, BFloat16Type
>(*this))
95 if (llvm::isa
<Float32Type
, FloatTF32Type
>(*this))
97 if (llvm::isa
<Float64Type
>(*this))
99 if (llvm::isa
<Float80Type
>(*this))
101 if (llvm::isa
<Float128Type
>(*this))
103 llvm_unreachable("unexpected float type");
106 /// Returns the floating semantics for the given type.
107 const llvm::fltSemantics
&FloatType::getFloatSemantics() {
108 if (llvm::isa
<Float8E5M2Type
>(*this))
109 return APFloat::Float8E5M2();
110 if (llvm::isa
<Float8E4M3FNType
>(*this))
111 return APFloat::Float8E4M3FN();
112 if (llvm::isa
<Float8E5M2FNUZType
>(*this))
113 return APFloat::Float8E5M2FNUZ();
114 if (llvm::isa
<Float8E4M3FNUZType
>(*this))
115 return APFloat::Float8E4M3FNUZ();
116 if (llvm::isa
<Float8E4M3B11FNUZType
>(*this))
117 return APFloat::Float8E4M3B11FNUZ();
118 if (llvm::isa
<BFloat16Type
>(*this))
119 return APFloat::BFloat();
120 if (llvm::isa
<Float16Type
>(*this))
121 return APFloat::IEEEhalf();
122 if (llvm::isa
<FloatTF32Type
>(*this))
123 return APFloat::FloatTF32();
124 if (llvm::isa
<Float32Type
>(*this))
125 return APFloat::IEEEsingle();
126 if (llvm::isa
<Float64Type
>(*this))
127 return APFloat::IEEEdouble();
128 if (llvm::isa
<Float80Type
>(*this))
129 return APFloat::x87DoubleExtended();
130 if (llvm::isa
<Float128Type
>(*this))
131 return APFloat::IEEEquad();
132 llvm_unreachable("non-floating point type used");
135 FloatType
FloatType::scaleElementBitwidth(unsigned scale
) {
138 MLIRContext
*ctx
= getContext();
139 if (isF16() || isBF16()) {
141 return FloatType::getF32(ctx
);
143 return FloatType::getF64(ctx
);
147 return FloatType::getF64(ctx
);
151 unsigned FloatType::getFPMantissaWidth() {
152 return APFloat::semanticsPrecision(getFloatSemantics());
155 //===----------------------------------------------------------------------===//
157 //===----------------------------------------------------------------------===//
159 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs
; }
161 ArrayRef
<Type
> FunctionType::getInputs() const {
162 return getImpl()->getInputs();
165 unsigned FunctionType::getNumResults() const { return getImpl()->numResults
; }
167 ArrayRef
<Type
> FunctionType::getResults() const {
168 return getImpl()->getResults();
171 FunctionType
FunctionType::clone(TypeRange inputs
, TypeRange results
) const {
172 return get(getContext(), inputs
, results
);
175 /// Returns a new function type with the specified arguments and results
177 FunctionType
FunctionType::getWithArgsAndResults(
178 ArrayRef
<unsigned> argIndices
, TypeRange argTypes
,
179 ArrayRef
<unsigned> resultIndices
, TypeRange resultTypes
) {
180 SmallVector
<Type
> argStorage
, resultStorage
;
181 TypeRange newArgTypes
=
182 insertTypesInto(getInputs(), argIndices
, argTypes
, argStorage
);
183 TypeRange newResultTypes
=
184 insertTypesInto(getResults(), resultIndices
, resultTypes
, resultStorage
);
185 return clone(newArgTypes
, newResultTypes
);
188 /// Returns a new function type without the specified arguments and results.
190 FunctionType::getWithoutArgsAndResults(const BitVector
&argIndices
,
191 const BitVector
&resultIndices
) {
192 SmallVector
<Type
> argStorage
, resultStorage
;
193 TypeRange newArgTypes
= filterTypesOut(getInputs(), argIndices
, argStorage
);
194 TypeRange newResultTypes
=
195 filterTypesOut(getResults(), resultIndices
, resultStorage
);
196 return clone(newArgTypes
, newResultTypes
);
199 //===----------------------------------------------------------------------===//
201 //===----------------------------------------------------------------------===//
203 /// Verify the construction of an opaque type.
204 LogicalResult
OpaqueType::verify(function_ref
<InFlightDiagnostic()> emitError
,
205 StringAttr dialect
, StringRef typeData
) {
206 if (!Dialect::isValidNamespace(dialect
.strref()))
207 return emitError() << "invalid dialect namespace '" << dialect
<< "'";
209 // Check that the dialect is actually registered.
210 MLIRContext
*context
= dialect
.getContext();
211 if (!context
->allowsUnregisteredDialects() &&
212 !context
->getLoadedDialect(dialect
.strref())) {
214 << "`!" << dialect
<< "<\"" << typeData
<< "\">"
215 << "` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
224 //===----------------------------------------------------------------------===//
226 //===----------------------------------------------------------------------===//
228 LogicalResult
VectorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
229 ArrayRef
<int64_t> shape
, Type elementType
,
230 ArrayRef
<bool> scalableDims
) {
231 if (!isValidElementType(elementType
))
233 << "vector elements must be int/index/float type but got "
236 if (any_of(shape
, [](int64_t i
) { return i
<= 0; }))
238 << "vector types must have positive constant sizes but got "
241 if (scalableDims
.size() != shape
.size())
242 return emitError() << "number of dims must match, got "
243 << scalableDims
.size() << " and " << shape
.size();
248 VectorType
VectorType::scaleElementBitwidth(unsigned scale
) {
251 if (auto et
= llvm::dyn_cast
<IntegerType
>(getElementType()))
252 if (auto scaledEt
= et
.scaleElementBitwidth(scale
))
253 return VectorType::get(getShape(), scaledEt
, getScalableDims());
254 if (auto et
= llvm::dyn_cast
<FloatType
>(getElementType()))
255 if (auto scaledEt
= et
.scaleElementBitwidth(scale
))
256 return VectorType::get(getShape(), scaledEt
, getScalableDims());
260 VectorType
VectorType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
261 Type elementType
) const {
262 return VectorType::get(shape
.value_or(getShape()), elementType
,
266 //===----------------------------------------------------------------------===//
268 //===----------------------------------------------------------------------===//
270 Type
TensorType::getElementType() const {
271 return llvm::TypeSwitch
<TensorType
, Type
>(*this)
272 .Case
<RankedTensorType
, UnrankedTensorType
>(
273 [](auto type
) { return type
.getElementType(); });
276 bool TensorType::hasRank() const { return !llvm::isa
<UnrankedTensorType
>(*this); }
278 ArrayRef
<int64_t> TensorType::getShape() const {
279 return llvm::cast
<RankedTensorType
>(*this).getShape();
282 TensorType
TensorType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
283 Type elementType
) const {
284 if (llvm::dyn_cast
<UnrankedTensorType
>(*this)) {
286 return RankedTensorType::get(*shape
, elementType
);
287 return UnrankedTensorType::get(elementType
);
290 auto rankedTy
= llvm::cast
<RankedTensorType
>(*this);
292 return RankedTensorType::get(rankedTy
.getShape(), elementType
,
293 rankedTy
.getEncoding());
294 return RankedTensorType::get(shape
.value_or(rankedTy
.getShape()), elementType
,
295 rankedTy
.getEncoding());
298 RankedTensorType
TensorType::clone(::llvm::ArrayRef
<int64_t> shape
,
299 Type elementType
) const {
300 return ::llvm::cast
<RankedTensorType
>(cloneWith(shape
, elementType
));
303 RankedTensorType
TensorType::clone(::llvm::ArrayRef
<int64_t> shape
) const {
304 return ::llvm::cast
<RankedTensorType
>(cloneWith(shape
, getElementType()));
307 // Check if "elementType" can be an element type of a tensor.
309 checkTensorElementType(function_ref
<InFlightDiagnostic()> emitError
,
311 if (!TensorType::isValidElementType(elementType
))
312 return emitError() << "invalid tensor element type: " << elementType
;
316 /// Return true if the specified element type is ok in a tensor.
317 bool TensorType::isValidElementType(Type type
) {
318 // Note: Non standard/builtin types are allowed to exist within tensor
319 // types. Dialects are expected to verify that tensor types have a valid
320 // element type within that dialect.
321 return llvm::isa
<ComplexType
, FloatType
, IntegerType
, OpaqueType
, VectorType
,
323 !llvm::isa
<BuiltinDialect
>(type
.getDialect());
326 //===----------------------------------------------------------------------===//
328 //===----------------------------------------------------------------------===//
331 RankedTensorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
332 ArrayRef
<int64_t> shape
, Type elementType
,
333 Attribute encoding
) {
334 for (int64_t s
: shape
)
335 if (s
< 0 && !ShapedType::isDynamic(s
))
336 return emitError() << "invalid tensor dimension size";
337 if (auto v
= llvm::dyn_cast_or_null
<VerifiableTensorEncoding
>(encoding
))
338 if (failed(v
.verifyEncoding(shape
, elementType
, emitError
)))
340 return checkTensorElementType(emitError
, elementType
);
343 //===----------------------------------------------------------------------===//
344 // UnrankedTensorType
345 //===----------------------------------------------------------------------===//
348 UnrankedTensorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
350 return checkTensorElementType(emitError
, elementType
);
353 //===----------------------------------------------------------------------===//
355 //===----------------------------------------------------------------------===//
357 Type
BaseMemRefType::getElementType() const {
358 return llvm::TypeSwitch
<BaseMemRefType
, Type
>(*this)
359 .Case
<MemRefType
, UnrankedMemRefType
>(
360 [](auto type
) { return type
.getElementType(); });
363 bool BaseMemRefType::hasRank() const { return !llvm::isa
<UnrankedMemRefType
>(*this); }
365 ArrayRef
<int64_t> BaseMemRefType::getShape() const {
366 return llvm::cast
<MemRefType
>(*this).getShape();
369 BaseMemRefType
BaseMemRefType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
370 Type elementType
) const {
371 if (llvm::dyn_cast
<UnrankedMemRefType
>(*this)) {
373 return UnrankedMemRefType::get(elementType
, getMemorySpace());
374 MemRefType::Builder
builder(*shape
, elementType
);
375 builder
.setMemorySpace(getMemorySpace());
379 MemRefType::Builder
builder(llvm::cast
<MemRefType
>(*this));
381 builder
.setShape(*shape
);
382 builder
.setElementType(elementType
);
386 MemRefType
BaseMemRefType::clone(::llvm::ArrayRef
<int64_t> shape
,
387 Type elementType
) const {
388 return ::llvm::cast
<MemRefType
>(cloneWith(shape
, elementType
));
391 MemRefType
BaseMemRefType::clone(::llvm::ArrayRef
<int64_t> shape
) const {
392 return ::llvm::cast
<MemRefType
>(cloneWith(shape
, getElementType()));
395 Attribute
BaseMemRefType::getMemorySpace() const {
396 if (auto rankedMemRefTy
= llvm::dyn_cast
<MemRefType
>(*this))
397 return rankedMemRefTy
.getMemorySpace();
398 return llvm::cast
<UnrankedMemRefType
>(*this).getMemorySpace();
401 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
402 if (auto rankedMemRefTy
= llvm::dyn_cast
<MemRefType
>(*this))
403 return rankedMemRefTy
.getMemorySpaceAsInt();
404 return llvm::cast
<UnrankedMemRefType
>(*this).getMemorySpaceAsInt();
407 //===----------------------------------------------------------------------===//
409 //===----------------------------------------------------------------------===//
411 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
412 /// `originalShape` with some `1` entries erased, return the set of indices
413 /// that specifies which of the entries of `originalShape` are dropped to obtain
414 /// `reducedShape`. The returned mask can be applied as a projection to
415 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
416 /// which dimensions must be kept when e.g. compute MemRef strides under
417 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
418 /// obtained by dropping only `1` entries in `originalShape`.
419 std::optional
<llvm::SmallDenseSet
<unsigned>>
420 mlir::computeRankReductionMask(ArrayRef
<int64_t> originalShape
,
421 ArrayRef
<int64_t> reducedShape
) {
422 size_t originalRank
= originalShape
.size(), reducedRank
= reducedShape
.size();
423 llvm::SmallDenseSet
<unsigned> unusedDims
;
424 unsigned reducedIdx
= 0;
425 for (unsigned originalIdx
= 0; originalIdx
< originalRank
; ++originalIdx
) {
426 // Greedily insert `originalIdx` if match.
427 if (reducedIdx
< reducedRank
&&
428 originalShape
[originalIdx
] == reducedShape
[reducedIdx
]) {
433 unusedDims
.insert(originalIdx
);
434 // If no match on `originalIdx`, the `originalShape` at this dimension
435 // must be 1, otherwise we bail.
436 if (originalShape
[originalIdx
] != 1)
439 // The whole reducedShape must be scanned, otherwise we bail.
440 if (reducedIdx
!= reducedRank
)
445 SliceVerificationResult
446 mlir::isRankReducedType(ShapedType originalType
,
447 ShapedType candidateReducedType
) {
448 if (originalType
== candidateReducedType
)
449 return SliceVerificationResult::Success
;
451 ShapedType originalShapedType
= llvm::cast
<ShapedType
>(originalType
);
452 ShapedType candidateReducedShapedType
=
453 llvm::cast
<ShapedType
>(candidateReducedType
);
455 // Rank and size logic is valid for all ShapedTypes.
456 ArrayRef
<int64_t> originalShape
= originalShapedType
.getShape();
457 ArrayRef
<int64_t> candidateReducedShape
=
458 candidateReducedShapedType
.getShape();
459 unsigned originalRank
= originalShape
.size(),
460 candidateReducedRank
= candidateReducedShape
.size();
461 if (candidateReducedRank
> originalRank
)
462 return SliceVerificationResult::RankTooLarge
;
464 auto optionalUnusedDimsMask
=
465 computeRankReductionMask(originalShape
, candidateReducedShape
);
467 // Sizes cannot be matched in case empty vector is returned.
468 if (!optionalUnusedDimsMask
)
469 return SliceVerificationResult::SizeMismatch
;
471 if (originalShapedType
.getElementType() !=
472 candidateReducedShapedType
.getElementType())
473 return SliceVerificationResult::ElemTypeMismatch
;
475 return SliceVerificationResult::Success
;
478 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace
) {
479 // Empty attribute is allowed as default memory space.
483 // Supported built-in attributes.
484 if (llvm::isa
<IntegerAttr
, StringAttr
, DictionaryAttr
>(memorySpace
))
487 // Allow custom dialect attributes.
488 if (!isa
<BuiltinDialect
>(memorySpace
.getDialect()))
494 Attribute
mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace
,
496 if (memorySpace
== 0)
499 return IntegerAttr::get(IntegerType::get(ctx
, 64), memorySpace
);
502 Attribute
mlir::detail::skipDefaultMemorySpace(Attribute memorySpace
) {
503 IntegerAttr intMemorySpace
= llvm::dyn_cast_or_null
<IntegerAttr
>(memorySpace
);
504 if (intMemorySpace
&& intMemorySpace
.getValue() == 0)
510 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace
) {
514 assert(llvm::isa
<IntegerAttr
>(memorySpace
) &&
515 "Using `getMemorySpaceInteger` with non-Integer attribute");
517 return static_cast<unsigned>(llvm::cast
<IntegerAttr
>(memorySpace
).getInt());
520 unsigned MemRefType::getMemorySpaceAsInt() const {
521 return detail::getMemorySpaceAsInt(getMemorySpace());
524 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
525 MemRefLayoutAttrInterface layout
,
526 Attribute memorySpace
) {
527 // Use default layout for empty attribute.
529 layout
= AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
530 shape
.size(), elementType
.getContext()));
532 // Drop default memory space value and replace it with empty attribute.
533 memorySpace
= skipDefaultMemorySpace(memorySpace
);
535 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
539 MemRefType
MemRefType::getChecked(
540 function_ref
<InFlightDiagnostic()> emitErrorFn
, ArrayRef
<int64_t> shape
,
541 Type elementType
, MemRefLayoutAttrInterface layout
, Attribute memorySpace
) {
543 // Use default layout for empty attribute.
545 layout
= AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
546 shape
.size(), elementType
.getContext()));
548 // Drop default memory space value and replace it with empty attribute.
549 memorySpace
= skipDefaultMemorySpace(memorySpace
);
551 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
552 elementType
, layout
, memorySpace
);
555 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
556 AffineMap map
, Attribute memorySpace
) {
558 // Use default layout for empty map.
560 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
561 elementType
.getContext());
563 // Wrap AffineMap into Attribute.
564 auto layout
= AffineMapAttr::get(map
);
566 // Drop default memory space value and replace it with empty attribute.
567 memorySpace
= skipDefaultMemorySpace(memorySpace
);
569 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
574 MemRefType::getChecked(function_ref
<InFlightDiagnostic()> emitErrorFn
,
575 ArrayRef
<int64_t> shape
, Type elementType
, AffineMap map
,
576 Attribute memorySpace
) {
578 // Use default layout for empty map.
580 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
581 elementType
.getContext());
583 // Wrap AffineMap into Attribute.
584 auto layout
= AffineMapAttr::get(map
);
586 // Drop default memory space value and replace it with empty attribute.
587 memorySpace
= skipDefaultMemorySpace(memorySpace
);
589 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
590 elementType
, layout
, memorySpace
);
593 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
594 AffineMap map
, unsigned memorySpaceInd
) {
596 // Use default layout for empty map.
598 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
599 elementType
.getContext());
601 // Wrap AffineMap into Attribute.
602 auto layout
= AffineMapAttr::get(map
);
604 // Convert deprecated integer-like memory space to Attribute.
605 Attribute memorySpace
=
606 wrapIntegerMemorySpace(memorySpaceInd
, elementType
.getContext());
608 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
613 MemRefType::getChecked(function_ref
<InFlightDiagnostic()> emitErrorFn
,
614 ArrayRef
<int64_t> shape
, Type elementType
, AffineMap map
,
615 unsigned memorySpaceInd
) {
617 // Use default layout for empty map.
619 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
620 elementType
.getContext());
622 // Wrap AffineMap into Attribute.
623 auto layout
= AffineMapAttr::get(map
);
625 // Convert deprecated integer-like memory space to Attribute.
626 Attribute memorySpace
=
627 wrapIntegerMemorySpace(memorySpaceInd
, elementType
.getContext());
629 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
630 elementType
, layout
, memorySpace
);
633 LogicalResult
MemRefType::verify(function_ref
<InFlightDiagnostic()> emitError
,
634 ArrayRef
<int64_t> shape
, Type elementType
,
635 MemRefLayoutAttrInterface layout
,
636 Attribute memorySpace
) {
637 if (!BaseMemRefType::isValidElementType(elementType
))
638 return emitError() << "invalid memref element type";
640 // Negative sizes are not allowed except for `kDynamic`.
641 for (int64_t s
: shape
)
642 if (s
< 0 && !ShapedType::isDynamic(s
))
643 return emitError() << "invalid memref size";
645 assert(layout
&& "missing layout specification");
646 if (failed(layout
.verifyLayout(shape
, emitError
)))
649 if (!isSupportedMemorySpace(memorySpace
))
650 return emitError() << "unsupported memory space Attribute";
655 //===----------------------------------------------------------------------===//
656 // UnrankedMemRefType
657 //===----------------------------------------------------------------------===//
659 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
660 return detail::getMemorySpaceAsInt(getMemorySpace());
664 UnrankedMemRefType::verify(function_ref
<InFlightDiagnostic()> emitError
,
665 Type elementType
, Attribute memorySpace
) {
666 if (!BaseMemRefType::isValidElementType(elementType
))
667 return emitError() << "invalid memref element type";
669 if (!isSupportedMemorySpace(memorySpace
))
670 return emitError() << "unsupported memory space Attribute";
675 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
676 // i.e. single term). Accumulate the AffineExpr into the existing one.
677 static void extractStridesFromTerm(AffineExpr e
,
678 AffineExpr multiplicativeFactor
,
679 MutableArrayRef
<AffineExpr
> strides
,
680 AffineExpr
&offset
) {
681 if (auto dim
= e
.dyn_cast
<AffineDimExpr
>())
682 strides
[dim
.getPosition()] =
683 strides
[dim
.getPosition()] + multiplicativeFactor
;
685 offset
= offset
+ e
* multiplicativeFactor
;
688 /// Takes a single AffineExpr `e` and populates the `strides` array with the
689 /// strides expressions for each dim position.
690 /// The convention is that the strides for dimensions d0, .. dn appear in
691 /// order to make indexing intuitive into the result.
692 static LogicalResult
extractStrides(AffineExpr e
,
693 AffineExpr multiplicativeFactor
,
694 MutableArrayRef
<AffineExpr
> strides
,
695 AffineExpr
&offset
) {
696 auto bin
= e
.dyn_cast
<AffineBinaryOpExpr
>();
698 extractStridesFromTerm(e
, multiplicativeFactor
, strides
, offset
);
702 if (bin
.getKind() == AffineExprKind::CeilDiv
||
703 bin
.getKind() == AffineExprKind::FloorDiv
||
704 bin
.getKind() == AffineExprKind::Mod
)
707 if (bin
.getKind() == AffineExprKind::Mul
) {
708 auto dim
= bin
.getLHS().dyn_cast
<AffineDimExpr
>();
710 strides
[dim
.getPosition()] =
711 strides
[dim
.getPosition()] + bin
.getRHS() * multiplicativeFactor
;
714 // LHS and RHS may both contain complex expressions of dims. Try one path
715 // and if it fails try the other. This is guaranteed to succeed because
716 // only one path may have a `dim`, otherwise this is not an AffineExpr in
718 if (bin
.getLHS().isSymbolicOrConstant())
719 return extractStrides(bin
.getRHS(), multiplicativeFactor
* bin
.getLHS(),
721 return extractStrides(bin
.getLHS(), multiplicativeFactor
* bin
.getRHS(),
725 if (bin
.getKind() == AffineExprKind::Add
) {
727 extractStrides(bin
.getLHS(), multiplicativeFactor
, strides
, offset
);
729 extractStrides(bin
.getRHS(), multiplicativeFactor
, strides
, offset
);
730 return success(succeeded(res1
) && succeeded(res2
));
733 llvm_unreachable("unexpected binary operation");
736 /// A stride specification is a list of integer values that are either static
737 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
738 /// the distance in the number of elements between successive entries along a
739 /// particular dimension.
741 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
742 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
743 /// distance between two consecutive elements along the outer dimension is `1`
744 /// and the distance between two consecutive elements along the inner dimension
747 /// The convention is that the strides for dimensions d0, .. dn appear in
748 /// order to make indexing intuitive into the result.
749 static LogicalResult
getStridesAndOffset(MemRefType t
,
750 SmallVectorImpl
<AffineExpr
> &strides
,
751 AffineExpr
&offset
) {
752 AffineMap m
= t
.getLayout().getAffineMap();
754 if (m
.getNumResults() != 1 && !m
.isIdentity())
757 auto zero
= getAffineConstantExpr(0, t
.getContext());
758 auto one
= getAffineConstantExpr(1, t
.getContext());
760 strides
.assign(t
.getRank(), zero
);
762 // Canonical case for empty map.
763 if (m
.isIdentity()) {
764 // 0-D corner case, offset is already 0.
765 if (t
.getRank() == 0)
768 makeCanonicalStridedLayoutExpr(t
.getShape(), t
.getContext());
769 if (succeeded(extractStrides(stridedExpr
, one
, strides
, offset
)))
771 assert(false && "unexpected failure: extract strides in canonical layout");
774 // Non-canonical case requires more work.
776 simplifyAffineExpr(m
.getResult(0), m
.getNumDims(), m
.getNumSymbols());
777 if (failed(extractStrides(stridedExpr
, one
, strides
, offset
))) {
778 offset
= AffineExpr();
783 // Simplify results to allow folding to constants and simple checks.
784 unsigned numDims
= m
.getNumDims();
785 unsigned numSymbols
= m
.getNumSymbols();
786 offset
= simplifyAffineExpr(offset
, numDims
, numSymbols
);
787 for (auto &stride
: strides
)
788 stride
= simplifyAffineExpr(stride
, numDims
, numSymbols
);
790 // In practice, a strided memref must be internally non-aliasing. Test
791 // against 0 as a proxy.
792 // TODO: static cases can have more advanced checks.
793 // TODO: dynamic cases would require a way to compare symbolic
794 // expressions and would probably need an affine set context propagated
796 if (llvm::any_of(strides
, [](AffineExpr e
) {
797 return e
== getAffineConstantExpr(0, e
.getContext());
799 offset
= AffineExpr();
807 LogicalResult
mlir::getStridesAndOffset(MemRefType t
,
808 SmallVectorImpl
<int64_t> &strides
,
810 // Happy path: the type uses the strided layout directly.
811 if (auto strided
= llvm::dyn_cast
<StridedLayoutAttr
>(t
.getLayout())) {
812 llvm::append_range(strides
, strided
.getStrides());
813 offset
= strided
.getOffset();
817 // Otherwise, defer to the affine fallback as layouts are supposed to be
818 // convertible to affine maps.
819 AffineExpr offsetExpr
;
820 SmallVector
<AffineExpr
, 4> strideExprs
;
821 if (failed(::getStridesAndOffset(t
, strideExprs
, offsetExpr
)))
823 if (auto cst
= offsetExpr
.dyn_cast
<AffineConstantExpr
>())
824 offset
= cst
.getValue();
826 offset
= ShapedType::kDynamic
;
827 for (auto e
: strideExprs
) {
828 if (auto c
= e
.dyn_cast
<AffineConstantExpr
>())
829 strides
.push_back(c
.getValue());
831 strides
.push_back(ShapedType::kDynamic
);
836 std::pair
<SmallVector
<int64_t>, int64_t>
837 mlir::getStridesAndOffset(MemRefType t
) {
838 SmallVector
<int64_t> strides
;
840 LogicalResult status
= getStridesAndOffset(t
, strides
, offset
);
842 assert(succeeded(status
) && "Invalid use of check-free getStridesAndOffset");
843 return {strides
, offset
};
846 //===----------------------------------------------------------------------===//
848 //===----------------------------------------------------------------------===//
850 /// Return the elements types for this tuple.
851 ArrayRef
<Type
> TupleType::getTypes() const { return getImpl()->getTypes(); }
853 /// Accumulate the types contained in this tuple and tuples nested within it.
854 /// Note that this only flattens nested tuples, not any other container type,
855 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
856 /// (i32, tensor<i32>, f32, i64)
857 void TupleType::getFlattenedTypes(SmallVectorImpl
<Type
> &types
) {
858 for (Type type
: getTypes()) {
859 if (auto nestedTuple
= llvm::dyn_cast
<TupleType
>(type
))
860 nestedTuple
.getFlattenedTypes(types
);
862 types
.push_back(type
);
866 /// Return the number of element types.
867 size_t TupleType::size() const { return getImpl()->size(); }
869 //===----------------------------------------------------------------------===//
871 //===----------------------------------------------------------------------===//
873 /// Return a version of `t` with identity layout if it can be determined
874 /// statically that the layout is the canonical contiguous strided layout.
875 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
876 /// `t` with simplified layout.
877 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
878 MemRefType
mlir::canonicalizeStridedLayout(MemRefType t
) {
879 AffineMap m
= t
.getLayout().getAffineMap();
881 // Already in canonical form.
885 // Can't reduce to canonical identity form, return in canonical form.
886 if (m
.getNumResults() > 1)
889 // Corner-case for 0-D affine maps.
890 if (m
.getNumDims() == 0 && m
.getNumSymbols() == 0) {
891 if (auto cst
= m
.getResult(0).dyn_cast
<AffineConstantExpr
>())
892 if (cst
.getValue() == 0)
893 return MemRefType::Builder(t
).setLayout({});
897 // 0-D corner case for empty shape that still have an affine map. Example:
898 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
899 // offset needs to remain, just return t.
900 if (t
.getShape().empty())
903 // If the canonical strided layout for the sizes of `t` is equal to the
904 // simplified layout of `t` we can just return an empty layout. Otherwise,
905 // just simplify the existing layout.
907 makeCanonicalStridedLayoutExpr(t
.getShape(), t
.getContext());
908 auto simplifiedLayoutExpr
=
909 simplifyAffineExpr(m
.getResult(0), m
.getNumDims(), m
.getNumSymbols());
910 if (expr
!= simplifiedLayoutExpr
)
911 return MemRefType::Builder(t
).setLayout(AffineMapAttr::get(AffineMap::get(
912 m
.getNumDims(), m
.getNumSymbols(), simplifiedLayoutExpr
)));
913 return MemRefType::Builder(t
).setLayout({});
916 AffineExpr
mlir::makeCanonicalStridedLayoutExpr(ArrayRef
<int64_t> sizes
,
917 ArrayRef
<AffineExpr
> exprs
,
918 MLIRContext
*context
) {
919 // Size 0 corner case is useful for canonicalizations.
921 return getAffineConstantExpr(0, context
);
923 assert(!exprs
.empty() && "expected exprs");
924 auto maps
= AffineMap::inferFromExprList(exprs
);
925 assert(!maps
.empty() && "Expected one non-empty map");
926 unsigned numDims
= maps
[0].getNumDims(), nSymbols
= maps
[0].getNumSymbols();
929 bool dynamicPoisonBit
= false;
930 int64_t runningSize
= 1;
931 for (auto en
: llvm::zip(llvm::reverse(exprs
), llvm::reverse(sizes
))) {
932 int64_t size
= std::get
<1>(en
);
933 AffineExpr dimExpr
= std::get
<0>(en
);
934 AffineExpr stride
= dynamicPoisonBit
935 ? getAffineSymbolExpr(nSymbols
++, context
)
936 : getAffineConstantExpr(runningSize
, context
);
937 expr
= expr
? expr
+ dimExpr
* stride
: dimExpr
* stride
;
940 assert(runningSize
> 0 && "integer overflow in size computation");
942 dynamicPoisonBit
= true;
945 return simplifyAffineExpr(expr
, numDims
, nSymbols
);
948 AffineExpr
mlir::makeCanonicalStridedLayoutExpr(ArrayRef
<int64_t> sizes
,
949 MLIRContext
*context
) {
950 SmallVector
<AffineExpr
, 4> exprs
;
951 exprs
.reserve(sizes
.size());
952 for (auto dim
: llvm::seq
<unsigned>(0, sizes
.size()))
953 exprs
.push_back(getAffineDimExpr(dim
, context
));
954 return makeCanonicalStridedLayoutExpr(sizes
, exprs
, context
);
957 bool mlir::isStrided(MemRefType t
) {
959 SmallVector
<int64_t, 4> strides
;
960 auto res
= getStridesAndOffset(t
, strides
, offset
);
961 return succeeded(res
);
964 bool mlir::isLastMemrefDimUnitStride(MemRefType type
) {
966 SmallVector
<int64_t> strides
;
967 auto successStrides
= getStridesAndOffset(type
, strides
, offset
);
968 return succeeded(successStrides
) && (strides
.empty() || strides
.back() == 1);