//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
10 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/Utils/IndexingUtils.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
23 #include "../GPUCommon/GPUOpsLowering.h"
24 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25 #include "../GPUCommon/OpToFuncCallLowering.h"
26 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
29 #define GEN_PASS_DEF_CONVERTMATHTOROCDL
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "math-to-rocdl"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
38 template <typename OpTy
>
39 static void populateOpPatterns(const LLVMTypeConverter
&converter
,
40 RewritePatternSet
&patterns
, StringRef f32Func
,
41 StringRef f64Func
, StringRef f16Func
,
42 StringRef f32ApproxFunc
= "") {
43 patterns
.add
<ScalarizeVectorOpLowering
<OpTy
>>(converter
);
44 patterns
.add
<OpToFuncCallLowering
<OpTy
>>(converter
, f32Func
, f64Func
,
45 f32ApproxFunc
, f16Func
);
48 void mlir::populateMathToROCDLConversionPatterns(
49 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
) {
50 // Handled by mathToLLVM: math::AbsIOp
51 // Handled by mathToLLVM: math::AbsFOp
52 // Handled by mathToLLVM: math::CopySignOp
53 // Handled by mathToLLVM: math::CountLeadingZerosOp
54 // Handled by mathToLLVM: math::CountTrailingZerosOp
55 // Handled by mathToLLVM: math::CgPopOp
56 // Handled by mathToLLVM: math::ExpOp (32-bit only)
57 // Handled by mathToLLVM: math::FmaOp
58 // Handled by mathToLLVM: math::LogOp (32-bit only)
59 // FIXME: math::IPowIOp
60 // FIXME: math::FPowIOp
61 // Handled by mathToLLVM: math::RoundEvenOp
62 // Handled by mathToLLVM: math::RoundOp
63 // Handled by mathToLLVM: math::SqrtOp
64 // Handled by mathToLLVM: math::TruncOp
65 populateOpPatterns
<math::AcosOp
>(converter
, patterns
, "__ocml_acos_f32",
66 "__ocml_acos_f64", "__ocml_acos_f16");
67 populateOpPatterns
<math::AcoshOp
>(converter
, patterns
, "__ocml_acosh_f32",
68 "__ocml_acosh_f64", "__ocml_acosh_f16");
69 populateOpPatterns
<math::AsinOp
>(converter
, patterns
, "__ocml_asin_f32",
70 "__ocml_asin_f64", "__ocml_asin_f16");
71 populateOpPatterns
<math::AsinhOp
>(converter
, patterns
, "__ocml_asinh_f32",
72 "__ocml_asinh_f64", "__ocml_asinh_f16");
73 populateOpPatterns
<math::AtanOp
>(converter
, patterns
, "__ocml_atan_f32",
74 "__ocml_atan_f64", "__ocml_atan_f16");
75 populateOpPatterns
<math::AtanhOp
>(converter
, patterns
, "__ocml_atanh_f32",
76 "__ocml_atanh_f64", "__ocml_atanh_f16");
77 populateOpPatterns
<math::Atan2Op
>(converter
, patterns
, "__ocml_atan2_f32",
78 "__ocml_atan2_f64", "__ocml_atan2_f16");
79 populateOpPatterns
<math::CbrtOp
>(converter
, patterns
, "__ocml_cbrt_f32",
80 "__ocml_cbrt_f64", "__ocml_cbrt_f16");
81 populateOpPatterns
<math::CeilOp
>(converter
, patterns
, "__ocml_ceil_f32",
82 "__ocml_ceil_f64", "__ocml_ceil_f16");
83 populateOpPatterns
<math::CosOp
>(converter
, patterns
, "__ocml_cos_f32",
84 "__ocml_cos_f64", "__ocml_cos_f16");
85 populateOpPatterns
<math::CoshOp
>(converter
, patterns
, "__ocml_cosh_f32",
86 "__ocml_cosh_f64", "__ocml_cosh_f16");
87 populateOpPatterns
<math::SinhOp
>(converter
, patterns
, "__ocml_sinh_f32",
88 "__ocml_sinh_f64", "__ocml_sinh_f16");
89 populateOpPatterns
<math::ExpOp
>(converter
, patterns
, "", "__ocml_exp_f64",
91 populateOpPatterns
<math::Exp2Op
>(converter
, patterns
, "__ocml_exp2_f32",
92 "__ocml_exp2_f64", "__ocml_exp2_f16");
93 populateOpPatterns
<math::ExpM1Op
>(converter
, patterns
, "__ocml_expm1_f32",
94 "__ocml_expm1_f64", "__ocml_expm1_f16");
95 populateOpPatterns
<math::FloorOp
>(converter
, patterns
, "__ocml_floor_f32",
96 "__ocml_floor_f64", "__ocml_floor_f16");
97 populateOpPatterns
<math::LogOp
>(converter
, patterns
, "", "__ocml_log_f64",
99 populateOpPatterns
<math::Log10Op
>(converter
, patterns
, "__ocml_log10_f32",
100 "__ocml_log10_f64", "__ocml_log10_f16");
101 populateOpPatterns
<math::Log1pOp
>(converter
, patterns
, "__ocml_log1p_f32",
102 "__ocml_log1p_f64", "__ocml_log1p_f16");
103 populateOpPatterns
<math::Log2Op
>(converter
, patterns
, "__ocml_log2_f32",
104 "__ocml_log2_f64", "__ocml_log2_f16");
105 populateOpPatterns
<math::PowFOp
>(converter
, patterns
, "__ocml_pow_f32",
106 "__ocml_pow_f64", "__ocml_pow_f16");
107 populateOpPatterns
<math::RsqrtOp
>(converter
, patterns
, "__ocml_rsqrt_f32",
108 "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
109 populateOpPatterns
<math::SinOp
>(converter
, patterns
, "__ocml_sin_f32",
110 "__ocml_sin_f64", "__ocml_sin_f16");
111 populateOpPatterns
<math::TanhOp
>(converter
, patterns
, "__ocml_tanh_f32",
112 "__ocml_tanh_f64", "__ocml_tanh_f16");
113 populateOpPatterns
<math::TanOp
>(converter
, patterns
, "__ocml_tan_f32",
114 "__ocml_tan_f64", "__ocml_tan_f16");
115 populateOpPatterns
<math::ErfOp
>(converter
, patterns
, "__ocml_erf_f32",
116 "__ocml_erf_f64", "__ocml_erf_f16");
117 // Single arith pattern that needs a ROCDL call, probably not
118 // worth creating a separate pass for it.
119 populateOpPatterns
<arith::RemFOp
>(converter
, patterns
, "__ocml_fmod_f32",
120 "__ocml_fmod_f64", "__ocml_fmod_f16");
124 struct ConvertMathToROCDLPass
125 : public impl::ConvertMathToROCDLBase
<ConvertMathToROCDLPass
> {
126 ConvertMathToROCDLPass() = default;
127 void runOnOperation() override
;
131 void ConvertMathToROCDLPass::runOnOperation() {
132 auto m
= getOperation();
133 MLIRContext
*ctx
= m
.getContext();
135 RewritePatternSet
patterns(&getContext());
136 LowerToLLVMOptions
options(ctx
, DataLayout(m
));
137 LLVMTypeConverter
converter(ctx
, options
);
138 populateMathToROCDLConversionPatterns(converter
, patterns
);
139 ConversionTarget
target(getContext());
140 target
.addLegalDialect
<BuiltinDialect
, func::FuncDialect
,
141 vector::VectorDialect
, LLVM::LLVMDialect
>();
142 target
.addIllegalOp
<LLVM::CosOp
, LLVM::ExpOp
, LLVM::Exp2Op
, LLVM::FAbsOp
,
143 LLVM::FCeilOp
, LLVM::FFloorOp
, LLVM::FRemOp
, LLVM::LogOp
,
144 LLVM::Log10Op
, LLVM::Log2Op
, LLVM::PowOp
, LLVM::SinOp
,
146 if (failed(applyPartialConversion(m
, target
, std::move(patterns
))))