[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Tensor / decompose-concat.mlir
blob2b1cb138ecda5bd4a45801455fa7008fa4bdf6ed
1 // RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
3 func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
4   %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
5   return %0 : tensor<?x?xf32>
7 // CHECK-LABEL: func @decompose_dynamic_concat(
8 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
9 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
11 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
12 //   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
13 //       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
14 //       CHECK:     %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
15 //       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
16 //       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
17 //       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
18 //       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
19 //       CHECK:     %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
20 //       CHECK:     return %[[CAST]] : tensor<?x?xf32>
22 func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
23                             %arg1 : tensor<2xf32>,
24                             %arg2 : tensor<3xf32>,
25                             %arg3: tensor<4xf32>) -> tensor<10xf32> {
26   %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
27              : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
28   return %0 : tensor<10xf32>
30 // CHECK-LABEL: func @decompose_1d_concat
31 //       CHECK:    tensor.empty() : tensor<10xf32>
32 //       CHECK:    tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
33 //       CHECK:    tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
34 //       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
35 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
36 //       CHECK:    return %[[CONCAT]] : tensor<10xf32>
38 func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
39                                %arg1: tensor<1x?x64xf32>) -> tensor<1x?x128xf32> {
40   %0 = tensor.concat dim(2) %arg0, %arg1
41              : (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
42   return %0 : tensor<1x?x128xf32>
44 // CHECK-LABEL: func @decompose_static_concat_dim(
45 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
46 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
47 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
48 //       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
49 //       CHECK:     %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
50 //       CHECK:    tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
51 //       CHECK:    tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
52 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
53 //       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
56 func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
57                                %arg1: tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
58   %0 = tensor.concat dim(2) %arg0, %arg1
59              : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
60   return %0 : tensor<1x?x128xf32>
62 // CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
63 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
64 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
65 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
66 //   CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
67 //       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
68 //       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
69 //       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
70 //       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
71 //       CHECK:     %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
72 //       CHECK:     tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
73 //       CHECK:     tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
74 //  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x?xf32>
75 //       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
76 //  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x?xf32>
77 //       CHECK:     %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
78 //       CHECK:     return %[[CAST]] : tensor<1x?x128xf32>
80 module attributes {transform.with_named_sequence} {
81   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
82     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
83     transform.apply_patterns to %func_op {
84       transform.apply_patterns.tensor.decompose_concat
85     } : !transform.op<"func.func">
86     transform.yield
87   }