[X86] Better handling of impossibly large stack frames (#124217)
[llvm-project.git] / mlir / test / Interfaces / TilingInterface / tile-fuse-and-yield-using-scfforall.mlir
blob8fc8f3245be159adef690160c4546b554822d6dd
1 // RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
3 func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
4     %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
5     -> (tensor<?x?xf32>, tensor<?x?xf32>) {
6   %c0 = arith.constant 0 : index
7   %c1 = arith.constant 1 : index
8   %cst = arith.constant 0.0 : f32
9   %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
10   %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
11   %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
12   %gemm0 = linalg.matmul
13       ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
14   %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
15   %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
16   %gemm1 = linalg.matmul
17       ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
18   return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
21 module attributes {transform.with_named_sequence} {
22   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
23     %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
24       : (!transform.any_op) -> !transform.any_op
25     %mm1, %mm2 = transform.split_handle %matmuls
26       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
27     %a, %b = transform.test.fuse_and_yield %mm2 [10] use_forall true
28       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
29     transform.yield
30   }
32 //      CHECK: func.func @gemm_gemm_fusion_yield_both(
33 // CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
34 // CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
35 // CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
36 // CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
37 // CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
38 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
39 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
40 //      CHECK:   %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
41 // CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
42 //  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
43 //  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
44 //  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
45 //      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
46 // CHECK-SAME:         outs(%[[INIT0_TILE]] :
47 //      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
48 // CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
49 // CHECK-SAME:         outs(%[[FILL0_TILE]] :
50 //  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
51 //  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
52 //      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
53 // CHECK-SAME:         outs(%[[INIT1_TILE]] :
54 //      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
55 // CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
56 // CHECK-SAME:         outs(%[[FILL1_TILE]] :
57 //      CHECK:     scf.forall.in_parallel {
58 //      CHECK:       tensor.parallel_insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
59 //      CHECK:       tensor.parallel_insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
60 //      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0