Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / Interfaces / InferShapedTypeOpInterface / resolve-shaped-type-result-dims.mlir
blob4fa7406f21042ed97c5f8b3c0efe78cbcda3f8aa
1 // RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
3 func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
4     -> (index, index, index, index, index) {
5   %c0 = arith.constant 0 : index
6   %c1 = arith.constant 1 : index
7   %c2 = arith.constant 2 : index
8   %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
9       : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
10   %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
11   %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
12   %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
13   %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
14   %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
15   return %1, %2, %3, %4, %5 : index, index, index, index, index
17 // CHECK-LABEL: func @result_shape(
18 //  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
19 //  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
20 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
21 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
22 //   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
23 //   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
24 //   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
25 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
26 //       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
28 // -----
30 func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
31     -> (index, index, index, index, index) {
32   %c0 = arith.constant 0 : index
33   %c1 = arith.constant 1 : index
34   %c2 = arith.constant 2 : index
35   %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
36       : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
37   %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
38   %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
39   %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
40   %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
41   %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
42   return %1, %2, %3, %4, %5 : index, index, index, index, index
44 // CHECK-LABEL: func @result_shape_per_dim(
45 //  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
46 //  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
47 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
48 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
49 //   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
50 //   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
51 //   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
52 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
53 //       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]