[libc++abi] Build cxxabi with sanitizers (#119612)
[llvm-project.git] / mlir / test / lib / Dialect / Tosa / TosaTestPasses.cpp
blobe5a3e2b6fccaa32f539e0457473b7388554f359e
1 //===- TosaTestPasses.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Test passes to exercise TOSA helper functions.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #define PASS_NAME "tosa-test-quant-utils"
25 using namespace mlir;
26 using namespace mlir::tosa;
28 // This transformation converts quantized uint8 to quantized int8. The
29 // construction of the new type invokes buildQTypeFromMinMax. Extracted from
30 // TOSA legalization infrastructure.
31 struct ConvertTosaNegateOp : public RewritePattern {
32 explicit ConvertTosaNegateOp(MLIRContext *context)
33 : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
34 LogicalResult matchAndRewrite(Operation *op,
35 PatternRewriter &rewriter) const override;
38 LogicalResult
39 ConvertTosaNegateOp::matchAndRewrite(Operation *op,
40 PatternRewriter &rewriter) const {
42 auto tosaNegateOp = cast<tosa::NegateOp>(op);
44 auto inputType =
45 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
46 // skip if input is not ranked tensor type
47 if (!inputType)
48 return failure();
50 // skip if it's not ranked tensor type.
51 auto outputType =
52 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
53 if (!outputType)
54 return failure();
56 // skip if output is not per-tensor quantized type.
57 auto outputElementType =
58 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
59 if (!outputElementType)
60 return failure();
62 // skip if output is not uint8.
63 if (outputElementType.isSigned() ||
64 outputElementType.getStorageTypeIntegralWidth() != 8)
65 return failure();
67 double typeRangeMin = double(outputElementType.getStorageTypeMin() -
68 outputElementType.getZeroPoint()) *
69 outputElementType.getScale();
70 double typeRangeMax = double(outputElementType.getStorageTypeMax() -
71 outputElementType.getZeroPoint()) *
72 outputElementType.getScale();
73 bool narrowRange = outputElementType.getStorageTypeMin() == 1;
75 auto dstQConstType = RankedTensorType::get(
76 outputType.getShape(),
77 buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
78 rewriter.getF64FloatAttr(typeRangeMin),
79 rewriter.getF64FloatAttr(typeRangeMax),
80 rewriter.getI32IntegerAttr(
81 outputElementType.getStorageTypeIntegralWidth()),
82 0, true /* signed */,
83 rewriter.getBoolAttr(narrowRange)));
85 ElementsAttr inputElems;
86 if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems)))
87 return failure();
89 auto newConstOp =
90 rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
91 auto newNegateOp = rewriter.create<tosa::NegateOp>(
92 op->getLoc(), dstQConstType, newConstOp.getResult());
94 rewriter.replaceOp(op, {newNegateOp.getResult()});
95 return success();
98 // This transformation modifies the quantized output of a test conv2d input and
99 // appends a TOSA rescale after it. The rescale op requires the invocation of
100 // computeMultiplierAndShift. From TOSA legalization infrastructure.
101 struct ConvertTosaConv2DOp : public RewritePattern {
102 explicit ConvertTosaConv2DOp(MLIRContext *context)
103 : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
104 LogicalResult matchAndRewrite(Operation *op,
105 PatternRewriter &rewriter) const override;
108 LogicalResult
109 ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
110 PatternRewriter &rewriter) const {
112 auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
114 auto inputType =
115 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());
117 // skip if input is not ranked tensor type
118 if (!inputType)
119 return failure();
121 auto weightType =
122 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());
124 // skip if wt is not ranked tensor type
125 if (!weightType)
126 return failure();
128 // skip if it's not ranked tensor type.
129 auto outputType =
130 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
131 if (!outputType)
132 return failure();
134 auto inputQType =
135 dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
136 auto weightQType =
137 dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
138 auto outputQType =
139 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
141 // Works on quantized type only.
142 if (!(inputQType && weightQType && outputQType))
143 return failure();
145 auto newTosaConv2DOpType =
146 RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
148 auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
149 op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
150 tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(),
151 tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(),
152 tosaConv2DOp.getDilationAttr());
154 // Create rescale to quantized type
155 double inputScale = inputQType.getScale();
156 double weightScale = weightQType.getScale();
157 double outputScale = outputQType.getScale();
158 int64_t outputZp = outputQType.getZeroPoint();
160 double opTensorScale = (inputScale * weightScale) / outputScale;
162 int32_t multiplier;
163 int32_t shift;
165 // Obtain the quantized scale = multiplier and shift.
166 computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
168 auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
169 op->getLoc(), outputType, newTosaConv2DOp.getResult(),
170 rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
171 rewriter.getDenseI32ArrayAttr({multiplier}),
172 rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
173 rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
174 rewriter.getBoolAttr(false));
176 rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
177 return success();
180 namespace {
182 struct TosaTestQuantUtilAPI
183 : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> {
184 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI)
186 StringRef getArgument() const final { return PASS_NAME; }
187 StringRef getDescription() const final {
188 return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
190 void runOnOperation() override;
193 void TosaTestQuantUtilAPI::runOnOperation() {
194 auto *ctx = &getContext();
195 RewritePatternSet patterns(ctx);
196 auto func = getOperation();
198 patterns.add<ConvertTosaNegateOp>(ctx);
199 patterns.add<ConvertTosaConv2DOp>(ctx);
200 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
203 } // namespace
205 namespace mlir {
206 void registerTosaTestQuantUtilAPIPass() {
207 PassRegistration<TosaTestQuantUtilAPI>();
209 } // namespace mlir