Revert "[HLSL] Add `Increment`/`DecrementCounter` methods to structured buffers ...
[llvm-project.git] / mlir / lib / Conversion / TosaToLinalg / TosaToLinalgPass.cpp
blob06a7262c467421ddfb7134a0a3ede8371564e28e
1 //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes Tosa operations to the Linalg dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Index/IR/IndexDialect.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
23 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
24 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/PassManager.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "mlir/Transforms/Passes.h"
31 namespace mlir {
32 #define GEN_PASS_DEF_TOSATOLINALG
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
36 using namespace mlir;
38 namespace {
39 struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
40 public:
41 void getDependentDialects(DialectRegistry &registry) const override {
42 registry
43 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
44 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
47 void runOnOperation() override {
48 RewritePatternSet patterns(&getContext());
49 ConversionTarget target(getContext());
50 target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
51 scf::SCFDialect>();
52 target.addIllegalDialect<tosa::TosaDialect>();
54 // Not every TOSA op can be legalized to linalg.
55 target.addLegalOp<tosa::ApplyScaleOp>();
56 target.addLegalOp<tosa::IfOp>();
57 target.addLegalOp<tosa::ConstOp>();
58 target.addLegalOp<tosa::WhileOp>();
59 target.addLegalOp<tosa::ConcatOp>();
60 target.addLegalOp<tosa::SliceOp>();
61 target.addLegalOp<tosa::ReshapeOp>();
62 target.addLegalOp<tosa::PadOp>();
64 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
66 TypeConverter converter;
67 tosa::populateTosaTypeConversion(converter);
69 FunctionOpInterface func = getOperation();
70 mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
71 if (failed(applyFullConversion(func, target, std::move(patterns))))
72 signalPassFailure();
75 } // namespace
77 std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
78 return std::make_unique<TosaToLinalg>();
81 void mlir::tosa::addTosaToLinalgPasses(
82 OpPassManager &pm, const TosaToLinalgOptions &options,
83 const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
84 std::optional<tosa::TosaValidationOptions> validationOptions) {
85 // Optional decompositions are designed to benefit linalg.
86 if (!options.disableTosaDecompositions)
87 pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
88 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
90 pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
91 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
92 pm.addNestedPass<func::FuncOp>(
93 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
94 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
95 // TODO: Remove pass that operates on const tensor and enable optionality
96 pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
97 {options.aggressiveReduceConstant}));
98 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
99 if (validationOptions)
100 pm.addPass(tosa::createTosaValidation(*validationOptions));
101 pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
104 //===----------------------------------------------------------------------===//
105 // Pipeline registration.
106 //===----------------------------------------------------------------------===//
108 void mlir::tosa::registerTosaToLinalgPipelines() {
109 PassPipelineRegistration<>(
110 "tosa-to-linalg-pipeline",
111 "The default pipeline for converting TOSA operators to the equivalent "
112 "operations using the tensor operations in LinAlg as well as LinAlg "
113 "named operations.",
114 [](OpPassManager &pm) {
115 TosaToLinalgOptions tosaToLinalgOptions;
116 TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
117 TosaValidationOptions validationOptions;
118 validationOptions.profile = {"none"};
119 validationOptions.StrictOperationSpecAlignment = true;
120 validationOptions.level = tosa::TosaLevelEnum::EightK;
121 tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
122 tosaToLinalgNamedOptions,
123 validationOptions);