//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Arith dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
25 class ConstOpConverter
: public OpRewritePattern
<tosa::ConstOp
> {
27 using OpRewritePattern
<tosa::ConstOp
>::OpRewritePattern
;
29 LogicalResult
matchAndRewrite(tosa::ConstOp op
,
30 PatternRewriter
&rewriter
) const final
{
31 rewriter
.replaceOpWithNewOp
<arith::ConstantOp
>(op
, op
.getValue());
36 Type
matchContainerType(Type element
, Type container
) {
37 if (auto shapedTy
= dyn_cast
<ShapedType
>(container
))
38 return shapedTy
.clone(element
);
43 TypedAttr
getConstantAttr(Type type
, int64_t value
, PatternRewriter
&rewriter
) {
44 if (auto shapedTy
= dyn_cast
<ShapedType
>(type
)) {
45 Type eTy
= shapedTy
.getElementType();
46 APInt
valueInt(eTy
.getIntOrFloatBitWidth(), value
, /*isSigned=*/true);
47 return DenseIntElementsAttr::get(shapedTy
, valueInt
);
50 return rewriter
.getIntegerAttr(type
, value
);
53 Value
getConstantValue(Location loc
, Type type
, int64_t value
,
54 PatternRewriter
&rewriter
) {
55 return rewriter
.create
<arith::ConstantOp
>(
56 loc
, getConstantAttr(type
, value
, rewriter
));
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62 : public OpRewritePattern
<tosa::ApplyScaleOp
> {
64 using OpRewritePattern
<tosa::ApplyScaleOp
>::OpRewritePattern
;
66 LogicalResult
matchAndRewrite(tosa::ApplyScaleOp op
,
67 PatternRewriter
&rewriter
) const final
{
68 Location loc
= op
.getLoc();
69 Value value
= op
.getValue();
70 Value multiplier32
= op
.getMultiplier();
72 Type resultTy
= op
.getType();
73 Type valueTy
= value
.getType();
74 Type i32Ty
= matchContainerType(rewriter
.getI32Type(), resultTy
);
75 Type i64Ty
= matchContainerType(rewriter
.getI64Type(), resultTy
);
77 Value zero
= getConstantValue(loc
, valueTy
, 0, rewriter
);
78 Value one64
= getConstantValue(loc
, i64Ty
, 1, rewriter
);
79 Value thirtyOne32
= getConstantValue(loc
, i32Ty
, 31, rewriter
);
81 Value shift32
= rewriter
.create
<arith::ExtUIOp
>(loc
, i32Ty
, op
.getShift());
83 // Compute the multiplication in 64-bits then select the high / low parts.
84 Value value64
= value
;
85 if (getElementTypeOrSelf(valueTy
) != rewriter
.getI64Type())
86 value64
= rewriter
.create
<arith::ExtSIOp
>(loc
, i64Ty
, value
);
88 rewriter
.create
<arith::ExtSIOp
>(loc
, i64Ty
, multiplier32
);
90 rewriter
.create
<arith::MulIOp
>(loc
, value64
, multiplier64
);
92 // Apply normal rounding.
93 Value shift64
= rewriter
.create
<arith::ExtUIOp
>(loc
, i64Ty
, shift32
);
94 Value round
= rewriter
.create
<arith::ShLIOp
>(loc
, one64
, shift64
);
95 round
= rewriter
.create
<arith::ShRUIOp
>(loc
, round
, one64
);
96 multiply64
= rewriter
.create
<arith::AddIOp
>(loc
, multiply64
, round
);
98 // Apply double rounding if necessary.
99 if (op
.getDoubleRound()) {
100 int64_t roundInt
= 1 << 30;
101 Value roundUp
= getConstantValue(loc
, i64Ty
, roundInt
, rewriter
);
102 Value roundDown
= getConstantValue(loc
, i64Ty
, -roundInt
, rewriter
);
103 Value positive
= rewriter
.create
<arith::CmpIOp
>(
104 loc
, arith::CmpIPredicate::sge
, value
, zero
);
106 rewriter
.create
<arith::SelectOp
>(loc
, positive
, roundUp
, roundDown
);
107 Value val
= rewriter
.create
<arith::AddIOp
>(loc
, dir
, multiply64
);
108 Value valid
= rewriter
.create
<arith::CmpIOp
>(
109 loc
, arith::CmpIPredicate::sgt
, shift32
, thirtyOne32
);
111 rewriter
.create
<arith::SelectOp
>(loc
, valid
, val
, multiply64
);
114 Value result64
= rewriter
.create
<arith::ShRSIOp
>(loc
, multiply64
, shift64
);
115 Value result32
= rewriter
.create
<arith::TruncIOp
>(loc
, i32Ty
, result64
);
117 rewriter
.replaceOp(op
, result32
);
122 class ApplyScale32BitOpConverter
: public OpRewritePattern
<tosa::ApplyScaleOp
> {
124 using OpRewritePattern
<tosa::ApplyScaleOp
>::OpRewritePattern
;
126 LogicalResult
matchAndRewrite(tosa::ApplyScaleOp op
,
127 PatternRewriter
&rewriter
) const final
{
128 Location loc
= op
.getLoc();
130 Type resultTy
= op
.getType();
131 Type i32Ty
= matchContainerType(rewriter
.getI32Type(), resultTy
);
133 Value value
= op
.getValue();
134 if (getElementTypeOrSelf(value
.getType()).getIntOrFloatBitWidth() > 32) {
138 Value value32
= op
.getValue();
139 Value multiplier32
= op
.getMultiplier();
140 Value shift32
= rewriter
.create
<arith::ExtUIOp
>(loc
, i32Ty
, op
.getShift());
142 // Constants used during the scaling operation.
143 Value zero32
= getConstantValue(loc
, i32Ty
, 0, rewriter
);
144 Value one32
= getConstantValue(loc
, i32Ty
, 1, rewriter
);
145 Value two32
= getConstantValue(loc
, i32Ty
, 2, rewriter
);
146 Value thirty32
= getConstantValue(loc
, i32Ty
, 30, rewriter
);
147 Value thirtyTwo32
= getConstantValue(loc
, i32Ty
, 32, rewriter
);
149 // Compute the multiplication in 64-bits then select the high / low parts.
150 // Grab out the high/low of the computation
152 rewriter
.create
<arith::MulSIExtendedOp
>(loc
, value32
, multiplier32
);
153 Value low32
= value64
.getLow();
154 Value high32
= value64
.getHigh();
156 // Determine the direction and amount to shift the high bits.
157 Value shiftOver32
= rewriter
.create
<arith::CmpIOp
>(
158 loc
, arith::CmpIPredicate::sge
, shift32
, thirtyTwo32
);
159 Value roundHighBits
= rewriter
.create
<arith::CmpIOp
>(
160 loc
, arith::CmpIPredicate::sgt
, shift32
, thirtyTwo32
);
163 rewriter
.create
<arith::SubIOp
>(loc
, thirtyTwo32
, shift32
);
165 rewriter
.create
<arith::SubIOp
>(loc
, shift32
, thirtyTwo32
);
168 rewriter
.create
<arith::SelectOp
>(loc
, shiftOver32
, zero32
, shiftHighL
);
170 rewriter
.create
<arith::SelectOp
>(loc
, shiftOver32
, shiftHighR
, zero32
);
172 // Conditionally perform our double round.
173 if (op
.getDoubleRound()) {
174 Value negOne32
= getConstantValue(loc
, i32Ty
, -1, rewriter
);
175 Value valuePositive
= rewriter
.create
<arith::CmpIOp
>(
176 loc
, arith::CmpIPredicate::sge
, value32
, zero32
);
179 rewriter
.create
<arith::SelectOp
>(loc
, valuePositive
, one32
, negOne32
);
181 rewriter
.create
<arith::SelectOp
>(loc
, shiftOver32
, roundDir
, zero32
);
183 Value shiftLow
= rewriter
.create
<arith::ShRUIOp
>(loc
, low32
, thirty32
);
184 Value rounded
= rewriter
.create
<arith::AddIOp
>(loc
, shiftLow
, roundDir
);
185 Value carry
= rewriter
.create
<arith::ShRSIOp
>(loc
, rounded
, two32
);
188 rewriter
.create
<arith::ShLIOp
>(loc
, roundDir
, thirty32
);
190 low32
= rewriter
.create
<arith::AddIOp
>(loc
, low32
, shiftRound
);
191 high32
= rewriter
.create
<arith::AddIOp
>(loc
, high32
, carry
);
194 // Conditionally apply rounding in the low bits.
196 Value shiftSubOne
= rewriter
.create
<arith::SubIOp
>(loc
, shift32
, one32
);
197 Value roundBit
= rewriter
.create
<arith::ShLIOp
>(loc
, one32
, shiftSubOne
);
198 roundBit
= rewriter
.create
<arith::SelectOp
>(loc
, roundHighBits
, zero32
,
201 Value newLow32
= rewriter
.create
<arith::AddIOp
>(loc
, low32
, roundBit
);
202 Value wasRounded
= rewriter
.create
<arith::CmpIOp
>(
203 loc
, arith::CmpIPredicate::ugt
, low32
, newLow32
);
206 Value rounded32
= rewriter
.create
<arith::ExtUIOp
>(loc
, i32Ty
, wasRounded
);
207 high32
= rewriter
.create
<arith::AddIOp
>(loc
, high32
, rounded32
);
210 // Conditionally apply rounding in the high bits.
213 rewriter
.create
<arith::SubIOp
>(loc
, shiftHighR
, one32
);
214 Value roundBit
= rewriter
.create
<arith::ShLIOp
>(loc
, one32
, shiftSubOne
);
215 roundBit
= rewriter
.create
<arith::SelectOp
>(loc
, roundHighBits
, roundBit
,
217 high32
= rewriter
.create
<arith::AddIOp
>(loc
, high32
, roundBit
);
220 // Combine the correct high/low bits into the final rescale result.
221 high32
= rewriter
.create
<arith::ShLIOp
>(loc
, high32
, shiftHighL
);
222 high32
= rewriter
.create
<arith::ShRSIOp
>(loc
, high32
, shiftHighR
);
223 low32
= rewriter
.create
<arith::ShRUIOp
>(loc
, low32
, shift32
);
224 low32
= rewriter
.create
<arith::SelectOp
>(loc
, shiftOver32
, zero32
, low32
);
226 // Apply the rounding behavior and shift to the final alignment.
227 Value result
= rewriter
.create
<arith::AddIOp
>(loc
, low32
, high32
);
229 // Truncate if necessary.
230 if (!getElementTypeOrSelf(resultTy
).isInteger(32)) {
231 result
= rewriter
.create
<arith::TruncIOp
>(loc
, resultTy
, result
);
234 rewriter
.replaceOp(op
, result
);
241 void mlir::tosa::populateTosaToArithConversionPatterns(
242 RewritePatternSet
*patterns
) {
243 patterns
->add
<ConstOpConverter
>(patterns
->getContext());
246 void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
247 RewritePatternSet
*patterns
, bool include32Bit
) {
248 patterns
->add
<ApplyScaleGenericOpConverter
>(patterns
->getContext(), 100);
250 patterns
->add
<ApplyScale32BitOpConverter
>(patterns
->getContext(), 200);