[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / MathToSPIRV / MathToSPIRV.cpp
blob1b83794b5f45021159793c456717eb7e5b185e63
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
25 #define DEBUG_TYPE "math-to-spirv-pattern"
27 using namespace mlir;
29 //===----------------------------------------------------------------------===//
30 // Utility functions
31 //===----------------------------------------------------------------------===//
33 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34 /// given type is not a 32-bit scalar/vector type.
35 static Value getScalarOrVectorI32Constant(Type type, int value,
36 OpBuilder &builder, Location loc) {
37 if (auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
39 return nullptr;
40 SmallVector<int> values(vectorType.getNumElements(), value);
41 return builder.create<spirv::ConstantOp>(loc, type,
42 builder.getI32VectorAttr(values));
44 if (type.isInteger(32))
45 return builder.create<spirv::ConstantOp>(loc, type,
46 builder.getI32IntegerAttr(value));
48 return nullptr;
51 /// Check if the type is supported by math-to-spirv conversion. We expect to
52 /// only see scalars and vectors at this point, with higher-level types already
53 /// lowered.
54 static bool isSupportedSourceType(Type originalType) {
55 if (originalType.isIntOrIndexOrFloat())
56 return true;
58 if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
60 return false;
61 if (vecTy.isScalable())
62 return false;
63 if (vecTy.getRank() > 1)
64 return false;
66 return true;
69 return false;
72 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73 /// Notify of a match failure othwerise and return a `failure` result.
74 /// This is intended to simplify type checks in `OpConversionPattern`s.
75 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76 Operation *sourceOp) {
77 auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
78 llvm::append_range(allTypes, sourceOp->getResultTypes());
80 for (Type ty : allTypes) {
81 if (!isSupportedSourceType(ty)) {
82 return rewriter.notifyMatchFailure(
83 sourceOp,
84 llvm::formatv(
85 "unsupported source type for Math to SPIR-V conversion: {0}",
86 ty));
90 return success();
93 //===----------------------------------------------------------------------===//
94 // Operation conversion
95 //===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
101 namespace {
102 /// Converts elementwise unary, binary, and ternary standard operations to
103 /// SPIR-V operations. Checks that source `Op` types are supported.
104 template <typename Op, typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
106 : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107 using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108 using BasePattern::BasePattern;
110 LogicalResult
111 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
114 return res;
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
120 /// Converts math.copysign to SPIR-V ops.
121 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
122 using OpConversionPattern::OpConversionPattern;
124 LogicalResult
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter) const override {
127 if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128 failed(res))
129 return res;
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
132 if (!type)
133 return failure();
135 FloatType floatType;
136 if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
140 } else {
141 return failure();
144 Location loc = copySignOp.getLoc();
145 int bitwidth = floatType.getWidth();
146 Type intType = rewriter.getIntegerType(bitwidth);
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
149 Value signMask = rewriter.create<spirv::ConstantOp>(
150 loc, intType, rewriter.getIntegerAttr(intType, intValue));
151 Value valueMask = rewriter.create<spirv::ConstantOp>(
152 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
154 if (auto vectorType = dyn_cast<VectorType>(type)) {
155 assert(vectorType.getRank() == 1);
156 int count = vectorType.getNumElements();
157 intType = VectorType::get(count, intType);
159 SmallVector<Value> signSplat(count, signMask);
160 signMask =
161 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
163 SmallVector<Value> valueSplat(count, valueMask);
164 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165 valueSplat);
168 Value lhsCast =
169 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170 Value rhsCast =
171 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
173 Value value = rewriter.create<spirv::BitwiseAndOp>(
174 loc, intType, ValueRange{lhsCast, valueMask});
175 Value sign = rewriter.create<spirv::BitwiseAndOp>(
176 loc, intType, ValueRange{rhsCast, signMask});
178 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179 ValueRange{value, sign});
180 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181 return success();
185 /// Converts math.ctlz to SPIR-V ops.
187 /// SPIR-V does not have a direct operations for counting leading zeros. If
188 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
189 /// it.
190 struct CountLeadingZerosPattern final
191 : public OpConversionPattern<math::CountLeadingZerosOp> {
192 using OpConversionPattern::OpConversionPattern;
194 LogicalResult
195 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
198 return res;
200 Type type = getTypeConverter()->convertType(countOp.getType());
201 if (!type)
202 return failure();
204 // We can only support 32-bit integer types for now.
205 unsigned bitwidth = 0;
206 if (isa<IntegerType>(type))
207 bitwidth = type.getIntOrFloatBitWidth();
208 if (auto vectorType = dyn_cast<VectorType>(type))
209 bitwidth = vectorType.getElementTypeBitWidth();
210 if (bitwidth != 32)
211 return failure();
213 Location loc = countOp.getLoc();
214 Value input = adaptor.getOperand();
215 Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
216 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
217 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
219 Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
220 // We need to subtract from 31 given that the index returned by GLSL
221 // FindUMsb is counted from the least significant bit. Theoretically this
222 // also gives the correct result even if the integer has all zero bits, in
223 // which case GL FindUMsb would return -1.
224 Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
225 // However, certain Vulkan implementations have driver bugs for the corner
226 // case where the input is zero. And.. it can be smart to optimize a select
227 // only involving the corner case. So separately compute the result when the
228 // input is either zero or one.
229 Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
230 Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
231 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
232 subMsb);
233 return success();
237 /// Converts math.expm1 to SPIR-V ops.
239 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
240 /// these operations.
241 template <typename ExpOp>
242 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
243 using OpConversionPattern::OpConversionPattern;
245 LogicalResult
246 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter) const override {
248 assert(adaptor.getOperands().size() == 1);
249 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
250 failed(res))
251 return res;
253 Location loc = operation.getLoc();
254 Type type = this->getTypeConverter()->convertType(operation.getType());
255 if (!type)
256 return failure();
258 Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
259 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
260 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
261 return success();
265 /// Converts math.log1p to SPIR-V ops.
267 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268 /// these operations.
269 template <typename LogOp>
270 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
271 using OpConversionPattern::OpConversionPattern;
273 LogicalResult
274 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 assert(adaptor.getOperands().size() == 1);
277 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
278 failed(res))
279 return res;
281 Location loc = operation.getLoc();
282 Type type = this->getTypeConverter()->convertType(operation.getType());
283 if (!type)
284 return failure();
286 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287 Value onePlus =
288 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
289 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290 return success();
294 /// Converts math.log2 and math.log10 to SPIR-V ops.
296 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
297 /// lower to these operations using:
298 /// log2(x) = log(x) * 1/log(2)
299 /// log10(x) = log(x) * 1/log(10)
301 template <typename MathLogOp, typename SpirvLogOp>
302 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
303 using OpConversionPattern<MathLogOp>::OpConversionPattern;
304 using typename OpConversionPattern<MathLogOp>::OpAdaptor;
306 static constexpr double log2Reciprocal =
307 1.442695040888963407359924681001892137426645954152985934135449407;
308 static constexpr double log10Reciprocal =
309 0.4342944819032518276511289189166050822943970058036665661144537832;
311 LogicalResult
312 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313 ConversionPatternRewriter &rewriter) const override {
314 assert(adaptor.getOperands().size() == 1);
315 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
316 failed(res))
317 return res;
319 Location loc = operation.getLoc();
320 Type type = this->getTypeConverter()->convertType(operation.getType());
321 if (!type)
322 return rewriter.notifyMatchFailure(operation, "type conversion failed");
324 auto getConstantValue = [&](double value) {
325 if (auto floatType = dyn_cast<FloatType>(type)) {
326 return rewriter.create<spirv::ConstantOp>(
327 loc, type, rewriter.getFloatAttr(floatType, value));
329 if (auto vectorType = dyn_cast<VectorType>(type)) {
330 Type elemType = vectorType.getElementType();
332 if (isa<FloatType>(elemType)) {
333 return rewriter.create<spirv::ConstantOp>(
334 loc, type,
335 DenseFPElementsAttr::get(
336 vectorType, FloatAttr::get(elemType, value).getValue()));
340 llvm_unreachable("unimplemented types for log2/log10");
343 Value constantValue = getConstantValue(
344 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345 : log10Reciprocal);
346 Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
347 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
348 constantValue);
349 return success();
353 /// Converts math.powf to SPIRV-Ops.
354 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
355 using OpConversionPattern::OpConversionPattern;
357 LogicalResult
358 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
359 ConversionPatternRewriter &rewriter) const override {
360 if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
361 return res;
363 Type dstType = getTypeConverter()->convertType(powfOp.getType());
364 if (!dstType)
365 return failure();
367 // Get the scalar float type.
368 FloatType scalarFloatType;
369 if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370 scalarFloatType = scalarType;
371 } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372 scalarFloatType = cast<FloatType>(vectorType.getElementType());
373 } else {
374 return failure();
377 // Get int type of the same shape as the float type.
378 Type scalarIntType = rewriter.getIntegerType(32);
379 Type intType = scalarIntType;
380 auto operandType = adaptor.getRhs().getType();
381 if (auto vectorType = dyn_cast<VectorType>(operandType)) {
382 auto shape = vectorType.getShape();
383 intType = VectorType::get(shape, scalarIntType);
386 // Per GL Pow extended instruction spec:
387 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
388 Location loc = powfOp.getLoc();
389 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
390 Value lessThan =
391 rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
393 // Per C/C++ spec:
394 // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395 // > finite and negative and exponent is finite and non-integer.
396 // Calculate the reminder from the exponent and check whether it is zero.
397 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398 Value expRem =
399 rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400 Value expRemNonZero =
401 rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402 Value cmpNegativeWithFractionalExp =
403 rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404 // Create NaN result and replace base value if conditions are met.
405 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
406 const auto nan = APFloat::getNaN(floatSemantics);
407 Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
408 if (auto vectorType = dyn_cast<VectorType>(operandType))
409 nanAttr = DenseElementsAttr::get(vectorType, nan);
411 Value NanValue =
412 rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
413 Value lhs = rewriter.create<spirv::SelectOp>(
414 loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
415 Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
417 // TODO: The following just forcefully casts y into an integer value in
418 // order to properly propagate the sign, assuming integer y cases. It
419 // doesn't cover other cases and should be fixed.
421 // Cast exponent to integer and calculate exponent % 2 != 0.
422 Value intRhs =
423 rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
424 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425 Value bitwiseAndOne =
426 rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
427 Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
429 // calculate pow based on abs(lhs)^rhs.
430 Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
431 Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
432 // if the exponent is odd and lhs < 0, negate the result.
433 Value shouldNegate =
434 rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
435 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
436 pow);
437 return success();
441 /// Converts math.round to GLSL SPIRV extended ops.
442 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
443 using OpConversionPattern::OpConversionPattern;
445 LogicalResult
446 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter) const override {
448 if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
449 return res;
451 Location loc = roundOp.getLoc();
452 Value operand = roundOp.getOperand();
453 Type ty = operand.getType();
454 Type ety = getElementTypeOrSelf(ty);
456 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458 Value half;
459 if (VectorType vty = dyn_cast<VectorType>(ty)) {
460 half = rewriter.create<spirv::ConstantOp>(
461 loc, vty,
462 DenseElementsAttr::get(vty,
463 rewriter.getFloatAttr(ety, 0.5).getValue()));
464 } else {
465 half = rewriter.create<spirv::ConstantOp>(
466 loc, ty, rewriter.getFloatAttr(ety, 0.5));
469 auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
470 auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
471 auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
472 auto greater =
473 rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
474 auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
475 auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
476 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477 return success();
481 } // namespace
483 //===----------------------------------------------------------------------===//
484 // Pattern population
485 //===----------------------------------------------------------------------===//
487 namespace mlir {
488 void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
489 RewritePatternSet &patterns) {
490 // Core patterns
491 patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
493 // GLSL patterns
494 patterns
495 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
496 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
497 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
498 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
499 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
500 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
501 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
502 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
503 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
504 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
505 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
506 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
507 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
508 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
509 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
510 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
511 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
512 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
513 typeConverter, patterns.getContext());
515 // OpenCL patterns
516 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
517 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
518 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
519 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
520 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
521 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
522 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
523 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
524 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
525 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
526 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
527 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
528 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
529 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
530 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
531 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
532 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
533 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
534 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
535 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
536 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
537 typeConverter, patterns.getContext());
540 } // namespace mlir