[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Conversion / SCFToSPIRV / for.mlir
blob81661ec7a3a0603a446ac3bc234d216e9c33a137
1 // RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
3 module attributes {
4   spirv.target_env = #spirv.target_env<
5     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
6 } {
8 func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
9   // CHECK: %[[LB:.*]] = spirv.Constant 4 : i32
10   %lb = arith.constant 4 : index
11   // CHECK: %[[UB:.*]] = spirv.Constant 42 : i32
12   %ub = arith.constant 42 : index
13   // CHECK: %[[STEP:.*]] = spirv.Constant 2 : i32
14   %step = arith.constant 2 : index
15   // CHECK:      spirv.mlir.loop {
16   // CHECK-NEXT:   spirv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
17   // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
18   // CHECK:        %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
19   // CHECK:        spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
20   // CHECK:      ^[[BODY]]:
21   // CHECK:        %[[ZERO1:.*]] = spirv.Constant 0 : i32
22   // CHECK:        spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDVAR]]{{\]}}
23   // CHECK:        %[[ZERO2:.*]] = spirv.Constant 0 : i32
24   // CHECK:        spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDVAR]]]
25   // CHECK:        %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
26   // CHECK:        spirv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
27   // CHECK:      ^[[MERGE]]
28   // CHECK:        spirv.mlir.merge
29   // CHECK:      }
30   scf.for %arg4 = %lb to %ub step %step {
31     %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
32     memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
33   }
34   return
37 // CHECK-LABEL: @loop_yield
38 func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
39   // CHECK: %[[LB:.*]] = spirv.Constant 4 : i32
40   %lb = arith.constant 4 : index
41   // CHECK: %[[UB:.*]] = spirv.Constant 42 : i32
42   %ub = arith.constant 42 : index
43   // CHECK: %[[STEP:.*]] = spirv.Constant 2 : i32
44   %step = arith.constant 2 : index
45   // CHECK: %[[INITVAR1:.*]] = spirv.Constant 0.000000e+00 : f32
46   %s0 = arith.constant 0.0 : f32
47   // CHECK: %[[INITVAR2:.*]] = spirv.Constant 1.000000e+00 : f32
48   %s1 = arith.constant 1.0 : f32
49   // CHECK: %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
50   // CHECK: %[[VAR2:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
51   // CHECK: spirv.mlir.loop {
52   // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
53   // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
54   // CHECK:   %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
55   // CHECK:   spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
56   // CHECK: ^[[BODY]]:
57   // CHECK:   %[[UPDATED:.*]] = spirv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
58   // CHECK-DAG:   %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
59   // CHECK-DAG:   spirv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
60   // CHECK-DAG:   spirv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
61   // CHECK: spirv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
62   // CHECK: ^[[MERGE]]:
63   // CHECK:   spirv.mlir.merge
64   // CHECK: }
65   %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
66     %sn = arith.addf %si, %si : f32
67     scf.yield %sn, %sn : f32, f32
68   }
69   // CHECK-DAG: %[[OUT1:.*]] = spirv.Load "Function" %[[VAR1]] : f32
70   // CHECK-DAG: %[[OUT2:.*]] = spirv.Load "Function" %[[VAR2]] : f32
71   // CHECK: spirv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
72   // CHECK: spirv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
73   memref.store %result#0, %arg3[%lb] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
74   memref.store %result#1, %arg3[%ub] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
75   return
78 } // end module