Rename CODE_OWNERS -> Maintainers (#114544)
[llvm-project.git] / mlir / lib / Conversion / ComplexToSPIRV / ComplexToSPIRV.cpp
blob5e7d2d8491533f5c7029be097d93274524e0f3f5
1 //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Complex dialect to SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/Support/Debug.h"
21 #define DEBUG_TYPE "complex-to-spirv-pattern"
23 using namespace mlir;
25 //===----------------------------------------------------------------------===//
26 // Operation conversion
27 //===----------------------------------------------------------------------===//
29 namespace {
31 struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
32 using OpConversionPattern::OpConversionPattern;
34 LogicalResult
35 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
36 ConversionPatternRewriter &rewriter) const override {
37 auto spirvType =
38 getTypeConverter()->convertType<ShapedType>(constOp.getType());
39 if (!spirvType)
40 return rewriter.notifyMatchFailure(constOp,
41 "unable to convert result type");
43 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
44 constOp, spirvType,
45 DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
46 return success();
50 struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
51 using OpConversionPattern::OpConversionPattern;
53 LogicalResult
54 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
55 ConversionPatternRewriter &rewriter) const override {
56 Type spirvType = getTypeConverter()->convertType(createOp.getType());
57 if (!spirvType)
58 return rewriter.notifyMatchFailure(createOp,
59 "unable to convert result type");
61 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
62 createOp, spirvType, adaptor.getOperands());
63 return success();
67 struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
68 using OpConversionPattern::OpConversionPattern;
70 LogicalResult
71 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 Type spirvType = getTypeConverter()->convertType(reOp.getType());
74 if (!spirvType)
75 return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
77 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
78 reOp, adaptor.getComplex(), llvm::ArrayRef(0));
79 return success();
83 struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
84 using OpConversionPattern::OpConversionPattern;
86 LogicalResult
87 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 Type spirvType = getTypeConverter()->convertType(imOp.getType());
90 if (!spirvType)
91 return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
93 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
94 imOp, adaptor.getComplex(), llvm::ArrayRef(1));
95 return success();
99 } // namespace
101 //===----------------------------------------------------------------------===//
102 // Pattern population
103 //===----------------------------------------------------------------------===//
105 void mlir::populateComplexToSPIRVPatterns(
106 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
107 MLIRContext *context = patterns.getContext();
109 patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
110 typeConverter, context);