1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::complex;
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
23 OpFoldResult
ConstantOp::fold(FoldAdaptor adaptor
) {
27 void ConstantOp::getAsmResultNames(
28 function_ref
<void(Value
, StringRef
)> setNameFn
) {
29 setNameFn(getResult(), "cst");
32 bool ConstantOp::isBuildableWith(Attribute value
, Type type
) {
33 if (auto arrAttr
= llvm::dyn_cast
<ArrayAttr
>(value
)) {
34 auto complexTy
= llvm::dyn_cast
<ComplexType
>(type
);
35 if (!complexTy
|| arrAttr
.size() != 2)
37 auto complexEltTy
= complexTy
.getElementType();
38 if (auto fre
= llvm::dyn_cast
<FloatAttr
>(arrAttr
[0])) {
39 auto im
= llvm::dyn_cast
<FloatAttr
>(arrAttr
[1]);
40 return im
&& fre
.getType() == complexEltTy
&&
41 im
.getType() == complexEltTy
;
43 if (auto ire
= llvm::dyn_cast
<IntegerAttr
>(arrAttr
[0])) {
44 auto im
= llvm::dyn_cast
<IntegerAttr
>(arrAttr
[1]);
45 return im
&& ire
.getType() == complexEltTy
&&
46 im
.getType() == complexEltTy
;
52 LogicalResult
ConstantOp::verify() {
53 ArrayAttr arrayAttr
= getValue();
54 if (arrayAttr
.size() != 2) {
56 "requires 'value' to be a complex constant, represented as array of "
60 auto complexEltTy
= getType().getElementType();
61 auto re
= llvm::dyn_cast
<FloatAttr
>(arrayAttr
[0]);
62 auto im
= llvm::dyn_cast
<FloatAttr
>(arrayAttr
[1]);
64 return emitOpError("requires attribute's elements to be float attributes");
65 if (complexEltTy
!= re
.getType() || complexEltTy
!= im
.getType()) {
67 << "requires attribute's element types (" << re
.getType() << ", "
69 << ") to match the element type of the op's return type ("
70 << complexEltTy
<< ")";
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
79 OpFoldResult
BitcastOp::fold(FoldAdaptor bitcast
) {
80 if (getOperand().getType() == getType())
86 LogicalResult
BitcastOp::verify() {
87 auto operandType
= getOperand().getType();
88 auto resultType
= getType();
90 // We allow this to be legal as it can be folded away.
91 if (operandType
== resultType
)
94 if (!operandType
.isIntOrFloat() && !isa
<ComplexType
>(operandType
)) {
95 return emitOpError("operand must be int/float/complex");
98 if (!resultType
.isIntOrFloat() && !isa
<ComplexType
>(resultType
)) {
99 return emitOpError("result must be int/float/complex");
102 if (isa
<ComplexType
>(operandType
) == isa
<ComplexType
>(resultType
)) {
103 return emitOpError("requires input or output is a complex type");
106 if (isa
<ComplexType
>(resultType
))
107 std::swap(operandType
, resultType
);
109 int32_t operandBitwidth
= dyn_cast
<ComplexType
>(operandType
)
111 .getIntOrFloatBitWidth() *
113 int32_t resultBitwidth
= resultType
.getIntOrFloatBitWidth();
115 if (operandBitwidth
!= resultBitwidth
) {
116 return emitOpError("casting bitwidths do not match");
122 struct MergeComplexBitcast final
: OpRewritePattern
<BitcastOp
> {
123 using OpRewritePattern
<BitcastOp
>::OpRewritePattern
;
125 LogicalResult
matchAndRewrite(BitcastOp op
,
126 PatternRewriter
&rewriter
) const override
{
127 if (auto defining
= op
.getOperand().getDefiningOp
<BitcastOp
>()) {
128 rewriter
.replaceOpWithNewOp
<BitcastOp
>(op
, op
.getType(),
129 defining
.getOperand());
133 if (auto defining
= op
.getOperand().getDefiningOp
<arith::BitcastOp
>()) {
134 rewriter
.replaceOpWithNewOp
<BitcastOp
>(op
, op
.getType(),
135 defining
.getOperand());
143 struct MergeArithBitcast final
: OpRewritePattern
<arith::BitcastOp
> {
144 using OpRewritePattern
<arith::BitcastOp
>::OpRewritePattern
;
146 LogicalResult
matchAndRewrite(arith::BitcastOp op
,
147 PatternRewriter
&rewriter
) const override
{
148 if (auto defining
= op
.getOperand().getDefiningOp
<complex::BitcastOp
>()) {
149 rewriter
.replaceOpWithNewOp
<complex::BitcastOp
>(op
, op
.getType(),
150 defining
.getOperand());
158 struct ArithBitcast final
: OpRewritePattern
<BitcastOp
> {
159 using OpRewritePattern
<complex::BitcastOp
>::OpRewritePattern
;
161 LogicalResult
matchAndRewrite(BitcastOp op
,
162 PatternRewriter
&rewriter
) const override
{
163 if (isa
<ComplexType
>(op
.getType()) ||
164 isa
<ComplexType
>(op
.getOperand().getType()))
167 rewriter
.replaceOpWithNewOp
<arith::BitcastOp
>(op
, op
.getType(),
173 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet
&results
,
174 MLIRContext
*context
) {
175 results
.add
<ArithBitcast
, MergeComplexBitcast
, MergeArithBitcast
>(context
);
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
182 OpFoldResult
CreateOp::fold(FoldAdaptor adaptor
) {
183 // Fold complex.create(complex.re(op), complex.im(op)).
184 if (auto reOp
= getOperand(0).getDefiningOp
<ReOp
>()) {
185 if (auto imOp
= getOperand(1).getDefiningOp
<ImOp
>()) {
186 if (reOp
.getOperand() == imOp
.getOperand()) {
187 return reOp
.getOperand();
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
198 OpFoldResult
ImOp::fold(FoldAdaptor adaptor
) {
199 ArrayAttr arrayAttr
=
200 llvm::dyn_cast_if_present
<ArrayAttr
>(adaptor
.getComplex());
201 if (arrayAttr
&& arrayAttr
.size() == 2)
203 if (auto createOp
= getOperand().getDefiningOp
<CreateOp
>())
204 return createOp
.getOperand(1);
209 template <typename OpKind
, int ComponentIndex
>
210 struct FoldComponentNeg final
: OpRewritePattern
<OpKind
> {
211 using OpRewritePattern
<OpKind
>::OpRewritePattern
;
213 LogicalResult
matchAndRewrite(OpKind op
,
214 PatternRewriter
&rewriter
) const override
{
215 auto negOp
= op
.getOperand().template getDefiningOp
<NegOp
>();
219 auto createOp
= negOp
.getComplex().template getDefiningOp
<CreateOp
>();
223 Type elementType
= createOp
.getType().getElementType();
224 assert(isa
<FloatType
>(elementType
));
226 rewriter
.replaceOpWithNewOp
<arith::NegFOp
>(
227 op
, elementType
, createOp
.getOperand(ComponentIndex
));
233 void ImOp::getCanonicalizationPatterns(RewritePatternSet
&results
,
234 MLIRContext
*context
) {
235 results
.add
<FoldComponentNeg
<ImOp
, 1>>(context
);
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
242 OpFoldResult
ReOp::fold(FoldAdaptor adaptor
) {
243 ArrayAttr arrayAttr
=
244 llvm::dyn_cast_if_present
<ArrayAttr
>(adaptor
.getComplex());
245 if (arrayAttr
&& arrayAttr
.size() == 2)
247 if (auto createOp
= getOperand().getDefiningOp
<CreateOp
>())
248 return createOp
.getOperand(0);
252 void ReOp::getCanonicalizationPatterns(RewritePatternSet
&results
,
253 MLIRContext
*context
) {
254 results
.add
<FoldComponentNeg
<ReOp
, 0>>(context
);
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
261 OpFoldResult
AddOp::fold(FoldAdaptor adaptor
) {
262 // complex.add(complex.sub(a, b), b) -> a
263 if (auto sub
= getLhs().getDefiningOp
<SubOp
>())
264 if (getRhs() == sub
.getRhs())
267 // complex.add(b, complex.sub(a, b)) -> a
268 if (auto sub
= getRhs().getDefiningOp
<SubOp
>())
269 if (getLhs() == sub
.getRhs())
272 // complex.add(a, complex.constant<0.0, 0.0>) -> a
273 if (auto constantOp
= getRhs().getDefiningOp
<ConstantOp
>()) {
274 auto arrayAttr
= constantOp
.getValue();
275 if (llvm::cast
<FloatAttr
>(arrayAttr
[0]).getValue().isZero() &&
276 llvm::cast
<FloatAttr
>(arrayAttr
[1]).getValue().isZero()) {
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
288 OpFoldResult
SubOp::fold(FoldAdaptor adaptor
) {
289 // complex.sub(complex.add(a, b), b) -> a
290 if (auto add
= getLhs().getDefiningOp
<AddOp
>())
291 if (getRhs() == add
.getRhs())
294 // complex.sub(a, complex.constant<0.0, 0.0>) -> a
295 if (auto constantOp
= getRhs().getDefiningOp
<ConstantOp
>()) {
296 auto arrayAttr
= constantOp
.getValue();
297 if (llvm::cast
<FloatAttr
>(arrayAttr
[0]).getValue().isZero() &&
298 llvm::cast
<FloatAttr
>(arrayAttr
[1]).getValue().isZero()) {
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
310 OpFoldResult
NegOp::fold(FoldAdaptor adaptor
) {
311 // complex.neg(complex.neg(a)) -> a
312 if (auto negOp
= getOperand().getDefiningOp
<NegOp
>())
313 return negOp
.getOperand();
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
322 OpFoldResult
LogOp::fold(FoldAdaptor adaptor
) {
323 // complex.log(complex.exp(a)) -> a
324 if (auto expOp
= getOperand().getDefiningOp
<ExpOp
>())
325 return expOp
.getOperand();
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
334 OpFoldResult
ExpOp::fold(FoldAdaptor adaptor
) {
335 // complex.exp(complex.log(a)) -> a
336 if (auto logOp
= getOperand().getDefiningOp
<LogOp
>())
337 return logOp
.getOperand();
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
346 OpFoldResult
ConjOp::fold(FoldAdaptor adaptor
) {
347 // complex.conj(complex.conj(a)) -> a
348 if (auto conjOp
= getOperand().getDefiningOp
<ConjOp
>())
349 return conjOp
.getOperand();
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
358 OpFoldResult
MulOp::fold(FoldAdaptor adaptor
) {
359 auto constant
= getRhs().getDefiningOp
<ConstantOp
>();
363 ArrayAttr arrayAttr
= constant
.getValue();
364 APFloat real
= cast
<FloatAttr
>(arrayAttr
[0]).getValue();
365 APFloat imag
= cast
<FloatAttr
>(arrayAttr
[1]).getValue();
370 // complex.mul(a, complex.constant<1.0, 0.0>) -> a
371 if (real
== APFloat(real
.getSemantics(), 1))
377 //===----------------------------------------------------------------------===//
378 // TableGen'd op method definitions
379 //===----------------------------------------------------------------------===//
381 #define GET_OP_CLASSES
382 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"