[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Dialect / Complex / IR / ComplexOps.cpp
blob 8fd914dd107ffb5c5199ecf623bb013180a7c998
1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Complex/IR/Complex.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
16 using namespace mlir;
17 using namespace mlir::complex;
19 //===----------------------------------------------------------------------===//
20 // ConstantOp
21 //===----------------------------------------------------------------------===//
23 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24 return getValue();
27 void ConstantOp::getAsmResultNames(
28 function_ref<void(Value, StringRef)> setNameFn) {
29 setNameFn(getResult(), "cst");
32 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33 if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34 auto complexTy = llvm::dyn_cast<ComplexType>(type);
35 if (!complexTy || arrAttr.size() != 2)
36 return false;
37 auto complexEltTy = complexTy.getElementType();
38 if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40 return im && fre.getType() == complexEltTy &&
41 im.getType() == complexEltTy;
43 if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45 return im && ire.getType() == complexEltTy &&
46 im.getType() == complexEltTy;
49 return false;
52 LogicalResult ConstantOp::verify() {
53 ArrayAttr arrayAttr = getValue();
54 if (arrayAttr.size() != 2) {
55 return emitOpError(
56 "requires 'value' to be a complex constant, represented as array of "
57 "two values");
60 auto complexEltTy = getType().getElementType();
61 auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
62 auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
63 if (!re || !im)
64 return emitOpError("requires attribute's elements to be float attributes");
65 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
66 return emitOpError()
67 << "requires attribute's element types (" << re.getType() << ", "
68 << im.getType()
69 << ") to match the element type of the op's return type ("
70 << complexEltTy << ")";
72 return success();
75 //===----------------------------------------------------------------------===//
76 // BitcastOp
77 //===----------------------------------------------------------------------===//
79 OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
80 if (getOperand().getType() == getType())
81 return getOperand();
83 return {};
86 LogicalResult BitcastOp::verify() {
87 auto operandType = getOperand().getType();
88 auto resultType = getType();
90 // We allow this to be legal as it can be folded away.
91 if (operandType == resultType)
92 return success();
94 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
95 return emitOpError("operand must be int/float/complex");
98 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
99 return emitOpError("result must be int/float/complex");
102 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
103 return emitOpError("requires input or output is a complex type");
106 if (isa<ComplexType>(resultType))
107 std::swap(operandType, resultType);
109 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
110 .getElementType()
111 .getIntOrFloatBitWidth() *
113 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
115 if (operandBitwidth != resultBitwidth) {
116 return emitOpError("casting bitwidths do not match");
119 return success();
122 struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
123 using OpRewritePattern<BitcastOp>::OpRewritePattern;
125 LogicalResult matchAndRewrite(BitcastOp op,
126 PatternRewriter &rewriter) const override {
127 if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
128 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
129 defining.getOperand());
130 return success();
133 if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
134 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135 defining.getOperand());
136 return success();
139 return failure();
143 struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
144 using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
146 LogicalResult matchAndRewrite(arith::BitcastOp op,
147 PatternRewriter &rewriter) const override {
148 if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
149 rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
150 defining.getOperand());
151 return success();
154 return failure();
158 struct ArithBitcast final : OpRewritePattern<BitcastOp> {
159 using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
161 LogicalResult matchAndRewrite(BitcastOp op,
162 PatternRewriter &rewriter) const override {
163 if (isa<ComplexType>(op.getType()) ||
164 isa<ComplexType>(op.getOperand().getType()))
165 return failure();
167 rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
168 op.getOperand());
169 return success();
173 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
174 MLIRContext *context) {
175 results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
178 //===----------------------------------------------------------------------===//
179 // CreateOp
180 //===----------------------------------------------------------------------===//
182 OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
183 // Fold complex.create(complex.re(op), complex.im(op)).
184 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
185 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
186 if (reOp.getOperand() == imOp.getOperand()) {
187 return reOp.getOperand();
191 return {};
194 //===----------------------------------------------------------------------===//
195 // ImOp
196 //===----------------------------------------------------------------------===//
198 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
199 ArrayAttr arrayAttr =
200 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
201 if (arrayAttr && arrayAttr.size() == 2)
202 return arrayAttr[1];
203 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
204 return createOp.getOperand(1);
205 return {};
208 namespace {
209 template <typename OpKind, int ComponentIndex>
210 struct FoldComponentNeg final : OpRewritePattern<OpKind> {
211 using OpRewritePattern<OpKind>::OpRewritePattern;
213 LogicalResult matchAndRewrite(OpKind op,
214 PatternRewriter &rewriter) const override {
215 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
216 if (!negOp)
217 return failure();
219 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
220 if (!createOp)
221 return failure();
223 Type elementType = createOp.getType().getElementType();
224 assert(isa<FloatType>(elementType));
226 rewriter.replaceOpWithNewOp<arith::NegFOp>(
227 op, elementType, createOp.getOperand(ComponentIndex));
228 return success();
231 } // namespace
233 void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
234 MLIRContext *context) {
235 results.add<FoldComponentNeg<ImOp, 1>>(context);
238 //===----------------------------------------------------------------------===//
239 // ReOp
240 //===----------------------------------------------------------------------===//
242 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
243 ArrayAttr arrayAttr =
244 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
245 if (arrayAttr && arrayAttr.size() == 2)
246 return arrayAttr[0];
247 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
248 return createOp.getOperand(0);
249 return {};
252 void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
253 MLIRContext *context) {
254 results.add<FoldComponentNeg<ReOp, 0>>(context);
257 //===----------------------------------------------------------------------===//
258 // AddOp
259 //===----------------------------------------------------------------------===//
261 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
262 // complex.add(complex.sub(a, b), b) -> a
263 if (auto sub = getLhs().getDefiningOp<SubOp>())
264 if (getRhs() == sub.getRhs())
265 return sub.getLhs();
267 // complex.add(b, complex.sub(a, b)) -> a
268 if (auto sub = getRhs().getDefiningOp<SubOp>())
269 if (getLhs() == sub.getRhs())
270 return sub.getLhs();
272 // complex.add(a, complex.constant<0.0, 0.0>) -> a
273 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
274 auto arrayAttr = constantOp.getValue();
275 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
276 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
277 return getLhs();
281 return {};
284 //===----------------------------------------------------------------------===//
285 // SubOp
286 //===----------------------------------------------------------------------===//
288 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
289 // complex.sub(complex.add(a, b), b) -> a
290 if (auto add = getLhs().getDefiningOp<AddOp>())
291 if (getRhs() == add.getRhs())
292 return add.getLhs();
294 // complex.sub(a, complex.constant<0.0, 0.0>) -> a
295 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
296 auto arrayAttr = constantOp.getValue();
297 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
298 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
299 return getLhs();
303 return {};
306 //===----------------------------------------------------------------------===//
307 // NegOp
308 //===----------------------------------------------------------------------===//
310 OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
311 // complex.neg(complex.neg(a)) -> a
312 if (auto negOp = getOperand().getDefiningOp<NegOp>())
313 return negOp.getOperand();
315 return {};
318 //===----------------------------------------------------------------------===//
319 // LogOp
320 //===----------------------------------------------------------------------===//
322 OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
323 // complex.log(complex.exp(a)) -> a
324 if (auto expOp = getOperand().getDefiningOp<ExpOp>())
325 return expOp.getOperand();
327 return {};
330 //===----------------------------------------------------------------------===//
331 // ExpOp
332 //===----------------------------------------------------------------------===//
334 OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
335 // complex.exp(complex.log(a)) -> a
336 if (auto logOp = getOperand().getDefiningOp<LogOp>())
337 return logOp.getOperand();
339 return {};
342 //===----------------------------------------------------------------------===//
343 // ConjOp
344 //===----------------------------------------------------------------------===//
346 OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
347 // complex.conj(complex.conj(a)) -> a
348 if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
349 return conjOp.getOperand();
351 return {};
354 //===----------------------------------------------------------------------===//
355 // MulOp
356 //===----------------------------------------------------------------------===//
358 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
359 auto constant = getRhs().getDefiningOp<ConstantOp>();
360 if (!constant)
361 return {};
363 ArrayAttr arrayAttr = constant.getValue();
364 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
365 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
367 if (!imag.isZero())
368 return {};
370 // complex.mul(a, complex.constant<1.0, 0.0>) -> a
371 if (real == APFloat(real.getSemantics(), 1))
372 return getLhs();
374 return {};
377 //===----------------------------------------------------------------------===//
378 // TableGen'd op method definitions
379 //===----------------------------------------------------------------------===//
381 #define GET_OP_CLASSES
382 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"