// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
// Test case: every `tensor.dim` taken on a result of
// "test.op_with_result_shape_interface" should be rewritten by the pass so
// that it no longer refers to the op's results. Per the expected output below:
//   * result #0 (tensor<?x5xf32>):   dim 0 -> dim 0 of %arg1, dim 1 -> constant 5
//   * result #1 (tensor<2x3x?xf32>): dim 0 -> constant 2, dim 1 -> constant 3,
//                                    dim 2 -> dim 2 of %arg0
// NOTE(review): the test op presumably reports its result shapes through a
// shape-reification interface -- confirm against the test dialect definition.
func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  // Query both dimensions of result #0 ...
  %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
  // ... and all three dimensions of result #1.
  %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape(
//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
//       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
// -----

// Test case: same queries as the previous case, but on the results of
// "test.op_with_result_shape_per_dim_interface". The expected output is the
// same: each `tensor.dim` on an op result becomes either a constant (static
// dimensions 5, 2 and 3) or a `tensor.dim` on the corresponding function
// argument (dim 0 of %arg1 for result #0, dim 2 of %arg0 for result #1).
// NOTE(review): the op name suggests shapes are reported one dimension at a
// time rather than per result -- confirm against the test dialect definition.
func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  // Query both dimensions of result #0 ...
  %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
  // ... and all three dimensions of result #1.
  %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_per_dim(
//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
//       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]