//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass legalizes Tosa operations to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

// The generated pass base class lives in `mlir::impl`, so the generated
// definitions must be included inside `namespace mlir`.
namespace mlir {
#define GEN_PASS_DEF_TOSATOLINALG
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace tosa;
39 struct TosaToLinalg
: public impl::TosaToLinalgBase
<TosaToLinalg
> {
41 void getDependentDialects(DialectRegistry
®istry
) const override
{
43 .insert
<arith::ArithDialect
, linalg::LinalgDialect
, math::MathDialect
,
44 index::IndexDialect
, tensor::TensorDialect
, scf::SCFDialect
>();
47 void runOnOperation() override
{
48 RewritePatternSet
patterns(&getContext());
49 ConversionTarget
target(getContext());
50 target
.addLegalDialect
<linalg::LinalgDialect
, tensor::TensorDialect
,
52 target
.addIllegalDialect
<tosa::TosaDialect
>();
54 // Not every TOSA op can be legalized to linalg.
55 target
.addLegalOp
<tosa::ApplyScaleOp
>();
56 target
.addLegalOp
<tosa::IfOp
>();
57 target
.addLegalOp
<tosa::ConstOp
>();
58 target
.addLegalOp
<tosa::WhileOp
>();
59 target
.addLegalOp
<tosa::ConcatOp
>();
60 target
.addLegalOp
<tosa::SliceOp
>();
61 target
.addLegalOp
<tosa::ReshapeOp
>();
62 target
.addLegalOp
<tosa::PadOp
>();
64 target
.markUnknownOpDynamicallyLegal([](Operation
*) { return true; });
66 TypeConverter converter
;
67 tosa::populateTosaTypeConversion(converter
);
69 FunctionOpInterface func
= getOperation();
70 mlir::tosa::populateTosaToLinalgConversionPatterns(converter
, &patterns
);
71 if (failed(applyFullConversion(func
, target
, std::move(patterns
))))
77 std::unique_ptr
<Pass
> mlir::tosa::createTosaToLinalg() {
78 return std::make_unique
<TosaToLinalg
>();
81 void mlir::tosa::addTosaToLinalgPasses(
82 OpPassManager
&pm
, const TosaToLinalgOptions
&options
,
83 const TosaToLinalgNamedOptions
&tosaToLinalgNamedOptions
,
84 std::optional
<tosa::TosaValidationOptions
> validationOptions
) {
85 // Optional decompositions are designed to benefit linalg.
86 if (!options
.disableTosaDecompositions
)
87 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaOptionalDecompositions());
88 pm
.addNestedPass
<func::FuncOp
>(createCanonicalizerPass());
90 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaInferShapesPass());
91 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaMakeBroadcastablePass());
92 pm
.addNestedPass
<func::FuncOp
>(
93 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions
));
94 pm
.addNestedPass
<func::FuncOp
>(createCanonicalizerPass());
95 // TODO: Remove pass that operates on const tensor and enable optionality
96 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaLayerwiseConstantFoldPass(
97 {options
.aggressiveReduceConstant
}));
98 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaMakeBroadcastablePass());
99 if (validationOptions
)
100 pm
.addPass(tosa::createTosaValidation(*validationOptions
));
101 pm
.addNestedPass
<func::FuncOp
>(tosa::createTosaToLinalg());
//===----------------------------------------------------------------------===//
// Pipeline registration.
//===----------------------------------------------------------------------===//
108 void mlir::tosa::registerTosaToLinalgPipelines() {
109 PassPipelineRegistration
<>(
110 "tosa-to-linalg-pipeline",
111 "The default pipeline for converting TOSA operators to the equivalent "
112 "operations using the tensor operations in LinAlg as well as LinAlg "
114 [](OpPassManager
&pm
) {
115 TosaToLinalgOptions tosaToLinalgOptions
;
116 TosaToLinalgNamedOptions tosaToLinalgNamedOptions
;
117 TosaValidationOptions validationOptions
;
118 validationOptions
.profile
= {"none"};
119 validationOptions
.StrictOperationSpecAlignment
= true;
120 validationOptions
.level
= tosa::TosaLevelEnum::EightK
;
121 tosa::addTosaToLinalgPasses(pm
, tosaToLinalgOptions
,
122 tosaToLinalgNamedOptions
,