//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
#include <optional>
#include <string>

namespace mlir {
#define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
25 // Functor to resolve the function name corresponding to the given complex
27 struct ComplexTypeResolver
{
28 std::optional
<bool> operator()(Type type
) const {
29 auto complexType
= cast
<ComplexType
>(type
);
30 auto elementType
= complexType
.getElementType();
31 if (!isa
<Float32Type
, Float64Type
>(elementType
))
34 return elementType
.getIntOrFloatBitWidth() == 64;
38 // Functor to resolve the function name corresponding to the given float result
40 struct FloatTypeResolver
{
41 std::optional
<bool> operator()(Type type
) const {
42 auto elementType
= cast
<FloatType
>(type
);
43 if (!isa
<Float32Type
, Float64Type
>(elementType
))
46 return elementType
.getIntOrFloatBitWidth() == 64;
50 // Pattern to convert scalar complex operations to calls to libm functions.
51 // Additionally the libm function signatures are declared.
52 // TypeResolver is a functor returning the libm function name according to the
53 // expected type double or float.
54 template <typename Op
, typename TypeResolver
= ComplexTypeResolver
>
55 struct ScalarOpToLibmCall
: public OpRewritePattern
<Op
> {
57 using OpRewritePattern
<Op
>::OpRewritePattern
;
58 ScalarOpToLibmCall(MLIRContext
*context
, StringRef floatFunc
,
59 StringRef doubleFunc
, PatternBenefit benefit
)
60 : OpRewritePattern
<Op
>(context
, benefit
), floatFunc(floatFunc
),
61 doubleFunc(doubleFunc
){};
63 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
66 std::string floatFunc
, doubleFunc
;
70 template <typename Op
, typename TypeResolver
>
71 LogicalResult ScalarOpToLibmCall
<Op
, TypeResolver
>::matchAndRewrite(
72 Op op
, PatternRewriter
&rewriter
) const {
73 auto module
= SymbolTable::getNearestSymbolTable(op
);
74 auto isDouble
= TypeResolver()(op
.getType());
75 if (!isDouble
.has_value())
78 auto name
= *isDouble
? doubleFunc
: floatFunc
;
80 auto opFunc
= dyn_cast_or_null
<SymbolOpInterface
>(
81 SymbolTable::lookupSymbolIn(module
, name
));
82 // Forward declare function if it hasn't already been
84 OpBuilder::InsertionGuard
guard(rewriter
);
85 rewriter
.setInsertionPointToStart(&module
->getRegion(0).front());
86 auto opFunctionTy
= FunctionType::get(
87 rewriter
.getContext(), op
->getOperandTypes(), op
->getResultTypes());
88 opFunc
= rewriter
.create
<func::FuncOp
>(rewriter
.getUnknownLoc(), name
,
92 assert(isa
<FunctionOpInterface
>(SymbolTable::lookupSymbolIn(module
, name
)));
94 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, name
, op
.getType(),
100 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet
&patterns
,
101 PatternBenefit benefit
) {
102 patterns
.add
<ScalarOpToLibmCall
<complex::PowOp
>>(patterns
.getContext(),
103 "cpowf", "cpow", benefit
);
104 patterns
.add
<ScalarOpToLibmCall
<complex::SqrtOp
>>(patterns
.getContext(),
105 "csqrtf", "csqrt", benefit
);
106 patterns
.add
<ScalarOpToLibmCall
<complex::TanhOp
>>(patterns
.getContext(),
107 "ctanhf", "ctanh", benefit
);
108 patterns
.add
<ScalarOpToLibmCall
<complex::CosOp
>>(patterns
.getContext(),
109 "ccosf", "ccos", benefit
);
110 patterns
.add
<ScalarOpToLibmCall
<complex::SinOp
>>(patterns
.getContext(),
111 "csinf", "csin", benefit
);
112 patterns
.add
<ScalarOpToLibmCall
<complex::ConjOp
>>(patterns
.getContext(),
113 "conjf", "conj", benefit
);
114 patterns
.add
<ScalarOpToLibmCall
<complex::LogOp
>>(patterns
.getContext(),
115 "clogf", "clog", benefit
);
116 patterns
.add
<ScalarOpToLibmCall
<complex::AbsOp
, FloatTypeResolver
>>(
117 patterns
.getContext(), "cabsf", "cabs", benefit
);
118 patterns
.add
<ScalarOpToLibmCall
<complex::AngleOp
, FloatTypeResolver
>>(
119 patterns
.getContext(), "cargf", "carg", benefit
);
120 patterns
.add
<ScalarOpToLibmCall
<complex::TanOp
>>(patterns
.getContext(),
121 "ctanf", "ctan", benefit
);
125 struct ConvertComplexToLibmPass
126 : public impl::ConvertComplexToLibmBase
<ConvertComplexToLibmPass
> {
127 void runOnOperation() override
;
131 void ConvertComplexToLibmPass::runOnOperation() {
132 auto module
= getOperation();
134 RewritePatternSet
patterns(&getContext());
135 populateComplexToLibmConversionPatterns(patterns
, /*benefit=*/1);
137 ConversionTarget
target(getContext());
138 target
.addLegalDialect
<func::FuncDialect
>();
139 target
.addIllegalOp
<complex::PowOp
, complex::SqrtOp
, complex::TanhOp
,
140 complex::CosOp
, complex::SinOp
, complex::ConjOp
,
141 complex::LogOp
, complex::AbsOp
, complex::AngleOp
,
143 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
147 std::unique_ptr
<OperationPass
<ModuleOp
>>
148 mlir::createConvertComplexToLibmPass() {
149 return std::make_unique
<ConvertComplexToLibmPass
>();