1 //===- TosaTestPasses.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Test passes to exercise TOSA helper functions.
11 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define PASS_NAME "tosa-test-quant-utils"

using namespace mlir;
using namespace mlir::tosa;
28 // This transformation converts quantized uint8 to quantized int8. The
29 // construction of the new type invokes buildQTypeFromMinMax. Extracted from
30 // TOSA legalization infrastructure.
31 struct ConvertTosaNegateOp
: public RewritePattern
{
32 explicit ConvertTosaNegateOp(MLIRContext
*context
)
33 : RewritePattern(tosa::NegateOp::getOperationName(), 1, context
) {}
34 LogicalResult
matchAndRewrite(Operation
*op
,
35 PatternRewriter
&rewriter
) const override
;
39 ConvertTosaNegateOp::matchAndRewrite(Operation
*op
,
40 PatternRewriter
&rewriter
) const {
42 auto tosaNegateOp
= cast
<tosa::NegateOp
>(op
);
45 dyn_cast
<mlir::RankedTensorType
>(tosaNegateOp
.getInput1().getType());
46 // skip if input is not ranked tensor type
50 // skip if it's not ranked tensor type.
52 dyn_cast
<mlir::RankedTensorType
>(tosaNegateOp
.getResult().getType());
56 // skip if output is not per-tensor quantized type.
57 auto outputElementType
=
58 dyn_cast
<mlir::quant::UniformQuantizedType
>(outputType
.getElementType());
59 if (!outputElementType
)
62 // skip if output is not uint8.
63 if (outputElementType
.isSigned() ||
64 outputElementType
.getStorageTypeIntegralWidth() != 8)
67 double typeRangeMin
= double(outputElementType
.getStorageTypeMin() -
68 outputElementType
.getZeroPoint()) *
69 outputElementType
.getScale();
70 double typeRangeMax
= double(outputElementType
.getStorageTypeMax() -
71 outputElementType
.getZeroPoint()) *
72 outputElementType
.getScale();
73 bool narrowRange
= outputElementType
.getStorageTypeMin() == 1;
75 auto dstQConstType
= RankedTensorType::get(
76 outputType
.getShape(),
77 buildQTypeFromMinMax(rewriter
, outputElementType
.getExpressedType(),
78 rewriter
.getF64FloatAttr(typeRangeMin
),
79 rewriter
.getF64FloatAttr(typeRangeMax
),
80 rewriter
.getI32IntegerAttr(
81 outputElementType
.getStorageTypeIntegralWidth()),
83 rewriter
.getBoolAttr(narrowRange
)));
85 ElementsAttr inputElems
;
86 if (!matchPattern(tosaNegateOp
.getInput1(), m_Constant(&inputElems
)))
90 rewriter
.create
<tosa::ConstOp
>(op
->getLoc(), dstQConstType
, inputElems
);
91 auto newNegateOp
= rewriter
.create
<tosa::NegateOp
>(
92 op
->getLoc(), dstQConstType
, newConstOp
.getResult());
94 rewriter
.replaceOp(op
, {newNegateOp
.getResult()});
// This transformation rewrites a quantized tosa.conv2d so that it produces an
// i32 result, and appends a tosa.rescale back to the original quantized output
// type. Building the rescale requires computeMultiplierAndShift. Extracted from
// the TOSA legalization infrastructure.
101 struct ConvertTosaConv2DOp
: public RewritePattern
{
102 explicit ConvertTosaConv2DOp(MLIRContext
*context
)
103 : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context
) {}
104 LogicalResult
matchAndRewrite(Operation
*op
,
105 PatternRewriter
&rewriter
) const override
;
109 ConvertTosaConv2DOp::matchAndRewrite(Operation
*op
,
110 PatternRewriter
&rewriter
) const {
112 auto tosaConv2DOp
= cast
<tosa::Conv2DOp
>(op
);
115 dyn_cast
<mlir::RankedTensorType
>(tosaConv2DOp
.getInput().getType());
117 // skip if input is not ranked tensor type
122 dyn_cast
<mlir::RankedTensorType
>(tosaConv2DOp
.getWeight().getType());
124 // skip if wt is not ranked tensor type
128 // skip if it's not ranked tensor type.
130 dyn_cast
<mlir::RankedTensorType
>(tosaConv2DOp
.getResult().getType());
135 dyn_cast
<mlir::quant::UniformQuantizedType
>(inputType
.getElementType());
137 dyn_cast
<mlir::quant::UniformQuantizedType
>(weightType
.getElementType());
139 dyn_cast
<mlir::quant::UniformQuantizedType
>(outputType
.getElementType());
141 // Works on quantized type only.
142 if (!(inputQType
&& weightQType
&& outputQType
))
145 auto newTosaConv2DOpType
=
146 RankedTensorType::get(outputType
.getShape(), rewriter
.getIntegerType(32));
148 auto newTosaConv2DOp
= rewriter
.create
<tosa::Conv2DOp
>(
149 op
->getLoc(), newTosaConv2DOpType
, tosaConv2DOp
.getInput(),
150 tosaConv2DOp
.getWeight(), tosaConv2DOp
.getBias(),
151 tosaConv2DOp
.getPadAttr(), tosaConv2DOp
.getStrideAttr(),
152 tosaConv2DOp
.getDilationAttr());
154 // Create rescale to quantized type
155 double inputScale
= inputQType
.getScale();
156 double weightScale
= weightQType
.getScale();
157 double outputScale
= outputQType
.getScale();
158 int64_t outputZp
= outputQType
.getZeroPoint();
160 double opTensorScale
= (inputScale
* weightScale
) / outputScale
;
165 // Obtain the quantized scale = multiplier and shift.
166 computeMultiplierAndShift(opTensorScale
, multiplier
, shift
, 32);
168 auto newTosaRescaleOp
= rewriter
.create
<tosa::RescaleOp
>(
169 op
->getLoc(), outputType
, newTosaConv2DOp
.getResult(),
170 rewriter
.getI32IntegerAttr(0), rewriter
.getI32IntegerAttr(outputZp
),
171 rewriter
.getDenseI32ArrayAttr({multiplier
}),
172 rewriter
.getDenseI8ArrayAttr({static_cast<int8_t>(shift
)}),
173 rewriter
.getBoolAttr(true), rewriter
.getBoolAttr(true),
174 rewriter
.getBoolAttr(false));
176 rewriter
.replaceOp(op
, {newTosaRescaleOp
.getResult()});
182 struct TosaTestQuantUtilAPI
183 : public PassWrapper
<TosaTestQuantUtilAPI
, OperationPass
<func::FuncOp
>> {
184 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI
)
186 StringRef
getArgument() const final
{ return PASS_NAME
; }
187 StringRef
getDescription() const final
{
188 return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
190 void runOnOperation() override
;
193 void TosaTestQuantUtilAPI::runOnOperation() {
194 auto *ctx
= &getContext();
195 RewritePatternSet
patterns(ctx
);
196 auto func
= getOperation();
198 patterns
.add
<ConvertTosaNegateOp
>(ctx
);
199 patterns
.add
<ConvertTosaConv2DOp
>(ctx
);
200 (void)applyPatternsAndFoldGreedily(func
, std::move(patterns
));
206 void registerTosaTestQuantUtilAPIPass() {
207 PassRegistration
<TosaTestQuantUtilAPI
>();