Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / Conversion / MathToSPIRV / math-to-opencl-spirv.mlir
blob393a910c1fb1d740bcbf30dd999b9b042b78e0bd
1 // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
3 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
5 // CHECK-LABEL: @float32_unary_scalar
6 func.func @float32_unary_scalar(%arg0: f32) {
7   // CHECK: spirv.CL.atan %{{.*}}: f32
8   %0 = math.atan %arg0 : f32
9   // CHECK: spirv.CL.cos %{{.*}}: f32
10   %1 = math.cos %arg0 : f32
11   // CHECK: spirv.CL.exp %{{.*}}: f32
12   %2 = math.exp %arg0 : f32
13   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
14   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
15   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
16   %3 = math.expm1 %arg0 : f32
17   // CHECK: spirv.CL.log %{{.*}}: f32
18   %4 = math.log %arg0 : f32
19   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
20   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
21   // CHECK: spirv.CL.log %[[ADDONE]]
22   %5 = math.log1p %arg0 : f32
23   // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
24   // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
25   // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
26   %6 = math.log2 %arg0 : f32
27   // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
28   // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
29   // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
30   %7 = math.log10 %arg0 : f32
31   // CHECK: spirv.CL.rint %{{.*}}: f32
32   %8 = math.roundeven %arg0 : f32
33   // CHECK: spirv.CL.rsqrt %{{.*}}: f32
34   %9 = math.rsqrt %arg0 : f32
35   // CHECK: spirv.CL.sqrt %{{.*}}: f32
36   %10 = math.sqrt %arg0 : f32
37   // CHECK: spirv.CL.tanh %{{.*}}: f32
38   %11 = math.tanh %arg0 : f32
39   // CHECK: spirv.CL.sin %{{.*}}: f32
40   %12 = math.sin %arg0 : f32
41   // CHECK: spirv.CL.fabs %{{.*}}: f32
42   %13 = math.absf %arg0 : f32
43   // CHECK: spirv.CL.ceil %{{.*}}: f32
44   %14 = math.ceil %arg0 : f32
45   // CHECK: spirv.CL.floor %{{.*}}: f32
46   %15 = math.floor %arg0 : f32
47   // CHECK: spirv.CL.erf %{{.*}}: f32
48   %16 = math.erf %arg0 : f32
49   // CHECK: spirv.CL.round %{{.*}}: f32
50   %17 = math.round %arg0 : f32
51   return
54 // CHECK-LABEL: @float32_unary_vector
55 func.func @float32_unary_vector(%arg0: vector<3xf32>) {
56   // CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
57   %0 = math.atan %arg0 : vector<3xf32>
58   // CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
59   %1 = math.cos %arg0 : vector<3xf32>
60   // CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
61   %2 = math.exp %arg0 : vector<3xf32>
62   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
63   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
64   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
65   %3 = math.expm1 %arg0 : vector<3xf32>
66   // CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
67   %4 = math.log %arg0 : vector<3xf32>
68   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
69   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
70   // CHECK: spirv.CL.log %[[ADDONE]]
71   %5 = math.log1p %arg0 : vector<3xf32>
72   // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
73   // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
74   // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
75   %6 = math.log2 %arg0 : vector<3xf32>
76   // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
77   // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
78   // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
79   %7 = math.log10 %arg0 : vector<3xf32>
80   // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
81   %8 = math.roundeven %arg0 : vector<3xf32>
82   // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
83   %9 = math.rsqrt %arg0 : vector<3xf32>
84   // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
85   %10 = math.sqrt %arg0 : vector<3xf32>
86   // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
87   %11 = math.tanh %arg0 : vector<3xf32>
88   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
89   %12 = math.sin %arg0 : vector<3xf32>
90   return
93 // CHECK-LABEL: @float32_binary_scalar
94 func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
95   // CHECK: spirv.CL.atan2 %{{.*}}: f32
96   %0 = math.atan2 %lhs, %rhs : f32
97   // CHECK: spirv.CL.pow %{{.*}}: f32
98   %1 = math.powf %lhs, %rhs : f32
99   return
102 // CHECK-LABEL: @float32_binary_vector
103 func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
104   // CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
105   %0 = math.atan2 %lhs, %rhs : vector<4xf32>
106   // CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
107   %1 = math.powf %lhs, %rhs : vector<4xf32>
108   return
111 // CHECK-LABEL: @float32_ternary_scalar
112 func.func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
113   // CHECK: spirv.CL.fma %{{.*}}: f32
114   %0 = math.fma %a, %b, %c : f32
115   return
118 // CHECK-LABEL: @float32_ternary_vector
119 func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
120                             %c: vector<4xf32>) {
121   // CHECK: spirv.CL.fma %{{.*}}: vector<4xf32>
122   %0 = math.fma %a, %b, %c : vector<4xf32>
123   return
126 // CHECK-LABEL: @int_unary
127 func.func @int_unary(%arg0: i32) {
128   // CHECK: spirv.CL.s_abs %{{.*}}
129   %0 = math.absi %arg0 : i32
130   return
133 } // end module
135 // -----
137 module attributes {
138   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
139 } {
141 // 2-D vectors are not supported.
143 // CHECK-LABEL: @vector_2d
144 func.func @vector_2d(%arg0: vector<2x2xf32>) {
145   // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
146   %0 = math.atan %arg0 : vector<2x2xf32>
147   // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
148   %1 = math.cos %arg0 : vector<2x2xf32>
149   // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
150   %2 = math.exp %arg0 : vector<2x2xf32>
151   // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
152   %3 = math.absf %arg0 : vector<2x2xf32>
153   // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
154   %4 = math.ceil %arg0 : vector<2x2xf32>
155   // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
156   %5 = math.floor %arg0 : vector<2x2xf32>
157   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
158   %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
159   // CHECK-NEXT: return
160   return
163 // Tensors are not supported.
165 // CHECK-LABEL: @tensor_1d
166 func.func @tensor_1d(%arg0: tensor<2xf32>) {
167   // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
168   %0 = math.atan %arg0 : tensor<2xf32>
169   // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
170   %1 = math.cos %arg0 : tensor<2xf32>
171   // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
172   %2 = math.exp %arg0 : tensor<2xf32>
173   // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
174   %3 = math.absf %arg0 : tensor<2xf32>
175   // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
176   %4 = math.ceil %arg0 : tensor<2xf32>
177   // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
178   %5 = math.floor %arg0 : tensor<2xf32>
179   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
180   %6 = math.powf %arg0, %arg0 : tensor<2xf32>
181   // CHECK-NEXT: return
182   return
185 } // end module