1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
17 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM
18 #include "mlir/Conversion/Passes.h.inc"
24 // Functor to resolve the function name corresponding to the given complex
26 struct ComplexTypeResolver
{
27 llvm::Optional
<bool> operator()(Type type
) const {
28 auto complexType
= type
.cast
<ComplexType
>();
29 auto elementType
= complexType
.getElementType();
30 if (!elementType
.isa
<Float32Type
, Float64Type
>())
33 return elementType
.getIntOrFloatBitWidth() == 64;
37 // Functor to resolve the function name corresponding to the given float result
39 struct FloatTypeResolver
{
40 llvm::Optional
<bool> operator()(Type type
) const {
41 auto elementType
= type
.cast
<FloatType
>();
42 if (!elementType
.isa
<Float32Type
, Float64Type
>())
45 return elementType
.getIntOrFloatBitWidth() == 64;
49 // Pattern to convert scalar complex operations to calls to libm functions.
50 // Additionally the libm function signatures are declared.
51 // TypeResolver is a functor returning the libm function name according to the
52 // expected type double or float.
53 template <typename Op
, typename TypeResolver
= ComplexTypeResolver
>
54 struct ScalarOpToLibmCall
: public OpRewritePattern
<Op
> {
56 using OpRewritePattern
<Op
>::OpRewritePattern
;
57 ScalarOpToLibmCall
<Op
, TypeResolver
>(MLIRContext
*context
,
60 PatternBenefit benefit
)
61 : OpRewritePattern
<Op
>(context
, benefit
), floatFunc(floatFunc
),
62 doubleFunc(doubleFunc
){};
64 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
67 std::string floatFunc
, doubleFunc
;
71 template <typename Op
, typename TypeResolver
>
72 LogicalResult ScalarOpToLibmCall
<Op
, TypeResolver
>::matchAndRewrite(
73 Op op
, PatternRewriter
&rewriter
) const {
74 auto module
= SymbolTable::getNearestSymbolTable(op
);
75 auto isDouble
= TypeResolver()(op
.getType());
76 if (!isDouble
.has_value())
79 auto name
= isDouble
.value() ? doubleFunc
: floatFunc
;
81 auto opFunc
= dyn_cast_or_null
<SymbolOpInterface
>(
82 SymbolTable::lookupSymbolIn(module
, name
));
83 // Forward declare function if it hasn't already been
85 OpBuilder::InsertionGuard
guard(rewriter
);
86 rewriter
.setInsertionPointToStart(&module
->getRegion(0).front());
87 auto opFunctionTy
= FunctionType::get(
88 rewriter
.getContext(), op
->getOperandTypes(), op
->getResultTypes());
89 opFunc
= rewriter
.create
<func::FuncOp
>(rewriter
.getUnknownLoc(), name
,
93 assert(isa
<FunctionOpInterface
>(SymbolTable::lookupSymbolIn(module
, name
)));
95 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, name
, op
.getType(),
101 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet
&patterns
,
102 PatternBenefit benefit
) {
103 patterns
.add
<ScalarOpToLibmCall
<complex::PowOp
>>(patterns
.getContext(),
104 "cpowf", "cpow", benefit
);
105 patterns
.add
<ScalarOpToLibmCall
<complex::SqrtOp
>>(patterns
.getContext(),
106 "csqrtf", "csqrt", benefit
);
107 patterns
.add
<ScalarOpToLibmCall
<complex::TanhOp
>>(patterns
.getContext(),
108 "ctanhf", "ctanh", benefit
);
109 patterns
.add
<ScalarOpToLibmCall
<complex::CosOp
>>(patterns
.getContext(),
110 "ccosf", "ccos", benefit
);
111 patterns
.add
<ScalarOpToLibmCall
<complex::SinOp
>>(patterns
.getContext(),
112 "csinf", "csin", benefit
);
113 patterns
.add
<ScalarOpToLibmCall
<complex::ConjOp
>>(patterns
.getContext(),
114 "conjf", "conj", benefit
);
115 patterns
.add
<ScalarOpToLibmCall
<complex::LogOp
>>(patterns
.getContext(),
116 "clogf", "clog", benefit
);
117 patterns
.add
<ScalarOpToLibmCall
<complex::AbsOp
, FloatTypeResolver
>>(
118 patterns
.getContext(), "cabsf", "cabs", benefit
);
119 patterns
.add
<ScalarOpToLibmCall
<complex::AngleOp
, FloatTypeResolver
>>(
120 patterns
.getContext(), "cargf", "carg", benefit
);
124 struct ConvertComplexToLibmPass
125 : public impl::ConvertComplexToLibmBase
<ConvertComplexToLibmPass
> {
126 void runOnOperation() override
;
130 void ConvertComplexToLibmPass::runOnOperation() {
131 auto module
= getOperation();
133 RewritePatternSet
patterns(&getContext());
134 populateComplexToLibmConversionPatterns(patterns
, /*benefit=*/1);
136 ConversionTarget
target(getContext());
137 target
.addLegalDialect
<func::FuncDialect
>();
138 target
.addIllegalOp
<complex::PowOp
, complex::SqrtOp
, complex::TanhOp
,
139 complex::CosOp
, complex::SinOp
, complex::ConjOp
,
140 complex::LogOp
, complex::AbsOp
, complex::AngleOp
>();
141 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
145 std::unique_ptr
<OperationPass
<ModuleOp
>>
146 mlir::createConvertComplexToLibmPass() {
147 return std::make_unique
<ConvertComplexToLibmPass
>();