[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Tensor / mesh-spmdization.mlir
blob5443eea83aa2d89272440c5eedab49ffc66b65c1
1 // RUN: mlir-opt \
2 // RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
3 // RUN:   %s | FileCheck %s
5 mesh.mesh @mesh_1d_4(shape = 4)
7 // CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
8 func.func @tensor_empty_static_sharded_dims_offsets() -> () {
9   %b = tensor.empty() : tensor<8x16xf32>
10   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
11   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
12   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
13   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
14   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
15   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
17   return
20 // CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
21 // CHECK-SAME: %[[A0:.*]]: index
22 func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
23   %b = tensor.empty(%arg0) : tensor<8x?xf32>
24   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
25   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
26   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
27   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
28   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
29   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
31   return
34 // CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
35 func.func @tensor_empty_same_static_dims_sizes() -> () {
36   %b = tensor.empty() : tensor<16x16xf32>
37   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
38   %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
39   // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
41   return