//===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"

#include <gtest/gtest.h>
#include <iterator>

using namespace mlir;
25 class ValueShapeRangeTest
: public testing::Test
{
27 void SetUp() override
{
28 const char *ir
= R
"MLIR(
29 func.func @map(%arg : tensor<1xi64>) {
30 %0 = arith.constant dense<[10]> : tensor<1xi64>
31 %1 = arith.addi %arg, %0 : tensor<1xi64>
36 registry
.insert
<func::FuncDialect
, arith::ArithDialect
>();
37 ctx
.appendDialectRegistry(registry
);
38 module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
40 mapFn
= cast
<func::FuncOp
>(module
->front());
43 // Create ValueShapeRange on the arith.addi operation.
44 ValueShapeRange
addiRange() {
45 auto &fnBody
= mapFn
.getBody();
46 return std::next(fnBody
.front().begin())->getOperands();
49 DialectRegistry registry
;
51 OwningOpRef
<ModuleOp
> module
;
55 TEST_F(ValueShapeRangeTest
, ShapesFromValues
) {
56 ValueShapeRange range
= addiRange();
58 EXPECT_FALSE(range
.getValueAsShape(0));
59 ASSERT_TRUE(range
.getValueAsShape(1));
60 EXPECT_TRUE(range
.getValueAsShape(1).hasRank());
61 EXPECT_EQ(range
.getValueAsShape(1).getRank(), 1);
62 EXPECT_EQ(range
.getValueAsShape(1).getDimSize(0), 10);
63 EXPECT_EQ(range
.getShape(1).getRank(), 1);
64 EXPECT_EQ(range
.getShape(1).getDimSize(0), 1);
67 TEST_F(ValueShapeRangeTest
, MapValuesToShapes
) {
68 ValueShapeRange range
= addiRange();
69 ShapedTypeComponents
fixed(SmallVector
<int64_t>{30});
70 auto mapping
= [&](Value val
) -> ShapeAdaptor
{
71 if (val
== mapFn
.getArgument(0))
75 range
.setValueToShapeMapping(mapping
);
77 ASSERT_TRUE(range
.getValueAsShape(0));
78 EXPECT_TRUE(range
.getValueAsShape(0).hasRank());
79 EXPECT_EQ(range
.getValueAsShape(0).getRank(), 1);
80 EXPECT_EQ(range
.getValueAsShape(0).getDimSize(0), 30);
81 ASSERT_TRUE(range
.getValueAsShape(1));
82 EXPECT_TRUE(range
.getValueAsShape(1).hasRank());
83 EXPECT_EQ(range
.getValueAsShape(1).getRank(), 1);
84 EXPECT_EQ(range
.getValueAsShape(1).getDimSize(0), 10);
87 TEST_F(ValueShapeRangeTest
, SettingShapes
) {
88 ShapedTypeComponents
shape(SmallVector
<int64_t>{10, 20});
89 ValueShapeRange range
= addiRange();
90 auto mapping
= [&](Value val
) -> ShapeAdaptor
{
91 if (val
== mapFn
.getArgument(0))
95 range
.setOperandShapeMapping(mapping
);
97 ASSERT_TRUE(range
.getShape(0));
98 EXPECT_EQ(range
.getShape(0).getRank(), 2);
99 EXPECT_EQ(range
.getShape(0).getDimSize(0), 10);
100 EXPECT_EQ(range
.getShape(0).getDimSize(1), 20);
101 ASSERT_TRUE(range
.getShape(1));
102 EXPECT_EQ(range
.getShape(1).getRank(), 1);
103 EXPECT_EQ(range
.getShape(1).getDimSize(0), 1);
104 EXPECT_FALSE(range
.getShape(2));