1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/ViewLikeInterface.h"
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
/// Verifies one mixed static/dynamic list of `op` (used below for the
/// "offset", "size" and "stride" lists):
///   - `staticVals` must contain exactly `numElements` entries, and
///   - `values` must contain one entry for every element of `staticVals` for
///     which `ShapedType::isDynamic` is true.
/// On a mismatch, an error mentioning `name` is emitted on `op` and returned.
/// NOTE(review): this copy of the definition is missing lines (the remaining
/// parameters `name`, `numElements` and `values`, the end of the count_if
/// lambda, and the trailing return); restore them from upstream before
/// building.
20 LogicalResult
mlir::verifyListOfOperandsOrIntegers(Operation
*op
,
23 ArrayRef
<int64_t> staticVals
,
25 // Check that the static list has `numElements` entries and that one dynamic value is supplied per dynamic entry.
26 if (staticVals
.size() != numElements
)
27 return op
->emitError("expected ") << numElements
<< " " << name
28 << " values, got " << staticVals
.size();
// Every entry marked dynamic in `staticVals` must be backed by a value in
// `values`, so count how many such entries there are.
29 unsigned expectedNumDynamicEntries
=
30 llvm::count_if(staticVals
, [](int64_t staticVal
) {
31 return ShapedType::isDynamic(staticVal
);
33 if (values
.size() != expectedNumDynamicEntries
)
34 return op
->emitError("expected ")
35 << expectedNumDynamicEntries
<< " dynamic " << name
<< " values";
/// Verifies the structural invariants of an OffsetSizeAndStrideOpInterface
/// op:
///   - the mixed offsets rank matches the mixed sizes rank, unless there is a
///     single offset and maxRanks[0] == 1;
///   - the mixed sizes rank matches the mixed strides rank;
///   - the offset, size and stride lists each pass
///     verifyListOfOperandsOrIntegers against the corresponding entry of
///     op.getArrayAttrMaxRanks();
///   - static (non-dynamic) offsets and sizes are non-negative.
/// NOTE(review): this copy of the definition appears to be missing lines
/// (the return type, the call that receives each rank-mismatch message
/// string, the statements taken when a list check fails, the value streamed
/// after each non-negativity message, and closing braces); restore them from
/// upstream before building.
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op
) {
41 std::array
<unsigned, 3> maxRanks
= op
.getArrayAttrMaxRanks();
42 // Offsets can come in 2 flavors:
43 // 1. Either single entry (when maxRanks[0] == 1).
44 // 2. Or as an array whose rank must match that of the mixed sizes.
45 // So that the result type is well-formed.
46 if (!(op
.getMixedOffsets().size() == 1 && maxRanks
[0] == 1) && // NOLINT
47 op
.getMixedOffsets().size() != op
.getMixedSizes().size())
49 "expected mixed offsets rank to match mixed sizes rank (")
50 << op
.getMixedOffsets().size() << " vs " << op
.getMixedSizes().size()
51 << ") so the rank of the result type is well-formed.";
52 // Ranks of mixed sizes and strides must always match so the result type is
// well-formed.
54 if (op
.getMixedSizes().size() != op
.getMixedStrides().size())
56 "expected mixed sizes rank to match mixed strides rank (")
57 << op
.getMixedSizes().size() << " vs " << op
.getMixedStrides().size()
58 << ") so the rank of the result type is well-formed.";
// Each list must have maxRanks[i] static entries and one SSA operand per
// dynamic entry (index 0: offsets, 1: sizes, 2: strides).
60 if (failed(verifyListOfOperandsOrIntegers(
61 op
, "offset", maxRanks
[0], op
.getStaticOffsets(), op
.getOffsets())))
63 if (failed(verifyListOfOperandsOrIntegers(
64 op
, "size", maxRanks
[1], op
.getStaticSizes(), op
.getSizes())))
66 if (failed(verifyListOfOperandsOrIntegers(
67 op
, "stride", maxRanks
[2], op
.getStaticStrides(), op
.getStrides())))
// Static offsets must be non-negative; entries for which
// ShapedType::isDynamic is true are exempt.
70 for (int64_t offset
: op
.getStaticOffsets()) {
71 if (offset
< 0 && !ShapedType::isDynamic(offset
))
72 return op
->emitError("expected offsets to be non-negative, but got ")
// Same non-negativity rule for static sizes.
75 for (int64_t size
: op
.getStaticSizes()) {
76 if (size
< 0 && !ShapedType::isDynamic(size
))
77 return op
->emitError("expected sizes to be non-negative, but got ")
/// Returns the character that printDynamicIndexList emits before the list for
/// the given `delimiter`. Only Paren, LessGreater, Square and Braces are
/// handled; any other value reaches llvm_unreachable.
/// NOTE(review): the `switch` header and the per-case return statements are
/// missing from this copy of the definition; restore them from upstream.
83 static char getLeftDelimiter(AsmParser::Delimiter delimiter
) {
85 case AsmParser::Delimiter::Paren
:
87 case AsmParser::Delimiter::LessGreater
:
89 case AsmParser::Delimiter::Square
:
91 case AsmParser::Delimiter::Braces
:
94 llvm_unreachable("unsupported delimiter");
/// Returns the character that printDynamicIndexList emits after the list for
/// the given `delimiter`. Only Paren, LessGreater, Square and Braces are
/// handled; any other value reaches llvm_unreachable.
/// NOTE(review): the `switch` header and the per-case return statements are
/// missing from this copy of the definition; restore them from upstream.
98 static char getRightDelimiter(AsmParser::Delimiter delimiter
) {
100 case AsmParser::Delimiter::Paren
:
102 case AsmParser::Delimiter::LessGreater
:
104 case AsmParser::Delimiter::Square
:
106 case AsmParser::Delimiter::Braces
:
109 llvm_unreachable("unsupported delimiter");
/// Prints `integers` as a comma-separated list wrapped in the characters
/// chosen by getLeftDelimiter/getRightDelimiter for `delimiter`.
/// An entry for which ShapedType::isDynamic is true is printed as
/// `values[dynamicValIdx]`, followed by ` : <type>` from `valueTypes` when
/// `valueTypes` is non-empty. When `scalables` is non-empty, the flag at
/// `scalableIndexIdx` is checked before and after each entry, presumably to
/// mark scalable indices.
/// NOTE(review): this copy of the definition is missing lines (the `values`
/// parameter, the early exit after printing an empty list, the static-integer
/// branch, the counter increments, the statements guarded by the
/// `scalables` checks, and closing braces); restore them from upstream.
113 void mlir::printDynamicIndexList(OpAsmPrinter
&printer
, Operation
*op
,
115 ArrayRef
<int64_t> integers
,
116 ArrayRef
<bool> scalables
, TypeRange valueTypes
,
117 AsmParser::Delimiter delimiter
) {
118 char leftDelimiter
= getLeftDelimiter(delimiter
);
119 char rightDelimiter
= getRightDelimiter(delimiter
);
120 printer
<< leftDelimiter
;
// An empty list is printed as just the two delimiters.
121 if (integers
.empty()) {
122 printer
<< rightDelimiter
;
// Running positions into `values`/`valueTypes` (dynamic entries only) and
// into `scalables`.
126 unsigned dynamicValIdx
= 0;
127 unsigned scalableIndexIdx
= 0;
128 llvm::interleaveComma(integers
, printer
, [&](int64_t integer
) {
129 if (!scalables
.empty() && scalables
[scalableIndexIdx
])
131 if (ShapedType::isDynamic(integer
)) {
132 printer
<< values
[dynamicValIdx
];
133 if (!valueTypes
.empty())
134 printer
<< " : " << valueTypes
[dynamicValIdx
];
139 if (!scalables
.empty() && scalables
[scalableIndexIdx
])
145 printer
<< rightDelimiter
;
/// Parses a delimited, comma-separated list whose entries are either SSA
/// operands or integers (the inverse of printDynamicIndexList).
///   - An operand is appended to `values` and recorded as
///     ShapedType::kDynamic. When `valueTypes` is non-null, a `: type`
///     suffix is also parsed into it.
///   - An integer is recorded as-is.
///   - An entry introduced by `[` is recorded as scalable and must be
///     followed by `]`.
/// On success, `integers` and `scalables` are set to DenseI64ArrayAttr and
/// DenseBoolArrayAttr built from the collected values. If list parsing fails,
/// "expected SSA value or integer" is emitted at the parser's name location.
/// NOTE(review): this copy of the definition is missing lines (the `parser`
/// parameter, the `integer` local, the failure returns inside the lambda, the
/// lambda's success return, the final return, and closing braces); restore
/// them from upstream.
148 ParseResult
mlir::parseDynamicIndexList(
150 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &values
,
151 DenseI64ArrayAttr
&integers
, DenseBoolArrayAttr
&scalables
,
152 SmallVectorImpl
<Type
> *valueTypes
, AsmParser::Delimiter delimiter
) {
154 SmallVector
<int64_t, 4> integerVals
;
155 SmallVector
<bool, 4> scalableVals
;
// Parses one list entry; invoked once per element by
// parseCommaSeparatedList below.
156 auto parseIntegerOrValue
= [&]() {
157 OpAsmParser::UnresolvedOperand operand
;
158 auto res
= parser
.parseOptionalOperand(operand
);
160 // When encountering `[`, assume that this is a scalable index.
161 scalableVals
.push_back(parser
.parseOptionalLSquare().succeeded());
163 if (res
.has_value() && succeeded(res
.value())) {
164 values
.push_back(operand
);
165 integerVals
.push_back(ShapedType::kDynamic
);
166 if (valueTypes
&& parser
.parseColonType(valueTypes
->emplace_back()))
170 if (failed(parser
.parseInteger(integer
)))
172 integerVals
.push_back(integer
);
175 // If this is assumed to be a scalable index, verify that there's a closing
// `]` after the entry.
177 if (scalableVals
.back() && parser
.parseOptionalRSquare().failed())
181 if (parser
.parseCommaSeparatedList(delimiter
, parseIntegerOrValue
,
182 " in dynamic index list"))
183 return parser
.emitError(parser
.getNameLoc())
184 << "expected SSA value or integer";
185 integers
= parser
.getBuilder().getDenseI64ArrayAttr(integerVals
);
186 scalables
= parser
.getBuilder().getDenseBoolArrayAttr(scalableVals
);
/// Compares the offsets, sizes and strides of `a` and `b`.
/// First, the static offset, size and stride lists must have matching
/// lengths. Then the mixed (OpFoldResult) offsets, sizes and strides are
/// compared element-wise with `cmp`. Presumably the function returns false on
/// the first length mismatch or on the first pair that `cmp` rejects, and true
/// otherwise.
/// NOTE(review): the return statements and the closing brace are missing from
/// this copy of the definition; restore them from upstream.
190 bool mlir::detail::sameOffsetsSizesAndStrides(
191 OffsetSizeAndStrideOpInterface a
, OffsetSizeAndStrideOpInterface b
,
192 llvm::function_ref
<bool(OpFoldResult
, OpFoldResult
)> cmp
) {
// Rank checks: llvm::zip below stops at the shorter range, so lengths must be
// compared separately.
193 if (a
.getStaticOffsets().size() != b
.getStaticOffsets().size())
195 if (a
.getStaticSizes().size() != b
.getStaticSizes().size())
197 if (a
.getStaticStrides().size() != b
.getStaticStrides().size())
199 for (auto it
: llvm::zip(a
.getMixedOffsets(), b
.getMixedOffsets()))
200 if (!cmp(std::get
<0>(it
), std::get
<1>(it
)))
202 for (auto it
: llvm::zip(a
.getMixedSizes(), b
.getMixedSizes()))
203 if (!cmp(std::get
<0>(it
), std::get
<1>(it
)))
205 for (auto it
: llvm::zip(a
.getMixedStrides(), b
.getMixedStrides()))
206 if (!cmp(std::get
<0>(it
), std::get
<1>(it
)))
/// Returns how many of the first `idx` entries of `staticVals` (the half-open
/// range [0, idx)) are dynamic according to ShapedType::isDynamic.
/// NOTE(review): `idx` is not bounds-checked here; callers presumably
/// guarantee idx <= staticVals.size() — confirm. The `idx` parameter line and
/// the closing brace are missing from this copy of the definition.
211 unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef
<int64_t> staticVals
,
213 return std::count_if(staticVals
.begin(), staticVals
.begin() + idx
,
214 [&](int64_t val
) { return ShapedType::isDynamic(val
); });