//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
30 // Pattern to convert vector operations to scalar operations. This is needed as
31 // libm calls require scalars.
32 template <typename Op
>
33 struct VecOpToScalarOp
: public OpRewritePattern
<Op
> {
35 using OpRewritePattern
<Op
>::OpRewritePattern
;
37 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
39 // Pattern to promote an op of a smaller floating point type to F32.
40 template <typename Op
>
41 struct PromoteOpToF32
: public OpRewritePattern
<Op
> {
43 using OpRewritePattern
<Op
>::OpRewritePattern
;
45 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
47 // Pattern to convert scalar math operations to calls to libm functions.
48 // Additionally the libm function signatures are declared.
49 template <typename Op
>
50 struct ScalarOpToLibmCall
: public OpRewritePattern
<Op
> {
52 using OpRewritePattern
<Op
>::OpRewritePattern
;
53 ScalarOpToLibmCall(MLIRContext
*context
, StringRef floatFunc
,
55 : OpRewritePattern
<Op
>(context
), floatFunc(floatFunc
),
56 doubleFunc(doubleFunc
){};
58 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
61 std::string floatFunc
, doubleFunc
;
64 template <typename OpTy
>
65 void populatePatternsForOp(RewritePatternSet
&patterns
, MLIRContext
*ctx
,
66 StringRef floatFunc
, StringRef doubleFunc
) {
67 patterns
.add
<VecOpToScalarOp
<OpTy
>, PromoteOpToF32
<OpTy
>>(ctx
);
68 patterns
.add
<ScalarOpToLibmCall
<OpTy
>>(ctx
, floatFunc
, doubleFunc
);
73 template <typename Op
>
75 VecOpToScalarOp
<Op
>::matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const {
76 auto opType
= op
.getType();
77 auto loc
= op
.getLoc();
78 auto vecType
= dyn_cast
<VectorType
>(opType
);
82 if (!vecType
.hasRank())
84 auto shape
= vecType
.getShape();
85 int64_t numElements
= vecType
.getNumElements();
87 Value result
= rewriter
.create
<arith::ConstantOp
>(
88 loc
, DenseElementsAttr::get(
89 vecType
, FloatAttr::get(vecType
.getElementType(), 0.0)));
90 SmallVector
<int64_t> strides
= computeStrides(shape
);
91 for (auto linearIndex
= 0; linearIndex
< numElements
; ++linearIndex
) {
92 SmallVector
<int64_t> positions
= delinearize(linearIndex
, strides
);
93 SmallVector
<Value
> operands
;
94 for (auto input
: op
->getOperands())
96 rewriter
.create
<vector::ExtractOp
>(loc
, input
, positions
));
98 rewriter
.create
<Op
>(loc
, vecType
.getElementType(), operands
);
100 rewriter
.create
<vector::InsertOp
>(loc
, scalarOp
, result
, positions
);
102 rewriter
.replaceOp(op
, {result
});
106 template <typename Op
>
108 PromoteOpToF32
<Op
>::matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const {
109 auto opType
= op
.getType();
110 if (!isa
<Float16Type
, BFloat16Type
>(opType
))
113 auto loc
= op
.getLoc();
114 auto f32
= rewriter
.getF32Type();
115 auto extendedOperands
= llvm::to_vector(
116 llvm::map_range(op
->getOperands(), [&](Value operand
) -> Value
{
117 return rewriter
.create
<arith::ExtFOp
>(loc
, f32
, operand
);
119 auto newOp
= rewriter
.create
<Op
>(loc
, f32
, extendedOperands
);
120 rewriter
.replaceOpWithNewOp
<arith::TruncFOp
>(op
, opType
, newOp
);
124 template <typename Op
>
126 ScalarOpToLibmCall
<Op
>::matchAndRewrite(Op op
,
127 PatternRewriter
&rewriter
) const {
128 auto module
= SymbolTable::getNearestSymbolTable(op
);
129 auto type
= op
.getType();
130 if (!isa
<Float32Type
, Float64Type
>(type
))
133 auto name
= type
.getIntOrFloatBitWidth() == 64 ? doubleFunc
: floatFunc
;
134 auto opFunc
= dyn_cast_or_null
<SymbolOpInterface
>(
135 SymbolTable::lookupSymbolIn(module
, name
));
136 // Forward declare function if it hasn't already been
138 OpBuilder::InsertionGuard
guard(rewriter
);
139 rewriter
.setInsertionPointToStart(&module
->getRegion(0).front());
140 auto opFunctionTy
= FunctionType::get(
141 rewriter
.getContext(), op
->getOperandTypes(), op
->getResultTypes());
142 opFunc
= rewriter
.create
<func::FuncOp
>(rewriter
.getUnknownLoc(), name
,
146 // By definition Math dialect operations imply LLVM's "readnone"
147 // function attribute, so we can set it here to provide more
148 // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
149 // This will have to be changed, when strict FP behavior is supported
151 opFunc
->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
152 UnitAttr::get(rewriter
.getContext()));
154 assert(isa
<FunctionOpInterface
>(SymbolTable::lookupSymbolIn(module
, name
)));
156 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, name
, op
.getType(),
162 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet
&patterns
) {
163 MLIRContext
*ctx
= patterns
.getContext();
165 populatePatternsForOp
<math::AbsFOp
>(patterns
, ctx
, "fabsf", "fabs");
166 populatePatternsForOp
<math::AcosOp
>(patterns
, ctx
, "acosf", "acos");
167 populatePatternsForOp
<math::AcoshOp
>(patterns
, ctx
, "acoshf", "acosh");
168 populatePatternsForOp
<math::AsinOp
>(patterns
, ctx
, "asinf", "asin");
169 populatePatternsForOp
<math::AsinhOp
>(patterns
, ctx
, "asinhf", "asinh");
170 populatePatternsForOp
<math::Atan2Op
>(patterns
, ctx
, "atan2f", "atan2");
171 populatePatternsForOp
<math::AtanOp
>(patterns
, ctx
, "atanf", "atan");
172 populatePatternsForOp
<math::AtanhOp
>(patterns
, ctx
, "atanhf", "atanh");
173 populatePatternsForOp
<math::CbrtOp
>(patterns
, ctx
, "cbrtf", "cbrt");
174 populatePatternsForOp
<math::CeilOp
>(patterns
, ctx
, "ceilf", "ceil");
175 populatePatternsForOp
<math::CosOp
>(patterns
, ctx
, "cosf", "cos");
176 populatePatternsForOp
<math::CoshOp
>(patterns
, ctx
, "coshf", "cosh");
177 populatePatternsForOp
<math::ErfOp
>(patterns
, ctx
, "erff", "erf");
178 populatePatternsForOp
<math::ExpOp
>(patterns
, ctx
, "expf", "exp");
179 populatePatternsForOp
<math::Exp2Op
>(patterns
, ctx
, "exp2f", "exp2");
180 populatePatternsForOp
<math::ExpM1Op
>(patterns
, ctx
, "expm1f", "expm1");
181 populatePatternsForOp
<math::FloorOp
>(patterns
, ctx
, "floorf", "floor");
182 populatePatternsForOp
<math::FmaOp
>(patterns
, ctx
, "fmaf", "fma");
183 populatePatternsForOp
<math::LogOp
>(patterns
, ctx
, "logf", "log");
184 populatePatternsForOp
<math::Log2Op
>(patterns
, ctx
, "log2f", "log2");
185 populatePatternsForOp
<math::Log10Op
>(patterns
, ctx
, "log10f", "log10");
186 populatePatternsForOp
<math::Log1pOp
>(patterns
, ctx
, "log1pf", "log1p");
187 populatePatternsForOp
<math::PowFOp
>(patterns
, ctx
, "powf", "pow");
188 populatePatternsForOp
<math::RoundEvenOp
>(patterns
, ctx
, "roundevenf",
190 populatePatternsForOp
<math::RoundOp
>(patterns
, ctx
, "roundf", "round");
191 populatePatternsForOp
<math::SinOp
>(patterns
, ctx
, "sinf", "sin");
192 populatePatternsForOp
<math::SinhOp
>(patterns
, ctx
, "sinhf", "sinh");
193 populatePatternsForOp
<math::SqrtOp
>(patterns
, ctx
, "sqrtf", "sqrt");
194 populatePatternsForOp
<math::RsqrtOp
>(patterns
, ctx
, "rsqrtf", "rsqrt");
195 populatePatternsForOp
<math::TanOp
>(patterns
, ctx
, "tanf", "tan");
196 populatePatternsForOp
<math::TanhOp
>(patterns
, ctx
, "tanhf", "tanh");
197 populatePatternsForOp
<math::TruncOp
>(patterns
, ctx
, "truncf", "trunc");
201 struct ConvertMathToLibmPass
202 : public impl::ConvertMathToLibmBase
<ConvertMathToLibmPass
> {
203 void runOnOperation() override
;
207 void ConvertMathToLibmPass::runOnOperation() {
208 auto module
= getOperation();
210 RewritePatternSet
patterns(&getContext());
211 populateMathToLibmConversionPatterns(patterns
);
213 ConversionTarget
target(getContext());
214 target
.addLegalDialect
<arith::ArithDialect
, BuiltinDialect
, func::FuncDialect
,
215 vector::VectorDialect
>();
216 target
.addIllegalDialect
<math::MathDialect
>();
217 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
221 std::unique_ptr
<OperationPass
<ModuleOp
>> mlir::createConvertMathToLibmPass() {
222 return std::make_unique
<ConvertMathToLibmPass
>();