1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
30 template <typename SourceOp
, typename TargetOp
>
31 using ConvertFastMath
= arith::AttrConvertFastMathToLLVM
<SourceOp
, TargetOp
>;
33 template <typename SourceOp
, typename TargetOp
>
34 using ConvertFMFMathToLLVMPattern
=
35 VectorConvertToLLVMPattern
<SourceOp
, TargetOp
, ConvertFastMath
>;
37 using AbsFOpLowering
= ConvertFMFMathToLLVMPattern
<math::AbsFOp
, LLVM::FAbsOp
>;
38 using CeilOpLowering
= ConvertFMFMathToLLVMPattern
<math::CeilOp
, LLVM::FCeilOp
>;
39 using CopySignOpLowering
=
40 ConvertFMFMathToLLVMPattern
<math::CopySignOp
, LLVM::CopySignOp
>;
41 using CosOpLowering
= ConvertFMFMathToLLVMPattern
<math::CosOp
, LLVM::CosOp
>;
42 using CtPopFOpLowering
=
43 VectorConvertToLLVMPattern
<math::CtPopOp
, LLVM::CtPopOp
>;
44 using Exp2OpLowering
= ConvertFMFMathToLLVMPattern
<math::Exp2Op
, LLVM::Exp2Op
>;
45 using ExpOpLowering
= ConvertFMFMathToLLVMPattern
<math::ExpOp
, LLVM::ExpOp
>;
46 using FloorOpLowering
=
47 ConvertFMFMathToLLVMPattern
<math::FloorOp
, LLVM::FFloorOp
>;
48 using FmaOpLowering
= ConvertFMFMathToLLVMPattern
<math::FmaOp
, LLVM::FMAOp
>;
49 using Log10OpLowering
=
50 ConvertFMFMathToLLVMPattern
<math::Log10Op
, LLVM::Log10Op
>;
51 using Log2OpLowering
= ConvertFMFMathToLLVMPattern
<math::Log2Op
, LLVM::Log2Op
>;
52 using LogOpLowering
= ConvertFMFMathToLLVMPattern
<math::LogOp
, LLVM::LogOp
>;
53 using PowFOpLowering
= ConvertFMFMathToLLVMPattern
<math::PowFOp
, LLVM::PowOp
>;
54 using FPowIOpLowering
=
55 ConvertFMFMathToLLVMPattern
<math::FPowIOp
, LLVM::PowIOp
>;
56 using RoundEvenOpLowering
=
57 ConvertFMFMathToLLVMPattern
<math::RoundEvenOp
, LLVM::RoundEvenOp
>;
58 using RoundOpLowering
=
59 ConvertFMFMathToLLVMPattern
<math::RoundOp
, LLVM::RoundOp
>;
60 using SinOpLowering
= ConvertFMFMathToLLVMPattern
<math::SinOp
, LLVM::SinOp
>;
61 using SqrtOpLowering
= ConvertFMFMathToLLVMPattern
<math::SqrtOp
, LLVM::SqrtOp
>;
62 using FTruncOpLowering
=
63 ConvertFMFMathToLLVMPattern
<math::TruncOp
, LLVM::FTruncOp
>;
// Conversion pattern, templated on the math op and the LLVM op that replaces
// it, for ops whose LLVM counterpart takes one extra trailing argument.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips 70-71, 75, 77-78, 81, 84-87, 89-91 and everything after 95, so the
// return type, the statements guarded by the checks below, the trailing call
// arguments and the closing braces are not visible here.
65 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
66 template <typename MathOp
, typename LLVMOp
>
67 struct IntOpWithFlagLowering
: public ConvertOpToLLVMPattern
<MathOp
> {
// Inherit the constructors of the base pattern.
68 using ConvertOpToLLVMPattern
<MathOp
>::ConvertOpToLLVMPattern
;
// Shorthand for this instantiation.
69 using Super
= IntOpWithFlagLowering
<MathOp
, LLVMOp
>;
// Rewrites `op` using the already-converted operands provided by `adaptor`.
72 matchAndRewrite(MathOp op
, typename
MathOp::Adaptor adaptor
,
73 ConversionPatternRewriter
&rewriter
) const override
{
// The type is read from the adapted operand, not from the original op.
74 auto operandType
= adaptor
.getOperand().getType();
// Guard for a null or non-LLVM-compatible operand type; the guarded
// statement (embedded lines 77-78) is not visible in this listing.
76 if (!operandType
|| !LLVM::isCompatibleType(operandType
))
79 auto loc
= op
.getLoc();
// Unlike `operandType`, the result type is taken from the original op.
80 auto resultType
= op
.getResult().getType();
// Operand is not an LLVM array type: replace the op directly with `LLVMOp`.
// The remaining call arguments are not visible here.
82 if (!isa
<LLVM::LLVMArrayType
>(operandType
)) {
83 rewriter
.replaceOpWithNewOp
<LLVMOp
>(op
, resultType
, adaptor
.getOperand(),
// Array-typed operand: the result type is dyn_cast to VectorType.
// NOTE(review): handling of a failed cast is not visible here.
88 auto vectorType
= dyn_cast
<VectorType
>(resultType
);
// Delegate to the multi-dimensional vector helper; the callback creates one
// `LLVMOp` from the 1-D vector type and the operands it is handed.
92 return LLVM::detail::handleMultidimensionalVectors(
93 op
.getOperation(), adaptor
.getOperands(), *this->getTypeConverter(),
94 [&](Type llvm1DVectorTy
, ValueRange operands
) {
95 return rewriter
.create
<LLVMOp
>(loc
, llvm1DVectorTy
, operands
[0],
102 using CountLeadingZerosOpLowering
=
103 IntOpWithFlagLowering
<math::CountLeadingZerosOp
, LLVM::CountLeadingZerosOp
>;
104 using CountTrailingZerosOpLowering
=
105 IntOpWithFlagLowering
<math::CountTrailingZerosOp
,
106 LLVM::CountTrailingZerosOp
>;
107 using AbsIOpLowering
= IntOpWithFlagLowering
<math::AbsIOp
, LLVM::AbsOp
>;
// Lowers `math.expm1` to an `LLVM::ExpOp` followed by an `LLVM::FSubOp` that
// subtracts a constant one from the exponential.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips, among others, 112-113, 119-120, 132, 134, 136, 141-143, 145,
// 155-156 and everything after 161, so some statements and all closing
// braces are not visible here.
109 // A `expm1` is converted into `exp - 1`.
110 struct ExpM1OpLowering
: public ConvertOpToLLVMPattern
<math::ExpM1Op
> {
// Inherit the constructors of the base pattern.
111 using ConvertOpToLLVMPattern
<math::ExpM1Op
>::ConvertOpToLLVMPattern
;
// Rewrites `op` using the already-converted operand provided by `adaptor`.
114 matchAndRewrite(math::ExpM1Op op
, OpAdaptor adaptor
,
115 ConversionPatternRewriter
&rewriter
) const override
{
116 auto operandType
= adaptor
.getOperand().getType();
// Guard for a null or non-LLVM-compatible operand type; the guarded
// statement (embedded lines 119-120) is not visible in this listing.
118 if (!operandType
|| !LLVM::isCompatibleType(operandType
))
121 auto loc
= op
.getLoc();
122 auto resultType
= op
.getResult().getType();
// Element type of the result (or the result type itself when it is not a
// shaped type); `cast` asserts that it is a FloatType.
123 auto floatType
= cast
<FloatType
>(getElementTypeOrSelf(resultType
));
// Attribute holding the constant 1.0 in that float type.
124 auto floatOne
= rewriter
.getFloatAttr(floatType
, 1.0);
// Attribute converters built from the source op; their getAttrs() results
// are passed to the exp and sub ops created below.
125 ConvertFastMath
<math::ExpM1Op
, LLVM::ExpOp
> expAttrs(op
);
126 ConvertFastMath
<math::ExpM1Op
, LLVM::FSubOp
> subAttrs(op
);
// Operand is not an LLVM array type: emit constant, exp and sub directly.
128 if (!isa
<LLVM::LLVMArrayType
>(operandType
)) {
129 LLVM::ConstantOp one
;
// Vector-compatible operand: the constant is a splat of 1.0 shaped like the
// result type. The leading arguments of this create call (embedded line
// 132) are not visible here.
130 if (LLVM::isCompatibleVectorType(operandType
)) {
131 one
= rewriter
.create
<LLVM::ConstantOp
>(
133 SplatElementsAttr::get(cast
<ShapedType
>(resultType
), floatOne
));
// Constant built from `floatOne` with the operand type; presumably the
// alternative branch of the check above (embedded line 134 not visible).
135 one
= rewriter
.create
<LLVM::ConstantOp
>(loc
, operandType
, floatOne
);
137 auto exp
= rewriter
.create
<LLVM::ExpOp
>(loc
, adaptor
.getOperand(),
138 expAttrs
.getAttrs());
// Replace the original op with `exp - one`.
139 rewriter
.replaceOpWithNewOp
<LLVM::FSubOp
>(
140 op
, operandType
, ValueRange
{exp
, one
}, subAttrs
.getAttrs());
// Array-typed operand: the result must be a VectorType.
144 auto vectorType
= dyn_cast
<VectorType
>(resultType
);
// NOTE(review): the condition guarding this return (embedded line 145) is
// not visible; presumably a null check on `vectorType`.
146 return rewriter
.notifyMatchFailure(op
, "expected vector result type");
// Delegate to the multi-dimensional vector helper; the callback emits the
// constant/exp/sub sequence for each 1-D vector type it is handed.
148 return LLVM::detail::handleMultidimensionalVectors(
149 op
.getOperation(), adaptor
.getOperands(), *getTypeConverter(),
150 [&](Type llvm1DVectorTy
, ValueRange operands
) {
// Build the splat attribute on a builtin VectorType with the same element
// count and scalability as the 1-D LLVM vector type.
151 auto numElements
= LLVM::getVectorNumElements(llvm1DVectorTy
);
152 auto splatAttr
= SplatElementsAttr::get(
153 mlir::VectorType::get({numElements
.getKnownMinValue()}, floatType
,
154 {numElements
.isScalable()}),
// NOTE(review): embedded lines 155-156 are not visible; presumably they
// finish the splat attribute and bind this ConstantOp to the `one` used
// in the FSubOp below.
157 rewriter
.create
<LLVM::ConstantOp
>(loc
, llvm1DVectorTy
, splatAttr
);
158 auto exp
= rewriter
.create
<LLVM::ExpOp
>(
159 loc
, llvm1DVectorTy
, operands
[0], expAttrs
.getAttrs());
160 return rewriter
.create
<LLVM::FSubOp
>(
161 loc
, llvm1DVectorTy
, ValueRange
{exp
, one
}, subAttrs
.getAttrs());
// Lowers `math.log1p` to an `LLVM::FAddOp` of a constant one and the operand,
// followed by an `LLVM::LogOp` of that sum.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips, among others, 170-171, 190, 192, 194, 200-202, 204, 214-215 and
// everything after 221, so some statements and all closing braces are not
// visible here.
167 // A `log1p` is converted into `log(1 + ...)`.
168 struct Log1pOpLowering
: public ConvertOpToLLVMPattern
<math::Log1pOp
> {
// Inherit the constructors of the base pattern.
169 using ConvertOpToLLVMPattern
<math::Log1pOp
>::ConvertOpToLLVMPattern
;
// Rewrites `op` using the already-converted operand provided by `adaptor`.
172 matchAndRewrite(math::Log1pOp op
, OpAdaptor adaptor
,
173 ConversionPatternRewriter
&rewriter
) const override
{
174 auto operandType
= adaptor
.getOperand().getType();
// Report a match failure for a null or non-LLVM-compatible operand type.
176 if (!operandType
|| !LLVM::isCompatibleType(operandType
))
177 return rewriter
.notifyMatchFailure(op
, "unsupported operand type");
179 auto loc
= op
.getLoc();
180 auto resultType
= op
.getResult().getType();
// Element type of the result (or the result type itself when it is not a
// shaped type); `cast` asserts that it is a FloatType.
181 auto floatType
= cast
<FloatType
>(getElementTypeOrSelf(resultType
));
// Attribute holding the constant 1.0 in that float type.
182 auto floatOne
= rewriter
.getFloatAttr(floatType
, 1.0);
// Attribute converters built from the source op; their getAttrs() results
// are passed to the add and log ops created below.
183 ConvertFastMath
<math::Log1pOp
, LLVM::FAddOp
> addAttrs(op
);
184 ConvertFastMath
<math::Log1pOp
, LLVM::LogOp
> logAttrs(op
);
// Operand is not an LLVM array type: emit constant, add and log directly.
186 if (!isa
<LLVM::LLVMArrayType
>(operandType
)) {
// Constant one: a splat shaped like the result type for vector-compatible
// operands, otherwise a constant built from `floatOne` with the operand
// type. Parts of the first branch (embedded lines 190 and 192) are not
// visible here.
187 LLVM::ConstantOp one
=
188 LLVM::isCompatibleVectorType(operandType
)
189 ? rewriter
.create
<LLVM::ConstantOp
>(
191 SplatElementsAttr::get(cast
<ShapedType
>(resultType
),
193 : rewriter
.create
<LLVM::ConstantOp
>(loc
, operandType
, floatOne
);
// one + x
195 auto add
= rewriter
.create
<LLVM::FAddOp
>(
196 loc
, operandType
, ValueRange
{one
, adaptor
.getOperand()},
197 addAttrs
.getAttrs());
// Replace the original op with log(one + x).
198 rewriter
.replaceOpWithNewOp
<LLVM::LogOp
>(op
, operandType
, ValueRange
{add
},
199 logAttrs
.getAttrs());
// Array-typed operand: the result must be a VectorType.
203 auto vectorType
= dyn_cast
<VectorType
>(resultType
);
// NOTE(review): the condition guarding this return (embedded line 204) is
// not visible; presumably a null check on `vectorType`.
205 return rewriter
.notifyMatchFailure(op
, "expected vector result type");
// Delegate to the multi-dimensional vector helper; the callback emits the
// constant/add/log sequence for each 1-D vector type it is handed.
207 return LLVM::detail::handleMultidimensionalVectors(
208 op
.getOperation(), adaptor
.getOperands(), *getTypeConverter(),
209 [&](Type llvm1DVectorTy
, ValueRange operands
) {
// Build the splat attribute on a builtin VectorType with the same element
// count and scalability as the 1-D LLVM vector type.
210 auto numElements
= LLVM::getVectorNumElements(llvm1DVectorTy
);
211 auto splatAttr
= SplatElementsAttr::get(
212 mlir::VectorType::get({numElements
.getKnownMinValue()}, floatType
,
213 {numElements
.isScalable()}),
// NOTE(review): embedded lines 214-215 are not visible; presumably they
// finish the splat attribute and bind this ConstantOp to the `one` used
// in the FAddOp below.
216 rewriter
.create
<LLVM::ConstantOp
>(loc
, llvm1DVectorTy
, splatAttr
);
217 auto add
= rewriter
.create
<LLVM::FAddOp
>(loc
, llvm1DVectorTy
,
218 ValueRange
{one
, operands
[0]},
219 addAttrs
.getAttrs());
220 return rewriter
.create
<LLVM::LogOp
>(
221 loc
, llvm1DVectorTy
, ValueRange
{add
}, logAttrs
.getAttrs());
// Lowers `math.rsqrt` to an `LLVM::SqrtOp` followed by an `LLVM::FDivOp` that
// divides a constant one by the square root.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips, among others, 230-231, 237-238, 250, 252, 254, 259-261, 263-265,
// 273-274 and everything after 279, so some statements and all closing
// braces are not visible here.
227 // A `rsqrt` is converted into `1 / sqrt`.
228 struct RsqrtOpLowering
: public ConvertOpToLLVMPattern
<math::RsqrtOp
> {
// Inherit the constructors of the base pattern.
229 using ConvertOpToLLVMPattern
<math::RsqrtOp
>::ConvertOpToLLVMPattern
;
// Rewrites `op` using the already-converted operand provided by `adaptor`.
232 matchAndRewrite(math::RsqrtOp op
, OpAdaptor adaptor
,
233 ConversionPatternRewriter
&rewriter
) const override
{
234 auto operandType
= adaptor
.getOperand().getType();
// Guard for a null or non-LLVM-compatible operand type; the guarded
// statement (embedded lines 237-238) is not visible in this listing.
236 if (!operandType
|| !LLVM::isCompatibleType(operandType
))
239 auto loc
= op
.getLoc();
240 auto resultType
= op
.getResult().getType();
// Element type of the result (or the result type itself when it is not a
// shaped type); `cast` asserts that it is a FloatType.
241 auto floatType
= cast
<FloatType
>(getElementTypeOrSelf(resultType
));
// Attribute holding the constant 1.0 in that float type.
242 auto floatOne
= rewriter
.getFloatAttr(floatType
, 1.0);
// Attribute converters built from the source op; their getAttrs() results
// are passed to the sqrt and div ops created below.
243 ConvertFastMath
<math::RsqrtOp
, LLVM::SqrtOp
> sqrtAttrs(op
);
244 ConvertFastMath
<math::RsqrtOp
, LLVM::FDivOp
> divAttrs(op
);
// Operand is not an LLVM array type: emit constant, sqrt and div directly.
246 if (!isa
<LLVM::LLVMArrayType
>(operandType
)) {
247 LLVM::ConstantOp one
;
// Vector-compatible operand: the constant is a splat of 1.0 shaped like the
// result type. The leading arguments of this create call (embedded line
// 250) are not visible here.
248 if (LLVM::isCompatibleVectorType(operandType
)) {
249 one
= rewriter
.create
<LLVM::ConstantOp
>(
251 SplatElementsAttr::get(cast
<ShapedType
>(resultType
), floatOne
));
// Constant built from `floatOne` with the operand type; presumably the
// alternative branch of the check above (embedded line 252 not visible).
253 one
= rewriter
.create
<LLVM::ConstantOp
>(loc
, operandType
, floatOne
);
255 auto sqrt
= rewriter
.create
<LLVM::SqrtOp
>(loc
, adaptor
.getOperand(),
256 sqrtAttrs
.getAttrs());
// Replace the original op with `one / sqrt`.
257 rewriter
.replaceOpWithNewOp
<LLVM::FDivOp
>(
258 op
, operandType
, ValueRange
{one
, sqrt
}, divAttrs
.getAttrs());
// Array-typed operand: the result type is dyn_cast to VectorType.
// NOTE(review): handling of a failed cast (embedded lines 263-265) is not
// visible here.
262 auto vectorType
= dyn_cast
<VectorType
>(resultType
);
// Delegate to the multi-dimensional vector helper; the callback emits the
// constant/sqrt/div sequence for each 1-D vector type it is handed.
266 return LLVM::detail::handleMultidimensionalVectors(
267 op
.getOperation(), adaptor
.getOperands(), *getTypeConverter(),
268 [&](Type llvm1DVectorTy
, ValueRange operands
) {
// Build the splat attribute on a builtin VectorType with the same element
// count and scalability as the 1-D LLVM vector type.
269 auto numElements
= LLVM::getVectorNumElements(llvm1DVectorTy
);
270 auto splatAttr
= SplatElementsAttr::get(
271 mlir::VectorType::get({numElements
.getKnownMinValue()}, floatType
,
272 {numElements
.isScalable()}),
// NOTE(review): embedded lines 273-274 are not visible; presumably they
// finish the splat attribute and bind this ConstantOp to the `one` used
// in the FDivOp below.
275 rewriter
.create
<LLVM::ConstantOp
>(loc
, llvm1DVectorTy
, splatAttr
);
276 auto sqrt
= rewriter
.create
<LLVM::SqrtOp
>(
277 loc
, llvm1DVectorTy
, operands
[0], sqrtAttrs
.getAttrs());
278 return rewriter
.create
<LLVM::FDivOp
>(
279 loc
, llvm1DVectorTy
, ValueRange
{one
, sqrt
}, divAttrs
.getAttrs());
// Pass that applies the Math-to-LLVM conversion patterns to the operation it
// runs on, using a partial dialect conversion.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips 287-288 and everything after 295, so any declarations between the
// struct header and runOnOperation, the body of the failure check and the
// closing braces are not visible here.
285 struct ConvertMathToLLVMPass
286 : public impl::ConvertMathToLLVMPassBase
<ConvertMathToLLVMPass
> {
289 void runOnOperation() override
{
// Pattern set and LLVM type converter, both bound to this pass's context.
290 RewritePatternSet
patterns(&getContext());
291 LLVMTypeConverter
converter(&getContext());
// `approximateLog1p` is not declared in the visible code; presumably it is
// a pass option supplied by the generated base class - TODO confirm.
292 populateMathToLLVMConversionPatterns(converter
, patterns
, approximateLog1p
);
293 LLVMConversionTarget
target(getContext());
// Run the partial conversion; what happens on failure is not visible here.
294 if (failed(applyPartialConversion(getOperation(), target
,
295 std::move(patterns
))))
// Adds the Math-to-LLVM conversion patterns to `patterns`, constructing each
// pattern with `converter`. `Log1pOpLowering` is added only when
// `approximateLog1p` is true.
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips 306-312 and 315-335, so only two entries of the remaining pattern
// list and none of the surrounding call syntax are visible here.
301 void mlir::populateMathToLLVMConversionPatterns(
302 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
,
303 bool approximateLog1p
) {
// Optional pattern, gated on the caller's flag.
304 if (approximateLog1p
)
305 patterns
.add
<Log1pOpLowering
>(converter
);
// Fragment of the pattern list; the enclosing statement is not visible.
313 CountLeadingZerosOpLowering
,
314 CountTrailingZerosOpLowering
,
336 //===----------------------------------------------------------------------===//
337 // ConvertToLLVMPatternInterface implementation
338 //===----------------------------------------------------------------------===//
// NOTE(review): this listing is incomplete. The embedded source numbering
// skips 346-347 and everything after 353, so the closing braces of both
// member functions and of the struct are not visible here.
341 /// Implement the interface to convert Math to LLVM.
342 struct MathToLLVMDialectInterface
: public ConvertToLLVMPatternInterface
{
// Inherit the constructors of the interface base class.
343 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
/// Loads the LLVM dialect into `context`.
344 void loadDependentDialects(MLIRContext
*context
) const final
{
345 context
->loadDialect
<LLVM::LLVMDialect
>();
348 /// Hook for derived dialect interface to provide conversion patterns
349 /// and mark dialect legal for the conversion target.
// In the visible code `target` is not used; the hook only forwards the type
// converter and pattern set to populateMathToLLVMConversionPatterns. No
// third argument is passed, so that function's default applies.
350 void populateConvertToLLVMConversionPatterns(
351 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
352 RewritePatternSet
&patterns
) const final
{
353 populateMathToLLVMConversionPatterns(typeConverter
, patterns
);
// Registers an extension on `registry` that attaches
// `MathToLLVMDialectInterface` to the Math dialect.
// The parameter was mis-encoded as `®istry` (a mangled `&registry`); it is
// restored to a reference named `registry`, the name the body uses.
// NOTE(review): this listing ends at embedded line 360, so the closing of the
// lambda, of the addExtension call and of the function are not visible here.
358 void mlir::registerConvertMathToLLVMInterface(DialectRegistry
&registry
) {
// The leading `+` converts the capture-less lambda to a function pointer.
359 registry
.addExtension(+[](MLIRContext
*ctx
, math::MathDialect
*dialect
) {
360 dialect
->addInterfaces
<MathToLLVMDialectInterface
>();