[Hexagon] Better detection of impossible completions to perfect shuffles
[llvm-project.git] / mlir / lib / Conversion / ComplexToLibm / ComplexToLibm.cpp
blobfa8db2c965175223215f40bcbd1e7940aea7a9d7
1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
16 namespace mlir {
17 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM
18 #include "mlir/Conversion/Passes.h.inc"
19 } // namespace mlir
21 using namespace mlir;
23 namespace {
24 // Functor to resolve the function name corresponding to the given complex
25 // result type.
26 struct ComplexTypeResolver {
27 llvm::Optional<bool> operator()(Type type) const {
28 auto complexType = type.cast<ComplexType>();
29 auto elementType = complexType.getElementType();
30 if (!elementType.isa<Float32Type, Float64Type>())
31 return {};
33 return elementType.getIntOrFloatBitWidth() == 64;
37 // Functor to resolve the function name corresponding to the given float result
38 // type.
39 struct FloatTypeResolver {
40 llvm::Optional<bool> operator()(Type type) const {
41 auto elementType = type.cast<FloatType>();
42 if (!elementType.isa<Float32Type, Float64Type>())
43 return {};
45 return elementType.getIntOrFloatBitWidth() == 64;
49 // Pattern to convert scalar complex operations to calls to libm functions.
50 // Additionally the libm function signatures are declared.
51 // TypeResolver is a functor returning the libm function name according to the
52 // expected type double or float.
53 template <typename Op, typename TypeResolver = ComplexTypeResolver>
54 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
55 public:
56 using OpRewritePattern<Op>::OpRewritePattern;
57 ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
58 StringRef floatFunc,
59 StringRef doubleFunc,
60 PatternBenefit benefit)
61 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
62 doubleFunc(doubleFunc){};
64 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
66 private:
67 std::string floatFunc, doubleFunc;
69 } // namespace
71 template <typename Op, typename TypeResolver>
72 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
73 Op op, PatternRewriter &rewriter) const {
74 auto module = SymbolTable::getNearestSymbolTable(op);
75 auto isDouble = TypeResolver()(op.getType());
76 if (!isDouble.has_value())
77 return failure();
79 auto name = isDouble.value() ? doubleFunc : floatFunc;
81 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
82 SymbolTable::lookupSymbolIn(module, name));
83 // Forward declare function if it hasn't already been
84 if (!opFunc) {
85 OpBuilder::InsertionGuard guard(rewriter);
86 rewriter.setInsertionPointToStart(&module->getRegion(0).front());
87 auto opFunctionTy = FunctionType::get(
88 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
89 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
90 opFunctionTy);
91 opFunc.setPrivate();
93 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
95 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
96 op->getOperands());
98 return success();
101 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
102 PatternBenefit benefit) {
103 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
104 "cpowf", "cpow", benefit);
105 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
106 "csqrtf", "csqrt", benefit);
107 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
108 "ctanhf", "ctanh", benefit);
109 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(),
110 "ccosf", "ccos", benefit);
111 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(),
112 "csinf", "csin", benefit);
113 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
114 "conjf", "conj", benefit);
115 patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(),
116 "clogf", "clog", benefit);
117 patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
118 patterns.getContext(), "cabsf", "cabs", benefit);
119 patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>(
120 patterns.getContext(), "cargf", "carg", benefit);
123 namespace {
124 struct ConvertComplexToLibmPass
125 : public impl::ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
126 void runOnOperation() override;
128 } // namespace
130 void ConvertComplexToLibmPass::runOnOperation() {
131 auto module = getOperation();
133 RewritePatternSet patterns(&getContext());
134 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
136 ConversionTarget target(getContext());
137 target.addLegalDialect<func::FuncDialect>();
138 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
139 complex::CosOp, complex::SinOp, complex::ConjOp,
140 complex::LogOp, complex::AbsOp, complex::AngleOp>();
141 if (failed(applyPartialConversion(module, target, std::move(patterns))))
142 signalPassFailure();
145 std::unique_ptr<OperationPass<ModuleOp>>
146 mlir::createConvertComplexToLibmPass() {
147 return std::make_unique<ConvertComplexToLibmPass>();