1 //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements patterns to convert Complex dialect to SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/Support/Debug.h"
21 #define DEBUG_TYPE "complex-to-spirv-pattern"
25 //===----------------------------------------------------------------------===//
26 // Operation conversion
27 //===----------------------------------------------------------------------===//
31 struct ConstantOpPattern final
: OpConversionPattern
<complex::ConstantOp
> {
32 using OpConversionPattern::OpConversionPattern
;
35 matchAndRewrite(complex::ConstantOp constOp
, OpAdaptor adaptor
,
36 ConversionPatternRewriter
&rewriter
) const override
{
38 getTypeConverter()->convertType
<ShapedType
>(constOp
.getType());
40 return rewriter
.notifyMatchFailure(constOp
,
41 "unable to convert result type");
43 rewriter
.replaceOpWithNewOp
<spirv::ConstantOp
>(
45 DenseElementsAttr::get(spirvType
, constOp
.getValue().getValue()));
50 struct CreateOpPattern final
: OpConversionPattern
<complex::CreateOp
> {
51 using OpConversionPattern::OpConversionPattern
;
54 matchAndRewrite(complex::CreateOp createOp
, OpAdaptor adaptor
,
55 ConversionPatternRewriter
&rewriter
) const override
{
56 Type spirvType
= getTypeConverter()->convertType(createOp
.getType());
58 return rewriter
.notifyMatchFailure(createOp
,
59 "unable to convert result type");
61 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(
62 createOp
, spirvType
, adaptor
.getOperands());
67 struct ReOpPattern final
: OpConversionPattern
<complex::ReOp
> {
68 using OpConversionPattern::OpConversionPattern
;
71 matchAndRewrite(complex::ReOp reOp
, OpAdaptor adaptor
,
72 ConversionPatternRewriter
&rewriter
) const override
{
73 Type spirvType
= getTypeConverter()->convertType(reOp
.getType());
75 return rewriter
.notifyMatchFailure(reOp
, "unable to convert result type");
77 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
78 reOp
, adaptor
.getComplex(), llvm::ArrayRef(0));
83 struct ImOpPattern final
: OpConversionPattern
<complex::ImOp
> {
84 using OpConversionPattern::OpConversionPattern
;
87 matchAndRewrite(complex::ImOp imOp
, OpAdaptor adaptor
,
88 ConversionPatternRewriter
&rewriter
) const override
{
89 Type spirvType
= getTypeConverter()->convertType(imOp
.getType());
91 return rewriter
.notifyMatchFailure(imOp
, "unable to convert result type");
93 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
94 imOp
, adaptor
.getComplex(), llvm::ArrayRef(1));
101 //===----------------------------------------------------------------------===//
102 // Pattern population
103 //===----------------------------------------------------------------------===//
105 void mlir::populateComplexToSPIRVPatterns(
106 const SPIRVTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
107 MLIRContext
*context
= patterns
.getContext();
109 patterns
.add
<ConstantOpPattern
, CreateOpPattern
, ReOpPattern
, ImOpPattern
>(
110 typeConverter
, context
);