[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Tensor / rewrite-as-constant.mlir
blob35ee6f1caf0d9ba24aad1d5c4a75434a6e4badb5
1 // RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
3 module attributes {transform.with_named_sequence} {
4   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
5     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
6     transform.apply_patterns to %func_op {
7       transform.apply_patterns.tensor.rewrite_as_constant
8     } : !transform.op<"func.func">
9     transform.yield
10   }
13 // CHECK-LABEL: func @tensor_generate_constant(
14 //       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
15 //       CHECK:   return %[[cst]]
16 func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
17   %cst = arith.constant 5.0 : f32
18   %0 = tensor.generate {
19     ^bb0(%arg0: index, %arg1: index, %arg2: index):
20     tensor.yield %cst : f32
21   } : tensor<2x3x5xf32>
22   return %0 : tensor<2x3x5xf32>
25 //         CHECK-LABEL: func @pad_of_ints(
26 //               CHECK: %[[cst:.*]] = arith.constant dense<[
27 // CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
28 // CHECK-SAME{LITERAL}:     [0, 6, 7, 0],
29 // CHECK-SAME{LITERAL}:     [0, 8, 9, 0],
30 // CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
31 // CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
32 //               CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
33 //               CHECK: return %[[cast]]
34 func.func @pad_of_ints() -> tensor<?x?xi32> {
35   %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
36   %pad_value = arith.constant 0 : i32
38   %c1 = arith.constant 1 : index
40   %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
41     ^bb0(%arg1: index, %arg2: index):
42       tensor.yield %pad_value : i32
43   } : tensor<2x2xi32> to tensor<?x?xi32>
45   return %0 : tensor<?x?xi32>
48 //         CHECK-LABEL: func @pad_of_floats(
49 //               CHECK: %[[cst:.*]] = arith.constant dense<[
50 // CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
51 // CHECK-SAME{LITERAL}:     [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
52 // CHECK-SAME{LITERAL}:     [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
53 // CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
54 // CHECK-SAME{LITERAL}:     ]> : tensor<4x4xf32>
55 //               CHECK: return %[[cst]]
57 func.func @pad_of_floats() -> tensor<4x4xf32> {
58   %init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
59   %pad_value = arith.constant 0.0 : f32
61   %0 = tensor.pad %init low[1, 1] high[1, 1] {
62     ^bb0(%arg1: index, %arg2: index):
63       tensor.yield %pad_value : f32
64   } : tensor<2x2xf32> to tensor<4x4xf32>
66   return %0 : tensor<4x4xf32>
69 //         CHECK-LABEL: func @pad_of_ints_no_low_dims(
70 //               CHECK: %[[cst:.*]] = arith.constant dense<[
71 // CHECK-SAME{LITERAL}:     [6, 7, 0],
72 // CHECK-SAME{LITERAL}:     [8, 9, 0],
73 // CHECK-SAME{LITERAL}:     [0, 0, 0]
74 // CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
75 //               CHECK: return %[[cst]]
76 func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
77   %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
78   %pad_value = arith.constant 0 : i32
80   %0 = tensor.pad %init low[0, 0] high[1, 1] {
81     ^bb0(%arg1: index, %arg2: index):
82       tensor.yield %pad_value : i32
83   } : tensor<2x2xi32> to tensor<3x3xi32>
85   return %0 : tensor<3x3xi32>
88 //         CHECK-LABEL: func @pad_of_ints_no_high_dims(
89 //               CHECK: %[[cst:.*]] = arith.constant dense<[
90 // CHECK-SAME{LITERAL}:     [0, 0, 0],
91 // CHECK-SAME{LITERAL}:     [0, 6, 7],
92 // CHECK-SAME{LITERAL}:     [0, 8, 9]
93 // CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
94 //               CHECK: return %[[cst]]
95 func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
96   %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
97   %pad_value = arith.constant 0 : i32
99   %0 = tensor.pad %init low[1, 1] high[0, 0] {
100     ^bb0(%arg1: index, %arg2: index):
101       tensor.yield %pad_value : i32
102   } : tensor<2x2xi32> to tensor<3x3xi32>
104   return %0 : tensor<3x3xi32>
107 //         CHECK-LABEL: func @pad_multi_use_do_not_fold(
108 //               CHECK: %[[pad:.+]] = tensor.pad
109 //               CHECK: return %[[pad]]
110 func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
111   %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
112   %pad_value = arith.constant 0 : i32
114   %c1 = arith.constant 1 : index
116   %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
117     ^bb0(%arg1: index, %arg2: index):
118       tensor.yield %pad_value : i32
119   } : tensor<2x2xi32> to tensor<?x?xi32>
121   return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
124 // -----
126 module attributes {transform.with_named_sequence} {
127   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
128     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
129     transform.apply_patterns to %func_op {
130       transform.apply_patterns.tensor.rewrite_as_constant aggressive
131     } : !transform.op<"func.func">
132     transform.yield
133   }
136 //         CHECK-LABEL: func @pad_aggressive_fold(
137 //               CHECK: %[[init:.*]] = arith.constant dense<7> : tensor<2x2xi32>
138 //               CHECK: %[[cst:.*]] = arith.constant dense<[
139 // CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
140 // CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
141 // CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
142 // CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
143 // CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
144 //               CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
145 //               CHECK: return %[[cast]]
146 func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
147   %init = arith.constant dense<7> : tensor<2x2xi32>
148   %pad_value = arith.constant 0 : i32
150   %c1 = arith.constant 1 : index
152   %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
153     ^bb0(%arg1: index, %arg2: index):
154       tensor.yield %pad_value : i32
155   } : tensor<2x2xi32> to tensor<?x?xi32>
157   return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>