// Scrape residue from gitweb (preserved as a comment so the file stays valid C++):
// [OptTable] Fix typo VALUE => VALUES (NFCI) (#121523)
// [llvm-project.git] / mlir / lib / IR / BuiltinTypes.cpp
// blob 6546234429c8cbec3b3a21881560b403a9ab9275
//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"

#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
25 using namespace mlir;
26 using namespace mlir::detail;
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 } // namespace mlir
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
43 void BuiltinDialect::registerTypes() {
44 addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
47 >();
50 //===----------------------------------------------------------------------===//
51 /// ComplexType
52 //===----------------------------------------------------------------------===//
54 /// Verify the construction of an integer type.
55 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
56 Type elementType) {
57 if (!elementType.isIntOrFloat())
58 return emitError() << "invalid element type for complex";
59 return success();
62 //===----------------------------------------------------------------------===//
63 // Integer Type
64 //===----------------------------------------------------------------------===//
66 /// Verify the construction of an integer type.
67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
68 unsigned width,
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() << "integer bitwidth is limited to "
72 << IntegerType::kMaxWidth << " bits";
74 return success();
77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80 return getImpl()->signedness;
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84 if (!scale)
85 return IntegerType();
86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
93 unsigned FloatType::getWidth() {
94 return APFloat::semanticsSizeInBits(getFloatSemantics());
97 /// Returns the floating semantics for the given type.
98 const llvm::fltSemantics &FloatType::getFloatSemantics() {
99 if (llvm::isa<Float4E2M1FNType>(*this))
100 return APFloat::Float4E2M1FN();
101 if (llvm::isa<Float6E2M3FNType>(*this))
102 return APFloat::Float6E2M3FN();
103 if (llvm::isa<Float6E3M2FNType>(*this))
104 return APFloat::Float6E3M2FN();
105 if (llvm::isa<Float8E5M2Type>(*this))
106 return APFloat::Float8E5M2();
107 if (llvm::isa<Float8E4M3Type>(*this))
108 return APFloat::Float8E4M3();
109 if (llvm::isa<Float8E4M3FNType>(*this))
110 return APFloat::Float8E4M3FN();
111 if (llvm::isa<Float8E5M2FNUZType>(*this))
112 return APFloat::Float8E5M2FNUZ();
113 if (llvm::isa<Float8E4M3FNUZType>(*this))
114 return APFloat::Float8E4M3FNUZ();
115 if (llvm::isa<Float8E4M3B11FNUZType>(*this))
116 return APFloat::Float8E4M3B11FNUZ();
117 if (llvm::isa<Float8E3M4Type>(*this))
118 return APFloat::Float8E3M4();
119 if (llvm::isa<Float8E8M0FNUType>(*this))
120 return APFloat::Float8E8M0FNU();
121 if (llvm::isa<BFloat16Type>(*this))
122 return APFloat::BFloat();
123 if (llvm::isa<Float16Type>(*this))
124 return APFloat::IEEEhalf();
125 if (llvm::isa<FloatTF32Type>(*this))
126 return APFloat::FloatTF32();
127 if (llvm::isa<Float32Type>(*this))
128 return APFloat::IEEEsingle();
129 if (llvm::isa<Float64Type>(*this))
130 return APFloat::IEEEdouble();
131 if (llvm::isa<Float80Type>(*this))
132 return APFloat::x87DoubleExtended();
133 if (llvm::isa<Float128Type>(*this))
134 return APFloat::IEEEquad();
135 llvm_unreachable("non-floating point type used");
138 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
139 if (!scale)
140 return FloatType();
141 MLIRContext *ctx = getContext();
142 if (isF16() || isBF16()) {
143 if (scale == 2)
144 return FloatType::getF32(ctx);
145 if (scale == 4)
146 return FloatType::getF64(ctx);
148 if (isF32())
149 if (scale == 2)
150 return FloatType::getF64(ctx);
151 return FloatType();
154 unsigned FloatType::getFPMantissaWidth() {
155 return APFloat::semanticsPrecision(getFloatSemantics());
158 //===----------------------------------------------------------------------===//
159 // FunctionType
160 //===----------------------------------------------------------------------===//
162 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
164 ArrayRef<Type> FunctionType::getInputs() const {
165 return getImpl()->getInputs();
168 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
170 ArrayRef<Type> FunctionType::getResults() const {
171 return getImpl()->getResults();
174 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
175 return get(getContext(), inputs, results);
178 /// Returns a new function type with the specified arguments and results
179 /// inserted.
180 FunctionType FunctionType::getWithArgsAndResults(
181 ArrayRef<unsigned> argIndices, TypeRange argTypes,
182 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
183 SmallVector<Type> argStorage, resultStorage;
184 TypeRange newArgTypes =
185 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
186 TypeRange newResultTypes =
187 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
188 return clone(newArgTypes, newResultTypes);
191 /// Returns a new function type without the specified arguments and results.
192 FunctionType
193 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
194 const BitVector &resultIndices) {
195 SmallVector<Type> argStorage, resultStorage;
196 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
197 TypeRange newResultTypes =
198 filterTypesOut(getResults(), resultIndices, resultStorage);
199 return clone(newArgTypes, newResultTypes);
202 //===----------------------------------------------------------------------===//
203 // OpaqueType
204 //===----------------------------------------------------------------------===//
206 /// Verify the construction of an opaque type.
207 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
208 StringAttr dialect, StringRef typeData) {
209 if (!Dialect::isValidNamespace(dialect.strref()))
210 return emitError() << "invalid dialect namespace '" << dialect << "'";
212 // Check that the dialect is actually registered.
213 MLIRContext *context = dialect.getContext();
214 if (!context->allowsUnregisteredDialects() &&
215 !context->getLoadedDialect(dialect.strref())) {
216 return emitError()
217 << "`!" << dialect << "<\"" << typeData << "\">"
218 << "` type created with unregistered dialect. If this is "
219 "intended, please call allowUnregisteredDialects() on the "
220 "MLIRContext, or use -allow-unregistered-dialect with "
221 "the MLIR opt tool used";
224 return success();
227 //===----------------------------------------------------------------------===//
228 // VectorType
229 //===----------------------------------------------------------------------===//
231 bool VectorType::isValidElementType(Type t) {
232 return isValidVectorTypeElementType(t);
235 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
236 ArrayRef<int64_t> shape, Type elementType,
237 ArrayRef<bool> scalableDims) {
238 if (!isValidElementType(elementType))
239 return emitError()
240 << "vector elements must be int/index/float type but got "
241 << elementType;
243 if (any_of(shape, [](int64_t i) { return i <= 0; }))
244 return emitError()
245 << "vector types must have positive constant sizes but got "
246 << shape;
248 if (scalableDims.size() != shape.size())
249 return emitError() << "number of dims must match, got "
250 << scalableDims.size() << " and " << shape.size();
252 return success();
255 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
256 if (!scale)
257 return VectorType();
258 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
259 if (auto scaledEt = et.scaleElementBitwidth(scale))
260 return VectorType::get(getShape(), scaledEt, getScalableDims());
261 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
262 if (auto scaledEt = et.scaleElementBitwidth(scale))
263 return VectorType::get(getShape(), scaledEt, getScalableDims());
264 return VectorType();
267 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
268 Type elementType) const {
269 return VectorType::get(shape.value_or(getShape()), elementType,
270 getScalableDims());
273 //===----------------------------------------------------------------------===//
274 // TensorType
275 //===----------------------------------------------------------------------===//
277 Type TensorType::getElementType() const {
278 return llvm::TypeSwitch<TensorType, Type>(*this)
279 .Case<RankedTensorType, UnrankedTensorType>(
280 [](auto type) { return type.getElementType(); });
283 bool TensorType::hasRank() const {
284 return !llvm::isa<UnrankedTensorType>(*this);
287 ArrayRef<int64_t> TensorType::getShape() const {
288 return llvm::cast<RankedTensorType>(*this).getShape();
291 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
292 Type elementType) const {
293 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
294 if (shape)
295 return RankedTensorType::get(*shape, elementType);
296 return UnrankedTensorType::get(elementType);
299 auto rankedTy = llvm::cast<RankedTensorType>(*this);
300 if (!shape)
301 return RankedTensorType::get(rankedTy.getShape(), elementType,
302 rankedTy.getEncoding());
303 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
304 rankedTy.getEncoding());
307 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
308 Type elementType) const {
309 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
312 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
313 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
316 // Check if "elementType" can be an element type of a tensor.
317 static LogicalResult
318 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
319 Type elementType) {
320 if (!TensorType::isValidElementType(elementType))
321 return emitError() << "invalid tensor element type: " << elementType;
322 return success();
325 /// Return true if the specified element type is ok in a tensor.
326 bool TensorType::isValidElementType(Type type) {
327 // Note: Non standard/builtin types are allowed to exist within tensor
328 // types. Dialects are expected to verify that tensor types have a valid
329 // element type within that dialect.
330 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
331 IndexType>(type) ||
332 !llvm::isa<BuiltinDialect>(type.getDialect());
335 //===----------------------------------------------------------------------===//
336 // RankedTensorType
337 //===----------------------------------------------------------------------===//
339 LogicalResult
340 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
341 ArrayRef<int64_t> shape, Type elementType,
342 Attribute encoding) {
343 for (int64_t s : shape)
344 if (s < 0 && !ShapedType::isDynamic(s))
345 return emitError() << "invalid tensor dimension size";
346 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
347 if (failed(v.verifyEncoding(shape, elementType, emitError)))
348 return failure();
349 return checkTensorElementType(emitError, elementType);
352 //===----------------------------------------------------------------------===//
353 // UnrankedTensorType
354 //===----------------------------------------------------------------------===//
356 LogicalResult
357 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
358 Type elementType) {
359 return checkTensorElementType(emitError, elementType);
362 //===----------------------------------------------------------------------===//
363 // BaseMemRefType
364 //===----------------------------------------------------------------------===//
366 Type BaseMemRefType::getElementType() const {
367 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
368 .Case<MemRefType, UnrankedMemRefType>(
369 [](auto type) { return type.getElementType(); });
372 bool BaseMemRefType::hasRank() const {
373 return !llvm::isa<UnrankedMemRefType>(*this);
376 ArrayRef<int64_t> BaseMemRefType::getShape() const {
377 return llvm::cast<MemRefType>(*this).getShape();
380 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
381 Type elementType) const {
382 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
383 if (!shape)
384 return UnrankedMemRefType::get(elementType, getMemorySpace());
385 MemRefType::Builder builder(*shape, elementType);
386 builder.setMemorySpace(getMemorySpace());
387 return builder;
390 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
391 if (shape)
392 builder.setShape(*shape);
393 builder.setElementType(elementType);
394 return builder;
397 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
398 Type elementType) const {
399 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
402 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
403 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
406 Attribute BaseMemRefType::getMemorySpace() const {
407 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
408 return rankedMemRefTy.getMemorySpace();
409 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
412 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
413 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
414 return rankedMemRefTy.getMemorySpaceAsInt();
415 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
418 //===----------------------------------------------------------------------===//
419 // MemRefType
420 //===----------------------------------------------------------------------===//
422 std::optional<llvm::SmallDenseSet<unsigned>>
423 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
424 ArrayRef<int64_t> reducedShape,
425 bool matchDynamic) {
426 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
427 llvm::SmallDenseSet<unsigned> unusedDims;
428 unsigned reducedIdx = 0;
429 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
430 // Greedily insert `originalIdx` if match.
431 int64_t origSize = originalShape[originalIdx];
432 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
433 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
434 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
435 ShapedType::isDynamic(origSize))) {
436 reducedIdx++;
437 continue;
439 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
440 reducedIdx++;
441 continue;
444 unusedDims.insert(originalIdx);
445 // If no match on `originalIdx`, the `originalShape` at this dimension
446 // must be 1, otherwise we bail.
447 if (origSize != 1)
448 return std::nullopt;
450 // The whole reducedShape must be scanned, otherwise we bail.
451 if (reducedIdx != reducedRank)
452 return std::nullopt;
453 return unusedDims;
456 SliceVerificationResult
457 mlir::isRankReducedType(ShapedType originalType,
458 ShapedType candidateReducedType) {
459 if (originalType == candidateReducedType)
460 return SliceVerificationResult::Success;
462 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
463 ShapedType candidateReducedShapedType =
464 llvm::cast<ShapedType>(candidateReducedType);
466 // Rank and size logic is valid for all ShapedTypes.
467 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
468 ArrayRef<int64_t> candidateReducedShape =
469 candidateReducedShapedType.getShape();
470 unsigned originalRank = originalShape.size(),
471 candidateReducedRank = candidateReducedShape.size();
472 if (candidateReducedRank > originalRank)
473 return SliceVerificationResult::RankTooLarge;
475 auto optionalUnusedDimsMask =
476 computeRankReductionMask(originalShape, candidateReducedShape);
478 // Sizes cannot be matched in case empty vector is returned.
479 if (!optionalUnusedDimsMask)
480 return SliceVerificationResult::SizeMismatch;
482 if (originalShapedType.getElementType() !=
483 candidateReducedShapedType.getElementType())
484 return SliceVerificationResult::ElemTypeMismatch;
486 return SliceVerificationResult::Success;
489 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
490 // Empty attribute is allowed as default memory space.
491 if (!memorySpace)
492 return true;
494 // Supported built-in attributes.
495 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
496 return true;
498 // Allow custom dialect attributes.
499 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
500 return true;
502 return false;
505 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
506 MLIRContext *ctx) {
507 if (memorySpace == 0)
508 return nullptr;
510 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
513 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
514 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
515 if (intMemorySpace && intMemorySpace.getValue() == 0)
516 return nullptr;
518 return memorySpace;
521 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
522 if (!memorySpace)
523 return 0;
525 assert(llvm::isa<IntegerAttr>(memorySpace) &&
526 "Using `getMemorySpaceInteger` with non-Integer attribute");
528 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
531 unsigned MemRefType::getMemorySpaceAsInt() const {
532 return detail::getMemorySpaceAsInt(getMemorySpace());
535 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
536 MemRefLayoutAttrInterface layout,
537 Attribute memorySpace) {
538 // Use default layout for empty attribute.
539 if (!layout)
540 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
541 shape.size(), elementType.getContext()));
543 // Drop default memory space value and replace it with empty attribute.
544 memorySpace = skipDefaultMemorySpace(memorySpace);
546 return Base::get(elementType.getContext(), shape, elementType, layout,
547 memorySpace);
550 MemRefType MemRefType::getChecked(
551 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
552 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
554 // Use default layout for empty attribute.
555 if (!layout)
556 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
557 shape.size(), elementType.getContext()));
559 // Drop default memory space value and replace it with empty attribute.
560 memorySpace = skipDefaultMemorySpace(memorySpace);
562 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
563 elementType, layout, memorySpace);
566 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
567 AffineMap map, Attribute memorySpace) {
569 // Use default layout for empty map.
570 if (!map)
571 map = AffineMap::getMultiDimIdentityMap(shape.size(),
572 elementType.getContext());
574 // Wrap AffineMap into Attribute.
575 auto layout = AffineMapAttr::get(map);
577 // Drop default memory space value and replace it with empty attribute.
578 memorySpace = skipDefaultMemorySpace(memorySpace);
580 return Base::get(elementType.getContext(), shape, elementType, layout,
581 memorySpace);
584 MemRefType
585 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
586 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
587 Attribute memorySpace) {
589 // Use default layout for empty map.
590 if (!map)
591 map = AffineMap::getMultiDimIdentityMap(shape.size(),
592 elementType.getContext());
594 // Wrap AffineMap into Attribute.
595 auto layout = AffineMapAttr::get(map);
597 // Drop default memory space value and replace it with empty attribute.
598 memorySpace = skipDefaultMemorySpace(memorySpace);
600 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
601 elementType, layout, memorySpace);
604 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
605 AffineMap map, unsigned memorySpaceInd) {
607 // Use default layout for empty map.
608 if (!map)
609 map = AffineMap::getMultiDimIdentityMap(shape.size(),
610 elementType.getContext());
612 // Wrap AffineMap into Attribute.
613 auto layout = AffineMapAttr::get(map);
615 // Convert deprecated integer-like memory space to Attribute.
616 Attribute memorySpace =
617 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
619 return Base::get(elementType.getContext(), shape, elementType, layout,
620 memorySpace);
623 MemRefType
624 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
625 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
626 unsigned memorySpaceInd) {
628 // Use default layout for empty map.
629 if (!map)
630 map = AffineMap::getMultiDimIdentityMap(shape.size(),
631 elementType.getContext());
633 // Wrap AffineMap into Attribute.
634 auto layout = AffineMapAttr::get(map);
636 // Convert deprecated integer-like memory space to Attribute.
637 Attribute memorySpace =
638 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
640 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
641 elementType, layout, memorySpace);
644 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
645 ArrayRef<int64_t> shape, Type elementType,
646 MemRefLayoutAttrInterface layout,
647 Attribute memorySpace) {
648 if (!BaseMemRefType::isValidElementType(elementType))
649 return emitError() << "invalid memref element type";
651 // Negative sizes are not allowed except for `kDynamic`.
652 for (int64_t s : shape)
653 if (s < 0 && !ShapedType::isDynamic(s))
654 return emitError() << "invalid memref size";
656 assert(layout && "missing layout specification");
657 if (failed(layout.verifyLayout(shape, emitError)))
658 return failure();
660 if (!isSupportedMemorySpace(memorySpace))
661 return emitError() << "unsupported memory space Attribute";
663 return success();
666 //===----------------------------------------------------------------------===//
667 // UnrankedMemRefType
668 //===----------------------------------------------------------------------===//
670 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
671 return detail::getMemorySpaceAsInt(getMemorySpace());
674 LogicalResult
675 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
676 Type elementType, Attribute memorySpace) {
677 if (!BaseMemRefType::isValidElementType(elementType))
678 return emitError() << "invalid memref element type";
680 if (!isSupportedMemorySpace(memorySpace))
681 return emitError() << "unsupported memory space Attribute";
683 return success();
686 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
687 // i.e. single term). Accumulate the AffineExpr into the existing one.
688 static void extractStridesFromTerm(AffineExpr e,
689 AffineExpr multiplicativeFactor,
690 MutableArrayRef<AffineExpr> strides,
691 AffineExpr &offset) {
692 if (auto dim = dyn_cast<AffineDimExpr>(e))
693 strides[dim.getPosition()] =
694 strides[dim.getPosition()] + multiplicativeFactor;
695 else
696 offset = offset + e * multiplicativeFactor;
699 /// Takes a single AffineExpr `e` and populates the `strides` array with the
700 /// strides expressions for each dim position.
701 /// The convention is that the strides for dimensions d0, .. dn appear in
702 /// order to make indexing intuitive into the result.
703 static LogicalResult extractStrides(AffineExpr e,
704 AffineExpr multiplicativeFactor,
705 MutableArrayRef<AffineExpr> strides,
706 AffineExpr &offset) {
707 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
708 if (!bin) {
709 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
710 return success();
713 if (bin.getKind() == AffineExprKind::CeilDiv ||
714 bin.getKind() == AffineExprKind::FloorDiv ||
715 bin.getKind() == AffineExprKind::Mod)
716 return failure();
718 if (bin.getKind() == AffineExprKind::Mul) {
719 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
720 if (dim) {
721 strides[dim.getPosition()] =
722 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
723 return success();
725 // LHS and RHS may both contain complex expressions of dims. Try one path
726 // and if it fails try the other. This is guaranteed to succeed because
727 // only one path may have a `dim`, otherwise this is not an AffineExpr in
728 // the first place.
729 if (bin.getLHS().isSymbolicOrConstant())
730 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
731 strides, offset);
732 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
733 strides, offset);
736 if (bin.getKind() == AffineExprKind::Add) {
737 auto res1 =
738 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
739 auto res2 =
740 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
741 return success(succeeded(res1) && succeeded(res2));
744 llvm_unreachable("unexpected binary operation");
747 /// A stride specification is a list of integer values that are either static
748 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
749 /// the distance in the number of elements between successive entries along a
750 /// particular dimension.
752 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
753 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
754 /// distance between two consecutive elements along the outer dimension is `1`
755 /// and the distance between two consecutive elements along the inner dimension
756 /// is `64`.
758 /// The convention is that the strides for dimensions d0, .. dn appear in
759 /// order to make indexing intuitive into the result.
760 static LogicalResult getStridesAndOffset(MemRefType t,
761 SmallVectorImpl<AffineExpr> &strides,
762 AffineExpr &offset) {
763 AffineMap m = t.getLayout().getAffineMap();
765 if (m.getNumResults() != 1 && !m.isIdentity())
766 return failure();
768 auto zero = getAffineConstantExpr(0, t.getContext());
769 auto one = getAffineConstantExpr(1, t.getContext());
770 offset = zero;
771 strides.assign(t.getRank(), zero);
773 // Canonical case for empty map.
774 if (m.isIdentity()) {
775 // 0-D corner case, offset is already 0.
776 if (t.getRank() == 0)
777 return success();
778 auto stridedExpr =
779 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
780 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
781 return success();
782 assert(false && "unexpected failure: extract strides in canonical layout");
785 // Non-canonical case requires more work.
786 auto stridedExpr =
787 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
788 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
789 offset = AffineExpr();
790 strides.clear();
791 return failure();
794 // Simplify results to allow folding to constants and simple checks.
795 unsigned numDims = m.getNumDims();
796 unsigned numSymbols = m.getNumSymbols();
797 offset = simplifyAffineExpr(offset, numDims, numSymbols);
798 for (auto &stride : strides)
799 stride = simplifyAffineExpr(stride, numDims, numSymbols);
801 return success();
804 LogicalResult mlir::getStridesAndOffset(MemRefType t,
805 SmallVectorImpl<int64_t> &strides,
806 int64_t &offset) {
807 // Happy path: the type uses the strided layout directly.
808 if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
809 llvm::append_range(strides, strided.getStrides());
810 offset = strided.getOffset();
811 return success();
814 // Otherwise, defer to the affine fallback as layouts are supposed to be
815 // convertible to affine maps.
816 AffineExpr offsetExpr;
817 SmallVector<AffineExpr, 4> strideExprs;
818 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
819 return failure();
820 if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
821 offset = cst.getValue();
822 else
823 offset = ShapedType::kDynamic;
824 for (auto e : strideExprs) {
825 if (auto c = dyn_cast<AffineConstantExpr>(e))
826 strides.push_back(c.getValue());
827 else
828 strides.push_back(ShapedType::kDynamic);
830 return success();
833 std::pair<SmallVector<int64_t>, int64_t>
834 mlir::getStridesAndOffset(MemRefType t) {
835 SmallVector<int64_t> strides;
836 int64_t offset;
837 LogicalResult status = getStridesAndOffset(t, strides, offset);
838 (void)status;
839 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
840 return {strides, offset};
843 //===----------------------------------------------------------------------===//
844 /// TupleType
845 //===----------------------------------------------------------------------===//
847 /// Return the elements types for this tuple.
848 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
850 /// Accumulate the types contained in this tuple and tuples nested within it.
851 /// Note that this only flattens nested tuples, not any other container type,
852 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
853 /// (i32, tensor<i32>, f32, i64)
854 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
855 for (Type type : getTypes()) {
856 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
857 nestedTuple.getFlattenedTypes(types);
858 else
859 types.push_back(type);
863 /// Return the number of element types.
864 size_t TupleType::size() const { return getImpl()->size(); }
866 //===----------------------------------------------------------------------===//
867 // Type Utilities
868 //===----------------------------------------------------------------------===//
870 /// Return a version of `t` with identity layout if it can be determined
871 /// statically that the layout is the canonical contiguous strided layout.
872 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
873 /// `t` with simplified layout.
874 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
875 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
876 AffineMap m = t.getLayout().getAffineMap();
878 // Already in canonical form.
879 if (m.isIdentity())
880 return t;
882 // Can't reduce to canonical identity form, return in canonical form.
883 if (m.getNumResults() > 1)
884 return t;
886 // Corner-case for 0-D affine maps.
887 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
888 if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
889 if (cst.getValue() == 0)
890 return MemRefType::Builder(t).setLayout({});
891 return t;
894 // 0-D corner case for empty shape that still have an affine map. Example:
895 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
896 // offset needs to remain, just return t.
897 if (t.getShape().empty())
898 return t;
900 // If the canonical strided layout for the sizes of `t` is equal to the
901 // simplified layout of `t` we can just return an empty layout. Otherwise,
902 // just simplify the existing layout.
903 AffineExpr expr =
904 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
905 auto simplifiedLayoutExpr =
906 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
907 if (expr != simplifiedLayoutExpr)
908 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
909 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
910 return MemRefType::Builder(t).setLayout({});
913 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
914 ArrayRef<AffineExpr> exprs,
915 MLIRContext *context) {
916 // Size 0 corner case is useful for canonicalizations.
917 if (sizes.empty())
918 return getAffineConstantExpr(0, context);
920 assert(!exprs.empty() && "expected exprs");
921 auto maps = AffineMap::inferFromExprList(exprs, context);
922 assert(!maps.empty() && "Expected one non-empty map");
923 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
925 AffineExpr expr;
926 bool dynamicPoisonBit = false;
927 int64_t runningSize = 1;
928 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
929 int64_t size = std::get<1>(en);
930 AffineExpr dimExpr = std::get<0>(en);
931 AffineExpr stride = dynamicPoisonBit
932 ? getAffineSymbolExpr(nSymbols++, context)
933 : getAffineConstantExpr(runningSize, context);
934 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
935 if (size > 0) {
936 runningSize *= size;
937 assert(runningSize > 0 && "integer overflow in size computation");
938 } else {
939 dynamicPoisonBit = true;
942 return simplifyAffineExpr(expr, numDims, nSymbols);
945 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
946 MLIRContext *context) {
947 SmallVector<AffineExpr, 4> exprs;
948 exprs.reserve(sizes.size());
949 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
950 exprs.push_back(getAffineDimExpr(dim, context));
951 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
954 bool mlir::isStrided(MemRefType t) {
955 int64_t offset;
956 SmallVector<int64_t, 4> strides;
957 auto res = getStridesAndOffset(t, strides, offset);
958 return succeeded(res);
961 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
962 int64_t offset;
963 SmallVector<int64_t> strides;
964 auto successStrides = getStridesAndOffset(type, strides, offset);
965 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
968 bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
969 if (!isLastMemrefDimUnitStride(type))
970 return false;
972 auto memrefShape = type.getShape().take_back(n);
973 if (ShapedType::isDynamicShape(memrefShape))
974 return false;
976 if (type.getLayout().isIdentity())
977 return true;
979 int64_t offset;
980 SmallVector<int64_t> stridesFull;
981 if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
982 return false;
983 auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
985 if (strides.empty())
986 return true;
988 // Check whether strides match "flattened" dims.
989 SmallVector<int64_t> flattenedDims;
990 auto dimProduct = 1;
991 for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
992 dimProduct *= dim;
993 flattenedDims.push_back(dimProduct);
996 strides = strides.drop_back(1);
997 return llvm::equal(strides, llvm::reverse(flattenedDims));