[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / ComplexToStandard / ComplexToStandard.cpp
blob807beebe4fb22a72c47fb84694d1b7ad042e95ab
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include <memory>
19 #include <type_traits>
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
26 using namespace mlir;
28 namespace {
// Selects what computeAbs produces from the complex magnitude: the magnitude
// itself (abs), its square root (sqrt) or its reciprocal square root (rsqrt).
30 enum class AbsFn { abs, sqrt, rsqrt };
32 // Returns the absolute value, its square root or its reciprocal square root.
33 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35 Value one = b.create<arith::ConstantOp>(real.getType(),
36 b.getFloatAttr(real.getType(), 1.0));
38 Value absReal = b.create<math::AbsFOp>(real, fmf);
39 Value absImag = b.create<math::AbsFOp>(imag, fmf);
41 Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
42 Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
44 // The lowering below requires NaNs and infinities to work correctly.
45 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
47 Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
48 Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49 Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
50 Value result;
52 if (fn == AbsFn::rsqrt) {
53 ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
54 min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
55 max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
58 if (fn == AbsFn::sqrt) {
59 Value quarter = b.create<arith::ConstantOp>(
60 real.getType(), b.getFloatAttr(real.getType(), 0.25));
61 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
62 Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
63 Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64 result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
65 } else {
66 Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67 result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
70 Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71 result, fmfWithNaNInf);
72 return b.create<arith::SelectOp>(isNaN, min, result);
75 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
76 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
78 LogicalResult
79 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
80 ConversionPatternRewriter &rewriter) const override {
81 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
83 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85 Value real = b.create<complex::ReOp>(adaptor.getComplex());
86 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
87 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
89 return success();
93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
95 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
97 LogicalResult
98 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
99 ConversionPatternRewriter &rewriter) const override {
100 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
102 auto type = cast<ComplexType>(op.getType());
103 Type elementType = type.getElementType();
104 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106 Value lhs = adaptor.getLhs();
107 Value rhs = adaptor.getRhs();
109 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
110 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
111 Value rhsSquaredPlusLhsSquared =
112 b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113 Value sqrtOfRhsSquaredPlusLhsSquared =
114 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
116 Value zero =
117 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
118 Value one = b.create<arith::ConstantOp>(elementType,
119 b.getFloatAttr(elementType, 1));
120 Value i = b.create<complex::CreateOp>(type, zero, one);
121 Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
122 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
124 Value divResult = b.create<complex::DivOp>(
125 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126 Value logResult = b.create<complex::LogOp>(divResult, fmf);
128 Value negativeOne = b.create<arith::ConstantOp>(
129 elementType, b.getFloatAttr(elementType, -1));
130 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
132 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
133 return success();
137 template <typename ComparisonOp, arith::CmpFPredicate p>
138 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
139 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
140 using ResultCombiner =
141 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142 arith::AndIOp, arith::OrIOp>;
144 LogicalResult
145 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
146 ConversionPatternRewriter &rewriter) const override {
147 auto loc = op.getLoc();
148 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
151 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
152 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
153 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
154 Value realComparison =
155 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156 Value imagComparison =
157 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
159 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
160 imagComparison);
161 return success();
165 // Default conversion which applies the BinaryStandardOp separately on the real
166 // and imaginary parts. Can for example be used for complex::AddOp and
167 // complex::SubOp.
168 template <typename BinaryComplexOp, typename BinaryStandardOp>
169 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
170 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
172 LogicalResult
173 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
174 ConversionPatternRewriter &rewriter) const override {
175 auto type = cast<ComplexType>(adaptor.getLhs().getType());
176 auto elementType = cast<FloatType>(type.getElementType());
177 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
178 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
181 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
182 Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
183 fmf.getValue());
184 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
185 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
186 Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
187 fmf.getValue());
188 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
189 resultImag);
190 return success();
194 template <typename TrigonometricOp>
195 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
198 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
200 LogicalResult
201 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter) const override {
203 auto loc = op.getLoc();
204 auto type = cast<ComplexType>(adaptor.getComplex().getType());
205 auto elementType = cast<FloatType>(type.getElementType());
206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
208 Value real =
209 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
210 Value imag =
211 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
213 // Trigonometric ops use a set of common building blocks to convert to real
214 // ops. Here we create these building blocks and call into an op-specific
215 // implementation in the subclass to combine them.
216 Value half = rewriter.create<arith::ConstantOp>(
217 loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
218 Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
219 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
220 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
221 Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
222 Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
224 auto resultPair =
225 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
227 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
228 resultPair.second);
229 return success();
232 virtual std::pair<Value, Value>
233 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234 Value cos, ConversionPatternRewriter &rewriter,
235 arith::FastMathFlagsAttr fmf) const = 0;
238 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
239 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
241 std::pair<Value, Value> combine(Location loc, Value scaledExp,
242 Value reciprocalExp, Value sin, Value cos,
243 ConversionPatternRewriter &rewriter,
244 arith::FastMathFlagsAttr fmf) const override {
245 // Complex cosine is defined as;
246 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
247 // Plugging in:
248 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250 // and defining t := exp(y)
251 // We get:
252 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
254 Value sum =
255 rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
257 Value diff =
258 rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
260 return {resultReal, resultImag};
// Lowers complex.div to arith/math ops on the real and imaginary parts.
// Both variants of Smith's algorithm are emitted unconditionally and the
// applicable one is picked with a select on |rhsReal| < |rhsImag|. Three
// special-case results (zero denominator, infinite numerator with finite
// denominator, finite numerator with infinite denominator) are also computed
// and substituted only when both components of the plain result are NaN.
// The original op's fast-math flags are forwarded to the arithmetic ops; the
// comparison, select and copysign ops are created without them.
264 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
265 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
267 LogicalResult
268 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270 auto loc = op.getLoc();
271 auto type = cast<ComplexType>(adaptor.getLhs().getType());
272 auto elementType = cast<FloatType>(type.getElementType());
273 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
275 Value lhsReal =
276 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
277 Value lhsImag =
278 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
279 Value rhsReal =
280 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
281 Value rhsImag =
282 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
284 // Smith's algorithm to divide complex numbers. It is just a bit smarter
285 // way to compute the following formula:
286 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
287 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
288 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
289 // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
290 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
292 // Depending on whether |rhsReal| < |rhsImag| we compute either
293 // rhsRealImagRatio = rhsReal / rhsImag
294 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
295 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
296 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
298 // or
300 // rhsImagRealRatio = rhsImag / rhsReal
301 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
302 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
303 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
305 // See https://dl.acm.org/citation.cfm?id=368661 for more details.
// Variant 1 (results suffixed with 1): uses the ratio rhsReal / rhsImag.
306 Value rhsRealImagRatio =
307 rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308 Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
309 loc, rhsImag,
310 rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
311 fmf);
312 Value realNumerator1 = rewriter.create<arith::AddFOp>(
313 loc,
314 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
315 lhsImag, fmf);
316 Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
317 rhsRealImagDenom, fmf);
318 Value imagNumerator1 = rewriter.create<arith::SubFOp>(
319 loc,
320 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
321 lhsReal, fmf);
322 Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
323 rhsRealImagDenom, fmf);
// Variant 2 (results suffixed with 2): uses the ratio rhsImag / rhsReal.
325 Value rhsImagRealRatio =
326 rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327 Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
328 loc, rhsReal,
329 rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
330 fmf);
331 Value realNumerator2 = rewriter.create<arith::AddFOp>(
332 loc, lhsReal,
333 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
334 fmf);
335 Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
336 rhsImagRealDenom, fmf);
337 Value imagNumerator2 = rewriter.create<arith::SubFOp>(
338 loc, lhsImag,
339 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
340 fmf);
341 Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
342 rhsImagRealDenom, fmf);
344 // Consider corner cases.
345 // Case 1. Zero denominator, numerator contains at most one NaN value.
// The result is lhs scaled by an infinity carrying the sign of rhsReal.
346 Value zero = rewriter.create<arith::ConstantOp>(
347 loc, elementType, rewriter.getZeroAttr(elementType));
348 Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
349 Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
350 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351 Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
352 Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
353 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354 Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
355 loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356 Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
357 loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358 Value lhsContainsNotNaNValue =
359 rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360 Value resultIsInfinity = rewriter.create<arith::AndIOp>(
361 loc, lhsContainsNotNaNValue,
362 rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
363 Value inf = rewriter.create<arith::ConstantOp>(
364 loc, elementType,
365 rewriter.getFloatAttr(
366 elementType, APFloat::getInf(elementType.getFloatSemantics())));
367 Value infWithSignOfRhsReal =
368 rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
369 Value infinityResultReal =
370 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371 Value infinityResultImag =
372 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
374 // Case 2. Infinite numerator, finite denominator.
// Each lhs component is replaced by copysign(1 or 0, component), 1 when the
// component is infinite; the product with conj(rhs) is then scaled by inf.
375 Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
376 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377 Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
378 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
379 Value rhsFinite =
380 rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381 Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
382 Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
383 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384 Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
385 Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
386 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
387 Value lhsInfinite =
388 rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389 Value infNumFiniteDenom =
390 rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
391 Value one = rewriter.create<arith::ConstantOp>(
392 loc, elementType, rewriter.getFloatAttr(elementType, 1));
393 Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
394 loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
395 lhsReal);
396 Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
397 loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
398 lhsImag);
399 Value lhsRealIsInfWithSignTimesRhsReal =
400 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401 Value lhsImagIsInfWithSignTimesRhsImag =
402 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403 Value resultReal3 = rewriter.create<arith::MulFOp>(
404 loc, inf,
405 rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406 lhsImagIsInfWithSignTimesRhsImag, fmf),
407 fmf);
408 Value lhsRealIsInfWithSignTimesRhsImag =
409 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410 Value lhsImagIsInfWithSignTimesRhsReal =
411 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412 Value resultImag3 = rewriter.create<arith::MulFOp>(
413 loc, inf,
414 rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415 lhsRealIsInfWithSignTimesRhsImag, fmf),
416 fmf);
418 // Case 3: Finite numerator, infinite denominator.
// Mirror of case 2 with the roles of lhs and rhs swapped; the combined
// value is multiplied by zero instead of inf.
419 Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
420 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421 Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
422 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
423 Value lhsFinite =
424 rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425 Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
426 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427 Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
428 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
429 Value rhsInfinite =
430 rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431 Value finiteNumInfiniteDenom =
432 rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433 Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
434 loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
435 rhsReal);
436 Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
437 loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
438 rhsImag);
439 Value rhsRealIsInfWithSignTimesLhsReal =
440 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441 Value rhsImagIsInfWithSignTimesLhsImag =
442 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443 Value resultReal4 = rewriter.create<arith::MulFOp>(
444 loc, zero,
445 rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446 rhsImagIsInfWithSignTimesLhsImag, fmf),
447 fmf);
448 Value rhsRealIsInfWithSignTimesLhsImag =
449 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450 Value rhsImagIsInfWithSignTimesLhsReal =
451 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452 Value resultImag4 = rewriter.create<arith::MulFOp>(
453 loc, zero,
454 rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455 rhsImagIsInfWithSignTimesLhsReal, fmf),
456 fmf);
// Pick the Smith variant, then layer the special cases so that case 1 has
// the highest priority, followed by case 2, then case 3.
458 Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
459 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460 Value resultReal = rewriter.create<arith::SelectOp>(
461 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462 Value resultImag = rewriter.create<arith::SelectOp>(
463 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464 Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
465 loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466 Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
467 loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468 Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
469 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470 Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
471 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472 Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
473 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474 Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
475 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
// The special-case chain replaces the plain Smith result only when both of
// its components are NaN (unordered comparison against zero).
477 Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
478 loc, arith::CmpFPredicate::UNO, resultReal, zero);
479 Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
480 loc, arith::CmpFPredicate::UNO, resultImag, zero);
481 Value resultIsNaN =
482 rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483 Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
484 loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485 Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
486 loc, resultIsNaN, resultImagSpecialCase1, resultImag);
488 rewriter.replaceOpWithNewOp<complex::CreateOp>(
489 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
490 return success();
494 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
495 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
497 LogicalResult
498 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
499 ConversionPatternRewriter &rewriter) const override {
500 auto loc = op.getLoc();
501 auto type = cast<ComplexType>(adaptor.getComplex().getType());
502 auto elementType = cast<FloatType>(type.getElementType());
503 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
505 Value real =
506 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
507 Value imag =
508 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509 Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
510 Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
511 Value resultReal =
512 rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513 Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
514 Value resultImag =
515 rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
517 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
518 resultImag);
519 return success();
523 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
524 ArrayRef<double> coefficients,
525 arith::FastMathFlagsAttr fmf) {
526 auto argType = mlir::cast<FloatType>(arg.getType());
527 Value poly =
528 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
529 for (unsigned i = 1; i < coefficients.size(); ++i) {
530 poly = b.create<math::FmaOp>(
531 poly, arg,
532 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
533 fmf);
535 return poly;
538 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
539 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
541 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
542 // [handle inaccuracies when a and/or b are small]
543 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
544 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
545 LogicalResult
546 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
547 ConversionPatternRewriter &rewriter) const override {
548 auto type = op.getType();
549 auto elemType = mlir::cast<FloatType>(type.getElementType());
551 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
552 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
553 Value real = b.create<complex::ReOp>(adaptor.getComplex());
554 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
556 Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
557 Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
559 Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
560 Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
562 Value sinImag = b.create<math::SinOp>(imag, fmf);
563 Value cosm1Imag = emitCosm1(imag, fmf, b);
564 Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
566 Value realResult = b.create<arith::AddFOp>(
567 b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
569 Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570 zero, fmf.getValue());
571 Value imagResult = b.create<arith::SelectOp>(
572 imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
574 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
575 imagResult);
576 return success();
579 private:
580 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
581 ImplicitLocOpBuilder &b) const {
582 auto argType = mlir::cast<FloatType>(arg.getType());
583 auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
584 auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
586 // Algorithm copied from cephes cosm1.
587 SmallVector<double, 7> kCoeffs{
588 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591 4.1666666666666666609054E-2,
593 Value cos = b.create<math::CosOp>(arg, fmf);
594 Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
596 Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
597 Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
598 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
600 auto forSmallArg =
601 b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
602 b.create<arith::MulFOp>(negHalf, argPow2, fmf));
604 // (pi/4)^2 is approximately 0.61685
605 Value piOver4Pow2 =
606 b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
607 Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608 piOver4Pow2, fmf.getValue());
609 return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
613 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
614 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
616 LogicalResult
617 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
618 ConversionPatternRewriter &rewriter) const override {
619 auto type = cast<ComplexType>(adaptor.getComplex().getType());
620 auto elementType = cast<FloatType>(type.getElementType());
621 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
622 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
624 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
625 fmf.getValue());
626 Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
627 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
628 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
629 Value resultImag =
630 b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
631 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
632 resultImag);
633 return success();
637 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
638 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
640 LogicalResult
641 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
642 ConversionPatternRewriter &rewriter) const override {
643 auto type = cast<ComplexType>(adaptor.getComplex().getType());
644 auto elementType = cast<FloatType>(type.getElementType());
645 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
646 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
648 Value real = b.create<complex::ReOp>(adaptor.getComplex());
649 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
651 Value half = b.create<arith::ConstantOp>(elementType,
652 b.getFloatAttr(elementType, 0.5));
653 Value one = b.create<arith::ConstantOp>(elementType,
654 b.getFloatAttr(elementType, 1));
655 Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
656 Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
657 Value absImag = b.create<math::AbsFOp>(imag, fmf);
659 Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
660 Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
662 Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
663 realPlusOne, absImag, fmf);
664 Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
665 Value maxAbsOfRealPlusOneAndImagMinusOne =
666 b.create<arith::SelectOp>(useReal, real, maxMinusOne);
667 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
668 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
669 Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
670 Value logOfMaxAbsOfRealPlusOneAndImag =
671 b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
672 Value logOfSqrtPart = b.create<math::Log1pOp>(
673 b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
674 fmfWithNaNInf);
675 Value r = b.create<arith::AddFOp>(
676 b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
677 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
678 Value resultReal = b.create<arith::SelectOp>(
679 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
680 minAbs, r);
681 Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
682 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
683 resultImag);
684 return success();
688 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
689 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
691 LogicalResult
692 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter) const override {
694 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
695 auto type = cast<ComplexType>(adaptor.getLhs().getType());
696 auto elementType = cast<FloatType>(type.getElementType());
697 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698 auto fmfValue = fmf.getValue();
700 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
701 Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
702 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
703 Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
704 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
705 Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
706 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
707 Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
709 Value lhsRealTimesRhsReal =
710 b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
711 Value lhsRealTimesRhsRealAbs =
712 b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
713 Value lhsImagTimesRhsImag =
714 b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
715 Value lhsImagTimesRhsImagAbs =
716 b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
717 Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
718 lhsImagTimesRhsImag, fmfValue);
720 Value lhsImagTimesRhsReal =
721 b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
722 Value lhsImagTimesRhsRealAbs =
723 b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
724 Value lhsRealTimesRhsImag =
725 b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
726 Value lhsRealTimesRhsImagAbs =
727 b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
728 Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
729 lhsRealTimesRhsImag, fmfValue);
731 // Handle cases where the "naive" calculation results in NaN values.
732 Value realIsNan =
733 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
734 Value imagIsNan =
735 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
736 Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
738 Value inf = b.create<arith::ConstantOp>(
739 elementType,
740 b.getFloatAttr(elementType,
741 APFloat::getInf(elementType.getFloatSemantics())));
743 // Case 1. `lhsReal` or `lhsImag` are infinite.
744 Value lhsRealIsInf =
745 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
746 Value lhsImagIsInf =
747 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
748 Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
749 Value rhsRealIsNan =
750 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
751 Value rhsImagIsNan =
752 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
753 Value zero =
754 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
755 Value one = b.create<arith::ConstantOp>(elementType,
756 b.getFloatAttr(elementType, 1));
757 Value lhsRealIsInfFloat =
758 b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
759 lhsReal = b.create<arith::SelectOp>(
760 lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
761 lhsReal);
762 Value lhsImagIsInfFloat =
763 b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
764 lhsImag = b.create<arith::SelectOp>(
765 lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
766 lhsImag);
767 Value lhsIsInfAndRhsRealIsNan =
768 b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
769 rhsReal = b.create<arith::SelectOp>(
770 lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
771 rhsReal);
772 Value lhsIsInfAndRhsImagIsNan =
773 b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
774 rhsImag = b.create<arith::SelectOp>(
775 lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
776 rhsImag);
778 // Case 2. `rhsReal` or `rhsImag` are infinite.
779 Value rhsRealIsInf =
780 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
781 Value rhsImagIsInf =
782 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
783 Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
784 Value lhsRealIsNan =
785 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
786 Value lhsImagIsNan =
787 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
788 Value rhsRealIsInfFloat =
789 b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
790 rhsReal = b.create<arith::SelectOp>(
791 rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
792 rhsReal);
793 Value rhsImagIsInfFloat =
794 b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
795 rhsImag = b.create<arith::SelectOp>(
796 rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
797 rhsImag);
798 Value rhsIsInfAndLhsRealIsNan =
799 b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
800 lhsReal = b.create<arith::SelectOp>(
801 rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
802 lhsReal);
803 Value rhsIsInfAndLhsImagIsNan =
804 b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
805 lhsImag = b.create<arith::SelectOp>(
806 rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
807 lhsImag);
808 Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
810 // Case 3. One of the pairwise products of left hand side with right hand
811 // side is infinite.
812 Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
813 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
814 Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
815 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
816 Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
817 lhsImagTimesRhsImagIsInf);
818 Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
819 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
820 isSpecialCase =
821 b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
822 Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
823 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
824 isSpecialCase =
825 b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
826 Type i1Type = b.getI1Type();
827 Value notRecalc = b.create<arith::XOrIOp>(
828 recalc,
829 b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
830 isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
831 Value isSpecialCaseAndLhsRealIsNan =
832 b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
833 lhsReal = b.create<arith::SelectOp>(
834 isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
835 lhsReal);
836 Value isSpecialCaseAndLhsImagIsNan =
837 b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
838 lhsImag = b.create<arith::SelectOp>(
839 isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
840 lhsImag);
841 Value isSpecialCaseAndRhsRealIsNan =
842 b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
843 rhsReal = b.create<arith::SelectOp>(
844 isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
845 rhsReal);
846 Value isSpecialCaseAndRhsImagIsNan =
847 b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
848 rhsImag = b.create<arith::SelectOp>(
849 isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
850 rhsImag);
851 recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
852 recalc = b.create<arith::AndIOp>(isNan, recalc);
854 // Recalculate real part.
855 lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
856 lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
857 Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
858 lhsImagTimesRhsImag, fmfValue);
859 real = b.create<arith::SelectOp>(
860 recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
862 // Recalculate imag part.
863 lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
864 lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
865 Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
866 lhsRealTimesRhsImag, fmfValue);
867 imag = b.create<arith::SelectOp>(
868 recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
870 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
871 return success();
875 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
876 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
878 LogicalResult
879 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
880 ConversionPatternRewriter &rewriter) const override {
881 auto loc = op.getLoc();
882 auto type = cast<ComplexType>(adaptor.getComplex().getType());
883 auto elementType = cast<FloatType>(type.getElementType());
885 Value real =
886 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
887 Value imag =
888 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
889 Value negReal = rewriter.create<arith::NegFOp>(loc, real);
890 Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
891 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
892 return success();
896 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
897 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
899 std::pair<Value, Value> combine(Location loc, Value scaledExp,
900 Value reciprocalExp, Value sin, Value cos,
901 ConversionPatternRewriter &rewriter,
902 arith::FastMathFlagsAttr fmf) const override {
903 // Complex sine is defined as;
904 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
905 // Plugging in:
906 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
907 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
908 // and defining t := exp(y)
909 // We get:
910 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
911 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
912 Value sum =
913 rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
914 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
915 Value diff =
916 rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
917 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
918 return {resultReal, resultImag};
922 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
923 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
924 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
926 LogicalResult
927 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
928 ConversionPatternRewriter &rewriter) const override {
929 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
931 auto type = cast<ComplexType>(op.getType());
932 auto elementType = cast<FloatType>(type.getElementType());
933 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
935 auto cst = [&](APFloat v) {
936 return b.create<arith::ConstantOp>(elementType,
937 b.getFloatAttr(elementType, v));
939 const auto &floatSemantics = elementType.getFloatSemantics();
940 Value zero = cst(APFloat::getZero(floatSemantics));
941 Value half = b.create<arith::ConstantOp>(elementType,
942 b.getFloatAttr(elementType, 0.5));
944 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
945 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
946 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
947 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
948 Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
949 Value cos = b.create<math::CosOp>(sqrtArg, fmf);
950 Value sin = b.create<math::SinOp>(sqrtArg, fmf);
951 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
952 // 0 * inf.
953 Value sinIsZero =
954 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
956 Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
957 Value resultImag = b.create<arith::SelectOp>(
958 sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
959 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
960 arith::FastMathFlags::ninf)) {
961 Value inf = cst(APFloat::getInf(floatSemantics));
962 Value negInf = cst(APFloat::getInf(floatSemantics, true));
963 Value nan = cst(APFloat::getNaN(floatSemantics));
964 Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
966 Value absImagIsInf =
967 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
968 Value absImagIsNotInf =
969 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
970 Value realIsInf =
971 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
972 Value realIsNegInf =
973 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
975 resultReal = b.create<arith::SelectOp>(
976 b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
977 resultReal);
978 resultReal = b.create<arith::SelectOp>(
979 b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
981 Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
982 resultImag = b.create<arith::SelectOp>(
983 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
984 nan, resultImag);
985 resultImag = b.create<arith::SelectOp>(
986 b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
987 resultImag);
990 Value resultIsZero =
991 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
992 resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
993 resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
995 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
996 resultImag);
997 return success();
1001 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
1002 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
1004 LogicalResult
1005 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
1006 ConversionPatternRewriter &rewriter) const override {
1007 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008 auto elementType = cast<FloatType>(type.getElementType());
1009 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1010 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1012 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1013 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1014 Value zero =
1015 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
1016 Value realIsZero =
1017 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
1018 Value imagIsZero =
1019 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
1020 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
1021 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
1022 Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
1023 Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
1024 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
1025 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
1026 adaptor.getComplex(), sign);
1027 return success();
1031 template <typename Op>
1032 struct TanTanhOpConversion : public OpConversionPattern<Op> {
1033 using OpConversionPattern<Op>::OpConversionPattern;
1035 LogicalResult
1036 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1037 ConversionPatternRewriter &rewriter) const override {
1038 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1039 auto loc = op.getLoc();
1040 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1041 auto elementType = cast<FloatType>(type.getElementType());
1042 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1043 const auto &floatSemantics = elementType.getFloatSemantics();
1045 Value real =
1046 b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1047 Value imag =
1048 b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1049 Value negOne = b.create<arith::ConstantOp>(
1050 elementType, b.getFloatAttr(elementType, -1.0));
1052 if constexpr (std::is_same_v<Op, complex::TanOp>) {
1053 // tan(x+yi) = -i*tanh(-y + xi)
1054 std::swap(real, imag);
1055 real = b.create<arith::MulFOp>(real, negOne, fmf);
1058 auto cst = [&](APFloat v) {
1059 return b.create<arith::ConstantOp>(elementType,
1060 b.getFloatAttr(elementType, v));
1062 Value inf = cst(APFloat::getInf(floatSemantics));
1063 Value four = b.create<arith::ConstantOp>(elementType,
1064 b.getFloatAttr(elementType, 4.0));
1065 Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1066 Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1068 Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1069 Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1070 Value realNum =
1071 b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1073 Value cosImag = b.create<math::CosOp>(imag, fmf);
1074 Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1075 Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1076 Value sinImag = b.create<math::SinOp>(imag, fmf);
1078 Value imagNum = b.create<arith::MulFOp>(
1079 four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1081 Value expSumMinusTwo =
1082 b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1083 Value denom =
1084 b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1086 Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1087 expSumMinusTwo, inf, fmf);
1088 Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1090 Value resultReal = b.create<arith::SelectOp>(
1091 isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1092 Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1094 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1095 arith::FastMathFlags::ninf)) {
1096 Value absReal = b.create<math::AbsFOp>(real, fmf);
1097 Value zero = b.create<arith::ConstantOp>(
1098 elementType, b.getFloatAttr(elementType, 0.0));
1099 Value nan = cst(APFloat::getNaN(floatSemantics));
1101 Value absRealIsInf =
1102 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1103 Value imagIsZero =
1104 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1105 Value absRealIsNotInf = b.create<arith::XOrIOp>(
1106 absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1108 Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1109 imagNum, imagNum, fmf);
1110 Value resultRealIsNaN =
1111 b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1112 Value resultImagIsZero = b.create<arith::OrIOp>(
1113 imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1115 resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1116 resultImag =
1117 b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1120 if constexpr (std::is_same_v<Op, complex::TanOp>) {
1121 // tan(x+yi) = -i*tanh(-y + xi)
1122 std::swap(resultReal, resultImag);
1123 resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
1126 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1127 resultImag);
1128 return success();
1132 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
1133 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
1135 LogicalResult
1136 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter) const override {
1138 auto loc = op.getLoc();
1139 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1140 auto elementType = cast<FloatType>(type.getElementType());
1141 Value real =
1142 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1143 Value imag =
1144 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1145 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
1147 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
1149 return success();
1153 /// Converts lhs^y = (a+bi)^(c+di) to
1154 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1155 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
1156 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
1157 ComplexType type, Value lhs, Value c, Value d,
1158 arith::FastMathFlags fmf) {
1159 auto elementType = cast<FloatType>(type.getElementType());
1161 Value a = builder.create<complex::ReOp>(lhs);
1162 Value b = builder.create<complex::ImOp>(lhs);
1164 Value abs = builder.create<complex::AbsOp>(lhs, fmf);
1165 Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
1167 Value negD = builder.create<arith::NegFOp>(d, fmf);
1168 Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
1169 Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
1170 Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
1172 Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1173 Value lnAbs = builder.create<math::LogOp>(abs, fmf);
1174 Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
1175 Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
1176 Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1177 Value cosQ = builder.create<math::CosOp>(q, fmf);
1178 Value sinQ = builder.create<math::SinOp>(q, fmf);
1180 Value inf = builder.create<arith::ConstantOp>(
1181 elementType,
1182 builder.getFloatAttr(elementType,
1183 APFloat::getInf(elementType.getFloatSemantics())));
1184 Value zero = builder.create<arith::ConstantOp>(
1185 elementType, builder.getFloatAttr(elementType, 0.0));
1186 Value one = builder.create<arith::ConstantOp>(
1187 elementType, builder.getFloatAttr(elementType, 1.0));
1188 Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
1189 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
1190 Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
1192 // Case 0:
1193 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
1194 // Branch Cuts for Complex Elementary Functions or Much Ado About
1195 // Nothing's Sign Bit, W. Kahan, Section 10.
1196 Value absEqZero =
1197 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
1198 Value dEqZero =
1199 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1200 Value cEqZero =
1201 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1202 Value bEqZero =
1203 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1205 Value zeroLeC =
1206 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1207 Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
1208 Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
1209 Value complexOneOrZero =
1210 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
1211 Value coeffCosSin =
1212 builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1213 Value cutoff0 = builder.create<arith::SelectOp>(
1214 builder.create<arith::AndIOp>(
1215 builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1216 complexOneOrZero, coeffCosSin);
1218 // Case 1:
1219 // x^0 is defined to be 1 for any x, see
1220 // Branch Cuts for Complex Elementary Functions or Much Ado About
1221 // Nothing's Sign Bit, W. Kahan, Section 10.
1222 Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
1223 Value cutoff1 =
1224 builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1226 // Case 2:
1227 // 1^(c + d*i) = 1 + 0*i
1228 Value lhsEqOne = builder.create<arith::AndIOp>(
1229 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1230 bEqZero);
1231 Value cutoff2 =
1232 builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1234 // Case 3:
1235 // inf^(c + 0*i) = inf + 0*i, c > 0
1236 Value lhsEqInf = builder.create<arith::AndIOp>(
1237 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1238 bEqZero);
1239 Value rhsGt0 = builder.create<arith::AndIOp>(
1240 dEqZero,
1241 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1242 Value cutoff3 = builder.create<arith::SelectOp>(
1243 builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1245 // Case 4:
1246 // inf^(c + 0*i) = 0 + 0*i, c < 0
1247 Value rhsLt0 = builder.create<arith::AndIOp>(
1248 dEqZero,
1249 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1250 Value cutoff4 = builder.create<arith::SelectOp>(
1251 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1253 return cutoff4;
1256 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1257 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
1259 LogicalResult
1260 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1261 ConversionPatternRewriter &rewriter) const override {
1262 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1263 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1264 auto elementType = cast<FloatType>(type.getElementType());
1266 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1267 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1269 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1270 c, d, op.getFastmath())});
1271 return success();
1275 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1276 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
1278 LogicalResult
1279 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1280 ConversionPatternRewriter &rewriter) const override {
1281 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1282 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1283 auto elementType = cast<FloatType>(type.getElementType());
1285 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1287 auto cst = [&](APFloat v) {
1288 return b.create<arith::ConstantOp>(elementType,
1289 b.getFloatAttr(elementType, v));
1291 const auto &floatSemantics = elementType.getFloatSemantics();
1292 Value zero = cst(APFloat::getZero(floatSemantics));
1293 Value inf = cst(APFloat::getInf(floatSemantics));
1294 Value negHalf = b.create<arith::ConstantOp>(
1295 elementType, b.getFloatAttr(elementType, -0.5));
1296 Value nan = cst(APFloat::getNaN(floatSemantics));
1298 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1299 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1300 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1301 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1302 Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1303 Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1304 Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1306 Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1307 Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1309 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1310 arith::FastMathFlags::ninf)) {
1311 Value negOne = b.create<arith::ConstantOp>(
1312 elementType, b.getFloatAttr(elementType, -1));
1314 Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1315 Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1316 Value negImagSignedZero =
1317 b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1319 Value absReal = b.create<math::AbsFOp>(real, fmf);
1320 Value absImag = b.create<math::AbsFOp>(imag, fmf);
1322 Value absImagIsInf =
1323 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1324 Value realIsNan =
1325 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1326 Value realIsInf =
1327 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1328 Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1330 Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1332 resultReal =
1333 b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1334 resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1335 resultImag);
1338 Value isRealZero =
1339 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1340 Value isImagZero =
1341 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1342 Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1344 resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1345 resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1347 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1348 resultImag);
1349 return success();
1353 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1354 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1356 LogicalResult
1357 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1358 ConversionPatternRewriter &rewriter) const override {
1359 auto loc = op.getLoc();
1360 auto type = op.getType();
1361 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1363 Value real =
1364 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1365 Value imag =
1366 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1368 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1370 return success();
1374 } // namespace
1376 void mlir::populateComplexToStandardConversionPatterns(
1377 RewritePatternSet &patterns) {
1378 // clang-format off
1379 patterns.add<
1380 AbsOpConversion,
1381 AngleOpConversion,
1382 Atan2OpConversion,
1383 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1384 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1385 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1386 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1387 ConjOpConversion,
1388 CosOpConversion,
1389 DivOpConversion,
1390 ExpOpConversion,
1391 Expm1OpConversion,
1392 Log1pOpConversion,
1393 LogOpConversion,
1394 MulOpConversion,
1395 NegOpConversion,
1396 SignOpConversion,
1397 SinOpConversion,
1398 SqrtOpConversion,
1399 TanTanhOpConversion<complex::TanOp>,
1400 TanTanhOpConversion<complex::TanhOp>,
1401 PowOpConversion,
1402 RsqrtOpConversion
1403 >(patterns.getContext());
1404 // clang-format on
1407 namespace {
1408 struct ConvertComplexToStandardPass
1409 : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1410 void runOnOperation() override;
1413 void ConvertComplexToStandardPass::runOnOperation() {
1414 // Convert to the Standard dialect using the converter defined above.
1415 RewritePatternSet patterns(&getContext());
1416 populateComplexToStandardConversionPatterns(patterns);
1418 ConversionTarget target(getContext());
1419 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1420 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1421 if (failed(
1422 applyPartialConversion(getOperation(), target, std::move(patterns))))
1423 signalPassFailure();
1425 } // namespace
1427 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1428 return std::make_unique<ConvertComplexToStandardPass>();