1 //===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // These rewriters lower from the TOSA dialect to the MLProgram dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/PatternMatch.h"
23 class VariableOpConverter
: public OpRewritePattern
<tosa::VariableOp
> {
25 using OpRewritePattern
<tosa::VariableOp
>::OpRewritePattern
;
27 LogicalResult
matchAndRewrite(tosa::VariableOp op
,
28 PatternRewriter
&rewriter
) const final
{
29 auto newVariable
= rewriter
.create
<mlir::ml_program::GlobalOp
>(
30 op
.getLoc(), op
.getName(), op
.getType(), /*is_mutable=*/true,
31 op
.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32 newVariable
.setPrivate();
33 rewriter
.replaceOp(op
, newVariable
);
38 class VariableWriteOpConverter
39 : public OpRewritePattern
<tosa::VariableWriteOp
> {
41 using OpRewritePattern
<tosa::VariableWriteOp
>::OpRewritePattern
;
43 LogicalResult
matchAndRewrite(tosa::VariableWriteOp op
,
44 PatternRewriter
&rewriter
) const final
{
45 auto globalSymbolRef
=
46 SymbolRefAttr::get(rewriter
.getContext(), op
.getName());
47 auto newVariableWrite
= rewriter
.create
<ml_program::GlobalStoreOp
>(
48 op
.getLoc(), globalSymbolRef
, op
.getValue());
49 rewriter
.replaceOp(op
, newVariableWrite
);
54 class VariableReadOpConverter
: public OpRewritePattern
<tosa::VariableReadOp
> {
56 using OpRewritePattern
<tosa::VariableReadOp
>::OpRewritePattern
;
58 LogicalResult
matchAndRewrite(tosa::VariableReadOp op
,
59 PatternRewriter
&rewriter
) const final
{
60 auto globalSymbolRef
=
61 SymbolRefAttr::get(rewriter
.getContext(), op
.getName());
62 auto newVariableRead
= rewriter
.create
<ml_program::GlobalLoadOp
>(
63 op
.getLoc(), op
.getType(), globalSymbolRef
);
64 rewriter
.replaceOp(op
, newVariableRead
);
72 void mlir::tosa::populateTosaToMLProgramConversionPatterns(
73 RewritePatternSet
*patterns
) {
74 patterns
->add
<VariableOpConverter
, VariableWriteOpConverter
,
75 VariableReadOpConverter
>(patterns
->getContext());