[X86] Better handling of impossibly large stack frames (#124217)
[llvm-project.git] / mlir / test / Interfaces / TilingInterface / tile-and-fuse-using-scfforall.mlir
blob0bd2546e082b5afcc383c2c8f965436d3d73a5ec
1 // RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
3 func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
4   %c0 = arith.constant 0 : index
5   %c1 = arith.constant 1 : index
6   %cst = arith.constant 0.0 : f32
7   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
8   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
9   %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
10   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
11   %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
12       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
13   return %gemm : tensor<?x?xf32>
16 module attributes {transform.with_named_sequence} {
17   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
18     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
19       : (!transform.any_op) -> !transform.any_op
20     %a, %b = transform.test.fuse_using_forall %matmul [10, 20]
21       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
22     transform.yield
23   }
25 //      CHECK: func.func @gemm_fill_fusion(
26 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
27 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
28 //      CHECK:   %[[INIT:.+]] = tensor.empty
29 //      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
30 // CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
31 //  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
32 //  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
33 //  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
34 //      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
35 // CHECK-SAME:         outs(%[[INIT_TILE]] :
36 //      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
37 // CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
38 // CHECK-SAME:         outs(%[[FILL_TILE]] :
39 //      CHECK:     scf.forall.in_parallel {
40 //      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
41 //      CHECK:     }
43 // -----
45 func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
46     %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
47   %c0 = arith.constant 0 : index
48   %c1 = arith.constant 1 : index
49   %cst = arith.constant 0.0 : f32
50   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
51   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
52   %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
53   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
54   %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
55       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
56   %generic = linalg.generic {
57       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
58       iterator_types = ["parallel", "parallel"]}
59       ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
60     ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
61       %add = arith.addf %b0, %b1 : f32
62       linalg.yield %add : f32
63   } -> tensor<?x?xf32>
64   return %generic : tensor<?x?xf32>
67 module attributes {transform.with_named_sequence} {
68   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
69     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
70       : (!transform.any_op) -> !transform.any_op
71     %a, %b = transform.test.fuse_using_forall %generic [10, 20]
72       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
73     transform.yield
74   }
76 //      CHECK: func.func @gemm_generic_fusion(
77 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
78 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
79 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
80 //      CHECK:   %[[INIT:.+]] = tensor.empty
81 //      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
82 // CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
83 //  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
84 //  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
85 //  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
86 //      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
87 // CHECK-SAME:         outs(%[[INIT_TILE]] :
88 //      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
89 // CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
90 // CHECK-SAME:         outs(%[[FILL_TILE]] :
91 //  CHECK-DAG:     %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
92 //  CHECK-DAG:     %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
93 //      CHECK:     %[[GENERIC_TILE:.+]] = linalg.generic
94 // CHECK-SAME:         ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
95 // CHECK-SAME:         outs(%[[OUTS_TILE]] :
96 //      CHECK:     scf.forall.in_parallel {
97 //      CHECK:       tensor.parallel_insert_slice %[[GENERIC_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
98 //      CHECK:     }
100 // -----
102 func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
103   %cst = arith.constant 0.000000e+00 : f32
104   %cst_0 = arith.constant 0xFF800000 : f32
105   %0 = tensor.empty() : tensor<30xf32>
106   %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
107   %2 = linalg.generic {
108       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
109       iterator_types = ["parallel", "reduction"]}
110       ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
111     ^bb0(%arg1: f32, %arg2: f32):
112       %8 = arith.maximumf %arg2, %arg1 : f32
113       linalg.yield %8 : f32
114     } -> tensor<30xf32>
115   %3 = tensor.empty() : tensor<30x3xf32>
116   %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
117   %5:2 = linalg.generic {
118       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
119                        affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
120       iterator_types = ["parallel", "reduction"]}
121       ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
122     ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
123       %8 = arith.subf %arg1, %arg2 : f32
124       %9 = math.exp %8 : f32
125       %10 = arith.addf %arg3, %9 : f32
126       linalg.yield %10, %9 : f32, f32
127     } -> (tensor<30xf32>, tensor<30x3xf32>)
128   %6 = linalg.generic {
129       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
130                        affine_map<(d0, d1) -> (d0, d1)>],
131       iterator_types = ["parallel", "parallel"]}
132       ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
133     ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
134       %8 = arith.divf %arg1, %arg2 : f32
135       linalg.yield %8 : f32
136     } -> tensor<30x3xf32>
137   return %6 : tensor<30x3xf32>
140 module attributes {transform.with_named_sequence} {
141   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
142     %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
143       : (!transform.any_op) -> !transform.any_op
144     %generic1, %generic2, %generic3 = transform.split_handle %generics
145       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
146     %a, %b = transform.test.fuse_using_forall %generic3 [10]
147       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
148     transform.yield
149   }
151 //       CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
152 //   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
153 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
154 //       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.forall (%[[IV:[a-zA-Z0-9]+]])
155 //  CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
156 //   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
157 //   CHECK-DAG:     %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
158 //       CHECK:     %[[FILL0:.+]] = linalg.fill
159 //  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
160 //       CHECK:     %[[GENERIC0:.+]] = linalg.generic
161 //  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
162 //  CHECK-SAME:         outs(%[[FILL0]] :
163 //       CHECK:     %[[FILL1:.+]] = linalg.fill
164 //  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
165 //       CHECK:     %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
166 //       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
167 //  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
168 //  CHECK-SAME:         outs(%[[FILL1]], %[[INIT1_SLICE]] :
169 //       CHECK:     %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
170 //       CHECK:     %[[GENERIC2:.+]] = linalg.generic
171 //  CHECK-SAME:         ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
172 //  CHECK-SAME:         outs(%[[ITERARG0_SLICE]] :
173 //       CHECK:     scf.forall.in_parallel {
174 //       CHECK:       tensor.parallel_insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
175 //       CHECK:     }
176 //       CHECK:   return %[[RESULT]]