//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <optional>

using namespace mlir;

24 struct ConvertAlloca final
: public OpConversionPattern
<memref::AllocaOp
> {
25 using OpConversionPattern::OpConversionPattern
;
28 matchAndRewrite(memref::AllocaOp op
, OpAdaptor operands
,
29 ConversionPatternRewriter
&rewriter
) const override
{
31 if (!op
.getType().hasStaticShape()) {
32 return rewriter
.notifyMatchFailure(
33 op
.getLoc(), "cannot transform alloca with dynamic shape");
36 if (op
.getAlignment().value_or(1) > 1) {
37 // TODO: Allow alignment if it is not more than the natural alignment
39 return rewriter
.notifyMatchFailure(
40 op
.getLoc(), "cannot transform alloca with alignment requirement");
43 auto resultTy
= getTypeConverter()->convertType(op
.getType());
45 return rewriter
.notifyMatchFailure(op
.getLoc(), "cannot convert type");
47 auto noInit
= emitc::OpaqueAttr::get(getContext(), "");
48 rewriter
.replaceOpWithNewOp
<emitc::VariableOp
>(op
, resultTy
, noInit
);
53 struct ConvertGlobal final
: public OpConversionPattern
<memref::GlobalOp
> {
54 using OpConversionPattern::OpConversionPattern
;
57 matchAndRewrite(memref::GlobalOp op
, OpAdaptor operands
,
58 ConversionPatternRewriter
&rewriter
) const override
{
60 if (!op
.getType().hasStaticShape()) {
61 return rewriter
.notifyMatchFailure(
62 op
.getLoc(), "cannot transform global with dynamic shape");
65 if (op
.getAlignment().value_or(1) > 1) {
66 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
67 return rewriter
.notifyMatchFailure(
68 op
.getLoc(), "global variable with alignment requirement is "
69 "currently not supported");
71 auto resultTy
= getTypeConverter()->convertType(op
.getType());
73 return rewriter
.notifyMatchFailure(op
.getLoc(),
74 "cannot convert result type");
77 SymbolTable::Visibility visibility
= SymbolTable::getSymbolVisibility(op
);
78 if (visibility
!= SymbolTable::Visibility::Public
&&
79 visibility
!= SymbolTable::Visibility::Private
) {
80 return rewriter
.notifyMatchFailure(
82 "only public and private visibility is currently supported");
84 // We are explicit in specifing the linkage because the default linkage
85 // for constants is different in C and C++.
86 bool staticSpecifier
= visibility
== SymbolTable::Visibility::Private
;
87 bool externSpecifier
= !staticSpecifier
;
89 Attribute initialValue
= operands
.getInitialValueAttr();
90 if (isa_and_present
<UnitAttr
>(initialValue
))
93 rewriter
.replaceOpWithNewOp
<emitc::GlobalOp
>(
94 op
, operands
.getSymName(), resultTy
, initialValue
, externSpecifier
,
95 staticSpecifier
, operands
.getConstant());
100 struct ConvertGetGlobal final
101 : public OpConversionPattern
<memref::GetGlobalOp
> {
102 using OpConversionPattern::OpConversionPattern
;
105 matchAndRewrite(memref::GetGlobalOp op
, OpAdaptor operands
,
106 ConversionPatternRewriter
&rewriter
) const override
{
108 auto resultTy
= getTypeConverter()->convertType(op
.getType());
110 return rewriter
.notifyMatchFailure(op
.getLoc(),
111 "cannot convert result type");
113 rewriter
.replaceOpWithNewOp
<emitc::GetGlobalOp
>(op
, resultTy
,
114 operands
.getNameAttr());
119 struct ConvertLoad final
: public OpConversionPattern
<memref::LoadOp
> {
120 using OpConversionPattern::OpConversionPattern
;
123 matchAndRewrite(memref::LoadOp op
, OpAdaptor operands
,
124 ConversionPatternRewriter
&rewriter
) const override
{
126 auto resultTy
= getTypeConverter()->convertType(op
.getType());
128 return rewriter
.notifyMatchFailure(op
.getLoc(), "cannot convert type");
132 dyn_cast
<TypedValue
<emitc::ArrayType
>>(operands
.getMemref());
134 return rewriter
.notifyMatchFailure(op
.getLoc(), "expected array type");
137 auto subscript
= rewriter
.create
<emitc::SubscriptOp
>(
138 op
.getLoc(), arrayValue
, operands
.getIndices());
140 rewriter
.replaceOpWithNewOp
<emitc::LoadOp
>(op
, resultTy
, subscript
);
145 struct ConvertStore final
: public OpConversionPattern
<memref::StoreOp
> {
146 using OpConversionPattern::OpConversionPattern
;
149 matchAndRewrite(memref::StoreOp op
, OpAdaptor operands
,
150 ConversionPatternRewriter
&rewriter
) const override
{
152 dyn_cast
<TypedValue
<emitc::ArrayType
>>(operands
.getMemref());
154 return rewriter
.notifyMatchFailure(op
.getLoc(), "expected array type");
157 auto subscript
= rewriter
.create
<emitc::SubscriptOp
>(
158 op
.getLoc(), arrayValue
, operands
.getIndices());
159 rewriter
.replaceOpWithNewOp
<emitc::AssignOp
>(op
, subscript
,
160 operands
.getValue());
166 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter
&typeConverter
) {
167 typeConverter
.addConversion(
168 [&](MemRefType memRefType
) -> std::optional
<Type
> {
169 if (!memRefType
.hasStaticShape() ||
170 !memRefType
.getLayout().isIdentity() || memRefType
.getRank() == 0 ||
171 llvm::any_of(memRefType
.getShape(),
172 [](int64_t dim
) { return dim
== 0; })) {
175 Type convertedElementType
=
176 typeConverter
.convertType(memRefType
.getElementType());
177 if (!convertedElementType
)
179 return emitc::ArrayType::get(memRefType
.getShape(),
180 convertedElementType
);
184 void mlir::populateMemRefToEmitCConversionPatterns(
185 RewritePatternSet
&patterns
, const TypeConverter
&converter
) {
186 patterns
.add
<ConvertAlloca
, ConvertGlobal
, ConvertGetGlobal
, ConvertLoad
,
187 ConvertStore
>(converter
, patterns
.getContext());