1 //===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
16 using namespace mlir::memref
;
18 // Source memref has identity layout.
19 TEST(InferShapeTest
, inferRankReducedShapeIdentity
) {
22 auto sourceMemref
= MemRefType::get({10, 5}, b
.getIndexType());
23 auto reducedType
= SubViewOp::inferRankReducedResultType(
24 /*resultShape=*/{2}, sourceMemref
, {2, 3}, {1, 2}, {1, 1});
25 auto expectedType
= MemRefType::get(
26 {2}, b
.getIndexType(),
27 StridedLayoutAttr::get(&ctx
, /*offset=*/13, /*strides=*/{1}));
28 EXPECT_EQ(reducedType
, expectedType
);
31 // Source memref has non-identity layout.
32 TEST(InferShapeTest
, inferRankReducedShapeNonIdentity
) {
35 AffineExpr dim0
, dim1
;
36 bindDims(&ctx
, dim0
, dim1
);
37 auto sourceMemref
= MemRefType::get({10, 5}, b
.getIndexType(),
38 AffineMap::get(2, 0, 1000 * dim0
+ dim1
));
39 auto reducedType
= SubViewOp::inferRankReducedResultType(
40 /*resultShape=*/{2}, sourceMemref
, {2, 3}, {1, 2}, {1, 1});
41 auto expectedType
= MemRefType::get(
42 {2}, b
.getIndexType(),
43 StridedLayoutAttr::get(&ctx
, /*offset=*/2003, /*strides=*/{1}));
44 EXPECT_EQ(reducedType
, expectedType
);
47 TEST(InferShapeTest
, inferRankReducedShapeToScalar
) {
50 AffineExpr dim0
, dim1
;
51 bindDims(&ctx
, dim0
, dim1
);
52 auto sourceMemref
= MemRefType::get({10, 5}, b
.getIndexType(),
53 AffineMap::get(2, 0, 1000 * dim0
+ dim1
));
54 auto reducedType
= SubViewOp::inferRankReducedResultType(
55 /*resultShape=*/{}, sourceMemref
, {2, 3}, {1, 1}, {1, 1});
56 auto expectedType
= MemRefType::get(
58 StridedLayoutAttr::get(&ctx
, /*offset=*/2003, /*strides=*/{}));
59 EXPECT_EQ(reducedType
, expectedType
);