1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
19 #include <type_traits>
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
/// Selects which function of the complex magnitude `computeAbs` emits.
enum class AbsFn {
  abs,   ///< The magnitude itself.
  sqrt,  ///< Square root of the magnitude.
  rsqrt, ///< Reciprocal square root of the magnitude.
};
32 // Returns the absolute value, its square root or its reciprocal square root.
33 Value
computeAbs(Value real
, Value imag
, arith::FastMathFlags fmf
,
34 ImplicitLocOpBuilder
&b
, AbsFn fn
= AbsFn::abs
) {
35 Value one
= b
.create
<arith::ConstantOp
>(real
.getType(),
36 b
.getFloatAttr(real
.getType(), 1.0));
38 Value absReal
= b
.create
<math::AbsFOp
>(real
, fmf
);
39 Value absImag
= b
.create
<math::AbsFOp
>(imag
, fmf
);
41 Value max
= b
.create
<arith::MaximumFOp
>(absReal
, absImag
, fmf
);
42 Value min
= b
.create
<arith::MinimumFOp
>(absReal
, absImag
, fmf
);
44 // The lowering below requires NaNs and infinities to work correctly.
45 arith::FastMathFlags fmfWithNaNInf
= arith::bitEnumClear(
46 fmf
, arith::FastMathFlags::nnan
| arith::FastMathFlags::ninf
);
47 Value ratio
= b
.create
<arith::DivFOp
>(min
, max
, fmfWithNaNInf
);
48 Value ratioSq
= b
.create
<arith::MulFOp
>(ratio
, ratio
, fmfWithNaNInf
);
49 Value ratioSqPlusOne
= b
.create
<arith::AddFOp
>(ratioSq
, one
, fmfWithNaNInf
);
52 if (fn
== AbsFn::rsqrt
) {
53 ratioSqPlusOne
= b
.create
<math::RsqrtOp
>(ratioSqPlusOne
, fmfWithNaNInf
);
54 min
= b
.create
<math::RsqrtOp
>(min
, fmfWithNaNInf
);
55 max
= b
.create
<math::RsqrtOp
>(max
, fmfWithNaNInf
);
58 if (fn
== AbsFn::sqrt
) {
59 Value quarter
= b
.create
<arith::ConstantOp
>(
60 real
.getType(), b
.getFloatAttr(real
.getType(), 0.25));
61 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
62 Value sqrt
= b
.create
<math::SqrtOp
>(max
, fmfWithNaNInf
);
63 Value p025
= b
.create
<math::PowFOp
>(ratioSqPlusOne
, quarter
, fmfWithNaNInf
);
64 result
= b
.create
<arith::MulFOp
>(sqrt
, p025
, fmfWithNaNInf
);
66 Value sqrt
= b
.create
<math::SqrtOp
>(ratioSqPlusOne
, fmfWithNaNInf
);
67 result
= b
.create
<arith::MulFOp
>(max
, sqrt
, fmfWithNaNInf
);
70 Value isNaN
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, result
,
71 result
, fmfWithNaNInf
);
72 return b
.create
<arith::SelectOp
>(isNaN
, min
, result
);
75 struct AbsOpConversion
: public OpConversionPattern
<complex::AbsOp
> {
76 using OpConversionPattern
<complex::AbsOp
>::OpConversionPattern
;
79 matchAndRewrite(complex::AbsOp op
, OpAdaptor adaptor
,
80 ConversionPatternRewriter
&rewriter
) const override
{
81 ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
83 arith::FastMathFlags fmf
= op
.getFastMathFlagsAttr().getValue();
85 Value real
= b
.create
<complex::ReOp
>(adaptor
.getComplex());
86 Value imag
= b
.create
<complex::ImOp
>(adaptor
.getComplex());
87 rewriter
.replaceOp(op
, computeAbs(real
, imag
, fmf
, b
));
93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94 struct Atan2OpConversion
: public OpConversionPattern
<complex::Atan2Op
> {
95 using OpConversionPattern
<complex::Atan2Op
>::OpConversionPattern
;
98 matchAndRewrite(complex::Atan2Op op
, OpAdaptor adaptor
,
99 ConversionPatternRewriter
&rewriter
) const override
{
100 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
102 auto type
= cast
<ComplexType
>(op
.getType());
103 Type elementType
= type
.getElementType();
104 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
106 Value lhs
= adaptor
.getLhs();
107 Value rhs
= adaptor
.getRhs();
109 Value rhsSquared
= b
.create
<complex::MulOp
>(type
, rhs
, rhs
, fmf
);
110 Value lhsSquared
= b
.create
<complex::MulOp
>(type
, lhs
, lhs
, fmf
);
111 Value rhsSquaredPlusLhsSquared
=
112 b
.create
<complex::AddOp
>(type
, rhsSquared
, lhsSquared
, fmf
);
113 Value sqrtOfRhsSquaredPlusLhsSquared
=
114 b
.create
<complex::SqrtOp
>(type
, rhsSquaredPlusLhsSquared
, fmf
);
117 b
.create
<arith::ConstantOp
>(elementType
, b
.getZeroAttr(elementType
));
118 Value one
= b
.create
<arith::ConstantOp
>(elementType
,
119 b
.getFloatAttr(elementType
, 1));
120 Value i
= b
.create
<complex::CreateOp
>(type
, zero
, one
);
121 Value iTimesLhs
= b
.create
<complex::MulOp
>(i
, lhs
, fmf
);
122 Value rhsPlusILhs
= b
.create
<complex::AddOp
>(rhs
, iTimesLhs
, fmf
);
124 Value divResult
= b
.create
<complex::DivOp
>(
125 rhsPlusILhs
, sqrtOfRhsSquaredPlusLhsSquared
, fmf
);
126 Value logResult
= b
.create
<complex::LogOp
>(divResult
, fmf
);
128 Value negativeOne
= b
.create
<arith::ConstantOp
>(
129 elementType
, b
.getFloatAttr(elementType
, -1));
130 Value negativeI
= b
.create
<complex::CreateOp
>(type
, zero
, negativeOne
);
132 rewriter
.replaceOpWithNewOp
<complex::MulOp
>(op
, negativeI
, logResult
, fmf
);
137 template <typename ComparisonOp
, arith::CmpFPredicate p
>
138 struct ComparisonOpConversion
: public OpConversionPattern
<ComparisonOp
> {
139 using OpConversionPattern
<ComparisonOp
>::OpConversionPattern
;
140 using ResultCombiner
=
141 std::conditional_t
<std::is_same
<ComparisonOp
, complex::EqualOp
>::value
,
142 arith::AndIOp
, arith::OrIOp
>;
145 matchAndRewrite(ComparisonOp op
, typename
ComparisonOp::Adaptor adaptor
,
146 ConversionPatternRewriter
&rewriter
) const override
{
147 auto loc
= op
.getLoc();
148 auto type
= cast
<ComplexType
>(adaptor
.getLhs().getType()).getElementType();
150 Value realLhs
= rewriter
.create
<complex::ReOp
>(loc
, type
, adaptor
.getLhs());
151 Value imagLhs
= rewriter
.create
<complex::ImOp
>(loc
, type
, adaptor
.getLhs());
152 Value realRhs
= rewriter
.create
<complex::ReOp
>(loc
, type
, adaptor
.getRhs());
153 Value imagRhs
= rewriter
.create
<complex::ImOp
>(loc
, type
, adaptor
.getRhs());
154 Value realComparison
=
155 rewriter
.create
<arith::CmpFOp
>(loc
, p
, realLhs
, realRhs
);
156 Value imagComparison
=
157 rewriter
.create
<arith::CmpFOp
>(loc
, p
, imagLhs
, imagRhs
);
159 rewriter
.replaceOpWithNewOp
<ResultCombiner
>(op
, realComparison
,
165 // Default conversion which applies the BinaryStandardOp separately on the real
166 // and imaginary parts. Can for example be used for complex::AddOp and
168 template <typename BinaryComplexOp
, typename BinaryStandardOp
>
169 struct BinaryComplexOpConversion
: public OpConversionPattern
<BinaryComplexOp
> {
170 using OpConversionPattern
<BinaryComplexOp
>::OpConversionPattern
;
173 matchAndRewrite(BinaryComplexOp op
, typename
BinaryComplexOp::Adaptor adaptor
,
174 ConversionPatternRewriter
&rewriter
) const override
{
175 auto type
= cast
<ComplexType
>(adaptor
.getLhs().getType());
176 auto elementType
= cast
<FloatType
>(type
.getElementType());
177 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
178 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
180 Value realLhs
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getLhs());
181 Value realRhs
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getRhs());
182 Value resultReal
= b
.create
<BinaryStandardOp
>(elementType
, realLhs
, realRhs
,
184 Value imagLhs
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getLhs());
185 Value imagRhs
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getRhs());
186 Value resultImag
= b
.create
<BinaryStandardOp
>(elementType
, imagLhs
, imagRhs
,
188 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
194 template <typename TrigonometricOp
>
195 struct TrigonometricOpConversion
: public OpConversionPattern
<TrigonometricOp
> {
196 using OpAdaptor
= typename OpConversionPattern
<TrigonometricOp
>::OpAdaptor
;
198 using OpConversionPattern
<TrigonometricOp
>::OpConversionPattern
;
201 matchAndRewrite(TrigonometricOp op
, OpAdaptor adaptor
,
202 ConversionPatternRewriter
&rewriter
) const override
{
203 auto loc
= op
.getLoc();
204 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
205 auto elementType
= cast
<FloatType
>(type
.getElementType());
206 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
209 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getComplex());
211 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getComplex());
213 // Trigonometric ops use a set of common building blocks to convert to real
214 // ops. Here we create these building blocks and call into an op-specific
215 // implementation in the subclass to combine them.
216 Value half
= rewriter
.create
<arith::ConstantOp
>(
217 loc
, elementType
, rewriter
.getFloatAttr(elementType
, 0.5));
218 Value exp
= rewriter
.create
<math::ExpOp
>(loc
, imag
, fmf
);
219 Value scaledExp
= rewriter
.create
<arith::MulFOp
>(loc
, half
, exp
, fmf
);
220 Value reciprocalExp
= rewriter
.create
<arith::DivFOp
>(loc
, half
, exp
, fmf
);
221 Value sin
= rewriter
.create
<math::SinOp
>(loc
, real
, fmf
);
222 Value cos
= rewriter
.create
<math::CosOp
>(loc
, real
, fmf
);
225 combine(loc
, scaledExp
, reciprocalExp
, sin
, cos
, rewriter
, fmf
);
227 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultPair
.first
,
232 virtual std::pair
<Value
, Value
>
233 combine(Location loc
, Value scaledExp
, Value reciprocalExp
, Value sin
,
234 Value cos
, ConversionPatternRewriter
&rewriter
,
235 arith::FastMathFlagsAttr fmf
) const = 0;
238 struct CosOpConversion
: public TrigonometricOpConversion
<complex::CosOp
> {
239 using TrigonometricOpConversion
<complex::CosOp
>::TrigonometricOpConversion
;
241 std::pair
<Value
, Value
> combine(Location loc
, Value scaledExp
,
242 Value reciprocalExp
, Value sin
, Value cos
,
243 ConversionPatternRewriter
&rewriter
,
244 arith::FastMathFlagsAttr fmf
) const override
{
245 // Complex cosine is defined as;
246 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
248 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250 // and defining t := exp(y)
252 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
255 rewriter
.create
<arith::AddFOp
>(loc
, reciprocalExp
, scaledExp
, fmf
);
256 Value resultReal
= rewriter
.create
<arith::MulFOp
>(loc
, sum
, cos
, fmf
);
258 rewriter
.create
<arith::SubFOp
>(loc
, reciprocalExp
, scaledExp
, fmf
);
259 Value resultImag
= rewriter
.create
<arith::MulFOp
>(loc
, diff
, sin
, fmf
);
260 return {resultReal
, resultImag
};
264 struct DivOpConversion
: public OpConversionPattern
<complex::DivOp
> {
265 using OpConversionPattern
<complex::DivOp
>::OpConversionPattern
;
268 matchAndRewrite(complex::DivOp op
, OpAdaptor adaptor
,
269 ConversionPatternRewriter
&rewriter
) const override
{
270 auto loc
= op
.getLoc();
271 auto type
= cast
<ComplexType
>(adaptor
.getLhs().getType());
272 auto elementType
= cast
<FloatType
>(type
.getElementType());
273 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
276 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getLhs());
278 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getLhs());
280 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getRhs());
282 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getRhs());
284 // Smith's algorithm to divide complex numbers. It is just a bit smarter
285 // way to compute the following formula:
286 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
287 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
288 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
289 // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
290 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
292 // Depending on whether |rhsReal| < |rhsImag| we compute either
293 // rhsRealImagRatio = rhsReal / rhsImag
294 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
295 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
296 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
300 // rhsImagRealRatio = rhsImag / rhsReal
301 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
302 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
303 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
305 // See https://dl.acm.org/citation.cfm?id=368661 for more details.
306 Value rhsRealImagRatio
=
307 rewriter
.create
<arith::DivFOp
>(loc
, rhsReal
, rhsImag
, fmf
);
308 Value rhsRealImagDenom
= rewriter
.create
<arith::AddFOp
>(
310 rewriter
.create
<arith::MulFOp
>(loc
, rhsRealImagRatio
, rhsReal
, fmf
),
312 Value realNumerator1
= rewriter
.create
<arith::AddFOp
>(
314 rewriter
.create
<arith::MulFOp
>(loc
, lhsReal
, rhsRealImagRatio
, fmf
),
316 Value resultReal1
= rewriter
.create
<arith::DivFOp
>(loc
, realNumerator1
,
317 rhsRealImagDenom
, fmf
);
318 Value imagNumerator1
= rewriter
.create
<arith::SubFOp
>(
320 rewriter
.create
<arith::MulFOp
>(loc
, lhsImag
, rhsRealImagRatio
, fmf
),
322 Value resultImag1
= rewriter
.create
<arith::DivFOp
>(loc
, imagNumerator1
,
323 rhsRealImagDenom
, fmf
);
325 Value rhsImagRealRatio
=
326 rewriter
.create
<arith::DivFOp
>(loc
, rhsImag
, rhsReal
, fmf
);
327 Value rhsImagRealDenom
= rewriter
.create
<arith::AddFOp
>(
329 rewriter
.create
<arith::MulFOp
>(loc
, rhsImagRealRatio
, rhsImag
, fmf
),
331 Value realNumerator2
= rewriter
.create
<arith::AddFOp
>(
333 rewriter
.create
<arith::MulFOp
>(loc
, lhsImag
, rhsImagRealRatio
, fmf
),
335 Value resultReal2
= rewriter
.create
<arith::DivFOp
>(loc
, realNumerator2
,
336 rhsImagRealDenom
, fmf
);
337 Value imagNumerator2
= rewriter
.create
<arith::SubFOp
>(
339 rewriter
.create
<arith::MulFOp
>(loc
, lhsReal
, rhsImagRealRatio
, fmf
),
341 Value resultImag2
= rewriter
.create
<arith::DivFOp
>(loc
, imagNumerator2
,
342 rhsImagRealDenom
, fmf
);
344 // Consider corner cases.
345 // Case 1. Zero denominator, numerator contains at most one NaN value.
346 Value zero
= rewriter
.create
<arith::ConstantOp
>(
347 loc
, elementType
, rewriter
.getZeroAttr(elementType
));
348 Value rhsRealAbs
= rewriter
.create
<math::AbsFOp
>(loc
, rhsReal
, fmf
);
349 Value rhsRealIsZero
= rewriter
.create
<arith::CmpFOp
>(
350 loc
, arith::CmpFPredicate::OEQ
, rhsRealAbs
, zero
);
351 Value rhsImagAbs
= rewriter
.create
<math::AbsFOp
>(loc
, rhsImag
, fmf
);
352 Value rhsImagIsZero
= rewriter
.create
<arith::CmpFOp
>(
353 loc
, arith::CmpFPredicate::OEQ
, rhsImagAbs
, zero
);
354 Value lhsRealIsNotNaN
= rewriter
.create
<arith::CmpFOp
>(
355 loc
, arith::CmpFPredicate::ORD
, lhsReal
, zero
);
356 Value lhsImagIsNotNaN
= rewriter
.create
<arith::CmpFOp
>(
357 loc
, arith::CmpFPredicate::ORD
, lhsImag
, zero
);
358 Value lhsContainsNotNaNValue
=
359 rewriter
.create
<arith::OrIOp
>(loc
, lhsRealIsNotNaN
, lhsImagIsNotNaN
);
360 Value resultIsInfinity
= rewriter
.create
<arith::AndIOp
>(
361 loc
, lhsContainsNotNaNValue
,
362 rewriter
.create
<arith::AndIOp
>(loc
, rhsRealIsZero
, rhsImagIsZero
));
363 Value inf
= rewriter
.create
<arith::ConstantOp
>(
365 rewriter
.getFloatAttr(
366 elementType
, APFloat::getInf(elementType
.getFloatSemantics())));
367 Value infWithSignOfRhsReal
=
368 rewriter
.create
<math::CopySignOp
>(loc
, inf
, rhsReal
);
369 Value infinityResultReal
=
370 rewriter
.create
<arith::MulFOp
>(loc
, infWithSignOfRhsReal
, lhsReal
, fmf
);
371 Value infinityResultImag
=
372 rewriter
.create
<arith::MulFOp
>(loc
, infWithSignOfRhsReal
, lhsImag
, fmf
);
374 // Case 2. Infinite numerator, finite denominator.
375 Value rhsRealFinite
= rewriter
.create
<arith::CmpFOp
>(
376 loc
, arith::CmpFPredicate::ONE
, rhsRealAbs
, inf
);
377 Value rhsImagFinite
= rewriter
.create
<arith::CmpFOp
>(
378 loc
, arith::CmpFPredicate::ONE
, rhsImagAbs
, inf
);
380 rewriter
.create
<arith::AndIOp
>(loc
, rhsRealFinite
, rhsImagFinite
);
381 Value lhsRealAbs
= rewriter
.create
<math::AbsFOp
>(loc
, lhsReal
, fmf
);
382 Value lhsRealInfinite
= rewriter
.create
<arith::CmpFOp
>(
383 loc
, arith::CmpFPredicate::OEQ
, lhsRealAbs
, inf
);
384 Value lhsImagAbs
= rewriter
.create
<math::AbsFOp
>(loc
, lhsImag
, fmf
);
385 Value lhsImagInfinite
= rewriter
.create
<arith::CmpFOp
>(
386 loc
, arith::CmpFPredicate::OEQ
, lhsImagAbs
, inf
);
388 rewriter
.create
<arith::OrIOp
>(loc
, lhsRealInfinite
, lhsImagInfinite
);
389 Value infNumFiniteDenom
=
390 rewriter
.create
<arith::AndIOp
>(loc
, lhsInfinite
, rhsFinite
);
391 Value one
= rewriter
.create
<arith::ConstantOp
>(
392 loc
, elementType
, rewriter
.getFloatAttr(elementType
, 1));
393 Value lhsRealIsInfWithSign
= rewriter
.create
<math::CopySignOp
>(
394 loc
, rewriter
.create
<arith::SelectOp
>(loc
, lhsRealInfinite
, one
, zero
),
396 Value lhsImagIsInfWithSign
= rewriter
.create
<math::CopySignOp
>(
397 loc
, rewriter
.create
<arith::SelectOp
>(loc
, lhsImagInfinite
, one
, zero
),
399 Value lhsRealIsInfWithSignTimesRhsReal
=
400 rewriter
.create
<arith::MulFOp
>(loc
, lhsRealIsInfWithSign
, rhsReal
, fmf
);
401 Value lhsImagIsInfWithSignTimesRhsImag
=
402 rewriter
.create
<arith::MulFOp
>(loc
, lhsImagIsInfWithSign
, rhsImag
, fmf
);
403 Value resultReal3
= rewriter
.create
<arith::MulFOp
>(
405 rewriter
.create
<arith::AddFOp
>(loc
, lhsRealIsInfWithSignTimesRhsReal
,
406 lhsImagIsInfWithSignTimesRhsImag
, fmf
),
408 Value lhsRealIsInfWithSignTimesRhsImag
=
409 rewriter
.create
<arith::MulFOp
>(loc
, lhsRealIsInfWithSign
, rhsImag
, fmf
);
410 Value lhsImagIsInfWithSignTimesRhsReal
=
411 rewriter
.create
<arith::MulFOp
>(loc
, lhsImagIsInfWithSign
, rhsReal
, fmf
);
412 Value resultImag3
= rewriter
.create
<arith::MulFOp
>(
414 rewriter
.create
<arith::SubFOp
>(loc
, lhsImagIsInfWithSignTimesRhsReal
,
415 lhsRealIsInfWithSignTimesRhsImag
, fmf
),
418 // Case 3: Finite numerator, infinite denominator.
419 Value lhsRealFinite
= rewriter
.create
<arith::CmpFOp
>(
420 loc
, arith::CmpFPredicate::ONE
, lhsRealAbs
, inf
);
421 Value lhsImagFinite
= rewriter
.create
<arith::CmpFOp
>(
422 loc
, arith::CmpFPredicate::ONE
, lhsImagAbs
, inf
);
424 rewriter
.create
<arith::AndIOp
>(loc
, lhsRealFinite
, lhsImagFinite
);
425 Value rhsRealInfinite
= rewriter
.create
<arith::CmpFOp
>(
426 loc
, arith::CmpFPredicate::OEQ
, rhsRealAbs
, inf
);
427 Value rhsImagInfinite
= rewriter
.create
<arith::CmpFOp
>(
428 loc
, arith::CmpFPredicate::OEQ
, rhsImagAbs
, inf
);
430 rewriter
.create
<arith::OrIOp
>(loc
, rhsRealInfinite
, rhsImagInfinite
);
431 Value finiteNumInfiniteDenom
=
432 rewriter
.create
<arith::AndIOp
>(loc
, lhsFinite
, rhsInfinite
);
433 Value rhsRealIsInfWithSign
= rewriter
.create
<math::CopySignOp
>(
434 loc
, rewriter
.create
<arith::SelectOp
>(loc
, rhsRealInfinite
, one
, zero
),
436 Value rhsImagIsInfWithSign
= rewriter
.create
<math::CopySignOp
>(
437 loc
, rewriter
.create
<arith::SelectOp
>(loc
, rhsImagInfinite
, one
, zero
),
439 Value rhsRealIsInfWithSignTimesLhsReal
=
440 rewriter
.create
<arith::MulFOp
>(loc
, lhsReal
, rhsRealIsInfWithSign
, fmf
);
441 Value rhsImagIsInfWithSignTimesLhsImag
=
442 rewriter
.create
<arith::MulFOp
>(loc
, lhsImag
, rhsImagIsInfWithSign
, fmf
);
443 Value resultReal4
= rewriter
.create
<arith::MulFOp
>(
445 rewriter
.create
<arith::AddFOp
>(loc
, rhsRealIsInfWithSignTimesLhsReal
,
446 rhsImagIsInfWithSignTimesLhsImag
, fmf
),
448 Value rhsRealIsInfWithSignTimesLhsImag
=
449 rewriter
.create
<arith::MulFOp
>(loc
, lhsImag
, rhsRealIsInfWithSign
, fmf
);
450 Value rhsImagIsInfWithSignTimesLhsReal
=
451 rewriter
.create
<arith::MulFOp
>(loc
, lhsReal
, rhsImagIsInfWithSign
, fmf
);
452 Value resultImag4
= rewriter
.create
<arith::MulFOp
>(
454 rewriter
.create
<arith::SubFOp
>(loc
, rhsRealIsInfWithSignTimesLhsImag
,
455 rhsImagIsInfWithSignTimesLhsReal
, fmf
),
458 Value realAbsSmallerThanImagAbs
= rewriter
.create
<arith::CmpFOp
>(
459 loc
, arith::CmpFPredicate::OLT
, rhsRealAbs
, rhsImagAbs
);
460 Value resultReal
= rewriter
.create
<arith::SelectOp
>(
461 loc
, realAbsSmallerThanImagAbs
, resultReal1
, resultReal2
);
462 Value resultImag
= rewriter
.create
<arith::SelectOp
>(
463 loc
, realAbsSmallerThanImagAbs
, resultImag1
, resultImag2
);
464 Value resultRealSpecialCase3
= rewriter
.create
<arith::SelectOp
>(
465 loc
, finiteNumInfiniteDenom
, resultReal4
, resultReal
);
466 Value resultImagSpecialCase3
= rewriter
.create
<arith::SelectOp
>(
467 loc
, finiteNumInfiniteDenom
, resultImag4
, resultImag
);
468 Value resultRealSpecialCase2
= rewriter
.create
<arith::SelectOp
>(
469 loc
, infNumFiniteDenom
, resultReal3
, resultRealSpecialCase3
);
470 Value resultImagSpecialCase2
= rewriter
.create
<arith::SelectOp
>(
471 loc
, infNumFiniteDenom
, resultImag3
, resultImagSpecialCase3
);
472 Value resultRealSpecialCase1
= rewriter
.create
<arith::SelectOp
>(
473 loc
, resultIsInfinity
, infinityResultReal
, resultRealSpecialCase2
);
474 Value resultImagSpecialCase1
= rewriter
.create
<arith::SelectOp
>(
475 loc
, resultIsInfinity
, infinityResultImag
, resultImagSpecialCase2
);
477 Value resultRealIsNaN
= rewriter
.create
<arith::CmpFOp
>(
478 loc
, arith::CmpFPredicate::UNO
, resultReal
, zero
);
479 Value resultImagIsNaN
= rewriter
.create
<arith::CmpFOp
>(
480 loc
, arith::CmpFPredicate::UNO
, resultImag
, zero
);
482 rewriter
.create
<arith::AndIOp
>(loc
, resultRealIsNaN
, resultImagIsNaN
);
483 Value resultRealWithSpecialCases
= rewriter
.create
<arith::SelectOp
>(
484 loc
, resultIsNaN
, resultRealSpecialCase1
, resultReal
);
485 Value resultImagWithSpecialCases
= rewriter
.create
<arith::SelectOp
>(
486 loc
, resultIsNaN
, resultImagSpecialCase1
, resultImag
);
488 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(
489 op
, type
, resultRealWithSpecialCases
, resultImagWithSpecialCases
);
494 struct ExpOpConversion
: public OpConversionPattern
<complex::ExpOp
> {
495 using OpConversionPattern
<complex::ExpOp
>::OpConversionPattern
;
498 matchAndRewrite(complex::ExpOp op
, OpAdaptor adaptor
,
499 ConversionPatternRewriter
&rewriter
) const override
{
500 auto loc
= op
.getLoc();
501 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
502 auto elementType
= cast
<FloatType
>(type
.getElementType());
503 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
506 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getComplex());
508 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getComplex());
509 Value expReal
= rewriter
.create
<math::ExpOp
>(loc
, real
, fmf
.getValue());
510 Value cosImag
= rewriter
.create
<math::CosOp
>(loc
, imag
, fmf
.getValue());
512 rewriter
.create
<arith::MulFOp
>(loc
, expReal
, cosImag
, fmf
.getValue());
513 Value sinImag
= rewriter
.create
<math::SinOp
>(loc
, imag
, fmf
.getValue());
515 rewriter
.create
<arith::MulFOp
>(loc
, expReal
, sinImag
, fmf
.getValue());
517 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
523 Value
evaluatePolynomial(ImplicitLocOpBuilder
&b
, Value arg
,
524 ArrayRef
<double> coefficients
,
525 arith::FastMathFlagsAttr fmf
) {
526 auto argType
= mlir::cast
<FloatType
>(arg
.getType());
528 b
.create
<arith::ConstantOp
>(b
.getFloatAttr(argType
, coefficients
[0]));
529 for (unsigned i
= 1; i
< coefficients
.size(); ++i
) {
530 poly
= b
.create
<math::FmaOp
>(
532 b
.create
<arith::ConstantOp
>(b
.getFloatAttr(argType
, coefficients
[i
])),
538 struct Expm1OpConversion
: public OpConversionPattern
<complex::Expm1Op
> {
539 using OpConversionPattern
<complex::Expm1Op
>::OpConversionPattern
;
541 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
542 // [handle inaccuracies when a and/or b are small]
543 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
544 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
546 matchAndRewrite(complex::Expm1Op op
, OpAdaptor adaptor
,
547 ConversionPatternRewriter
&rewriter
) const override
{
548 auto type
= op
.getType();
549 auto elemType
= mlir::cast
<FloatType
>(type
.getElementType());
551 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
552 ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
553 Value real
= b
.create
<complex::ReOp
>(adaptor
.getComplex());
554 Value imag
= b
.create
<complex::ImOp
>(adaptor
.getComplex());
556 Value zero
= b
.create
<arith::ConstantOp
>(b
.getFloatAttr(elemType
, 0.0));
557 Value one
= b
.create
<arith::ConstantOp
>(b
.getFloatAttr(elemType
, 1.0));
559 Value expm1Real
= b
.create
<math::ExpM1Op
>(real
, fmf
);
560 Value expReal
= b
.create
<arith::AddFOp
>(expm1Real
, one
, fmf
);
562 Value sinImag
= b
.create
<math::SinOp
>(imag
, fmf
);
563 Value cosm1Imag
= emitCosm1(imag
, fmf
, b
);
564 Value cosImag
= b
.create
<arith::AddFOp
>(cosm1Imag
, one
, fmf
);
566 Value realResult
= b
.create
<arith::AddFOp
>(
567 b
.create
<arith::MulFOp
>(expm1Real
, cosImag
, fmf
), cosm1Imag
, fmf
);
569 Value imagIsZero
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, imag
,
570 zero
, fmf
.getValue());
571 Value imagResult
= b
.create
<arith::SelectOp
>(
572 imagIsZero
, zero
, b
.create
<arith::MulFOp
>(expReal
, sinImag
, fmf
));
574 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, realResult
,
580 Value
emitCosm1(Value arg
, arith::FastMathFlagsAttr fmf
,
581 ImplicitLocOpBuilder
&b
) const {
582 auto argType
= mlir::cast
<FloatType
>(arg
.getType());
583 auto negHalf
= b
.create
<arith::ConstantOp
>(b
.getFloatAttr(argType
, -0.5));
584 auto negOne
= b
.create
<arith::ConstantOp
>(b
.getFloatAttr(argType
, -1.0));
586 // Algorithm copied from cephes cosm1.
587 SmallVector
<double, 7> kCoeffs
{
588 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591 4.1666666666666666609054E-2,
593 Value cos
= b
.create
<math::CosOp
>(arg
, fmf
);
594 Value forLargeArg
= b
.create
<arith::AddFOp
>(cos
, negOne
, fmf
);
596 Value argPow2
= b
.create
<arith::MulFOp
>(arg
, arg
, fmf
);
597 Value argPow4
= b
.create
<arith::MulFOp
>(argPow2
, argPow2
, fmf
);
598 Value poly
= evaluatePolynomial(b
, argPow2
, kCoeffs
, fmf
);
601 b
.create
<arith::AddFOp
>(b
.create
<arith::MulFOp
>(argPow4
, poly
, fmf
),
602 b
.create
<arith::MulFOp
>(negHalf
, argPow2
, fmf
));
604 // (pi/4)^2 is approximately 0.61685
606 b
.create
<arith::ConstantOp
>(b
.getFloatAttr(argType
, 0.61685));
607 Value cond
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OGE
, argPow2
,
608 piOver4Pow2
, fmf
.getValue());
609 return b
.create
<arith::SelectOp
>(cond
, forLargeArg
, forSmallArg
);
613 struct LogOpConversion
: public OpConversionPattern
<complex::LogOp
> {
614 using OpConversionPattern
<complex::LogOp
>::OpConversionPattern
;
617 matchAndRewrite(complex::LogOp op
, OpAdaptor adaptor
,
618 ConversionPatternRewriter
&rewriter
) const override
{
619 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
620 auto elementType
= cast
<FloatType
>(type
.getElementType());
621 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
622 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
624 Value abs
= b
.create
<complex::AbsOp
>(elementType
, adaptor
.getComplex(),
626 Value resultReal
= b
.create
<math::LogOp
>(elementType
, abs
, fmf
.getValue());
627 Value real
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getComplex());
628 Value imag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getComplex());
630 b
.create
<math::Atan2Op
>(elementType
, imag
, real
, fmf
.getValue());
631 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
637 struct Log1pOpConversion
: public OpConversionPattern
<complex::Log1pOp
> {
638 using OpConversionPattern
<complex::Log1pOp
>::OpConversionPattern
;
641 matchAndRewrite(complex::Log1pOp op
, OpAdaptor adaptor
,
642 ConversionPatternRewriter
&rewriter
) const override
{
643 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
644 auto elementType
= cast
<FloatType
>(type
.getElementType());
645 arith::FastMathFlags fmf
= op
.getFastMathFlagsAttr().getValue();
646 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
648 Value real
= b
.create
<complex::ReOp
>(adaptor
.getComplex());
649 Value imag
= b
.create
<complex::ImOp
>(adaptor
.getComplex());
651 Value half
= b
.create
<arith::ConstantOp
>(elementType
,
652 b
.getFloatAttr(elementType
, 0.5));
653 Value one
= b
.create
<arith::ConstantOp
>(elementType
,
654 b
.getFloatAttr(elementType
, 1));
655 Value realPlusOne
= b
.create
<arith::AddFOp
>(real
, one
, fmf
);
656 Value absRealPlusOne
= b
.create
<math::AbsFOp
>(realPlusOne
, fmf
);
657 Value absImag
= b
.create
<math::AbsFOp
>(imag
, fmf
);
659 Value maxAbs
= b
.create
<arith::MaximumFOp
>(absRealPlusOne
, absImag
, fmf
);
660 Value minAbs
= b
.create
<arith::MinimumFOp
>(absRealPlusOne
, absImag
, fmf
);
662 Value useReal
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OGT
,
663 realPlusOne
, absImag
, fmf
);
664 Value maxMinusOne
= b
.create
<arith::SubFOp
>(maxAbs
, one
, fmf
);
665 Value maxAbsOfRealPlusOneAndImagMinusOne
=
666 b
.create
<arith::SelectOp
>(useReal
, real
, maxMinusOne
);
667 arith::FastMathFlags fmfWithNaNInf
= arith::bitEnumClear(
668 fmf
, arith::FastMathFlags::nnan
| arith::FastMathFlags::ninf
);
669 Value minMaxRatio
= b
.create
<arith::DivFOp
>(minAbs
, maxAbs
, fmfWithNaNInf
);
670 Value logOfMaxAbsOfRealPlusOneAndImag
=
671 b
.create
<math::Log1pOp
>(maxAbsOfRealPlusOneAndImagMinusOne
, fmf
);
672 Value logOfSqrtPart
= b
.create
<math::Log1pOp
>(
673 b
.create
<arith::MulFOp
>(minMaxRatio
, minMaxRatio
, fmfWithNaNInf
),
675 Value r
= b
.create
<arith::AddFOp
>(
676 b
.create
<arith::MulFOp
>(half
, logOfSqrtPart
, fmfWithNaNInf
),
677 logOfMaxAbsOfRealPlusOneAndImag
, fmfWithNaNInf
);
678 Value resultReal
= b
.create
<arith::SelectOp
>(
679 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, r
, r
, fmfWithNaNInf
),
681 Value resultImag
= b
.create
<math::Atan2Op
>(imag
, realPlusOne
, fmf
);
682 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
// Conversion pattern for complex::MulOp. It builds the product from arith/math
// ops on the real and imaginary parts of both operands, then recomputes the
// result for the cases where the naive formula yields NaN in both components
// although an operand (or one of the pairwise products) is infinite.
// NOTE(review): several lines of this pattern (some result-variable
// declarations, some select operands, the return and the closing braces) are
// not visible in this copy; the comments below describe only the visible code.
688 struct MulOpConversion
: public OpConversionPattern
<complex::MulOp
> {
689 using OpConversionPattern
<complex::MulOp
>::OpConversionPattern
;
692 matchAndRewrite(complex::MulOp op
, OpAdaptor adaptor
,
693 ConversionPatternRewriter
&rewriter
) const override
{
694 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
695 auto type
= cast
<ComplexType
>(adaptor
.getLhs().getType());
696 auto elementType
= cast
<FloatType
>(type
.getElementType());
697 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
698 auto fmfValue
= fmf
.getValue();
// Split both operands into components. The absolute values are only used by
// the "is infinite" comparisons against +inf further down.
700 Value lhsReal
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getLhs());
701 Value lhsRealAbs
= b
.create
<math::AbsFOp
>(lhsReal
, fmfValue
);
702 Value lhsImag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getLhs());
703 Value lhsImagAbs
= b
.create
<math::AbsFOp
>(lhsImag
, fmfValue
);
704 Value rhsReal
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getRhs());
705 Value rhsRealAbs
= b
.create
<math::AbsFOp
>(rhsReal
, fmfValue
);
706 Value rhsImag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getRhs());
707 Value rhsImagAbs
= b
.create
<math::AbsFOp
>(rhsImag
, fmfValue
);
// Naive real part: lhsReal * rhsReal - lhsImag * rhsImag. The absolute values
// of the partial products feed the Case 3 checks.
709 Value lhsRealTimesRhsReal
=
710 b
.create
<arith::MulFOp
>(lhsReal
, rhsReal
, fmfValue
);
711 Value lhsRealTimesRhsRealAbs
=
712 b
.create
<math::AbsFOp
>(lhsRealTimesRhsReal
, fmfValue
);
713 Value lhsImagTimesRhsImag
=
714 b
.create
<arith::MulFOp
>(lhsImag
, rhsImag
, fmfValue
);
715 Value lhsImagTimesRhsImagAbs
=
716 b
.create
<math::AbsFOp
>(lhsImagTimesRhsImag
, fmfValue
);
717 Value real
= b
.create
<arith::SubFOp
>(lhsRealTimesRhsReal
,
718 lhsImagTimesRhsImag
, fmfValue
);
// Naive imaginary part: lhsImag * rhsReal + lhsReal * rhsImag.
720 Value lhsImagTimesRhsReal
=
721 b
.create
<arith::MulFOp
>(lhsImag
, rhsReal
, fmfValue
);
722 Value lhsImagTimesRhsRealAbs
=
723 b
.create
<math::AbsFOp
>(lhsImagTimesRhsReal
, fmfValue
);
724 Value lhsRealTimesRhsImag
=
725 b
.create
<arith::MulFOp
>(lhsReal
, rhsImag
, fmfValue
);
726 Value lhsRealTimesRhsImagAbs
=
727 b
.create
<math::AbsFOp
>(lhsRealTimesRhsImag
, fmfValue
);
728 Value imag
= b
.create
<arith::AddFOp
>(lhsImagTimesRhsReal
,
729 lhsRealTimesRhsImag
, fmfValue
);
731 // Handle cases where the "naive" calculation results in NaN values.
// The two UNO self-comparisons below are NaN tests on `real` and `imag`;
// presumably they are bound to `realIsNan` / `imagIsNan` on lines not visible
// here. `isNan` requires BOTH components to be NaN.
733 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, real
, real
);
735 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, imag
, imag
);
736 Value isNan
= b
.create
<arith::AndIOp
>(realIsNan
, imagIsNan
);
// +inf constant of the element type, compared against the absolute values.
738 Value inf
= b
.create
<arith::ConstantOp
>(
740 b
.getFloatAttr(elementType
,
741 APFloat::getInf(elementType
.getFloatSemantics())));
743 // Case 1. `lhsReal` or `lhsImag` are infinite.
// The next compares are presumably bound to `lhsRealIsInf`, `lhsImagIsInf`,
// `rhsRealIsNan`, `rhsImagIsNan` and the constant to `zero` (declarations not
// visible here).
745 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, lhsRealAbs
, inf
);
747 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, lhsImagAbs
, inf
);
748 Value lhsIsInf
= b
.create
<arith::OrIOp
>(lhsRealIsInf
, lhsImagIsInf
);
750 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, rhsReal
, rhsReal
);
752 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, rhsImag
, rhsImag
);
754 b
.create
<arith::ConstantOp
>(elementType
, b
.getZeroAttr(elementType
));
755 Value one
= b
.create
<arith::ConstantOp
>(elementType
,
756 b
.getFloatAttr(elementType
, 1));
// When lhs is infinite, each lhs component is replaced by
// copysign(component is infinite ? 1 : 0, component). The "else" operand of
// these selects is on a line not visible here.
757 Value lhsRealIsInfFloat
=
758 b
.create
<arith::SelectOp
>(lhsRealIsInf
, one
, zero
);
759 lhsReal
= b
.create
<arith::SelectOp
>(
760 lhsIsInf
, b
.create
<math::CopySignOp
>(lhsRealIsInfFloat
, lhsReal
),
762 Value lhsImagIsInfFloat
=
763 b
.create
<arith::SelectOp
>(lhsImagIsInf
, one
, zero
);
764 lhsImag
= b
.create
<arith::SelectOp
>(
765 lhsIsInf
, b
.create
<math::CopySignOp
>(lhsImagIsInfFloat
, lhsImag
),
// When lhs is infinite, NaN components of rhs are replaced by a zero that
// keeps the sign of the original component.
767 Value lhsIsInfAndRhsRealIsNan
=
768 b
.create
<arith::AndIOp
>(lhsIsInf
, rhsRealIsNan
);
769 rhsReal
= b
.create
<arith::SelectOp
>(
770 lhsIsInfAndRhsRealIsNan
, b
.create
<math::CopySignOp
>(zero
, rhsReal
),
772 Value lhsIsInfAndRhsImagIsNan
=
773 b
.create
<arith::AndIOp
>(lhsIsInf
, rhsImagIsNan
);
774 rhsImag
= b
.create
<arith::SelectOp
>(
775 lhsIsInfAndRhsImagIsNan
, b
.create
<math::CopySignOp
>(zero
, rhsImag
),
778 // Case 2. `rhsReal` or `rhsImag` are infinite.
// Mirror image of Case 1 with the operand roles swapped. Note that lhsReal /
// lhsImag / rhsReal / rhsImag may already have been rewritten by Case 1 here.
780 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, rhsRealAbs
, inf
);
782 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, rhsImagAbs
, inf
);
783 Value rhsIsInf
= b
.create
<arith::OrIOp
>(rhsRealIsInf
, rhsImagIsInf
);
785 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, lhsReal
, lhsReal
);
787 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, lhsImag
, lhsImag
);
788 Value rhsRealIsInfFloat
=
789 b
.create
<arith::SelectOp
>(rhsRealIsInf
, one
, zero
);
790 rhsReal
= b
.create
<arith::SelectOp
>(
791 rhsIsInf
, b
.create
<math::CopySignOp
>(rhsRealIsInfFloat
, rhsReal
),
793 Value rhsImagIsInfFloat
=
794 b
.create
<arith::SelectOp
>(rhsImagIsInf
, one
, zero
);
795 rhsImag
= b
.create
<arith::SelectOp
>(
796 rhsIsInf
, b
.create
<math::CopySignOp
>(rhsImagIsInfFloat
, rhsImag
),
798 Value rhsIsInfAndLhsRealIsNan
=
799 b
.create
<arith::AndIOp
>(rhsIsInf
, lhsRealIsNan
);
800 lhsReal
= b
.create
<arith::SelectOp
>(
801 rhsIsInfAndLhsRealIsNan
, b
.create
<math::CopySignOp
>(zero
, lhsReal
),
803 Value rhsIsInfAndLhsImagIsNan
=
804 b
.create
<arith::AndIOp
>(rhsIsInf
, lhsImagIsNan
);
805 lhsImag
= b
.create
<arith::SelectOp
>(
806 rhsIsInfAndLhsImagIsNan
, b
.create
<math::CopySignOp
>(zero
, lhsImag
),
// `recalc` starts as "either operand is infinite" (Case 1 or Case 2).
808 Value recalc
= b
.create
<arith::OrIOp
>(lhsIsInf
, rhsIsInf
);
810 // Case 3. One of the pairwise products of left hand side with right hand
// (The continuation of the comment above is not visible in this copy.) Each of
// the four naive partial products is tested for |product| == +inf and the
// results are OR-ed together; the two standalone OrIOp results are presumably
// assigned back to `isSpecialCase` on lines not visible here.
812 Value lhsRealTimesRhsRealIsInf
= b
.create
<arith::CmpFOp
>(
813 arith::CmpFPredicate::OEQ
, lhsRealTimesRhsRealAbs
, inf
);
814 Value lhsImagTimesRhsImagIsInf
= b
.create
<arith::CmpFOp
>(
815 arith::CmpFPredicate::OEQ
, lhsImagTimesRhsImagAbs
, inf
);
816 Value isSpecialCase
= b
.create
<arith::OrIOp
>(lhsRealTimesRhsRealIsInf
,
817 lhsImagTimesRhsImagIsInf
);
818 Value lhsRealTimesRhsImagIsInf
= b
.create
<arith::CmpFOp
>(
819 arith::CmpFPredicate::OEQ
, lhsRealTimesRhsImagAbs
, inf
);
821 b
.create
<arith::OrIOp
>(isSpecialCase
, lhsRealTimesRhsImagIsInf
);
822 Value lhsImagTimesRhsRealIsInf
= b
.create
<arith::CmpFOp
>(
823 arith::CmpFPredicate::OEQ
, lhsImagTimesRhsRealAbs
, inf
);
825 b
.create
<arith::OrIOp
>(isSpecialCase
, lhsImagTimesRhsRealIsInf
);
// Boolean negation spelled as XOR with the i1 constant 1. The first XOR
// operand is on a line not visible here (presumably `recalc`). Case 3 only
// applies when Cases 1 and 2 did not.
826 Type i1Type
= b
.getI1Type();
827 Value notRecalc
= b
.create
<arith::XOrIOp
>(
829 b
.create
<arith::ConstantOp
>(i1Type
, b
.getIntegerAttr(i1Type
, 1)));
830 isSpecialCase
= b
.create
<arith::AndIOp
>(isSpecialCase
, notRecalc
);
// In the special case, every NaN operand component becomes a zero carrying
// the original sign.
831 Value isSpecialCaseAndLhsRealIsNan
=
832 b
.create
<arith::AndIOp
>(isSpecialCase
, lhsRealIsNan
);
833 lhsReal
= b
.create
<arith::SelectOp
>(
834 isSpecialCaseAndLhsRealIsNan
, b
.create
<math::CopySignOp
>(zero
, lhsReal
),
836 Value isSpecialCaseAndLhsImagIsNan
=
837 b
.create
<arith::AndIOp
>(isSpecialCase
, lhsImagIsNan
);
838 lhsImag
= b
.create
<arith::SelectOp
>(
839 isSpecialCaseAndLhsImagIsNan
, b
.create
<math::CopySignOp
>(zero
, lhsImag
),
841 Value isSpecialCaseAndRhsRealIsNan
=
842 b
.create
<arith::AndIOp
>(isSpecialCase
, rhsRealIsNan
);
843 rhsReal
= b
.create
<arith::SelectOp
>(
844 isSpecialCaseAndRhsRealIsNan
, b
.create
<math::CopySignOp
>(zero
, rhsReal
),
846 Value isSpecialCaseAndRhsImagIsNan
=
847 b
.create
<arith::AndIOp
>(isSpecialCase
, rhsImagIsNan
);
848 rhsImag
= b
.create
<arith::SelectOp
>(
849 isSpecialCaseAndRhsImagIsNan
, b
.create
<math::CopySignOp
>(zero
, rhsImag
),
// Final condition: the naive result is NaN in both components AND one of the
// three cases above applies.
851 recalc
= b
.create
<arith::OrIOp
>(recalc
, isSpecialCase
);
852 recalc
= b
.create
<arith::AndIOp
>(isNan
, recalc
);
854 // Recalculate real part.
// Uses the patched operand components; the recomputed value is multiplied by
// +inf and only selected when `recalc` holds, otherwise the naive `real` stays.
855 lhsRealTimesRhsReal
= b
.create
<arith::MulFOp
>(lhsReal
, rhsReal
, fmfValue
);
856 lhsImagTimesRhsImag
= b
.create
<arith::MulFOp
>(lhsImag
, rhsImag
, fmfValue
);
857 Value newReal
= b
.create
<arith::SubFOp
>(lhsRealTimesRhsReal
,
858 lhsImagTimesRhsImag
, fmfValue
);
859 real
= b
.create
<arith::SelectOp
>(
860 recalc
, b
.create
<arith::MulFOp
>(inf
, newReal
, fmfValue
), real
);
862 // Recalculate imag part.
863 lhsImagTimesRhsReal
= b
.create
<arith::MulFOp
>(lhsImag
, rhsReal
, fmfValue
);
864 lhsRealTimesRhsImag
= b
.create
<arith::MulFOp
>(lhsReal
, rhsImag
, fmfValue
);
865 Value newImag
= b
.create
<arith::AddFOp
>(lhsImagTimesRhsReal
,
866 lhsRealTimesRhsImag
, fmfValue
);
867 imag
= b
.create
<arith::SelectOp
>(
868 recalc
, b
.create
<arith::MulFOp
>(inf
, newImag
, fmfValue
), imag
);
// Replace the original op with a complex value built from the final parts.
870 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, real
, imag
);
// Conversion pattern for complex::NegOp: negates the real and the imaginary
// part separately with arith::NegFOp and rebuilds the value with
// complex::CreateOp.
// NOTE(review): the declarations of `real` / `imag`, the return statement and
// the closing braces are not visible in this copy.
875 struct NegOpConversion
: public OpConversionPattern
<complex::NegOp
> {
876 using OpConversionPattern
<complex::NegOp
>::OpConversionPattern
;
879 matchAndRewrite(complex::NegOp op
, OpAdaptor adaptor
,
880 ConversionPatternRewriter
&rewriter
) const override
{
881 auto loc
= op
.getLoc();
882 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
883 auto elementType
= cast
<FloatType
>(type
.getElementType());
// Extract the two components of the (already converted) operand; presumably
// these results are bound to `real` and `imag`, which are used below.
886 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getComplex());
888 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getComplex());
889 Value negReal
= rewriter
.create
<arith::NegFOp
>(loc
, real
);
890 Value negImag
= rewriter
.create
<arith::NegFOp
>(loc
, imag
);
891 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, negReal
, negImag
);
// complex::SinOp lowering built on TrigonometricOpConversion. Only the
// combine() hook is provided here: it returns the pair {resultReal, resultImag}
// computed as {sum * sin, diff * cos}.
// NOTE(review): the formulas in the comment below imply scaledExp = 0.5 * t and
// reciprocalExp = 0.5 / t -- confirm against TrigonometricOpConversion, which
// is not visible in this section.
896 struct SinOpConversion
: public TrigonometricOpConversion
<complex::SinOp
> {
897 using TrigonometricOpConversion
<complex::SinOp
>::TrigonometricOpConversion
;
899 std::pair
<Value
, Value
> combine(Location loc
, Value scaledExp
,
900 Value reciprocalExp
, Value sin
, Value cos
,
901 ConversionPatternRewriter
&rewriter
,
902 arith::FastMathFlagsAttr fmf
) const override
{
903 // Complex sine is defined as;
904 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
906 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
907 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
908 // and defining t := exp(y)
910 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
911 // Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
// scaledExp + reciprocalExp; presumably bound to `sum` (declaration not
// visible in this copy).
913 rewriter
.create
<arith::AddFOp
>(loc
, scaledExp
, reciprocalExp
, fmf
);
914 Value resultReal
= rewriter
.create
<arith::MulFOp
>(loc
, sum
, sin
, fmf
);
// scaledExp - reciprocalExp; presumably bound to `diff` (declaration not
// visible in this copy).
916 rewriter
.create
<arith::SubFOp
>(loc
, scaledExp
, reciprocalExp
, fmf
);
917 Value resultImag
= rewriter
.create
<arith::MulFOp
>(loc
, diff
, cos
, fmf
);
918 return {resultReal
, resultImag
};
922 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
// Conversion pattern for complex::SqrtOp. The visible code works in polar
// form: the magnitude comes from computeAbs(..., AbsFn::sqrt), the angle is
// atan2(imag, real) * 0.5, and the result parts are magnitude * cos(angle) and
// magnitude * sin(angle). Unless both nnan and ninf fast-math flags are set,
// additional selects patch up infinite / NaN inputs, and a zero magnitude
// forces a zero result.
// NOTE(review): several declarations, select operands, the return and closing
// braces are not visible in this copy.
923 struct SqrtOpConversion
: public OpConversionPattern
<complex::SqrtOp
> {
924 using OpConversionPattern
<complex::SqrtOp
>::OpConversionPattern
;
927 matchAndRewrite(complex::SqrtOp op
, OpAdaptor adaptor
,
928 ConversionPatternRewriter
&rewriter
) const override
{
929 ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
931 auto type
= cast
<ComplexType
>(op
.getType());
932 auto elementType
= cast
<FloatType
>(type
.getElementType());
933 arith::FastMathFlags fmf
= op
.getFastMathFlagsAttr().getValue();
// Helper that materializes an APFloat as an arith constant of elementType.
935 auto cst
= [&](APFloat v
) {
936 return b
.create
<arith::ConstantOp
>(elementType
,
937 b
.getFloatAttr(elementType
, v
));
939 const auto &floatSemantics
= elementType
.getFloatSemantics();
940 Value zero
= cst(APFloat::getZero(floatSemantics
));
941 Value half
= b
.create
<arith::ConstantOp
>(elementType
,
942 b
.getFloatAttr(elementType
, 0.5));
// Polar decomposition: magnitude via computeAbs with AbsFn::sqrt, half of the
// argument angle, and its cosine / sine.
944 Value real
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getComplex());
945 Value imag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getComplex());
946 Value absSqrt
= computeAbs(real
, imag
, fmf
, b
, AbsFn::sqrt
);
947 Value argArg
= b
.create
<math::Atan2Op
>(imag
, real
, fmf
);
948 Value sqrtArg
= b
.create
<arith::MulFOp
>(argArg
, half
, fmf
);
949 Value cos
= b
.create
<math::CosOp
>(sqrtArg
, fmf
);
950 Value sin
= b
.create
<math::SinOp
>(sqrtArg
, fmf
);
951 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
// (The rest of the comment above is not visible in this copy.) The compare
// below is presumably bound to `sinIsZero`; when it holds, the imaginary part
// is forced to zero instead of computing absSqrt * sin.
954 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, sin
, zero
, fmf
);
956 Value resultReal
= b
.create
<arith::MulFOp
>(absSqrt
, cos
, fmf
);
957 Value resultImag
= b
.create
<arith::SelectOp
>(
958 sinIsZero
, zero
, b
.create
<arith::MulFOp
>(absSqrt
, sin
, fmf
));
// Special-value handling is only emitted when the fast-math flags do not
// contain both nnan and ninf.
959 if (!arith::bitEnumContainsAll(fmf
, arith::FastMathFlags::nnan
|
960 arith::FastMathFlags::ninf
)) {
961 Value inf
= cst(APFloat::getInf(floatSemantics
));
962 Value negInf
= cst(APFloat::getInf(floatSemantics
, true));
963 Value nan
= cst(APFloat::getNaN(floatSemantics
));
964 Value absImag
= b
.create
<math::AbsFOp
>(elementType
, imag
, fmf
);
// The standalone compares below are presumably bound to `absImagIsInf`,
// `realIsInf` and `realIsNegInf`, which are used by the selects that follow.
967 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, absImag
, inf
, fmf
);
968 Value absImagIsNotInf
=
969 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::ONE
, absImag
, inf
, fmf
);
971 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, real
, inf
, fmf
);
973 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, real
, negInf
, fmf
);
// Real part: zero when real == -inf and |imag| is not infinite (the select's
// last operand is not visible here); then +inf when |imag| == inf or
// real == +inf.
975 resultReal
= b
.create
<arith::SelectOp
>(
976 b
.create
<arith::AndIOp
>(realIsNegInf
, absImagIsNotInf
), zero
,
978 resultReal
= b
.create
<arith::SelectOp
>(
979 b
.create
<arith::OrIOp
>(absImagIsInf
, realIsInf
), inf
, resultReal
);
// Imaginary part: first a select keyed on absSqrt being NaN, then a select
// that picks copysign(inf, imag) when |imag| == inf or real == -inf. The
// remaining operands of both selects are not visible here.
981 Value imagSignInf
= b
.create
<math::CopySignOp
>(inf
, imag
, fmf
);
982 resultImag
= b
.create
<arith::SelectOp
>(
983 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, absSqrt
, absSqrt
),
985 resultImag
= b
.create
<arith::SelectOp
>(
986 b
.create
<arith::OrIOp
>(absImagIsInf
, realIsNegInf
), imagSignInf
,
// Zero magnitude forces both result parts to zero; the compare is presumably
// bound to `resultIsZero` (declaration not visible here).
991 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, absSqrt
, zero
, fmf
);
992 resultReal
= b
.create
<arith::SelectOp
>(resultIsZero
, zero
, resultReal
);
993 resultImag
= b
.create
<arith::SelectOp
>(resultIsZero
, zero
, resultImag
);
995 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
// Conversion pattern for complex::SignOp: divides the real and imaginary parts
// by complex::AbsOp of the input, and selects the unmodified input instead
// when both parts compare equal to zero.
// NOTE(review): the declarations of `zero`, `realIsZero`, `imagIsZero`, the
// return statement and the closing braces are not visible in this copy.
1001 struct SignOpConversion
: public OpConversionPattern
<complex::SignOp
> {
1002 using OpConversionPattern
<complex::SignOp
>::OpConversionPattern
;
1005 matchAndRewrite(complex::SignOp op
, OpAdaptor adaptor
,
1006 ConversionPatternRewriter
&rewriter
) const override
{
1007 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
1008 auto elementType
= cast
<FloatType
>(type
.getElementType());
1009 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
1010 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
1012 Value real
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getComplex());
1013 Value imag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getComplex());
// Zero constant and the two ordered-equal tests against it; presumably bound
// to `zero`, `realIsZero` and `imagIsZero`.
1015 b
.create
<arith::ConstantOp
>(elementType
, b
.getZeroAttr(elementType
));
1017 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, real
, zero
);
1019 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, imag
, zero
);
1020 Value isZero
= b
.create
<arith::AndIOp
>(realIsZero
, imagIsZero
);
// sign = (real / |z|, imag / |z|). The division is emitted unconditionally;
// the select below discards it for a zero input.
1021 auto abs
= b
.create
<complex::AbsOp
>(elementType
, adaptor
.getComplex(), fmf
);
1022 Value realSign
= b
.create
<arith::DivFOp
>(real
, abs
, fmf
);
1023 Value imagSign
= b
.create
<arith::DivFOp
>(imag
, abs
, fmf
);
1024 Value sign
= b
.create
<complex::CreateOp
>(type
, realSign
, imagSign
);
1025 rewriter
.replaceOpWithNewOp
<arith::SelectOp
>(op
, isZero
,
1026 adaptor
.getComplex(), sign
);
// Shared conversion pattern for complex::TanOp and complex::TanhOp. The body
// computes tanh; for TanOp the input parts are swapped (with a sign flip)
// before and the result parts are swapped (with a sign flip) after, following
// the identity tan(x+yi) = -i*tanh(-y + xi) quoted in the code.
// With x = real and y = imag of the (possibly swapped) input, the visible ops
// compute
//   expm1(2x) - expm1(-2x)                    (presumably `realNum`)
//   imagNum = 4 * cos(y) * sin(y)
//   (expm1(2x) + expm1(-2x)) + 4 * cos(y)^2   (presumably `denom`)
// and divide the numerators by the denominator.
// NOTE(review): the declarations of `real`, `imag`, `realNum`, `denom`,
// `imagIsZero`, some assignments, the return and closing braces are not
// visible in this copy.
1031 template <typename Op
>
1032 struct TanTanhOpConversion
: public OpConversionPattern
<Op
> {
1033 using OpConversionPattern
<Op
>::OpConversionPattern
;
1036 matchAndRewrite(Op op
, typename
Op::Adaptor adaptor
,
1037 ConversionPatternRewriter
&rewriter
) const override
{
1038 ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
1039 auto loc
= op
.getLoc();
1040 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
1041 auto elementType
= cast
<FloatType
>(type
.getElementType());
1042 arith::FastMathFlags fmf
= op
.getFastMathFlagsAttr().getValue();
1043 const auto &floatSemantics
= elementType
.getFloatSemantics();
// Component extraction; presumably bound to `real` and `imag`.
1046 b
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getComplex());
1048 b
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getComplex());
1049 Value negOne
= b
.create
<arith::ConstantOp
>(
1050 elementType
, b
.getFloatAttr(elementType
, -1.0));
1052 if constexpr (std::is_same_v
<Op
, complex::TanOp
>) {
1053 // tan(x+yi) = -i*tanh(-y + xi)
1054 std::swap(real
, imag
);
1055 real
= b
.create
<arith::MulFOp
>(real
, negOne
, fmf
);
// Helper that materializes an APFloat as an arith constant of elementType.
1058 auto cst
= [&](APFloat v
) {
1059 return b
.create
<arith::ConstantOp
>(elementType
,
1060 b
.getFloatAttr(elementType
, v
));
1062 Value inf
= cst(APFloat::getInf(floatSemantics
));
1063 Value four
= b
.create
<arith::ConstantOp
>(elementType
,
1064 b
.getFloatAttr(elementType
, 4.0));
1065 Value twoReal
= b
.create
<arith::AddFOp
>(real
, real
, fmf
);
1066 Value negTwoReal
= b
.create
<arith::MulFOp
>(negOne
, twoReal
, fmf
);
1068 Value expTwoRealMinusOne
= b
.create
<math::ExpM1Op
>(twoReal
, fmf
);
1069 Value expNegTwoRealMinusOne
= b
.create
<math::ExpM1Op
>(negTwoReal
, fmf
);
// Difference of the two expm1 values; presumably bound to `realNum`.
1071 b
.create
<arith::SubFOp
>(expTwoRealMinusOne
, expNegTwoRealMinusOne
, fmf
);
// Despite its name, `twoCosTwoImagPlusOne` holds 4 * cos(imag)^2, which equals
// 2 * (cos(2 * imag) + 1).
1073 Value cosImag
= b
.create
<math::CosOp
>(imag
, fmf
);
1074 Value cosImagSq
= b
.create
<arith::MulFOp
>(cosImag
, cosImag
, fmf
);
1075 Value twoCosTwoImagPlusOne
= b
.create
<arith::MulFOp
>(cosImagSq
, four
, fmf
);
1076 Value sinImag
= b
.create
<math::SinOp
>(imag
, fmf
);
1078 Value imagNum
= b
.create
<arith::MulFOp
>(
1079 four
, b
.create
<arith::MulFOp
>(cosImag
, sinImag
, fmf
), fmf
);
1081 Value expSumMinusTwo
=
1082 b
.create
<arith::AddFOp
>(expTwoRealMinusOne
, expNegTwoRealMinusOne
, fmf
);
// Sum of the expm1 terms and the cosine term; presumably bound to `denom`.
1084 b
.create
<arith::AddFOp
>(expSumMinusTwo
, twoCosTwoImagPlusOne
, fmf
);
// When the exponential sum is +inf, the real part is replaced by
// copysign(-1, real) instead of performing the division.
1086 Value isInf
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
,
1087 expSumMinusTwo
, inf
, fmf
);
1088 Value realLimit
= b
.create
<math::CopySignOp
>(negOne
, real
, fmf
);
1090 Value resultReal
= b
.create
<arith::SelectOp
>(
1091 isInf
, realLimit
, b
.create
<arith::DivFOp
>(realNum
, denom
, fmf
));
1092 Value resultImag
= b
.create
<arith::DivFOp
>(imagNum
, denom
, fmf
);
// Special-value handling is only emitted when the fast-math flags do not
// contain both nnan and ninf.
1094 if (!arith::bitEnumContainsAll(fmf
, arith::FastMathFlags::nnan
|
1095 arith::FastMathFlags::ninf
)) {
1096 Value absReal
= b
.create
<math::AbsFOp
>(real
, fmf
);
1097 Value zero
= b
.create
<arith::ConstantOp
>(
1098 elementType
, b
.getFloatAttr(elementType
, 0.0));
1099 Value nan
= cst(APFloat::getNaN(floatSemantics
));
1101 Value absRealIsInf
=
1102 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, absReal
, inf
, fmf
);
// imag == 0 test; presumably bound to `imagIsZero`.
1104 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, imag
, zero
, fmf
);
// Boolean negation spelled as XOR with the 1-bit constant true.
1105 Value absRealIsNotInf
= b
.create
<arith::XOrIOp
>(
1106 absRealIsInf
, b
.create
<arith::ConstantIntOp
>(true, /*width=*/1));
// Real result becomes NaN when imagNum is NaN and |real| is finite or NaN;
// imaginary result becomes zero when imag == 0, or when |real| is infinite and
// imagNum is NaN.
1108 Value imagNumIsNaN
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
,
1109 imagNum
, imagNum
, fmf
);
1110 Value resultRealIsNaN
=
1111 b
.create
<arith::AndIOp
>(imagNumIsNaN
, absRealIsNotInf
);
1112 Value resultImagIsZero
= b
.create
<arith::OrIOp
>(
1113 imagIsZero
, b
.create
<arith::AndIOp
>(absRealIsInf
, imagNumIsNaN
));
1115 resultReal
= b
.create
<arith::SelectOp
>(resultRealIsNaN
, nan
, resultReal
);
// Presumably assigned back to `resultImag` on a line not visible here.
1117 b
.create
<arith::SelectOp
>(resultImagIsZero
, zero
, resultImag
);
1120 if constexpr (std::is_same_v
<Op
, complex::TanOp
>) {
1121 // tan(x+yi) = -i*tanh(-y + xi)
1122 std::swap(resultReal
, resultImag
);
1123 resultImag
= b
.create
<arith::MulFOp
>(resultImag
, negOne
, fmf
);
1126 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
// Conversion pattern for complex::ConjOp: keeps the real part and negates the
// imaginary part with arith::NegFOp, then rebuilds the value with
// complex::CreateOp.
// NOTE(review): the declarations of `real` / `imag`, the return statement and
// the closing braces are not visible in this copy.
1132 struct ConjOpConversion
: public OpConversionPattern
<complex::ConjOp
> {
1133 using OpConversionPattern
<complex::ConjOp
>::OpConversionPattern
;
1136 matchAndRewrite(complex::ConjOp op
, OpAdaptor adaptor
,
1137 ConversionPatternRewriter
&rewriter
) const override
{
1138 auto loc
= op
.getLoc();
1139 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
1140 auto elementType
= cast
<FloatType
>(type
.getElementType());
// Extract the two components; presumably bound to `real` and `imag`.
1142 rewriter
.create
<complex::ReOp
>(loc
, elementType
, adaptor
.getComplex());
1144 rewriter
.create
<complex::ImOp
>(loc
, elementType
, adaptor
.getComplex());
1145 Value negImag
= rewriter
.create
<arith::NegFOp
>(loc
, elementType
, imag
);
1147 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, real
, negImag
);
1153 /// Converts lhs^y = (a+bi)^(c+di) to
1154 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1155 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
// The code expresses the same formula through |lhs| = complex::AbsOp(lhs):
// (a*a+b*b)^(0.5c) is emitted as pow(|lhs|, c) and 0.5d*ln(a*a+b*b) as
// d * log(|lhs|). The general result is then overridden by a chain of selects
// (cutoff0 .. cutoff4) for special inputs.
// Parameters: `builder` emits the ops, `type` is the complex result type,
// `lhs` is the complex base, `c` / `d` are the real / imaginary parts of the
// exponent, `fmf` is attached to the emitted floating-point ops.
// NOTE(review): several declarations (`absEqZero`, `dEqZero`, `cEqZero`,
// `bEqZero`, `zeroLeC`, `coeffCosSin`, `cutoff1`, `cutoff2`), some operands and
// the return statement are not visible in this copy.
1156 static Value
powOpConversionImpl(mlir::ImplicitLocOpBuilder
&builder
,
1157 ComplexType type
, Value lhs
, Value c
, Value d
,
1158 arith::FastMathFlags fmf
) {
1159 auto elementType
= cast
<FloatType
>(type
.getElementType());
1161 Value a
= builder
.create
<complex::ReOp
>(lhs
);
1162 Value b
= builder
.create
<complex::ImOp
>(lhs
);
1164 Value abs
= builder
.create
<complex::AbsOp
>(lhs
, fmf
);
1165 Value absToC
= builder
.create
<math::PowFOp
>(abs
, c
, fmf
);
// exp(-d * atan2(b, a))
1167 Value negD
= builder
.create
<arith::NegFOp
>(d
, fmf
);
1168 Value argLhs
= builder
.create
<math::Atan2Op
>(b
, a
, fmf
);
1169 Value negDArgLhs
= builder
.create
<arith::MulFOp
>(negD
, argLhs
, fmf
);
1170 Value expNegDArgLhs
= builder
.create
<math::ExpOp
>(negDArgLhs
, fmf
);
// coeff = |lhs|^c * exp(-d * arg(lhs)); q = c * arg(lhs) + d * ln|lhs|.
1172 Value coeff
= builder
.create
<arith::MulFOp
>(absToC
, expNegDArgLhs
, fmf
);
1173 Value lnAbs
= builder
.create
<math::LogOp
>(abs
, fmf
);
1174 Value cArgLhs
= builder
.create
<arith::MulFOp
>(c
, argLhs
, fmf
);
1175 Value dLnAbs
= builder
.create
<arith::MulFOp
>(d
, lnAbs
, fmf
);
1176 Value q
= builder
.create
<arith::AddFOp
>(cArgLhs
, dLnAbs
, fmf
);
1177 Value cosQ
= builder
.create
<math::CosOp
>(q
, fmf
);
1178 Value sinQ
= builder
.create
<math::SinOp
>(q
, fmf
);
// Constants used by the special-case selects below.
1180 Value inf
= builder
.create
<arith::ConstantOp
>(
1182 builder
.getFloatAttr(elementType
,
1183 APFloat::getInf(elementType
.getFloatSemantics())));
1184 Value zero
= builder
.create
<arith::ConstantOp
>(
1185 elementType
, builder
.getFloatAttr(elementType
, 0.0));
1186 Value one
= builder
.create
<arith::ConstantOp
>(
1187 elementType
, builder
.getFloatAttr(elementType
, 1.0));
1188 Value complexOne
= builder
.create
<complex::CreateOp
>(type
, one
, zero
);
1189 Value complexZero
= builder
.create
<complex::CreateOp
>(type
, zero
, zero
);
1190 Value complexInf
= builder
.create
<complex::CreateOp
>(type
, inf
, zero
);
1193 // lhs^(c + d*i) is 0 if |lhs| is 0, d is 0 and c > 0. 0^0 is defined to be 1.0, see
1194 // Branch Cuts for Complex Elementary Functions or Much Ado About
1195 // Nothing's Sign Bit, W. Kahan, Section 10.
// The standalone compares below test |lhs| == 0, d == 0, c == 0, b == 0 and
// 0 <= c; presumably they are bound to `absEqZero`, `dEqZero`, `cEqZero`,
// `bEqZero` and `zeroLeC`.
1197 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, abs
, zero
, fmf
);
1199 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, d
, zero
, fmf
);
1201 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, c
, zero
, fmf
);
1203 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, b
, zero
, fmf
);
1206 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OLE
, zero
, c
, fmf
);
1207 Value coeffCosQ
= builder
.create
<arith::MulFOp
>(coeff
, cosQ
, fmf
);
1208 Value coeffSinQ
= builder
.create
<arith::MulFOp
>(coeff
, sinQ
, fmf
);
1209 Value complexOneOrZero
=
1210 builder
.create
<arith::SelectOp
>(cEqZero
, complexOne
, complexZero
);
// General result coeff * (cos(q) + i * sin(q)); presumably bound to
// `coeffCosSin`.
1212 builder
.create
<complex::CreateOp
>(type
, coeffCosQ
, coeffSinQ
);
// cutoff0: when |lhs| == 0, d == 0 and 0 <= c, pick 1 (c == 0) or 0 (c > 0);
// otherwise keep the general result.
1213 Value cutoff0
= builder
.create
<arith::SelectOp
>(
1214 builder
.create
<arith::AndIOp
>(
1215 builder
.create
<arith::AndIOp
>(absEqZero
, dEqZero
), zeroLeC
),
1216 complexOneOrZero
, coeffCosSin
);
1219 // x^0 is defined to be 1 for any x, see
1220 // Branch Cuts for Complex Elementary Functions or Much Ado About
1221 // Nothing's Sign Bit, W. Kahan, Section 10.
1222 Value rhsEqZero
= builder
.create
<arith::AndIOp
>(cEqZero
, dEqZero
);
// Presumably bound to `cutoff1`.
1224 builder
.create
<arith::SelectOp
>(rhsEqZero
, complexOne
, cutoff0
);
1227 // 1^(c + d*i) = 1 + 0*i
// The second AndIOp operand of `lhsEqOne` is not visible here.
1228 Value lhsEqOne
= builder
.create
<arith::AndIOp
>(
1229 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, a
, one
, fmf
),
// Presumably bound to `cutoff2`.
1232 builder
.create
<arith::SelectOp
>(lhsEqOne
, complexOne
, cutoff1
);
1235 // inf^(c + 0*i) = inf + 0*i, c > 0
// The second operand of `lhsEqInf` and the first operand of `rhsGt0` are not
// visible here.
1236 Value lhsEqInf
= builder
.create
<arith::AndIOp
>(
1237 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, a
, inf
, fmf
),
1239 Value rhsGt0
= builder
.create
<arith::AndIOp
>(
1241 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OGT
, c
, zero
, fmf
));
1242 Value cutoff3
= builder
.create
<arith::SelectOp
>(
1243 builder
.create
<arith::AndIOp
>(lhsEqInf
, rhsGt0
), complexInf
, cutoff2
);
1246 // inf^(c + 0*i) = 0 + 0*i, c < 0
// The first operand of `rhsLt0` is not visible here. `cutoff4` is the last
// value of the select chain visible in this copy.
1247 Value rhsLt0
= builder
.create
<arith::AndIOp
>(
1249 builder
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OLT
, c
, zero
, fmf
));
1250 Value cutoff4
= builder
.create
<arith::SelectOp
>(
1251 builder
.create
<arith::AndIOp
>(lhsEqInf
, rhsLt0
), complexZero
, cutoff3
);
// Conversion pattern for complex::PowOp: extracts the real (c) and imaginary
// (d) parts of the exponent and replaces the op with the value produced by
// powOpConversionImpl, passing the op's fastmath flags through.
// NOTE(review): the return statement and the closing braces are not visible in
// this copy.
1256 struct PowOpConversion
: public OpConversionPattern
<complex::PowOp
> {
1257 using OpConversionPattern
<complex::PowOp
>::OpConversionPattern
;
1260 matchAndRewrite(complex::PowOp op
, OpAdaptor adaptor
,
1261 ConversionPatternRewriter
&rewriter
) const override
{
1262 mlir::ImplicitLocOpBuilder
builder(op
.getLoc(), rewriter
);
1263 auto type
= cast
<ComplexType
>(adaptor
.getLhs().getType());
1264 auto elementType
= cast
<FloatType
>(type
.getElementType());
// Only the exponent is split here; the base is handed over as a complex value.
1266 Value c
= builder
.create
<complex::ReOp
>(elementType
, adaptor
.getRhs());
1267 Value d
= builder
.create
<complex::ImOp
>(elementType
, adaptor
.getRhs());
1269 rewriter
.replaceOp(op
, {powOpConversionImpl(builder
, type
, adaptor
.getLhs(),
1270 c
, d
, op
.getFastmath())});
// Conversion pattern for complex::RsqrtOp. The visible code works in polar
// form: the magnitude comes from computeAbs(..., AbsFn::rsqrt), the angle is
// atan2(imag, real) * -0.5, and the result parts are magnitude * cos(angle)
// and magnitude * sin(angle). Unless both nnan and ninf fast-math flags are
// set, additional selects handle infinite / NaN inputs and a zero input.
// NOTE(review): several declarations (`realIsNan`, `realIsInf`, `isRealZero`,
// `isImagZero`), some assignments and select operands, the return and closing
// braces are not visible in this copy.
1275 struct RsqrtOpConversion
: public OpConversionPattern
<complex::RsqrtOp
> {
1276 using OpConversionPattern
<complex::RsqrtOp
>::OpConversionPattern
;
1279 matchAndRewrite(complex::RsqrtOp op
, OpAdaptor adaptor
,
1280 ConversionPatternRewriter
&rewriter
) const override
{
1281 mlir::ImplicitLocOpBuilder
b(op
.getLoc(), rewriter
);
1282 auto type
= cast
<ComplexType
>(adaptor
.getComplex().getType());
1283 auto elementType
= cast
<FloatType
>(type
.getElementType());
1285 arith::FastMathFlags fmf
= op
.getFastMathFlagsAttr().getValue();
// Helper that materializes an APFloat as an arith constant of elementType.
1287 auto cst
= [&](APFloat v
) {
1288 return b
.create
<arith::ConstantOp
>(elementType
,
1289 b
.getFloatAttr(elementType
, v
));
1291 const auto &floatSemantics
= elementType
.getFloatSemantics();
1292 Value zero
= cst(APFloat::getZero(floatSemantics
));
1293 Value inf
= cst(APFloat::getInf(floatSemantics
));
1294 Value negHalf
= b
.create
<arith::ConstantOp
>(
1295 elementType
, b
.getFloatAttr(elementType
, -0.5));
1296 Value nan
= cst(APFloat::getNaN(floatSemantics
));
// Polar decomposition: magnitude via computeAbs with AbsFn::rsqrt, and the
// argument angle scaled by -0.5.
1298 Value real
= b
.create
<complex::ReOp
>(elementType
, adaptor
.getComplex());
1299 Value imag
= b
.create
<complex::ImOp
>(elementType
, adaptor
.getComplex());
1300 Value absRsqrt
= computeAbs(real
, imag
, fmf
, b
, AbsFn::rsqrt
);
1301 Value argArg
= b
.create
<math::Atan2Op
>(imag
, real
, fmf
);
1302 Value rsqrtArg
= b
.create
<arith::MulFOp
>(argArg
, negHalf
, fmf
);
1303 Value cos
= b
.create
<math::CosOp
>(rsqrtArg
, fmf
);
1304 Value sin
= b
.create
<math::SinOp
>(rsqrtArg
, fmf
);
1306 Value resultReal
= b
.create
<arith::MulFOp
>(absRsqrt
, cos
, fmf
);
1307 Value resultImag
= b
.create
<arith::MulFOp
>(absRsqrt
, sin
, fmf
);
// Special-value handling is only emitted when the fast-math flags do not
// contain both nnan and ninf.
1309 if (!arith::bitEnumContainsAll(fmf
, arith::FastMathFlags::nnan
|
1310 arith::FastMathFlags::ninf
)) {
1311 Value negOne
= b
.create
<arith::ConstantOp
>(
1312 elementType
, b
.getFloatAttr(elementType
, -1));
// Signed zeros that take their sign from the input parts; the imaginary one
// is additionally multiplied by -1.
1314 Value realSignedZero
= b
.create
<math::CopySignOp
>(zero
, real
, fmf
);
1315 Value imagSignedZero
= b
.create
<math::CopySignOp
>(zero
, imag
, fmf
);
1316 Value negImagSignedZero
=
1317 b
.create
<arith::MulFOp
>(negOne
, imagSignedZero
, fmf
);
1319 Value absReal
= b
.create
<math::AbsFOp
>(real
, fmf
);
1320 Value absImag
= b
.create
<math::AbsFOp
>(imag
, fmf
);
1322 Value absImagIsInf
=
1323 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, absImag
, inf
, fmf
);
// Presumably bound to `realIsNan` (NaN test) and `realIsInf` (|real| == inf).
1325 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::UNO
, real
, real
, fmf
);
1327 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, absReal
, inf
, fmf
);
1328 Value inIsNanInf
= b
.create
<arith::AndIOp
>(absImagIsInf
, realIsNan
);
// The result is a signed zero when the input is (NaN, +-inf) or |real| is
// infinite. The first select below is presumably assigned to `resultReal`; the
// last operand of the second select is not visible here.
1330 Value resultIsZero
= b
.create
<arith::OrIOp
>(inIsNanInf
, realIsInf
);
1333 b
.create
<arith::SelectOp
>(resultIsZero
, realSignedZero
, resultReal
);
1334 resultImag
= b
.create
<arith::SelectOp
>(resultIsZero
, negImagSignedZero
,
// Zero input: both parts compare equal to zero (presumably bound to
// `isRealZero` / `isImagZero`); the result is then (inf, nan).
1339 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, real
, zero
, fmf
);
1341 b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OEQ
, imag
, zero
, fmf
);
1342 Value isZero
= b
.create
<arith::AndIOp
>(isRealZero
, isImagZero
);
1344 resultReal
= b
.create
<arith::SelectOp
>(isZero
, inf
, resultReal
);
1345 resultImag
= b
.create
<arith::SelectOp
>(isZero
, nan
, resultImag
);
1347 rewriter
.replaceOpWithNewOp
<complex::CreateOp
>(op
, type
, resultReal
,
// Conversion pattern for complex::AngleOp: replaces the op with
// math::Atan2Op(imag, real), forwarding the op's fastmath attribute.
// NOTE(review): the declarations of `real` / `imag`, the return statement and
// the closing braces are not visible in this copy.
1353 struct AngleOpConversion
: public OpConversionPattern
<complex::AngleOp
> {
1354 using OpConversionPattern
<complex::AngleOp
>::OpConversionPattern
;
1357 matchAndRewrite(complex::AngleOp op
, OpAdaptor adaptor
,
1358 ConversionPatternRewriter
&rewriter
) const override
{
1359 auto loc
= op
.getLoc();
// `type` is the op's result type and is also used as the result type of the
// Re/Im extraction below.
1360 auto type
= op
.getType();
1361 arith::FastMathFlagsAttr fmf
= op
.getFastMathFlagsAttr();
// Presumably bound to `real` and `imag`.
1364 rewriter
.create
<complex::ReOp
>(loc
, type
, adaptor
.getComplex());
1366 rewriter
.create
<complex::ImOp
>(loc
, type
, adaptor
.getComplex());
1368 rewriter
.replaceOpWithNewOp
<math::Atan2Op
>(op
, imag
, real
, fmf
);
// Registers the Complex-to-Standard conversion patterns in `patterns`; the
// patterns are constructed with the context obtained from
// patterns.getContext().
// NOTE(review): only part of the pattern list is visible in this copy (the
// registration call itself and most entries are on lines not shown).
1376 void mlir::populateComplexToStandardConversionPatterns(
1377 RewritePatternSet
&patterns
) {
1383 BinaryComplexOpConversion
<complex::AddOp
, arith::AddFOp
>,
1384 BinaryComplexOpConversion
<complex::SubOp
, arith::SubFOp
>,
// Equality uses the ordered OEQ predicate, inequality the unordered UNE one.
1385 ComparisonOpConversion
<complex::EqualOp
, arith::CmpFPredicate::OEQ
>,
1386 ComparisonOpConversion
<complex::NotEqualOp
, arith::CmpFPredicate::UNE
>,
// One template pattern is instantiated for both tan and tanh.
1399 TanTanhOpConversion
<complex::TanOp
>,
1400 TanTanhOpConversion
<complex::TanhOp
>,
1403 >(patterns
.getContext());
// Pass class derived from the generated impl::ConvertComplexToStandardBase
// (CRTP on this type). It only declares the runOnOperation() override; the
// definition follows out of line.
1408 struct ConvertComplexToStandardPass
1409 : public impl::ConvertComplexToStandardBase
<ConvertComplexToStandardPass
> {
1410 void runOnOperation() override
;
// Pass entry point: collects the Complex-to-Standard patterns, sets up a
// conversion target and applies a partial conversion to the current operation.
// NOTE(review): signalPassFailure() is presumably guarded by a failed(...)
// check on a line not visible in this copy.
1413 void ConvertComplexToStandardPass::runOnOperation() {
1414 // Convert to the Standard dialect using the converter defined above.
1415 RewritePatternSet
patterns(&getContext());
1416 populateComplexToStandardConversionPatterns(patterns
);
// The arith and math dialects are legal as a whole; from the complex dialect
// only create / im / re are marked legal here.
1418 ConversionTarget
target(getContext());
1419 target
.addLegalDialect
<arith::ArithDialect
, math::MathDialect
>();
1420 target
.addLegalOp
<complex::CreateOp
, complex::ImOp
, complex::ReOp
>();
1422 applyPartialConversion(getOperation(), target
, std::move(patterns
))))
1423 signalPassFailure();
// Factory function: returns a newly allocated ConvertComplexToStandardPass
// owned by a std::unique_ptr<Pass>.
1427 std::unique_ptr
<Pass
> mlir::createConvertComplexToStandardPass() {
1428 return std::make_unique
<ConvertComplexToStandardPass
>();