1 // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
3 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
// Scalar f32 unary math ops under the Kernel target env: each is expected to
// lower to a spirv.CL.* extended instruction, or, for expm1/log1p/log2/log10,
// to a short spirv.CL.exp/log sequence combined with a spirv.Constant.
5 // CHECK-LABEL: @float32_unary_scalar
6 func.func @float32_unary_scalar(%arg0: f32) {
7 // CHECK: spirv.CL.atan %{{.*}}: f32
8 %0 = math.atan %arg0 : f32
9 // CHECK: spirv.CL.cos %{{.*}}: f32
10 %1 = math.cos %arg0 : f32
11 // CHECK: spirv.CL.exp %{{.*}}: f32
12 %2 = math.exp %arg0 : f32
// expm1(x) is expected as exp(x) - 1.0.
13 // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
14 // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
15 // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
16 %3 = math.expm1 %arg0 : f32
17 // CHECK: spirv.CL.log %{{.*}}: f32
18 %4 = math.log %arg0 : f32
// log1p(x) is expected as log(1.0 + x).
19 // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
20 // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
21 // CHECK: spirv.CL.log %[[ADDONE]]
22 %5 = math.log1p %arg0 : f32
// log2(x) is expected as log(x) * 1.44269502, i.e. log(x) / ln(2) in f32.
23 // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
24 // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
25 // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
26 %6 = math.log2 %arg0 : f32
// log10(x) is expected as log(x) * 0.434294492, i.e. log(x) / ln(10) in f32.
27 // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
28 // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
29 // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
30 %7 = math.log10 %arg0 : f32
// roundeven maps to CL.rint; math.round (further below) maps to CL.round.
31 // CHECK: spirv.CL.rint %{{.*}}: f32
32 %8 = math.roundeven %arg0 : f32
33 // CHECK: spirv.CL.rsqrt %{{.*}}: f32
34 %9 = math.rsqrt %arg0 : f32
35 // CHECK: spirv.CL.sqrt %{{.*}}: f32
36 %10 = math.sqrt %arg0 : f32
37 // CHECK: spirv.CL.tanh %{{.*}}: f32
38 %11 = math.tanh %arg0 : f32
39 // CHECK: spirv.CL.sin %{{.*}}: f32
40 %12 = math.sin %arg0 : f32
// absf maps to the floating-point CL.fabs.
41 // CHECK: spirv.CL.fabs %{{.*}}: f32
42 %13 = math.absf %arg0 : f32
43 // CHECK: spirv.CL.ceil %{{.*}}: f32
44 %14 = math.ceil %arg0 : f32
45 // CHECK: spirv.CL.floor %{{.*}}: f32
46 %15 = math.floor %arg0 : f32
47 // CHECK: spirv.CL.erf %{{.*}}: f32
48 %16 = math.erf %arg0 : f32
49 // CHECK: spirv.CL.round %{{.*}}: f32
50 %17 = math.round %arg0 : f32
// 1-D vector<3xf32> counterparts of the scalar unary cases: each op is
// expected to lower to the same spirv.CL.* instruction (or exp/log sequence)
// as its scalar form, with the vector type preserved. This includes
// absf/ceil/floor/erf/round, so every scalar case also has a vector check.
// CHECK-LABEL: @float32_unary_vector
func.func @float32_unary_vector(%arg0: vector<3xf32>) {
  // CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
  %0 = math.atan %arg0 : vector<3xf32>
  // CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
  %1 = math.cos %arg0 : vector<3xf32>
  // CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
  %2 = math.exp %arg0 : vector<3xf32>
  // expm1(x) is expected as exp(x) - splat(1.0).
  // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
  // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
  // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
  %3 = math.expm1 %arg0 : vector<3xf32>
  // CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
  %4 = math.log %arg0 : vector<3xf32>
  // log1p(x) is expected as log(splat(1.0) + x).
  // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
  // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
  // CHECK: spirv.CL.log %[[ADDONE]]
  %5 = math.log1p %arg0 : vector<3xf32>
  // log2(x) is expected as log(x) * splat(1/ln(2)).
  // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
  // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
  // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
  %6 = math.log2 %arg0 : vector<3xf32>
  // log10(x) is expected as log(x) * splat(1/ln(10)).
  // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
  // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
  // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
  %7 = math.log10 %arg0 : vector<3xf32>
  // roundeven maps to CL.rint; math.round (below) maps to CL.round.
  // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
  %8 = math.roundeven %arg0 : vector<3xf32>
  // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
  %9 = math.rsqrt %arg0 : vector<3xf32>
  // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
  %10 = math.sqrt %arg0 : vector<3xf32>
  // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
  %11 = math.tanh %arg0 : vector<3xf32>
  // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
  %12 = math.sin %arg0 : vector<3xf32>
  // The remaining ops mirror the scalar test so vector inputs are covered too.
  // CHECK: spirv.CL.fabs %{{.*}}: vector<3xf32>
  %13 = math.absf %arg0 : vector<3xf32>
  // CHECK: spirv.CL.ceil %{{.*}}: vector<3xf32>
  %14 = math.ceil %arg0 : vector<3xf32>
  // CHECK: spirv.CL.floor %{{.*}}: vector<3xf32>
  %15 = math.floor %arg0 : vector<3xf32>
  // CHECK: spirv.CL.erf %{{.*}}: vector<3xf32>
  %16 = math.erf %arg0 : vector<3xf32>
  // CHECK: spirv.CL.round %{{.*}}: vector<3xf32>
  %17 = math.round %arg0 : vector<3xf32>
// Binary f32 ops on scalars: math.atan2 and math.powf are expected to lower
// to spirv.CL.atan2 and spirv.CL.pow respectively.
93 // CHECK-LABEL: @float32_binary_scalar
94 func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
95 // CHECK: spirv.CL.atan2 %{{.*}}: f32
96 %0 = math.atan2 %lhs, %rhs : f32
97 // CHECK: spirv.CL.pow %{{.*}}: f32
98 %1 = math.powf %lhs, %rhs : f32
// Binary ops on 1-D vector<4xf32>: same CL instructions as the scalar case,
// with the vector type preserved on the result.
102 // CHECK-LABEL: @float32_binary_vector
103 func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
104 // CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
105 %0 = math.atan2 %lhs, %rhs : vector<4xf32>
106 // CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
107 %1 = math.powf %lhs, %rhs : vector<4xf32>
// Ternary op on scalars: math.fma is expected to lower to spirv.CL.fma.
111 // CHECK-LABEL: @float32_ternary_scalar
112 func.func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
113 // CHECK: spirv.CL.fma %{{.*}}: f32
114 %0 = math.fma %a, %b, %c : f32
// Ternary op on 1-D vector<4xf32>: math.fma is expected to lower to
// spirv.CL.fma with the vector type preserved.
// NOTE(review): in this copy the signature line ends after %b and no line
// declares %c, yet %c is used by math.fma below; the continuation line
// appears to have been lost. Confirm against the upstream test.
118 // CHECK-LABEL: @float32_ternary_vector
119 func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
121 // CHECK: spirv.CL.fma %{{.*}}: vector<4xf32>
122 %0 = math.fma %a, %b, %c : vector<4xf32>
// Integer unary op: math.absi on i32 is expected to lower to the signed
// absolute-value instruction spirv.CL.s_abs.
126 // CHECK-LABEL: @int_unary
127 func.func @int_unary(%arg0: i32) {
128 // CHECK: spirv.CL.s_abs %{{.*}}
129 %0 = math.absi %arg0 : i32
138 spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
141 // 2-D vectors are not supported; math ops on them must be left unconverted.
// Negative test: math ops on a 2-D vector<2x2xf32> must be left unconverted.
// After the label, each following output line must be the original math.* op,
// in order, up to the return.
// NOTE(review): the enclosing module's target env lists [Shader] rather than
// the [Kernel] capability used by the OpenCL cases above. Confirm that this
// negative test is meant to exercise the CL patterns.
143 // CHECK-LABEL: @vector_2d
144 func.func @vector_2d(%arg0: vector<2x2xf32>) {
145 // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
146 %0 = math.atan %arg0 : vector<2x2xf32>
147 // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
148 %1 = math.cos %arg0 : vector<2x2xf32>
149 // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
150 %2 = math.exp %arg0 : vector<2x2xf32>
151 // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
152 %3 = math.absf %arg0 : vector<2x2xf32>
153 // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
154 %4 = math.ceil %arg0 : vector<2x2xf32>
155 // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
156 %5 = math.floor %arg0 : vector<2x2xf32>
157 // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
158 %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
159 // CHECK-NEXT: return
163 // Tensors are not supported; math ops on them must be left unconverted.
// Negative test: math ops on tensor<2xf32> must be left unconverted. After the
// label, each following output line must be the original math.* op, in order,
// up to the return.
165 // CHECK-LABEL: @tensor_1d
166 func.func @tensor_1d(%arg0: tensor<2xf32>) {
167 // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
168 %0 = math.atan %arg0 : tensor<2xf32>
169 // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
170 %1 = math.cos %arg0 : tensor<2xf32>
171 // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
172 %2 = math.exp %arg0 : tensor<2xf32>
173 // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
174 %3 = math.absf %arg0 : tensor<2xf32>
175 // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
176 %4 = math.ceil %arg0 : tensor<2xf32>
177 // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
178 %5 = math.floor %arg0 : tensor<2xf32>
179 // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
180 %6 = math.powf %arg0, %arg0 : tensor<2xf32>
181 // CHECK-NEXT: return