1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
26 using namespace mlir::detail
;
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
43 void BuiltinDialect::registerTypes() {
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
54 /// Verify the construction of an integer type.
55 LogicalResult
ComplexType::verify(function_ref
<InFlightDiagnostic()> emitError
,
57 if (!elementType
.isIntOrFloat())
58 return emitError() << "invalid element type for complex";
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
66 /// Verify the construction of an integer type.
67 LogicalResult
IntegerType::verify(function_ref
<InFlightDiagnostic()> emitError
,
69 SignednessSemantics signedness
) {
70 if (width
> IntegerType::kMaxWidth
) {
71 return emitError() << "integer bitwidth is limited to "
72 << IntegerType::kMaxWidth
<< " bits";
77 unsigned IntegerType::getWidth() const { return getImpl()->width
; }
79 IntegerType::SignednessSemantics
IntegerType::getSignedness() const {
80 return getImpl()->signedness
;
83 IntegerType
IntegerType::scaleElementBitwidth(unsigned scale
) {
86 return IntegerType::get(getContext(), scale
* getWidth(), getSignedness());
//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//
93 unsigned FloatType::getWidth() {
94 return APFloat::semanticsSizeInBits(getFloatSemantics());
97 /// Returns the floating semantics for the given type.
98 const llvm::fltSemantics
&FloatType::getFloatSemantics() {
99 if (llvm::isa
<Float4E2M1FNType
>(*this))
100 return APFloat::Float4E2M1FN();
101 if (llvm::isa
<Float6E2M3FNType
>(*this))
102 return APFloat::Float6E2M3FN();
103 if (llvm::isa
<Float6E3M2FNType
>(*this))
104 return APFloat::Float6E3M2FN();
105 if (llvm::isa
<Float8E5M2Type
>(*this))
106 return APFloat::Float8E5M2();
107 if (llvm::isa
<Float8E4M3Type
>(*this))
108 return APFloat::Float8E4M3();
109 if (llvm::isa
<Float8E4M3FNType
>(*this))
110 return APFloat::Float8E4M3FN();
111 if (llvm::isa
<Float8E5M2FNUZType
>(*this))
112 return APFloat::Float8E5M2FNUZ();
113 if (llvm::isa
<Float8E4M3FNUZType
>(*this))
114 return APFloat::Float8E4M3FNUZ();
115 if (llvm::isa
<Float8E4M3B11FNUZType
>(*this))
116 return APFloat::Float8E4M3B11FNUZ();
117 if (llvm::isa
<Float8E3M4Type
>(*this))
118 return APFloat::Float8E3M4();
119 if (llvm::isa
<Float8E8M0FNUType
>(*this))
120 return APFloat::Float8E8M0FNU();
121 if (llvm::isa
<BFloat16Type
>(*this))
122 return APFloat::BFloat();
123 if (llvm::isa
<Float16Type
>(*this))
124 return APFloat::IEEEhalf();
125 if (llvm::isa
<FloatTF32Type
>(*this))
126 return APFloat::FloatTF32();
127 if (llvm::isa
<Float32Type
>(*this))
128 return APFloat::IEEEsingle();
129 if (llvm::isa
<Float64Type
>(*this))
130 return APFloat::IEEEdouble();
131 if (llvm::isa
<Float80Type
>(*this))
132 return APFloat::x87DoubleExtended();
133 if (llvm::isa
<Float128Type
>(*this))
134 return APFloat::IEEEquad();
135 llvm_unreachable("non-floating point type used");
138 FloatType
FloatType::scaleElementBitwidth(unsigned scale
) {
141 MLIRContext
*ctx
= getContext();
142 if (isF16() || isBF16()) {
144 return FloatType::getF32(ctx
);
146 return FloatType::getF64(ctx
);
150 return FloatType::getF64(ctx
);
154 unsigned FloatType::getFPMantissaWidth() {
155 return APFloat::semanticsPrecision(getFloatSemantics());
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
162 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs
; }
164 ArrayRef
<Type
> FunctionType::getInputs() const {
165 return getImpl()->getInputs();
168 unsigned FunctionType::getNumResults() const { return getImpl()->numResults
; }
170 ArrayRef
<Type
> FunctionType::getResults() const {
171 return getImpl()->getResults();
174 FunctionType
FunctionType::clone(TypeRange inputs
, TypeRange results
) const {
175 return get(getContext(), inputs
, results
);
178 /// Returns a new function type with the specified arguments and results
180 FunctionType
FunctionType::getWithArgsAndResults(
181 ArrayRef
<unsigned> argIndices
, TypeRange argTypes
,
182 ArrayRef
<unsigned> resultIndices
, TypeRange resultTypes
) {
183 SmallVector
<Type
> argStorage
, resultStorage
;
184 TypeRange newArgTypes
=
185 insertTypesInto(getInputs(), argIndices
, argTypes
, argStorage
);
186 TypeRange newResultTypes
=
187 insertTypesInto(getResults(), resultIndices
, resultTypes
, resultStorage
);
188 return clone(newArgTypes
, newResultTypes
);
191 /// Returns a new function type without the specified arguments and results.
193 FunctionType::getWithoutArgsAndResults(const BitVector
&argIndices
,
194 const BitVector
&resultIndices
) {
195 SmallVector
<Type
> argStorage
, resultStorage
;
196 TypeRange newArgTypes
= filterTypesOut(getInputs(), argIndices
, argStorage
);
197 TypeRange newResultTypes
=
198 filterTypesOut(getResults(), resultIndices
, resultStorage
);
199 return clone(newArgTypes
, newResultTypes
);
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
206 /// Verify the construction of an opaque type.
207 LogicalResult
OpaqueType::verify(function_ref
<InFlightDiagnostic()> emitError
,
208 StringAttr dialect
, StringRef typeData
) {
209 if (!Dialect::isValidNamespace(dialect
.strref()))
210 return emitError() << "invalid dialect namespace '" << dialect
<< "'";
212 // Check that the dialect is actually registered.
213 MLIRContext
*context
= dialect
.getContext();
214 if (!context
->allowsUnregisteredDialects() &&
215 !context
->getLoadedDialect(dialect
.strref())) {
217 << "`!" << dialect
<< "<\"" << typeData
<< "\">"
218 << "` type created with unregistered dialect. If this is "
219 "intended, please call allowUnregisteredDialects() on the "
220 "MLIRContext, or use -allow-unregistered-dialect with "
221 "the MLIR opt tool used";
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
231 bool VectorType::isValidElementType(Type t
) {
232 return isValidVectorTypeElementType(t
);
235 LogicalResult
VectorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
236 ArrayRef
<int64_t> shape
, Type elementType
,
237 ArrayRef
<bool> scalableDims
) {
238 if (!isValidElementType(elementType
))
240 << "vector elements must be int/index/float type but got "
243 if (any_of(shape
, [](int64_t i
) { return i
<= 0; }))
245 << "vector types must have positive constant sizes but got "
248 if (scalableDims
.size() != shape
.size())
249 return emitError() << "number of dims must match, got "
250 << scalableDims
.size() << " and " << shape
.size();
255 VectorType
VectorType::scaleElementBitwidth(unsigned scale
) {
258 if (auto et
= llvm::dyn_cast
<IntegerType
>(getElementType()))
259 if (auto scaledEt
= et
.scaleElementBitwidth(scale
))
260 return VectorType::get(getShape(), scaledEt
, getScalableDims());
261 if (auto et
= llvm::dyn_cast
<FloatType
>(getElementType()))
262 if (auto scaledEt
= et
.scaleElementBitwidth(scale
))
263 return VectorType::get(getShape(), scaledEt
, getScalableDims());
267 VectorType
VectorType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
268 Type elementType
) const {
269 return VectorType::get(shape
.value_or(getShape()), elementType
,
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
277 Type
TensorType::getElementType() const {
278 return llvm::TypeSwitch
<TensorType
, Type
>(*this)
279 .Case
<RankedTensorType
, UnrankedTensorType
>(
280 [](auto type
) { return type
.getElementType(); });
283 bool TensorType::hasRank() const {
284 return !llvm::isa
<UnrankedTensorType
>(*this);
287 ArrayRef
<int64_t> TensorType::getShape() const {
288 return llvm::cast
<RankedTensorType
>(*this).getShape();
291 TensorType
TensorType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
292 Type elementType
) const {
293 if (llvm::dyn_cast
<UnrankedTensorType
>(*this)) {
295 return RankedTensorType::get(*shape
, elementType
);
296 return UnrankedTensorType::get(elementType
);
299 auto rankedTy
= llvm::cast
<RankedTensorType
>(*this);
301 return RankedTensorType::get(rankedTy
.getShape(), elementType
,
302 rankedTy
.getEncoding());
303 return RankedTensorType::get(shape
.value_or(rankedTy
.getShape()), elementType
,
304 rankedTy
.getEncoding());
307 RankedTensorType
TensorType::clone(::llvm::ArrayRef
<int64_t> shape
,
308 Type elementType
) const {
309 return ::llvm::cast
<RankedTensorType
>(cloneWith(shape
, elementType
));
312 RankedTensorType
TensorType::clone(::llvm::ArrayRef
<int64_t> shape
) const {
313 return ::llvm::cast
<RankedTensorType
>(cloneWith(shape
, getElementType()));
316 // Check if "elementType" can be an element type of a tensor.
318 checkTensorElementType(function_ref
<InFlightDiagnostic()> emitError
,
320 if (!TensorType::isValidElementType(elementType
))
321 return emitError() << "invalid tensor element type: " << elementType
;
325 /// Return true if the specified element type is ok in a tensor.
326 bool TensorType::isValidElementType(Type type
) {
327 // Note: Non standard/builtin types are allowed to exist within tensor
328 // types. Dialects are expected to verify that tensor types have a valid
329 // element type within that dialect.
330 return llvm::isa
<ComplexType
, FloatType
, IntegerType
, OpaqueType
, VectorType
,
332 !llvm::isa
<BuiltinDialect
>(type
.getDialect());
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
340 RankedTensorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
341 ArrayRef
<int64_t> shape
, Type elementType
,
342 Attribute encoding
) {
343 for (int64_t s
: shape
)
344 if (s
< 0 && !ShapedType::isDynamic(s
))
345 return emitError() << "invalid tensor dimension size";
346 if (auto v
= llvm::dyn_cast_or_null
<VerifiableTensorEncoding
>(encoding
))
347 if (failed(v
.verifyEncoding(shape
, elementType
, emitError
)))
349 return checkTensorElementType(emitError
, elementType
);
352 //===----------------------------------------------------------------------===//
353 // UnrankedTensorType
354 //===----------------------------------------------------------------------===//
357 UnrankedTensorType::verify(function_ref
<InFlightDiagnostic()> emitError
,
359 return checkTensorElementType(emitError
, elementType
);
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
366 Type
BaseMemRefType::getElementType() const {
367 return llvm::TypeSwitch
<BaseMemRefType
, Type
>(*this)
368 .Case
<MemRefType
, UnrankedMemRefType
>(
369 [](auto type
) { return type
.getElementType(); });
372 bool BaseMemRefType::hasRank() const {
373 return !llvm::isa
<UnrankedMemRefType
>(*this);
376 ArrayRef
<int64_t> BaseMemRefType::getShape() const {
377 return llvm::cast
<MemRefType
>(*this).getShape();
380 BaseMemRefType
BaseMemRefType::cloneWith(std::optional
<ArrayRef
<int64_t>> shape
,
381 Type elementType
) const {
382 if (llvm::dyn_cast
<UnrankedMemRefType
>(*this)) {
384 return UnrankedMemRefType::get(elementType
, getMemorySpace());
385 MemRefType::Builder
builder(*shape
, elementType
);
386 builder
.setMemorySpace(getMemorySpace());
390 MemRefType::Builder
builder(llvm::cast
<MemRefType
>(*this));
392 builder
.setShape(*shape
);
393 builder
.setElementType(elementType
);
397 MemRefType
BaseMemRefType::clone(::llvm::ArrayRef
<int64_t> shape
,
398 Type elementType
) const {
399 return ::llvm::cast
<MemRefType
>(cloneWith(shape
, elementType
));
402 MemRefType
BaseMemRefType::clone(::llvm::ArrayRef
<int64_t> shape
) const {
403 return ::llvm::cast
<MemRefType
>(cloneWith(shape
, getElementType()));
406 Attribute
BaseMemRefType::getMemorySpace() const {
407 if (auto rankedMemRefTy
= llvm::dyn_cast
<MemRefType
>(*this))
408 return rankedMemRefTy
.getMemorySpace();
409 return llvm::cast
<UnrankedMemRefType
>(*this).getMemorySpace();
412 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
413 if (auto rankedMemRefTy
= llvm::dyn_cast
<MemRefType
>(*this))
414 return rankedMemRefTy
.getMemorySpaceAsInt();
415 return llvm::cast
<UnrankedMemRefType
>(*this).getMemorySpaceAsInt();
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
422 std::optional
<llvm::SmallDenseSet
<unsigned>>
423 mlir::computeRankReductionMask(ArrayRef
<int64_t> originalShape
,
424 ArrayRef
<int64_t> reducedShape
,
426 size_t originalRank
= originalShape
.size(), reducedRank
= reducedShape
.size();
427 llvm::SmallDenseSet
<unsigned> unusedDims
;
428 unsigned reducedIdx
= 0;
429 for (unsigned originalIdx
= 0; originalIdx
< originalRank
; ++originalIdx
) {
430 // Greedily insert `originalIdx` if match.
431 int64_t origSize
= originalShape
[originalIdx
];
432 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
433 if (matchDynamic
&& reducedIdx
< reducedRank
&& origSize
!= 1 &&
434 (ShapedType::isDynamic(reducedShape
[reducedIdx
]) ||
435 ShapedType::isDynamic(origSize
))) {
439 if (reducedIdx
< reducedRank
&& origSize
== reducedShape
[reducedIdx
]) {
444 unusedDims
.insert(originalIdx
);
445 // If no match on `originalIdx`, the `originalShape` at this dimension
446 // must be 1, otherwise we bail.
450 // The whole reducedShape must be scanned, otherwise we bail.
451 if (reducedIdx
!= reducedRank
)
456 SliceVerificationResult
457 mlir::isRankReducedType(ShapedType originalType
,
458 ShapedType candidateReducedType
) {
459 if (originalType
== candidateReducedType
)
460 return SliceVerificationResult::Success
;
462 ShapedType originalShapedType
= llvm::cast
<ShapedType
>(originalType
);
463 ShapedType candidateReducedShapedType
=
464 llvm::cast
<ShapedType
>(candidateReducedType
);
466 // Rank and size logic is valid for all ShapedTypes.
467 ArrayRef
<int64_t> originalShape
= originalShapedType
.getShape();
468 ArrayRef
<int64_t> candidateReducedShape
=
469 candidateReducedShapedType
.getShape();
470 unsigned originalRank
= originalShape
.size(),
471 candidateReducedRank
= candidateReducedShape
.size();
472 if (candidateReducedRank
> originalRank
)
473 return SliceVerificationResult::RankTooLarge
;
475 auto optionalUnusedDimsMask
=
476 computeRankReductionMask(originalShape
, candidateReducedShape
);
478 // Sizes cannot be matched in case empty vector is returned.
479 if (!optionalUnusedDimsMask
)
480 return SliceVerificationResult::SizeMismatch
;
482 if (originalShapedType
.getElementType() !=
483 candidateReducedShapedType
.getElementType())
484 return SliceVerificationResult::ElemTypeMismatch
;
486 return SliceVerificationResult::Success
;
489 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace
) {
490 // Empty attribute is allowed as default memory space.
494 // Supported built-in attributes.
495 if (llvm::isa
<IntegerAttr
, StringAttr
, DictionaryAttr
>(memorySpace
))
498 // Allow custom dialect attributes.
499 if (!isa
<BuiltinDialect
>(memorySpace
.getDialect()))
505 Attribute
mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace
,
507 if (memorySpace
== 0)
510 return IntegerAttr::get(IntegerType::get(ctx
, 64), memorySpace
);
513 Attribute
mlir::detail::skipDefaultMemorySpace(Attribute memorySpace
) {
514 IntegerAttr intMemorySpace
= llvm::dyn_cast_or_null
<IntegerAttr
>(memorySpace
);
515 if (intMemorySpace
&& intMemorySpace
.getValue() == 0)
521 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace
) {
525 assert(llvm::isa
<IntegerAttr
>(memorySpace
) &&
526 "Using `getMemorySpaceInteger` with non-Integer attribute");
528 return static_cast<unsigned>(llvm::cast
<IntegerAttr
>(memorySpace
).getInt());
531 unsigned MemRefType::getMemorySpaceAsInt() const {
532 return detail::getMemorySpaceAsInt(getMemorySpace());
535 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
536 MemRefLayoutAttrInterface layout
,
537 Attribute memorySpace
) {
538 // Use default layout for empty attribute.
540 layout
= AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
541 shape
.size(), elementType
.getContext()));
543 // Drop default memory space value and replace it with empty attribute.
544 memorySpace
= skipDefaultMemorySpace(memorySpace
);
546 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
550 MemRefType
MemRefType::getChecked(
551 function_ref
<InFlightDiagnostic()> emitErrorFn
, ArrayRef
<int64_t> shape
,
552 Type elementType
, MemRefLayoutAttrInterface layout
, Attribute memorySpace
) {
554 // Use default layout for empty attribute.
556 layout
= AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
557 shape
.size(), elementType
.getContext()));
559 // Drop default memory space value and replace it with empty attribute.
560 memorySpace
= skipDefaultMemorySpace(memorySpace
);
562 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
563 elementType
, layout
, memorySpace
);
566 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
567 AffineMap map
, Attribute memorySpace
) {
569 // Use default layout for empty map.
571 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
572 elementType
.getContext());
574 // Wrap AffineMap into Attribute.
575 auto layout
= AffineMapAttr::get(map
);
577 // Drop default memory space value and replace it with empty attribute.
578 memorySpace
= skipDefaultMemorySpace(memorySpace
);
580 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
585 MemRefType::getChecked(function_ref
<InFlightDiagnostic()> emitErrorFn
,
586 ArrayRef
<int64_t> shape
, Type elementType
, AffineMap map
,
587 Attribute memorySpace
) {
589 // Use default layout for empty map.
591 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
592 elementType
.getContext());
594 // Wrap AffineMap into Attribute.
595 auto layout
= AffineMapAttr::get(map
);
597 // Drop default memory space value and replace it with empty attribute.
598 memorySpace
= skipDefaultMemorySpace(memorySpace
);
600 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
601 elementType
, layout
, memorySpace
);
604 MemRefType
MemRefType::get(ArrayRef
<int64_t> shape
, Type elementType
,
605 AffineMap map
, unsigned memorySpaceInd
) {
607 // Use default layout for empty map.
609 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
610 elementType
.getContext());
612 // Wrap AffineMap into Attribute.
613 auto layout
= AffineMapAttr::get(map
);
615 // Convert deprecated integer-like memory space to Attribute.
616 Attribute memorySpace
=
617 wrapIntegerMemorySpace(memorySpaceInd
, elementType
.getContext());
619 return Base::get(elementType
.getContext(), shape
, elementType
, layout
,
624 MemRefType::getChecked(function_ref
<InFlightDiagnostic()> emitErrorFn
,
625 ArrayRef
<int64_t> shape
, Type elementType
, AffineMap map
,
626 unsigned memorySpaceInd
) {
628 // Use default layout for empty map.
630 map
= AffineMap::getMultiDimIdentityMap(shape
.size(),
631 elementType
.getContext());
633 // Wrap AffineMap into Attribute.
634 auto layout
= AffineMapAttr::get(map
);
636 // Convert deprecated integer-like memory space to Attribute.
637 Attribute memorySpace
=
638 wrapIntegerMemorySpace(memorySpaceInd
, elementType
.getContext());
640 return Base::getChecked(emitErrorFn
, elementType
.getContext(), shape
,
641 elementType
, layout
, memorySpace
);
644 LogicalResult
MemRefType::verify(function_ref
<InFlightDiagnostic()> emitError
,
645 ArrayRef
<int64_t> shape
, Type elementType
,
646 MemRefLayoutAttrInterface layout
,
647 Attribute memorySpace
) {
648 if (!BaseMemRefType::isValidElementType(elementType
))
649 return emitError() << "invalid memref element type";
651 // Negative sizes are not allowed except for `kDynamic`.
652 for (int64_t s
: shape
)
653 if (s
< 0 && !ShapedType::isDynamic(s
))
654 return emitError() << "invalid memref size";
656 assert(layout
&& "missing layout specification");
657 if (failed(layout
.verifyLayout(shape
, emitError
)))
660 if (!isSupportedMemorySpace(memorySpace
))
661 return emitError() << "unsupported memory space Attribute";
666 //===----------------------------------------------------------------------===//
667 // UnrankedMemRefType
668 //===----------------------------------------------------------------------===//
670 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
671 return detail::getMemorySpaceAsInt(getMemorySpace());
675 UnrankedMemRefType::verify(function_ref
<InFlightDiagnostic()> emitError
,
676 Type elementType
, Attribute memorySpace
) {
677 if (!BaseMemRefType::isValidElementType(elementType
))
678 return emitError() << "invalid memref element type";
680 if (!isSupportedMemorySpace(memorySpace
))
681 return emitError() << "unsupported memory space Attribute";
686 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
687 // i.e. single term). Accumulate the AffineExpr into the existing one.
688 static void extractStridesFromTerm(AffineExpr e
,
689 AffineExpr multiplicativeFactor
,
690 MutableArrayRef
<AffineExpr
> strides
,
691 AffineExpr
&offset
) {
692 if (auto dim
= dyn_cast
<AffineDimExpr
>(e
))
693 strides
[dim
.getPosition()] =
694 strides
[dim
.getPosition()] + multiplicativeFactor
;
696 offset
= offset
+ e
* multiplicativeFactor
;
699 /// Takes a single AffineExpr `e` and populates the `strides` array with the
700 /// strides expressions for each dim position.
701 /// The convention is that the strides for dimensions d0, .. dn appear in
702 /// order to make indexing intuitive into the result.
703 static LogicalResult
extractStrides(AffineExpr e
,
704 AffineExpr multiplicativeFactor
,
705 MutableArrayRef
<AffineExpr
> strides
,
706 AffineExpr
&offset
) {
707 auto bin
= dyn_cast
<AffineBinaryOpExpr
>(e
);
709 extractStridesFromTerm(e
, multiplicativeFactor
, strides
, offset
);
713 if (bin
.getKind() == AffineExprKind::CeilDiv
||
714 bin
.getKind() == AffineExprKind::FloorDiv
||
715 bin
.getKind() == AffineExprKind::Mod
)
718 if (bin
.getKind() == AffineExprKind::Mul
) {
719 auto dim
= dyn_cast
<AffineDimExpr
>(bin
.getLHS());
721 strides
[dim
.getPosition()] =
722 strides
[dim
.getPosition()] + bin
.getRHS() * multiplicativeFactor
;
725 // LHS and RHS may both contain complex expressions of dims. Try one path
726 // and if it fails try the other. This is guaranteed to succeed because
727 // only one path may have a `dim`, otherwise this is not an AffineExpr in
729 if (bin
.getLHS().isSymbolicOrConstant())
730 return extractStrides(bin
.getRHS(), multiplicativeFactor
* bin
.getLHS(),
732 return extractStrides(bin
.getLHS(), multiplicativeFactor
* bin
.getRHS(),
736 if (bin
.getKind() == AffineExprKind::Add
) {
738 extractStrides(bin
.getLHS(), multiplicativeFactor
, strides
, offset
);
740 extractStrides(bin
.getRHS(), multiplicativeFactor
, strides
, offset
);
741 return success(succeeded(res1
) && succeeded(res2
));
744 llvm_unreachable("unexpected binary operation");
747 /// A stride specification is a list of integer values that are either static
748 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
749 /// the distance in the number of elements between successive entries along a
750 /// particular dimension.
752 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
753 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
754 /// distance between two consecutive elements along the outer dimension is `1`
755 /// and the distance between two consecutive elements along the inner dimension
758 /// The convention is that the strides for dimensions d0, .. dn appear in
759 /// order to make indexing intuitive into the result.
760 static LogicalResult
getStridesAndOffset(MemRefType t
,
761 SmallVectorImpl
<AffineExpr
> &strides
,
762 AffineExpr
&offset
) {
763 AffineMap m
= t
.getLayout().getAffineMap();
765 if (m
.getNumResults() != 1 && !m
.isIdentity())
768 auto zero
= getAffineConstantExpr(0, t
.getContext());
769 auto one
= getAffineConstantExpr(1, t
.getContext());
771 strides
.assign(t
.getRank(), zero
);
773 // Canonical case for empty map.
774 if (m
.isIdentity()) {
775 // 0-D corner case, offset is already 0.
776 if (t
.getRank() == 0)
779 makeCanonicalStridedLayoutExpr(t
.getShape(), t
.getContext());
780 if (succeeded(extractStrides(stridedExpr
, one
, strides
, offset
)))
782 assert(false && "unexpected failure: extract strides in canonical layout");
785 // Non-canonical case requires more work.
787 simplifyAffineExpr(m
.getResult(0), m
.getNumDims(), m
.getNumSymbols());
788 if (failed(extractStrides(stridedExpr
, one
, strides
, offset
))) {
789 offset
= AffineExpr();
794 // Simplify results to allow folding to constants and simple checks.
795 unsigned numDims
= m
.getNumDims();
796 unsigned numSymbols
= m
.getNumSymbols();
797 offset
= simplifyAffineExpr(offset
, numDims
, numSymbols
);
798 for (auto &stride
: strides
)
799 stride
= simplifyAffineExpr(stride
, numDims
, numSymbols
);
804 LogicalResult
mlir::getStridesAndOffset(MemRefType t
,
805 SmallVectorImpl
<int64_t> &strides
,
807 // Happy path: the type uses the strided layout directly.
808 if (auto strided
= llvm::dyn_cast
<StridedLayoutAttr
>(t
.getLayout())) {
809 llvm::append_range(strides
, strided
.getStrides());
810 offset
= strided
.getOffset();
814 // Otherwise, defer to the affine fallback as layouts are supposed to be
815 // convertible to affine maps.
816 AffineExpr offsetExpr
;
817 SmallVector
<AffineExpr
, 4> strideExprs
;
818 if (failed(::getStridesAndOffset(t
, strideExprs
, offsetExpr
)))
820 if (auto cst
= dyn_cast
<AffineConstantExpr
>(offsetExpr
))
821 offset
= cst
.getValue();
823 offset
= ShapedType::kDynamic
;
824 for (auto e
: strideExprs
) {
825 if (auto c
= dyn_cast
<AffineConstantExpr
>(e
))
826 strides
.push_back(c
.getValue());
828 strides
.push_back(ShapedType::kDynamic
);
833 std::pair
<SmallVector
<int64_t>, int64_t>
834 mlir::getStridesAndOffset(MemRefType t
) {
835 SmallVector
<int64_t> strides
;
837 LogicalResult status
= getStridesAndOffset(t
, strides
, offset
);
839 assert(succeeded(status
) && "Invalid use of check-free getStridesAndOffset");
840 return {strides
, offset
};
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
847 /// Return the elements types for this tuple.
848 ArrayRef
<Type
> TupleType::getTypes() const { return getImpl()->getTypes(); }
850 /// Accumulate the types contained in this tuple and tuples nested within it.
851 /// Note that this only flattens nested tuples, not any other container type,
852 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
853 /// (i32, tensor<i32>, f32, i64)
854 void TupleType::getFlattenedTypes(SmallVectorImpl
<Type
> &types
) {
855 for (Type type
: getTypes()) {
856 if (auto nestedTuple
= llvm::dyn_cast
<TupleType
>(type
))
857 nestedTuple
.getFlattenedTypes(types
);
859 types
.push_back(type
);
863 /// Return the number of element types.
864 size_t TupleType::size() const { return getImpl()->size(); }
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
870 /// Return a version of `t` with identity layout if it can be determined
871 /// statically that the layout is the canonical contiguous strided layout.
872 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
873 /// `t` with simplified layout.
874 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
875 MemRefType
mlir::canonicalizeStridedLayout(MemRefType t
) {
876 AffineMap m
= t
.getLayout().getAffineMap();
878 // Already in canonical form.
882 // Can't reduce to canonical identity form, return in canonical form.
883 if (m
.getNumResults() > 1)
886 // Corner-case for 0-D affine maps.
887 if (m
.getNumDims() == 0 && m
.getNumSymbols() == 0) {
888 if (auto cst
= dyn_cast
<AffineConstantExpr
>(m
.getResult(0)))
889 if (cst
.getValue() == 0)
890 return MemRefType::Builder(t
).setLayout({});
894 // 0-D corner case for empty shape that still have an affine map. Example:
895 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
896 // offset needs to remain, just return t.
897 if (t
.getShape().empty())
900 // If the canonical strided layout for the sizes of `t` is equal to the
901 // simplified layout of `t` we can just return an empty layout. Otherwise,
902 // just simplify the existing layout.
904 makeCanonicalStridedLayoutExpr(t
.getShape(), t
.getContext());
905 auto simplifiedLayoutExpr
=
906 simplifyAffineExpr(m
.getResult(0), m
.getNumDims(), m
.getNumSymbols());
907 if (expr
!= simplifiedLayoutExpr
)
908 return MemRefType::Builder(t
).setLayout(AffineMapAttr::get(AffineMap::get(
909 m
.getNumDims(), m
.getNumSymbols(), simplifiedLayoutExpr
)));
910 return MemRefType::Builder(t
).setLayout({});
913 AffineExpr
mlir::makeCanonicalStridedLayoutExpr(ArrayRef
<int64_t> sizes
,
914 ArrayRef
<AffineExpr
> exprs
,
915 MLIRContext
*context
) {
916 // Size 0 corner case is useful for canonicalizations.
918 return getAffineConstantExpr(0, context
);
920 assert(!exprs
.empty() && "expected exprs");
921 auto maps
= AffineMap::inferFromExprList(exprs
, context
);
922 assert(!maps
.empty() && "Expected one non-empty map");
923 unsigned numDims
= maps
[0].getNumDims(), nSymbols
= maps
[0].getNumSymbols();
926 bool dynamicPoisonBit
= false;
927 int64_t runningSize
= 1;
928 for (auto en
: llvm::zip(llvm::reverse(exprs
), llvm::reverse(sizes
))) {
929 int64_t size
= std::get
<1>(en
);
930 AffineExpr dimExpr
= std::get
<0>(en
);
931 AffineExpr stride
= dynamicPoisonBit
932 ? getAffineSymbolExpr(nSymbols
++, context
)
933 : getAffineConstantExpr(runningSize
, context
);
934 expr
= expr
? expr
+ dimExpr
* stride
: dimExpr
* stride
;
937 assert(runningSize
> 0 && "integer overflow in size computation");
939 dynamicPoisonBit
= true;
942 return simplifyAffineExpr(expr
, numDims
, nSymbols
);
945 AffineExpr
mlir::makeCanonicalStridedLayoutExpr(ArrayRef
<int64_t> sizes
,
946 MLIRContext
*context
) {
947 SmallVector
<AffineExpr
, 4> exprs
;
948 exprs
.reserve(sizes
.size());
949 for (auto dim
: llvm::seq
<unsigned>(0, sizes
.size()))
950 exprs
.push_back(getAffineDimExpr(dim
, context
));
951 return makeCanonicalStridedLayoutExpr(sizes
, exprs
, context
);
954 bool mlir::isStrided(MemRefType t
) {
956 SmallVector
<int64_t, 4> strides
;
957 auto res
= getStridesAndOffset(t
, strides
, offset
);
958 return succeeded(res
);
961 bool mlir::isLastMemrefDimUnitStride(MemRefType type
) {
963 SmallVector
<int64_t> strides
;
964 auto successStrides
= getStridesAndOffset(type
, strides
, offset
);
965 return succeeded(successStrides
) && (strides
.empty() || strides
.back() == 1);
968 bool mlir::trailingNDimsContiguous(MemRefType type
, int64_t n
) {
969 if (!isLastMemrefDimUnitStride(type
))
972 auto memrefShape
= type
.getShape().take_back(n
);
973 if (ShapedType::isDynamicShape(memrefShape
))
976 if (type
.getLayout().isIdentity())
980 SmallVector
<int64_t> stridesFull
;
981 if (!succeeded(getStridesAndOffset(type
, stridesFull
, offset
)))
983 auto strides
= ArrayRef
<int64_t>(stridesFull
).take_back(n
);
988 // Check whether strides match "flattened" dims.
989 SmallVector
<int64_t> flattenedDims
;
991 for (auto dim
: llvm::reverse(memrefShape
.drop_front(1))) {
993 flattenedDims
.push_back(dimProduct
);
996 strides
= strides
.drop_back(1);
997 return llvm::equal(strides
, llvm::reverse(flattenedDims
));