[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / IR / BuiltinTypes.cpp
blob a9284d5714637bc14b034e12ddf7dbceaf73ced3
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <utility>
25 using namespace mlir;
26 using namespace mlir::detail;
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 //===----------------------------------------------------------------------===//
36 // BuiltinDialect
37 //===----------------------------------------------------------------------===//
39 void BuiltinDialect::registerTypes() {
40 addTypes<
41 #define GET_TYPEDEF_LIST
42 #include "mlir/IR/BuiltinTypes.cpp.inc"
43 >();
46 //===----------------------------------------------------------------------===//
47 /// ComplexType
48 //===----------------------------------------------------------------------===//
50 /// Verify the construction of an integer type.
51 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
52 Type elementType) {
53 if (!elementType.isIntOrFloat())
54 return emitError() << "invalid element type for complex";
55 return success();
58 //===----------------------------------------------------------------------===//
59 // Integer Type
60 //===----------------------------------------------------------------------===//
62 /// Verify the construction of an integer type.
63 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
64 unsigned width,
65 SignednessSemantics signedness) {
66 if (width > IntegerType::kMaxWidth) {
67 return emitError() << "integer bitwidth is limited to "
68 << IntegerType::kMaxWidth << " bits";
70 return success();
73 unsigned IntegerType::getWidth() const { return getImpl()->width; }
75 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
76 return getImpl()->signedness;
79 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
80 if (!scale)
81 return IntegerType();
82 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
85 //===----------------------------------------------------------------------===//
86 // Float Type
87 //===----------------------------------------------------------------------===//
89 unsigned FloatType::getWidth() {
90 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
91 Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
92 return 8;
93 if (llvm::isa<Float16Type, BFloat16Type>(*this))
94 return 16;
95 if (llvm::isa<Float32Type, FloatTF32Type>(*this))
96 return 32;
97 if (llvm::isa<Float64Type>(*this))
98 return 64;
99 if (llvm::isa<Float80Type>(*this))
100 return 80;
101 if (llvm::isa<Float128Type>(*this))
102 return 128;
103 llvm_unreachable("unexpected float type");
106 /// Returns the floating semantics for the given type.
107 const llvm::fltSemantics &FloatType::getFloatSemantics() {
108 if (llvm::isa<Float8E5M2Type>(*this))
109 return APFloat::Float8E5M2();
110 if (llvm::isa<Float8E4M3FNType>(*this))
111 return APFloat::Float8E4M3FN();
112 if (llvm::isa<Float8E5M2FNUZType>(*this))
113 return APFloat::Float8E5M2FNUZ();
114 if (llvm::isa<Float8E4M3FNUZType>(*this))
115 return APFloat::Float8E4M3FNUZ();
116 if (llvm::isa<Float8E4M3B11FNUZType>(*this))
117 return APFloat::Float8E4M3B11FNUZ();
118 if (llvm::isa<BFloat16Type>(*this))
119 return APFloat::BFloat();
120 if (llvm::isa<Float16Type>(*this))
121 return APFloat::IEEEhalf();
122 if (llvm::isa<FloatTF32Type>(*this))
123 return APFloat::FloatTF32();
124 if (llvm::isa<Float32Type>(*this))
125 return APFloat::IEEEsingle();
126 if (llvm::isa<Float64Type>(*this))
127 return APFloat::IEEEdouble();
128 if (llvm::isa<Float80Type>(*this))
129 return APFloat::x87DoubleExtended();
130 if (llvm::isa<Float128Type>(*this))
131 return APFloat::IEEEquad();
132 llvm_unreachable("non-floating point type used");
135 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
136 if (!scale)
137 return FloatType();
138 MLIRContext *ctx = getContext();
139 if (isF16() || isBF16()) {
140 if (scale == 2)
141 return FloatType::getF32(ctx);
142 if (scale == 4)
143 return FloatType::getF64(ctx);
145 if (isF32())
146 if (scale == 2)
147 return FloatType::getF64(ctx);
148 return FloatType();
151 unsigned FloatType::getFPMantissaWidth() {
152 return APFloat::semanticsPrecision(getFloatSemantics());
155 //===----------------------------------------------------------------------===//
156 // FunctionType
157 //===----------------------------------------------------------------------===//
159 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
161 ArrayRef<Type> FunctionType::getInputs() const {
162 return getImpl()->getInputs();
165 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
167 ArrayRef<Type> FunctionType::getResults() const {
168 return getImpl()->getResults();
171 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
172 return get(getContext(), inputs, results);
175 /// Returns a new function type with the specified arguments and results
176 /// inserted.
177 FunctionType FunctionType::getWithArgsAndResults(
178 ArrayRef<unsigned> argIndices, TypeRange argTypes,
179 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
180 SmallVector<Type> argStorage, resultStorage;
181 TypeRange newArgTypes =
182 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
183 TypeRange newResultTypes =
184 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
185 return clone(newArgTypes, newResultTypes);
188 /// Returns a new function type without the specified arguments and results.
189 FunctionType
190 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
191 const BitVector &resultIndices) {
192 SmallVector<Type> argStorage, resultStorage;
193 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
194 TypeRange newResultTypes =
195 filterTypesOut(getResults(), resultIndices, resultStorage);
196 return clone(newArgTypes, newResultTypes);
199 //===----------------------------------------------------------------------===//
200 // OpaqueType
201 //===----------------------------------------------------------------------===//
203 /// Verify the construction of an opaque type.
204 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
205 StringAttr dialect, StringRef typeData) {
206 if (!Dialect::isValidNamespace(dialect.strref()))
207 return emitError() << "invalid dialect namespace '" << dialect << "'";
209 // Check that the dialect is actually registered.
210 MLIRContext *context = dialect.getContext();
211 if (!context->allowsUnregisteredDialects() &&
212 !context->getLoadedDialect(dialect.strref())) {
213 return emitError()
214 << "`!" << dialect << "<\"" << typeData << "\">"
215 << "` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
221 return success();
224 //===----------------------------------------------------------------------===//
225 // VectorType
226 //===----------------------------------------------------------------------===//
228 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
229 ArrayRef<int64_t> shape, Type elementType,
230 ArrayRef<bool> scalableDims) {
231 if (!isValidElementType(elementType))
232 return emitError()
233 << "vector elements must be int/index/float type but got "
234 << elementType;
236 if (any_of(shape, [](int64_t i) { return i <= 0; }))
237 return emitError()
238 << "vector types must have positive constant sizes but got "
239 << shape;
241 if (scalableDims.size() != shape.size())
242 return emitError() << "number of dims must match, got "
243 << scalableDims.size() << " and " << shape.size();
245 return success();
248 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
249 if (!scale)
250 return VectorType();
251 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
252 if (auto scaledEt = et.scaleElementBitwidth(scale))
253 return VectorType::get(getShape(), scaledEt, getScalableDims());
254 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
255 if (auto scaledEt = et.scaleElementBitwidth(scale))
256 return VectorType::get(getShape(), scaledEt, getScalableDims());
257 return VectorType();
260 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
261 Type elementType) const {
262 return VectorType::get(shape.value_or(getShape()), elementType,
263 getScalableDims());
266 //===----------------------------------------------------------------------===//
267 // TensorType
268 //===----------------------------------------------------------------------===//
270 Type TensorType::getElementType() const {
271 return llvm::TypeSwitch<TensorType, Type>(*this)
272 .Case<RankedTensorType, UnrankedTensorType>(
273 [](auto type) { return type.getElementType(); });
276 bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
278 ArrayRef<int64_t> TensorType::getShape() const {
279 return llvm::cast<RankedTensorType>(*this).getShape();
282 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
283 Type elementType) const {
284 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
285 if (shape)
286 return RankedTensorType::get(*shape, elementType);
287 return UnrankedTensorType::get(elementType);
290 auto rankedTy = llvm::cast<RankedTensorType>(*this);
291 if (!shape)
292 return RankedTensorType::get(rankedTy.getShape(), elementType,
293 rankedTy.getEncoding());
294 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
295 rankedTy.getEncoding());
298 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
299 Type elementType) const {
300 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
303 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
304 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
307 // Check if "elementType" can be an element type of a tensor.
308 static LogicalResult
309 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
310 Type elementType) {
311 if (!TensorType::isValidElementType(elementType))
312 return emitError() << "invalid tensor element type: " << elementType;
313 return success();
316 /// Return true if the specified element type is ok in a tensor.
317 bool TensorType::isValidElementType(Type type) {
318 // Note: Non standard/builtin types are allowed to exist within tensor
319 // types. Dialects are expected to verify that tensor types have a valid
320 // element type within that dialect.
321 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
322 IndexType>(type) ||
323 !llvm::isa<BuiltinDialect>(type.getDialect());
326 //===----------------------------------------------------------------------===//
327 // RankedTensorType
328 //===----------------------------------------------------------------------===//
330 LogicalResult
331 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
332 ArrayRef<int64_t> shape, Type elementType,
333 Attribute encoding) {
334 for (int64_t s : shape)
335 if (s < 0 && !ShapedType::isDynamic(s))
336 return emitError() << "invalid tensor dimension size";
337 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
338 if (failed(v.verifyEncoding(shape, elementType, emitError)))
339 return failure();
340 return checkTensorElementType(emitError, elementType);
343 //===----------------------------------------------------------------------===//
344 // UnrankedTensorType
345 //===----------------------------------------------------------------------===//
347 LogicalResult
348 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
349 Type elementType) {
350 return checkTensorElementType(emitError, elementType);
353 //===----------------------------------------------------------------------===//
354 // BaseMemRefType
355 //===----------------------------------------------------------------------===//
357 Type BaseMemRefType::getElementType() const {
358 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
359 .Case<MemRefType, UnrankedMemRefType>(
360 [](auto type) { return type.getElementType(); });
363 bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
365 ArrayRef<int64_t> BaseMemRefType::getShape() const {
366 return llvm::cast<MemRefType>(*this).getShape();
369 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
370 Type elementType) const {
371 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
372 if (!shape)
373 return UnrankedMemRefType::get(elementType, getMemorySpace());
374 MemRefType::Builder builder(*shape, elementType);
375 builder.setMemorySpace(getMemorySpace());
376 return builder;
379 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
380 if (shape)
381 builder.setShape(*shape);
382 builder.setElementType(elementType);
383 return builder;
386 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
387 Type elementType) const {
388 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
391 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
392 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
395 Attribute BaseMemRefType::getMemorySpace() const {
396 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
397 return rankedMemRefTy.getMemorySpace();
398 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
401 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
402 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
403 return rankedMemRefTy.getMemorySpaceAsInt();
404 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
407 //===----------------------------------------------------------------------===//
408 // MemRefType
409 //===----------------------------------------------------------------------===//
411 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
412 /// `originalShape` with some `1` entries erased, return the set of indices
413 /// that specifies which of the entries of `originalShape` are dropped to obtain
414 /// `reducedShape`. The returned mask can be applied as a projection to
415 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
416 /// which dimensions must be kept when e.g. compute MemRef strides under
417 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
418 /// obtained by dropping only `1` entries in `originalShape`.
419 std::optional<llvm::SmallDenseSet<unsigned>>
420 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
421 ArrayRef<int64_t> reducedShape) {
422 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
423 llvm::SmallDenseSet<unsigned> unusedDims;
424 unsigned reducedIdx = 0;
425 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
426 // Greedily insert `originalIdx` if match.
427 if (reducedIdx < reducedRank &&
428 originalShape[originalIdx] == reducedShape[reducedIdx]) {
429 reducedIdx++;
430 continue;
433 unusedDims.insert(originalIdx);
434 // If no match on `originalIdx`, the `originalShape` at this dimension
435 // must be 1, otherwise we bail.
436 if (originalShape[originalIdx] != 1)
437 return std::nullopt;
439 // The whole reducedShape must be scanned, otherwise we bail.
440 if (reducedIdx != reducedRank)
441 return std::nullopt;
442 return unusedDims;
445 SliceVerificationResult
446 mlir::isRankReducedType(ShapedType originalType,
447 ShapedType candidateReducedType) {
448 if (originalType == candidateReducedType)
449 return SliceVerificationResult::Success;
451 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
452 ShapedType candidateReducedShapedType =
453 llvm::cast<ShapedType>(candidateReducedType);
455 // Rank and size logic is valid for all ShapedTypes.
456 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
457 ArrayRef<int64_t> candidateReducedShape =
458 candidateReducedShapedType.getShape();
459 unsigned originalRank = originalShape.size(),
460 candidateReducedRank = candidateReducedShape.size();
461 if (candidateReducedRank > originalRank)
462 return SliceVerificationResult::RankTooLarge;
464 auto optionalUnusedDimsMask =
465 computeRankReductionMask(originalShape, candidateReducedShape);
467 // Sizes cannot be matched in case empty vector is returned.
468 if (!optionalUnusedDimsMask)
469 return SliceVerificationResult::SizeMismatch;
471 if (originalShapedType.getElementType() !=
472 candidateReducedShapedType.getElementType())
473 return SliceVerificationResult::ElemTypeMismatch;
475 return SliceVerificationResult::Success;
478 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
479 // Empty attribute is allowed as default memory space.
480 if (!memorySpace)
481 return true;
483 // Supported built-in attributes.
484 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
485 return true;
487 // Allow custom dialect attributes.
488 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
489 return true;
491 return false;
494 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
495 MLIRContext *ctx) {
496 if (memorySpace == 0)
497 return nullptr;
499 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
502 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
503 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
504 if (intMemorySpace && intMemorySpace.getValue() == 0)
505 return nullptr;
507 return memorySpace;
510 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
511 if (!memorySpace)
512 return 0;
514 assert(llvm::isa<IntegerAttr>(memorySpace) &&
515 "Using `getMemorySpaceInteger` with non-Integer attribute");
517 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
520 unsigned MemRefType::getMemorySpaceAsInt() const {
521 return detail::getMemorySpaceAsInt(getMemorySpace());
524 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
525 MemRefLayoutAttrInterface layout,
526 Attribute memorySpace) {
527 // Use default layout for empty attribute.
528 if (!layout)
529 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
530 shape.size(), elementType.getContext()));
532 // Drop default memory space value and replace it with empty attribute.
533 memorySpace = skipDefaultMemorySpace(memorySpace);
535 return Base::get(elementType.getContext(), shape, elementType, layout,
536 memorySpace);
539 MemRefType MemRefType::getChecked(
540 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
541 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
543 // Use default layout for empty attribute.
544 if (!layout)
545 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
546 shape.size(), elementType.getContext()));
548 // Drop default memory space value and replace it with empty attribute.
549 memorySpace = skipDefaultMemorySpace(memorySpace);
551 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
552 elementType, layout, memorySpace);
555 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
556 AffineMap map, Attribute memorySpace) {
558 // Use default layout for empty map.
559 if (!map)
560 map = AffineMap::getMultiDimIdentityMap(shape.size(),
561 elementType.getContext());
563 // Wrap AffineMap into Attribute.
564 auto layout = AffineMapAttr::get(map);
566 // Drop default memory space value and replace it with empty attribute.
567 memorySpace = skipDefaultMemorySpace(memorySpace);
569 return Base::get(elementType.getContext(), shape, elementType, layout,
570 memorySpace);
573 MemRefType
574 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
575 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
576 Attribute memorySpace) {
578 // Use default layout for empty map.
579 if (!map)
580 map = AffineMap::getMultiDimIdentityMap(shape.size(),
581 elementType.getContext());
583 // Wrap AffineMap into Attribute.
584 auto layout = AffineMapAttr::get(map);
586 // Drop default memory space value and replace it with empty attribute.
587 memorySpace = skipDefaultMemorySpace(memorySpace);
589 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
590 elementType, layout, memorySpace);
593 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
594 AffineMap map, unsigned memorySpaceInd) {
596 // Use default layout for empty map.
597 if (!map)
598 map = AffineMap::getMultiDimIdentityMap(shape.size(),
599 elementType.getContext());
601 // Wrap AffineMap into Attribute.
602 auto layout = AffineMapAttr::get(map);
604 // Convert deprecated integer-like memory space to Attribute.
605 Attribute memorySpace =
606 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
608 return Base::get(elementType.getContext(), shape, elementType, layout,
609 memorySpace);
612 MemRefType
613 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
614 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
615 unsigned memorySpaceInd) {
617 // Use default layout for empty map.
618 if (!map)
619 map = AffineMap::getMultiDimIdentityMap(shape.size(),
620 elementType.getContext());
622 // Wrap AffineMap into Attribute.
623 auto layout = AffineMapAttr::get(map);
625 // Convert deprecated integer-like memory space to Attribute.
626 Attribute memorySpace =
627 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
629 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
630 elementType, layout, memorySpace);
633 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
634 ArrayRef<int64_t> shape, Type elementType,
635 MemRefLayoutAttrInterface layout,
636 Attribute memorySpace) {
637 if (!BaseMemRefType::isValidElementType(elementType))
638 return emitError() << "invalid memref element type";
640 // Negative sizes are not allowed except for `kDynamic`.
641 for (int64_t s : shape)
642 if (s < 0 && !ShapedType::isDynamic(s))
643 return emitError() << "invalid memref size";
645 assert(layout && "missing layout specification");
646 if (failed(layout.verifyLayout(shape, emitError)))
647 return failure();
649 if (!isSupportedMemorySpace(memorySpace))
650 return emitError() << "unsupported memory space Attribute";
652 return success();
655 //===----------------------------------------------------------------------===//
656 // UnrankedMemRefType
657 //===----------------------------------------------------------------------===//
659 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
660 return detail::getMemorySpaceAsInt(getMemorySpace());
663 LogicalResult
664 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
665 Type elementType, Attribute memorySpace) {
666 if (!BaseMemRefType::isValidElementType(elementType))
667 return emitError() << "invalid memref element type";
669 if (!isSupportedMemorySpace(memorySpace))
670 return emitError() << "unsupported memory space Attribute";
672 return success();
675 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
676 // i.e. single term). Accumulate the AffineExpr into the existing one.
677 static void extractStridesFromTerm(AffineExpr e,
678 AffineExpr multiplicativeFactor,
679 MutableArrayRef<AffineExpr> strides,
680 AffineExpr &offset) {
681 if (auto dim = e.dyn_cast<AffineDimExpr>())
682 strides[dim.getPosition()] =
683 strides[dim.getPosition()] + multiplicativeFactor;
684 else
685 offset = offset + e * multiplicativeFactor;
688 /// Takes a single AffineExpr `e` and populates the `strides` array with the
689 /// strides expressions for each dim position.
690 /// The convention is that the strides for dimensions d0, .. dn appear in
691 /// order to make indexing intuitive into the result.
692 static LogicalResult extractStrides(AffineExpr e,
693 AffineExpr multiplicativeFactor,
694 MutableArrayRef<AffineExpr> strides,
695 AffineExpr &offset) {
696 auto bin = e.dyn_cast<AffineBinaryOpExpr>();
697 if (!bin) {
698 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
699 return success();
702 if (bin.getKind() == AffineExprKind::CeilDiv ||
703 bin.getKind() == AffineExprKind::FloorDiv ||
704 bin.getKind() == AffineExprKind::Mod)
705 return failure();
707 if (bin.getKind() == AffineExprKind::Mul) {
708 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
709 if (dim) {
710 strides[dim.getPosition()] =
711 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
712 return success();
714 // LHS and RHS may both contain complex expressions of dims. Try one path
715 // and if it fails try the other. This is guaranteed to succeed because
716 // only one path may have a `dim`, otherwise this is not an AffineExpr in
717 // the first place.
718 if (bin.getLHS().isSymbolicOrConstant())
719 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
720 strides, offset);
721 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
722 strides, offset);
725 if (bin.getKind() == AffineExprKind::Add) {
726 auto res1 =
727 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
728 auto res2 =
729 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
730 return success(succeeded(res1) && succeeded(res2));
733 llvm_unreachable("unexpected binary operation");
736 /// A stride specification is a list of integer values that are either static
737 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
738 /// the distance in the number of elements between successive entries along a
739 /// particular dimension.
741 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
742 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
743 /// distance between two consecutive elements along the outer dimension is `1`
744 /// and the distance between two consecutive elements along the inner dimension
745 /// is `64`.
747 /// The convention is that the strides for dimensions d0, .. dn appear in
748 /// order to make indexing intuitive into the result.
749 static LogicalResult getStridesAndOffset(MemRefType t,
750 SmallVectorImpl<AffineExpr> &strides,
751 AffineExpr &offset) {
752 AffineMap m = t.getLayout().getAffineMap();
754 if (m.getNumResults() != 1 && !m.isIdentity())
755 return failure();
757 auto zero = getAffineConstantExpr(0, t.getContext());
758 auto one = getAffineConstantExpr(1, t.getContext());
759 offset = zero;
760 strides.assign(t.getRank(), zero);
762 // Canonical case for empty map.
763 if (m.isIdentity()) {
764 // 0-D corner case, offset is already 0.
765 if (t.getRank() == 0)
766 return success();
767 auto stridedExpr =
768 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
769 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
770 return success();
771 assert(false && "unexpected failure: extract strides in canonical layout");
774 // Non-canonical case requires more work.
775 auto stridedExpr =
776 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
777 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
778 offset = AffineExpr();
779 strides.clear();
780 return failure();
783 // Simplify results to allow folding to constants and simple checks.
784 unsigned numDims = m.getNumDims();
785 unsigned numSymbols = m.getNumSymbols();
786 offset = simplifyAffineExpr(offset, numDims, numSymbols);
787 for (auto &stride : strides)
788 stride = simplifyAffineExpr(stride, numDims, numSymbols);
790 // In practice, a strided memref must be internally non-aliasing. Test
791 // against 0 as a proxy.
792 // TODO: static cases can have more advanced checks.
793 // TODO: dynamic cases would require a way to compare symbolic
794 // expressions and would probably need an affine set context propagated
795 // everywhere.
796 if (llvm::any_of(strides, [](AffineExpr e) {
797 return e == getAffineConstantExpr(0, e.getContext());
798 })) {
799 offset = AffineExpr();
800 strides.clear();
801 return failure();
804 return success();
807 LogicalResult mlir::getStridesAndOffset(MemRefType t,
808 SmallVectorImpl<int64_t> &strides,
809 int64_t &offset) {
810 // Happy path: the type uses the strided layout directly.
811 if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
812 llvm::append_range(strides, strided.getStrides());
813 offset = strided.getOffset();
814 return success();
817 // Otherwise, defer to the affine fallback as layouts are supposed to be
818 // convertible to affine maps.
819 AffineExpr offsetExpr;
820 SmallVector<AffineExpr, 4> strideExprs;
821 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
822 return failure();
823 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
824 offset = cst.getValue();
825 else
826 offset = ShapedType::kDynamic;
827 for (auto e : strideExprs) {
828 if (auto c = e.dyn_cast<AffineConstantExpr>())
829 strides.push_back(c.getValue());
830 else
831 strides.push_back(ShapedType::kDynamic);
833 return success();
836 std::pair<SmallVector<int64_t>, int64_t>
837 mlir::getStridesAndOffset(MemRefType t) {
838 SmallVector<int64_t> strides;
839 int64_t offset;
840 LogicalResult status = getStridesAndOffset(t, strides, offset);
841 (void)status;
842 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
843 return {strides, offset};
846 //===----------------------------------------------------------------------===//
847 /// TupleType
848 //===----------------------------------------------------------------------===//
850 /// Return the elements types for this tuple.
851 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
853 /// Accumulate the types contained in this tuple and tuples nested within it.
854 /// Note that this only flattens nested tuples, not any other container type,
855 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
856 /// (i32, tensor<i32>, f32, i64)
857 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
858 for (Type type : getTypes()) {
859 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
860 nestedTuple.getFlattenedTypes(types);
861 else
862 types.push_back(type);
866 /// Return the number of element types.
867 size_t TupleType::size() const { return getImpl()->size(); }
869 //===----------------------------------------------------------------------===//
870 // Type Utilities
871 //===----------------------------------------------------------------------===//
873 /// Return a version of `t` with identity layout if it can be determined
874 /// statically that the layout is the canonical contiguous strided layout.
875 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
876 /// `t` with simplified layout.
877 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
878 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
879 AffineMap m = t.getLayout().getAffineMap();
881 // Already in canonical form.
882 if (m.isIdentity())
883 return t;
885 // Can't reduce to canonical identity form, return in canonical form.
886 if (m.getNumResults() > 1)
887 return t;
889 // Corner-case for 0-D affine maps.
890 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
891 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
892 if (cst.getValue() == 0)
893 return MemRefType::Builder(t).setLayout({});
894 return t;
897 // 0-D corner case for empty shape that still have an affine map. Example:
898 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
899 // offset needs to remain, just return t.
900 if (t.getShape().empty())
901 return t;
903 // If the canonical strided layout for the sizes of `t` is equal to the
904 // simplified layout of `t` we can just return an empty layout. Otherwise,
905 // just simplify the existing layout.
906 AffineExpr expr =
907 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
908 auto simplifiedLayoutExpr =
909 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
910 if (expr != simplifiedLayoutExpr)
911 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
912 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
913 return MemRefType::Builder(t).setLayout({});
916 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
917 ArrayRef<AffineExpr> exprs,
918 MLIRContext *context) {
919 // Size 0 corner case is useful for canonicalizations.
920 if (sizes.empty())
921 return getAffineConstantExpr(0, context);
923 assert(!exprs.empty() && "expected exprs");
924 auto maps = AffineMap::inferFromExprList(exprs);
925 assert(!maps.empty() && "Expected one non-empty map");
926 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
928 AffineExpr expr;
929 bool dynamicPoisonBit = false;
930 int64_t runningSize = 1;
931 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
932 int64_t size = std::get<1>(en);
933 AffineExpr dimExpr = std::get<0>(en);
934 AffineExpr stride = dynamicPoisonBit
935 ? getAffineSymbolExpr(nSymbols++, context)
936 : getAffineConstantExpr(runningSize, context);
937 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
938 if (size > 0) {
939 runningSize *= size;
940 assert(runningSize > 0 && "integer overflow in size computation");
941 } else {
942 dynamicPoisonBit = true;
945 return simplifyAffineExpr(expr, numDims, nSymbols);
948 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
949 MLIRContext *context) {
950 SmallVector<AffineExpr, 4> exprs;
951 exprs.reserve(sizes.size());
952 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
953 exprs.push_back(getAffineDimExpr(dim, context));
954 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
957 bool mlir::isStrided(MemRefType t) {
958 int64_t offset;
959 SmallVector<int64_t, 4> strides;
960 auto res = getStridesAndOffset(t, strides, offset);
961 return succeeded(res);
964 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
965 int64_t offset;
966 SmallVector<int64_t> strides;
967 auto successStrides = getStridesAndOffset(type, strides, offset);
968 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);