[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / TosaToArith / TosaToArith.cpp
blob593dbaa6c6545abfd83021bfaef1d768b1e4d1ba
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 using namespace mlir;
21 using namespace tosa;
23 namespace {
25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26 public:
27 using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
29 LogicalResult matchAndRewrite(tosa::ConstOp op,
30 PatternRewriter &rewriter) const final {
31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
32 return success();
36 Type matchContainerType(Type element, Type container) {
37 if (auto shapedTy = dyn_cast<ShapedType>(container))
38 return shapedTy.clone(element);
40 return element;
43 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45 Type eTy = shapedTy.getElementType();
46 APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
47 return DenseIntElementsAttr::get(shapedTy, valueInt);
50 return rewriter.getIntegerAttr(type, value);
53 Value getConstantValue(Location loc, Type type, int64_t value,
54 PatternRewriter &rewriter) {
55 return rewriter.create<arith::ConstantOp>(
56 loc, getConstantAttr(type, value, rewriter));
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62 : public OpRewritePattern<tosa::ApplyScaleOp> {
63 public:
64 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
66 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67 PatternRewriter &rewriter) const final {
68 Location loc = op.getLoc();
69 Value value = op.getValue();
70 Value multiplier32 = op.getMultiplier();
72 Type resultTy = op.getType();
73 Type valueTy = value.getType();
74 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
75 Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
77 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
81 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
83 // Compute the multiplication in 64-bits then select the high / low parts.
84 Value value64 = value;
85 if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
86 value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
87 Value multiplier64 =
88 rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
89 Value multiply64 =
90 rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
92 // Apply normal rounding.
93 Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
94 Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
95 round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
96 multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
98 // Apply double rounding if necessary.
99 if (op.getDoubleRound()) {
100 int64_t roundInt = 1 << 30;
101 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
102 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
103 Value positive = rewriter.create<arith::CmpIOp>(
104 loc, arith::CmpIPredicate::sge, value, zero);
105 Value dir =
106 rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
107 Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
108 Value valid = rewriter.create<arith::CmpIOp>(
109 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
110 multiply64 =
111 rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
114 Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
115 Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
117 rewriter.replaceOp(op, result32);
118 return success();
122 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
123 public:
124 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
126 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
127 PatternRewriter &rewriter) const final {
128 Location loc = op.getLoc();
130 Type resultTy = op.getType();
131 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
133 Value value = op.getValue();
134 if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
135 return failure();
138 Value value32 = op.getValue();
139 Value multiplier32 = op.getMultiplier();
140 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
142 // Constants used during the scaling operation.
143 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
144 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
145 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
146 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
147 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
149 // Compute the multiplication in 64-bits then select the high / low parts.
150 // Grab out the high/low of the computation
151 auto value64 =
152 rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
153 Value low32 = value64.getLow();
154 Value high32 = value64.getHigh();
156 // Determine the direction and amount to shift the high bits.
157 Value shiftOver32 = rewriter.create<arith::CmpIOp>(
158 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
159 Value roundHighBits = rewriter.create<arith::CmpIOp>(
160 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
162 Value shiftHighL =
163 rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
164 Value shiftHighR =
165 rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
167 shiftHighL =
168 rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
169 shiftHighR =
170 rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
172 // Conditionally perform our double round.
173 if (op.getDoubleRound()) {
174 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
175 Value valuePositive = rewriter.create<arith::CmpIOp>(
176 loc, arith::CmpIPredicate::sge, value32, zero32);
178 Value roundDir =
179 rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
180 roundDir =
181 rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
183 Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
184 Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
185 Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
187 Value shiftRound =
188 rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
190 low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
191 high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
194 // Conditionally apply rounding in the low bits.
196 Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
197 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
198 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
199 roundBit);
201 Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
202 Value wasRounded = rewriter.create<arith::CmpIOp>(
203 loc, arith::CmpIPredicate::ugt, low32, newLow32);
204 low32 = newLow32;
206 Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
207 high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
210 // Conditionally apply rounding in the high bits.
212 Value shiftSubOne =
213 rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
214 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
215 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
216 zero32);
217 high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
220 // Combine the correct high/low bits into the final rescale result.
221 high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
222 high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
223 low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
224 low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
226 // Apply the rounding behavior and shift to the final alignment.
227 Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
229 // Truncate if necessary.
230 if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
231 result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
234 rewriter.replaceOp(op, result);
235 return success();
239 } // namespace
241 void mlir::tosa::populateTosaToArithConversionPatterns(
242 RewritePatternSet *patterns) {
243 patterns->add<ConstOpConverter>(patterns->getContext());
246 void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
247 RewritePatternSet *patterns, bool include32Bit) {
248 patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
249 if (include32Bit) {
250 patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);