1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // These rewriters lower from the Tosa to the Linalg dialect.
11 //===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <cstdlib>
#include <limits>

using namespace mlir;
using namespace mlir::tosa;
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation
*op
, const std::string
&attrName
,
42 Type requiredAttrType
, OpBuilder
&rewriter
) {
43 auto castedN
= static_cast<T
>(
44 cast
<IntegerAttr
>(op
->getAttr(attrName
)).getValue().getSExtValue());
45 return rewriter
.create
<arith::ConstantOp
>(
46 op
->getLoc(), IntegerAttr::get(requiredAttrType
, castedN
));
49 static Value
createLinalgBodyCalculationForElementwiseOp(
50 Operation
*op
, ValueRange args
, ArrayRef
<Type
> resultTypes
,
51 ConversionPatternRewriter
&rewriter
) {
52 Location loc
= op
->getLoc();
54 cast
<ShapedType
>(op
->getOperand(0).getType()).getElementType();
57 if (isa
<tosa::AbsOp
>(op
) && isa
<FloatType
>(elementTy
))
58 return rewriter
.create
<math::AbsFOp
>(loc
, resultTypes
, args
);
60 if (isa
<tosa::AbsOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
61 auto zero
= rewriter
.create
<arith::ConstantOp
>(
62 loc
, rewriter
.getZeroAttr(elementTy
));
63 auto neg
= rewriter
.create
<arith::SubIOp
>(loc
, zero
, args
[0]);
64 return rewriter
.create
<arith::MaxSIOp
>(loc
, args
[0], neg
);
68 if (isa
<tosa::AddOp
>(op
) && isa
<FloatType
>(elementTy
))
69 return rewriter
.create
<arith::AddFOp
>(loc
, resultTypes
, args
);
71 if (isa
<tosa::AddOp
>(op
) && isa
<IntegerType
>(elementTy
))
72 return rewriter
.create
<arith::AddIOp
>(loc
, resultTypes
, args
);
75 if (isa
<tosa::SubOp
>(op
) && isa
<FloatType
>(elementTy
))
76 return rewriter
.create
<arith::SubFOp
>(loc
, resultTypes
, args
);
78 if (isa
<tosa::SubOp
>(op
) && isa
<IntegerType
>(elementTy
))
79 return rewriter
.create
<arith::SubIOp
>(loc
, resultTypes
, args
);
82 if (isa
<tosa::IntDivOp
>(op
) && isa
<IntegerType
>(elementTy
))
83 return rewriter
.create
<arith::DivSIOp
>(loc
, resultTypes
, args
);
86 if (isa
<tosa::ReciprocalOp
>(op
) && isa
<FloatType
>(elementTy
)) {
88 rewriter
.create
<arith::ConstantOp
>(loc
, FloatAttr::get(elementTy
, 1));
89 return rewriter
.create
<arith::DivFOp
>(loc
, resultTypes
, one
, args
[0]);
93 if (isa
<tosa::MulOp
>(op
) && isa
<FloatType
>(elementTy
))
94 return rewriter
.create
<arith::MulFOp
>(loc
, resultTypes
, args
);
96 if (isa
<tosa::MulOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
100 cast
<IntegerAttr
>(op
->getAttr("shift")).getValue().getSExtValue();
103 rewriter
.create
<arith::ConstantIntOp
>(loc
, shift
, /*bitwidth=*/8);
104 if (!a
.getType().isInteger(32))
105 a
= rewriter
.create
<arith::ExtSIOp
>(loc
, rewriter
.getI32Type(), a
);
107 if (!b
.getType().isInteger(32))
108 b
= rewriter
.create
<arith::ExtSIOp
>(loc
, rewriter
.getI32Type(), b
);
110 auto result
= rewriter
.create
<tosa::ApplyScaleOp
>(
111 loc
, rewriter
.getI32Type(), a
, b
, shiftConst
,
112 rewriter
.getBoolAttr(false));
114 if (elementTy
.isInteger(32))
117 return rewriter
.create
<arith::TruncIOp
>(loc
, elementTy
, result
);
120 int aWidth
= a
.getType().getIntOrFloatBitWidth();
121 int bWidth
= b
.getType().getIntOrFloatBitWidth();
122 int cWidth
= resultTypes
[0].getIntOrFloatBitWidth();
125 a
= rewriter
.create
<arith::ExtSIOp
>(loc
, resultTypes
[0], a
);
127 b
= rewriter
.create
<arith::ExtSIOp
>(loc
, resultTypes
[0], b
);
129 return rewriter
.create
<arith::MulIOp
>(loc
, resultTypes
, a
, b
);
133 if (isa
<tosa::NegateOp
>(op
) && isa
<FloatType
>(elementTy
))
134 return rewriter
.create
<arith::NegFOp
>(loc
, resultTypes
, args
);
136 if (isa
<tosa::NegateOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
137 int64_t inZp
= 0, outZp
= 0;
139 if (cast
<tosa::NegateOp
>(op
).getQuantizationInfo()) {
140 auto quantizationInfo
= cast
<tosa::NegateOp
>(op
).getQuantizationInfo();
141 inZp
= quantizationInfo
.value().getInputZp();
142 outZp
= quantizationInfo
.value().getOutputZp();
145 int32_t inputBitWidth
= elementTy
.getIntOrFloatBitWidth();
146 if (!inZp
&& !outZp
) {
147 auto constant
= rewriter
.create
<arith::ConstantOp
>(
148 loc
, IntegerAttr::get(elementTy
, 0));
149 return rewriter
.create
<arith::SubIOp
>(loc
, resultTypes
, constant
,
153 // Compute the maximum value that can occur in the intermediate buffer.
154 int64_t zpAdd
= inZp
+ outZp
;
155 int64_t maxValue
= APInt::getSignedMaxValue(inputBitWidth
).getSExtValue() +
158 // Convert that maximum value into the maximum bitwidth needed to represent
159 // it. We assume 48-bit numbers may be supported further in the pipeline.
160 int intermediateBitWidth
= 64;
161 if (maxValue
<= APInt::getSignedMaxValue(16).getSExtValue()) {
162 intermediateBitWidth
= 16;
163 } else if (maxValue
<= APInt::getSignedMaxValue(32).getSExtValue()) {
164 intermediateBitWidth
= 32;
165 } else if (maxValue
<= APInt::getSignedMaxValue(48).getSExtValue()) {
166 intermediateBitWidth
= 48;
169 Type intermediateType
= rewriter
.getIntegerType(intermediateBitWidth
);
170 Value zpAddValue
= rewriter
.create
<arith::ConstantOp
>(
171 loc
, rewriter
.getIntegerAttr(intermediateType
, zpAdd
));
173 // The negation can be applied by doing:
174 // outputValue = inZp + outZp - inputValue
175 auto ext
= rewriter
.create
<arith::ExtSIOp
>(loc
, intermediateType
, args
[0]);
176 auto sub
= rewriter
.create
<arith::SubIOp
>(loc
, zpAddValue
, ext
);
178 // Clamp to the negation range.
179 Value min
= rewriter
.create
<arith::ConstantIntOp
>(
180 loc
, APInt::getSignedMinValue(inputBitWidth
).getSExtValue(),
182 Value max
= rewriter
.create
<arith::ConstantIntOp
>(
183 loc
, APInt::getSignedMaxValue(inputBitWidth
).getSExtValue(),
186 clampIntHelper(loc
, sub
, min
, max
, rewriter
, /*isUnsigned=*/false);
188 // Truncate to the final value.
189 return rewriter
.create
<arith::TruncIOp
>(loc
, elementTy
, clamp
);
192 // tosa::BitwiseAndOp
193 if (isa
<tosa::BitwiseAndOp
>(op
) && isa
<IntegerType
>(elementTy
))
194 return rewriter
.create
<arith::AndIOp
>(loc
, resultTypes
, args
);
197 if (isa
<tosa::BitwiseOrOp
>(op
) && isa
<IntegerType
>(elementTy
))
198 return rewriter
.create
<arith::OrIOp
>(loc
, resultTypes
, args
);
200 // tosa::BitwiseNotOp
201 if (isa
<tosa::BitwiseNotOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
202 auto allOnesAttr
= rewriter
.getIntegerAttr(
203 elementTy
, APInt::getAllOnes(elementTy
.getIntOrFloatBitWidth()));
204 auto allOnes
= rewriter
.create
<arith::ConstantOp
>(loc
, allOnesAttr
);
205 return rewriter
.create
<arith::XOrIOp
>(loc
, resultTypes
, args
[0], allOnes
);
208 // tosa::BitwiseXOrOp
209 if (isa
<tosa::BitwiseXorOp
>(op
) && isa
<IntegerType
>(elementTy
))
210 return rewriter
.create
<arith::XOrIOp
>(loc
, resultTypes
, args
);
212 // tosa::LogicalLeftShiftOp
213 if (isa
<tosa::LogicalLeftShiftOp
>(op
) && isa
<IntegerType
>(elementTy
))
214 return rewriter
.create
<arith::ShLIOp
>(loc
, resultTypes
, args
);
216 // tosa::LogicalRightShiftOp
217 if (isa
<tosa::LogicalRightShiftOp
>(op
) && isa
<IntegerType
>(elementTy
))
218 return rewriter
.create
<arith::ShRUIOp
>(loc
, resultTypes
, args
);
220 // tosa::ArithmeticRightShiftOp
221 if (isa
<tosa::ArithmeticRightShiftOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
222 auto result
= rewriter
.create
<arith::ShRSIOp
>(loc
, resultTypes
, args
);
223 auto round
= cast
<BoolAttr
>(op
->getAttr("round")).getValue();
228 Type i1Ty
= IntegerType::get(rewriter
.getContext(), /*width=*/1);
230 rewriter
.create
<arith::ConstantOp
>(loc
, IntegerAttr::get(elementTy
, 1));
232 rewriter
.create
<arith::ConstantOp
>(loc
, IntegerAttr::get(elementTy
, 0));
234 rewriter
.create
<arith::ConstantOp
>(loc
, IntegerAttr::get(i1Ty
, 1));
236 // Checking that input2 != 0
237 auto shiftValueGreaterThanZero
= rewriter
.create
<arith::CmpIOp
>(
238 loc
, arith::CmpIPredicate::sgt
, args
[1], zero
);
240 // Checking for the last bit of input1 to be 1
242 rewriter
.create
<arith::SubIOp
>(loc
, resultTypes
, args
[1], one
);
244 rewriter
.create
<arith::ShRSIOp
>(loc
, resultTypes
, args
[0], subtract
)
247 rewriter
.create
<arith::TruncIOp
>(loc
, i1Ty
, shifted
, std::nullopt
);
249 rewriter
.create
<arith::AndIOp
>(loc
, i1Ty
, truncated
, i1one
);
251 auto shouldRound
= rewriter
.create
<arith::AndIOp
>(
252 loc
, i1Ty
, shiftValueGreaterThanZero
, isInputOdd
);
254 rewriter
.create
<arith::ExtUIOp
>(loc
, resultTypes
, shouldRound
);
255 return rewriter
.create
<arith::AddIOp
>(loc
, resultTypes
, result
, extended
);
259 if (isa
<tosa::ClzOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
260 return rewriter
.create
<math::CountLeadingZerosOp
>(loc
, elementTy
, args
[0]);
264 if (isa
<tosa::LogicalAndOp
>(op
) && elementTy
.isInteger(1))
265 return rewriter
.create
<arith::AndIOp
>(loc
, resultTypes
, args
);
268 if (isa
<tosa::LogicalNotOp
>(op
) && elementTy
.isInteger(1)) {
269 auto one
= rewriter
.create
<arith::ConstantOp
>(
270 loc
, rewriter
.getIntegerAttr(elementTy
, 1));
271 return rewriter
.create
<arith::XOrIOp
>(loc
, resultTypes
, args
[0], one
);
275 if (isa
<tosa::LogicalOrOp
>(op
) && elementTy
.isInteger(1))
276 return rewriter
.create
<arith::OrIOp
>(loc
, resultTypes
, args
);
279 if (isa
<tosa::LogicalXorOp
>(op
) && elementTy
.isInteger(1))
280 return rewriter
.create
<arith::XOrIOp
>(loc
, resultTypes
, args
);
283 if (isa
<tosa::PowOp
>(op
) && isa
<FloatType
>(elementTy
))
284 return rewriter
.create
<mlir::math::PowFOp
>(loc
, resultTypes
, args
);
287 if (isa
<tosa::RsqrtOp
>(op
) && isa
<FloatType
>(elementTy
))
288 return rewriter
.create
<mlir::math::RsqrtOp
>(loc
, resultTypes
, args
);
291 if (isa
<tosa::LogOp
>(op
) && isa
<FloatType
>(elementTy
))
292 return rewriter
.create
<mlir::math::LogOp
>(loc
, resultTypes
, args
);
295 if (isa
<tosa::ExpOp
>(op
) && isa
<FloatType
>(elementTy
))
296 return rewriter
.create
<mlir::math::ExpOp
>(loc
, resultTypes
, args
);
299 if (isa
<tosa::SinOp
>(op
) && isa
<FloatType
>(elementTy
))
300 return rewriter
.create
<mlir::math::SinOp
>(loc
, resultTypes
, args
);
303 if (isa
<tosa::CosOp
>(op
) && isa
<FloatType
>(elementTy
))
304 return rewriter
.create
<mlir::math::CosOp
>(loc
, resultTypes
, args
);
307 if (isa
<tosa::TanhOp
>(op
) && isa
<FloatType
>(elementTy
))
308 return rewriter
.create
<mlir::math::TanhOp
>(loc
, resultTypes
, args
);
311 if (isa
<tosa::ErfOp
>(op
) && llvm::isa
<FloatType
>(elementTy
))
312 return rewriter
.create
<mlir::math::ErfOp
>(loc
, resultTypes
, args
);
315 if (isa
<tosa::GreaterOp
>(op
) && isa
<FloatType
>(elementTy
))
316 return rewriter
.create
<arith::CmpFOp
>(loc
, arith::CmpFPredicate::OGT
,
319 if (isa
<tosa::GreaterOp
>(op
) && elementTy
.isSignlessInteger())
320 return rewriter
.create
<arith::CmpIOp
>(loc
, arith::CmpIPredicate::sgt
,
323 // tosa::GreaterEqualOp
324 if (isa
<tosa::GreaterEqualOp
>(op
) && isa
<FloatType
>(elementTy
))
325 return rewriter
.create
<arith::CmpFOp
>(loc
, arith::CmpFPredicate::OGE
,
328 if (isa
<tosa::GreaterEqualOp
>(op
) && elementTy
.isSignlessInteger())
329 return rewriter
.create
<arith::CmpIOp
>(loc
, arith::CmpIPredicate::sge
,
333 if (isa
<tosa::EqualOp
>(op
) && isa
<FloatType
>(elementTy
))
334 return rewriter
.create
<arith::CmpFOp
>(loc
, arith::CmpFPredicate::OEQ
,
337 if (isa
<tosa::EqualOp
>(op
) && elementTy
.isSignlessInteger())
338 return rewriter
.create
<arith::CmpIOp
>(loc
, arith::CmpIPredicate::eq
,
342 if (isa
<tosa::SelectOp
>(op
)) {
343 elementTy
= cast
<ShapedType
>(op
->getOperand(1).getType()).getElementType();
344 if (isa
<FloatType
>(elementTy
) || isa
<IntegerType
>(elementTy
))
345 return rewriter
.create
<arith::SelectOp
>(loc
, args
[0], args
[1], args
[2]);
349 if (isa
<tosa::MaximumOp
>(op
) && isa
<FloatType
>(elementTy
)) {
350 return rewriter
.create
<arith::MaximumFOp
>(loc
, args
[0], args
[1]);
353 if (isa
<tosa::MaximumOp
>(op
) && elementTy
.isSignlessInteger()) {
354 return rewriter
.create
<arith::MaxSIOp
>(loc
, args
[0], args
[1]);
358 if (isa
<tosa::MinimumOp
>(op
) && isa
<FloatType
>(elementTy
)) {
359 return rewriter
.create
<arith::MinimumFOp
>(loc
, args
[0], args
[1]);
362 if (isa
<tosa::MinimumOp
>(op
) && elementTy
.isSignlessInteger()) {
363 return rewriter
.create
<arith::MinSIOp
>(loc
, args
[0], args
[1]);
367 if (isa
<tosa::CeilOp
>(op
) && isa
<FloatType
>(elementTy
))
368 return rewriter
.create
<math::CeilOp
>(loc
, resultTypes
, args
);
371 if (isa
<tosa::FloorOp
>(op
) && isa
<FloatType
>(elementTy
))
372 return rewriter
.create
<math::FloorOp
>(loc
, resultTypes
, args
);
375 if (isa
<tosa::ClampOp
>(op
) && isa
<FloatType
>(elementTy
)) {
376 bool losesInfo
= false;
377 APFloat minApf
= cast
<FloatAttr
>(op
->getAttr("min_fp")).getValue();
378 APFloat maxApf
= cast
<FloatAttr
>(op
->getAttr("max_fp")).getValue();
379 minApf
.convert(cast
<FloatType
>(elementTy
).getFloatSemantics(),
380 APFloat::rmNearestTiesToEven
, &losesInfo
);
381 maxApf
.convert(cast
<FloatType
>(elementTy
).getFloatSemantics(),
382 APFloat::rmNearestTiesToEven
, &losesInfo
);
383 auto min
= rewriter
.create
<arith::ConstantOp
>(
384 loc
, elementTy
, rewriter
.getFloatAttr(elementTy
, minApf
));
385 auto max
= rewriter
.create
<arith::ConstantOp
>(
386 loc
, elementTy
, rewriter
.getFloatAttr(elementTy
, maxApf
));
387 return clampFloatHelper(loc
, args
[0], min
, max
, rewriter
);
390 if (isa
<tosa::ClampOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
391 auto intTy
= cast
<IntegerType
>(elementTy
);
393 cast
<IntegerAttr
>(op
->getAttr("min_int")).getValue().getSExtValue();
395 cast
<IntegerAttr
>(op
->getAttr("max_int")).getValue().getSExtValue();
397 int64_t minRepresentable
= std::numeric_limits
<int64_t>::min();
398 int64_t maxRepresentable
= std::numeric_limits
<int64_t>::max();
399 if (intTy
.isUnsignedInteger()) {
400 minRepresentable
= 0;
401 if (intTy
.getIntOrFloatBitWidth() <= 63) {
403 (int64_t)APInt::getMaxValue(intTy
.getIntOrFloatBitWidth())
406 } else if (intTy
.getIntOrFloatBitWidth() <= 64) {
407 // Ensure that min & max fit into signed n-bit constants.
408 minRepresentable
= APInt::getSignedMinValue(intTy
.getIntOrFloatBitWidth())
410 maxRepresentable
= APInt::getSignedMaxValue(intTy
.getIntOrFloatBitWidth())
413 // Ensure that the bounds are representable as n-bit signed/unsigned
415 min
= std::max(min
, minRepresentable
);
416 max
= std::max(max
, minRepresentable
);
417 min
= std::min(min
, maxRepresentable
);
418 max
= std::min(max
, maxRepresentable
);
420 auto minVal
= rewriter
.create
<arith::ConstantIntOp
>(
421 loc
, min
, intTy
.getIntOrFloatBitWidth());
422 auto maxVal
= rewriter
.create
<arith::ConstantIntOp
>(
423 loc
, max
, intTy
.getIntOrFloatBitWidth());
424 return clampIntHelper(loc
, args
[0], minVal
, maxVal
, rewriter
,
425 intTy
.isUnsignedInteger());
429 if (isa
<tosa::SigmoidOp
>(op
) && isa
<FloatType
>(elementTy
)) {
431 rewriter
.create
<arith::ConstantOp
>(loc
, FloatAttr::get(elementTy
, 1));
432 auto negate
= rewriter
.create
<arith::NegFOp
>(loc
, resultTypes
, args
[0]);
433 auto exp
= rewriter
.create
<mlir::math::ExpOp
>(loc
, resultTypes
, negate
);
434 auto added
= rewriter
.create
<arith::AddFOp
>(loc
, resultTypes
, exp
, one
);
435 return rewriter
.create
<arith::DivFOp
>(loc
, resultTypes
, one
, added
);
439 if (isa
<tosa::CastOp
>(op
)) {
440 Type srcTy
= elementTy
;
441 Type dstTy
= resultTypes
.front();
443 srcTy
.getIntOrFloatBitWidth() < dstTy
.getIntOrFloatBitWidth();
448 if (isa
<FloatType
>(srcTy
) && isa
<FloatType
>(dstTy
) && bitExtend
)
449 return rewriter
.create
<arith::ExtFOp
>(loc
, resultTypes
, args
,
452 if (isa
<FloatType
>(srcTy
) && isa
<FloatType
>(dstTy
) && !bitExtend
)
453 return rewriter
.create
<arith::TruncFOp
>(loc
, resultTypes
, args
,
456 // 1-bit integers need to be treated as signless.
457 if (srcTy
.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy
, dstTy
))
458 return rewriter
.create
<arith::UIToFPOp
>(loc
, resultTypes
, args
,
461 if (srcTy
.isInteger(1) && isa
<IntegerType
>(dstTy
) && bitExtend
)
462 return rewriter
.create
<arith::ExtUIOp
>(loc
, resultTypes
, args
,
465 // Unsigned integers need an unrealized cast so that they can be passed
467 if (srcTy
.isUnsignedInteger() && isa
<FloatType
>(dstTy
)) {
468 auto unrealizedCast
=
470 .create
<UnrealizedConversionCastOp
>(
471 loc
, rewriter
.getIntegerType(srcTy
.getIntOrFloatBitWidth()),
474 return rewriter
.create
<arith::UIToFPOp
>(loc
, resultTypes
[0],
478 // All other si-to-fp conversions should be handled by SIToFP.
479 if (arith::SIToFPOp::areCastCompatible(srcTy
, dstTy
))
480 return rewriter
.create
<arith::SIToFPOp
>(loc
, resultTypes
, args
,
483 // Casting to boolean, floats need to only be checked as not-equal to zero.
484 if (isa
<FloatType
>(srcTy
) && dstTy
.isInteger(1)) {
485 Value zero
= rewriter
.create
<arith::ConstantOp
>(
486 loc
, rewriter
.getFloatAttr(srcTy
, 0.0));
487 return rewriter
.create
<arith::CmpFOp
>(loc
, arith::CmpFPredicate::UNE
,
491 if (arith::FPToSIOp::areCastCompatible(srcTy
, dstTy
)) {
492 auto rounded
= rewriter
.create
<math::RoundEvenOp
>(loc
, args
[0]);
494 const auto &fltSemantics
= cast
<FloatType
>(srcTy
).getFloatSemantics();
495 // Check whether neither int min nor int max can be represented in the
496 // input floating-point type due to too short exponent range.
497 if (static_cast<int>(dstTy
.getIntOrFloatBitWidth()) - 1 >
498 APFloat::semanticsMaxExponent(fltSemantics
)) {
499 // Use cmp + select to replace infinites by int min / int max. Other
500 // integral values can be represented in the integer space.
501 auto conv
= rewriter
.create
<arith::FPToSIOp
>(loc
, dstTy
, rounded
);
502 auto posInf
= rewriter
.create
<arith::ConstantOp
>(
503 loc
, rewriter
.getFloatAttr(getElementTypeOrSelf(srcTy
),
504 APFloat::getInf(fltSemantics
)));
505 auto negInf
= rewriter
.create
<arith::ConstantOp
>(
506 loc
, rewriter
.getFloatAttr(
507 getElementTypeOrSelf(srcTy
),
508 APFloat::getInf(fltSemantics
, /*Negative=*/true)));
509 auto overflow
= rewriter
.create
<arith::CmpFOp
>(
510 loc
, arith::CmpFPredicate::UEQ
, rounded
, posInf
);
511 auto underflow
= rewriter
.create
<arith::CmpFOp
>(
512 loc
, arith::CmpFPredicate::UEQ
, rounded
, negInf
);
513 auto intMin
= rewriter
.create
<arith::ConstantOp
>(
514 loc
, rewriter
.getIntegerAttr(
515 getElementTypeOrSelf(dstTy
),
516 APInt::getSignedMinValue(dstTy
.getIntOrFloatBitWidth())));
517 auto intMax
= rewriter
.create
<arith::ConstantOp
>(
518 loc
, rewriter
.getIntegerAttr(
519 getElementTypeOrSelf(dstTy
),
520 APInt::getSignedMaxValue(dstTy
.getIntOrFloatBitWidth())));
522 rewriter
.create
<arith::SelectOp
>(loc
, overflow
, intMax
, conv
);
523 return rewriter
.create
<arith::SelectOp
>(loc
, underflow
, intMin
,
527 auto intMinFP
= rewriter
.create
<arith::ConstantOp
>(
528 loc
, rewriter
.getFloatAttr(
529 getElementTypeOrSelf(srcTy
),
530 APInt::getSignedMinValue(dstTy
.getIntOrFloatBitWidth())
533 // Check whether the mantissa has enough bits to represent int max.
534 if (cast
<FloatType
>(srcTy
).getFPMantissaWidth() >=
535 dstTy
.getIntOrFloatBitWidth() - 1) {
536 // Int min can also be represented since it is a power of two and thus
537 // consists of a single leading bit. Therefore we can clamp the input
538 // in the floating-point domain.
540 auto intMaxFP
= rewriter
.create
<arith::ConstantOp
>(
541 loc
, rewriter
.getFloatAttr(
542 getElementTypeOrSelf(srcTy
),
543 APInt::getSignedMaxValue(dstTy
.getIntOrFloatBitWidth())
547 clampFloatHelper(loc
, rounded
, intMinFP
, intMaxFP
, rewriter
);
548 return rewriter
.create
<arith::FPToSIOp
>(loc
, dstTy
, clamped
);
551 // Due to earlier check we know exponant range is big enough to represent
552 // int min. We can therefore rely on int max + 1 being representable as
553 // well because it's just int min with a positive sign. So clamp the min
554 // value and compare against that to select the max int value if needed.
555 auto intMaxPlusOneFP
= rewriter
.create
<arith::ConstantOp
>(
556 loc
, rewriter
.getFloatAttr(
557 getElementTypeOrSelf(srcTy
),
559 APInt::getSignedMaxValue(dstTy
.getIntOrFloatBitWidth())
563 auto intMax
= rewriter
.create
<arith::ConstantOp
>(
564 loc
, rewriter
.getIntegerAttr(
565 getElementTypeOrSelf(dstTy
),
566 APInt::getSignedMaxValue(dstTy
.getIntOrFloatBitWidth())));
568 rewriter
.create
<arith::MaximumFOp
>(loc
, rounded
, intMinFP
);
570 rewriter
.create
<arith::FPToSIOp
>(loc
, dstTy
, minClampedFP
);
571 auto overflow
= rewriter
.create
<arith::CmpFOp
>(
572 loc
, arith::CmpFPredicate::UGE
, rounded
, intMaxPlusOneFP
);
573 return rewriter
.create
<arith::SelectOp
>(loc
, overflow
, intMax
,
577 // Casting to boolean, integers need to only be checked as not-equal to
579 if (isa
<IntegerType
>(srcTy
) && dstTy
.isInteger(1)) {
580 Value zero
= rewriter
.create
<arith::ConstantIntOp
>(
581 loc
, 0, srcTy
.getIntOrFloatBitWidth());
582 return rewriter
.create
<arith::CmpIOp
>(loc
, arith::CmpIPredicate::ne
,
586 if (isa
<IntegerType
>(srcTy
) && isa
<IntegerType
>(dstTy
) && bitExtend
)
587 return rewriter
.create
<arith::ExtSIOp
>(loc
, resultTypes
, args
,
590 if (isa
<IntegerType
>(srcTy
) && isa
<IntegerType
>(dstTy
) && !bitExtend
) {
591 return rewriter
.create
<arith::TruncIOp
>(loc
, dstTy
, args
[0]);
595 (void)rewriter
.notifyMatchFailure(
596 op
, "unhandled op for linalg body calculation for elementwise op");
600 static Value
expandRank(PatternRewriter
&rewriter
, Location loc
, Value tensor
,
602 // No need to expand if we are already at the desired rank
603 auto shapedType
= dyn_cast
<ShapedType
>(tensor
.getType());
604 assert(shapedType
&& shapedType
.hasRank() && "expected a ranked shaped type");
605 int64_t numExtraDims
= rank
- shapedType
.getRank();
606 assert(numExtraDims
>= 0 && "cannot expand tensor to a lower rank");
610 // Compute reassociation indices
611 SmallVector
<SmallVector
<int64_t, 2>> reassociationIndices(
612 shapedType
.getRank());
614 for (index
= 0; index
<= numExtraDims
; index
++)
615 reassociationIndices
[0].push_back(index
);
616 for (size_t position
= 1; position
< reassociationIndices
.size(); position
++)
617 reassociationIndices
[position
].push_back(index
++);
619 // Compute result type
620 SmallVector
<int64_t> resultShape
;
621 for (index
= 0; index
< numExtraDims
; index
++)
622 resultShape
.push_back(1);
623 for (auto size
: shapedType
.getShape())
624 resultShape
.push_back(size
);
626 RankedTensorType::get(resultShape
, shapedType
.getElementType());
628 // Emit 'tensor.expand_shape' op
629 return rewriter
.create
<tensor::ExpandShapeOp
>(loc
, resultType
, tensor
,
630 reassociationIndices
);
633 static SmallVector
<Value
> expandInputRanks(PatternRewriter
&rewriter
,
634 Location loc
, ValueRange operands
,
636 return llvm::map_to_vector(operands
, [&](Value operand
) {
637 return expandRank(rewriter
, loc
, operand
, rank
);
641 using IndexPool
= DenseMap
<int64_t, Value
>;
643 // Emit an 'arith.constant' op for the given index if it has not been created
644 // yet, or return an existing constant. This will prevent an excessive creation
645 // of redundant constants, easing readability of emitted code for unit tests.
646 static Value
createIndex(PatternRewriter
&rewriter
, Location loc
,
647 IndexPool
&indexPool
, int64_t index
) {
648 auto [it
, inserted
] = indexPool
.try_emplace(index
);
651 rewriter
.create
<arith::ConstantOp
>(loc
, rewriter
.getIndexAttr(index
));
655 static Value
getTensorDim(PatternRewriter
&rewriter
, Location loc
,
656 IndexPool
&indexPool
, Value tensor
, int64_t index
) {
657 auto indexValue
= createIndex(rewriter
, loc
, indexPool
, index
);
658 return rewriter
.create
<tensor::DimOp
>(loc
, tensor
, indexValue
).getResult();
661 static OpFoldResult
getOrFoldTensorDim(PatternRewriter
&rewriter
, Location loc
,
662 IndexPool
&indexPool
, Value tensor
,
664 auto shapedType
= dyn_cast
<ShapedType
>(tensor
.getType());
665 assert(shapedType
&& shapedType
.hasRank() && "expected a ranked shaped type");
666 assert(index
>= 0 && index
< shapedType
.getRank() && "index out of bounds");
667 if (shapedType
.isDynamicDim(index
))
668 return getTensorDim(rewriter
, loc
, indexPool
, tensor
, index
);
669 return rewriter
.getIndexAttr(shapedType
.getDimSize(index
));
672 static bool operandsAndResultsRanked(Operation
*operation
) {
673 auto isRanked
= [](Value value
) {
674 return isa
<RankedTensorType
>(value
.getType());
676 return llvm::all_of(operation
->getOperands(), isRanked
) &&
677 llvm::all_of(operation
->getResults(), isRanked
);
680 // Compute the runtime dimension size for dimension 'dim' of the output by
681 // inspecting input 'operands', all of which are expected to have the same rank.
682 // This function returns a pair {targetSize, masterOperand}.
684 // The runtime size of the output dimension is returned either as a statically
685 // computed attribute or as a runtime SSA value.
687 // If the target size was inferred directly from one dominating operand, that
688 // operand is returned in 'masterOperand'. If the target size is inferred from
689 // multiple operands, 'masterOperand' is set to nullptr.
690 static std::pair
<OpFoldResult
, Value
>
691 computeTargetSize(PatternRewriter
&rewriter
, Location loc
, IndexPool
&indexPool
,
692 ValueRange operands
, int64_t dim
) {
693 // If any input operand contains a static size greater than 1 for this
694 // dimension, that is the target size. An occurrence of an additional static
695 // dimension greater than 1 with a different value is undefined behavior.
696 for (auto operand
: operands
) {
697 auto size
= cast
<RankedTensorType
>(operand
.getType()).getDimSize(dim
);
698 if (!ShapedType::isDynamic(size
) && size
> 1)
699 return {rewriter
.getIndexAttr(size
), operand
};
702 // Filter operands with dynamic dimension
703 auto operandsWithDynamicDim
=
704 llvm::to_vector(llvm::make_filter_range(operands
, [&](Value operand
) {
705 return cast
<RankedTensorType
>(operand
.getType()).isDynamicDim(dim
);
708 // If no operand has a dynamic dimension, it means all sizes were 1
709 if (operandsWithDynamicDim
.empty())
710 return {rewriter
.getIndexAttr(1), operands
.front()};
712 // Emit code that computes the runtime size for this dimension. If there is
713 // only one operand with a dynamic dimension, it is considered the master
714 // operand that determines the runtime size of the output dimension.
716 getTensorDim(rewriter
, loc
, indexPool
, operandsWithDynamicDim
[0], dim
);
717 if (operandsWithDynamicDim
.size() == 1)
718 return {targetSize
, operandsWithDynamicDim
[0]};
720 // Calculate maximum size among all dynamic dimensions
721 for (size_t i
= 1; i
< operandsWithDynamicDim
.size(); i
++) {
723 getTensorDim(rewriter
, loc
, indexPool
, operandsWithDynamicDim
[i
], dim
);
724 targetSize
= rewriter
.create
<arith::MaxUIOp
>(loc
, targetSize
, nextSize
);
726 return {targetSize
, nullptr};
729 // Compute the runtime output size for all dimensions. This function returns
730 // a pair {targetShape, masterOperands}.
731 static std::pair
<SmallVector
<OpFoldResult
>, SmallVector
<Value
>>
732 computeTargetShape(PatternRewriter
&rewriter
, Location loc
,
733 IndexPool
&indexPool
, ValueRange operands
) {
734 assert(!operands
.empty());
735 auto rank
= cast
<RankedTensorType
>(operands
.front().getType()).getRank();
736 SmallVector
<OpFoldResult
> targetShape
;
737 SmallVector
<Value
> masterOperands
;
738 for (auto dim
: llvm::seq
<int64_t>(0, rank
)) {
739 auto [targetSize
, masterOperand
] =
740 computeTargetSize(rewriter
, loc
, indexPool
, operands
, dim
);
741 targetShape
.push_back(targetSize
);
742 masterOperands
.push_back(masterOperand
);
744 return {targetShape
, masterOperands
};
747 static Value
broadcastDynamicDimension(PatternRewriter
&rewriter
, Location loc
,
748 IndexPool
&indexPool
, Value operand
,
749 int64_t dim
, OpFoldResult targetSize
,
750 Value masterOperand
) {
751 // Nothing to do if this is a static dimension
752 auto rankedTensorType
= cast
<RankedTensorType
>(operand
.getType());
753 if (!rankedTensorType
.isDynamicDim(dim
))
756 // If the target size for this dimension was directly inferred by only taking
757 // this operand into account, there is no need to broadcast. This is an
758 // optimization that will prevent redundant control flow, and constitutes the
759 // main motivation for tracking "master operands".
760 if (operand
== masterOperand
)
763 // Affine maps for 'linalg.generic' op
764 auto rank
= rankedTensorType
.getRank();
765 SmallVector
<AffineExpr
> affineExprs
;
766 for (auto index
: llvm::seq
<int64_t>(0, rank
)) {
767 auto affineExpr
= index
== dim
? rewriter
.getAffineConstantExpr(0)
768 : rewriter
.getAffineDimExpr(index
);
769 affineExprs
.push_back(affineExpr
);
771 auto broadcastAffineMap
=
772 AffineMap::get(rank
, 0, affineExprs
, rewriter
.getContext());
773 auto identityAffineMap
= rewriter
.getMultiDimIdentityMap(rank
);
774 SmallVector
<AffineMap
> affineMaps
= {broadcastAffineMap
, identityAffineMap
};
776 // Check if broadcast is necessary
777 auto one
= createIndex(rewriter
, loc
, indexPool
, 1);
778 auto runtimeSize
= getTensorDim(rewriter
, loc
, indexPool
, operand
, dim
);
779 auto broadcastNecessary
= rewriter
.create
<arith::CmpIOp
>(
780 loc
, arith::CmpIPredicate::eq
, runtimeSize
, one
);
782 // Emit 'then' region of 'scf.if'
783 auto emitThenRegion
= [&](OpBuilder
&opBuilder
, Location loc
) {
784 // It is not safe to cache constants across regions.
785 // New constants could potentially violate dominance requirements.
788 // Emit 'tensor.empty' op
789 SmallVector
<OpFoldResult
> outputTensorShape
;
790 for (auto index
: llvm::seq
<int64_t>(0, rank
)) {
791 auto size
= index
== dim
? targetSize
792 : getOrFoldTensorDim(rewriter
, loc
, localPool
,
794 outputTensorShape
.push_back(size
);
796 Value outputTensor
= opBuilder
.create
<tensor::EmptyOp
>(
797 loc
, outputTensorShape
, rankedTensorType
.getElementType());
799 // Emit 'linalg.generic' op
802 .create
<linalg::GenericOp
>(
803 loc
, outputTensor
.getType(), operand
, outputTensor
, affineMaps
,
804 getNParallelLoopsAttrs(rank
),
805 [&](OpBuilder
&opBuilder
, Location loc
, ValueRange blockArgs
) {
806 // Emit 'linalg.yield' op
807 opBuilder
.create
<linalg::YieldOp
>(loc
, blockArgs
.front());
811 // Cast to original operand type if necessary
812 auto castResultTensor
= rewriter
.createOrFold
<tensor::CastOp
>(
813 loc
, operand
.getType(), resultTensor
);
815 // Emit 'scf.yield' op
816 opBuilder
.create
<scf::YieldOp
>(loc
, castResultTensor
);
819 // Emit 'else' region of 'scf.if'
820 auto emitElseRegion
= [&](OpBuilder
&opBuilder
, Location loc
) {
821 opBuilder
.create
<scf::YieldOp
>(loc
, operand
);
825 auto ifOp
= rewriter
.create
<scf::IfOp
>(loc
, broadcastNecessary
,
826 emitThenRegion
, emitElseRegion
);
827 return ifOp
.getResult(0);
830 static Value
broadcastDynamicDimensions(PatternRewriter
&rewriter
, Location loc
,
831 IndexPool
&indexPool
, Value operand
,
832 ArrayRef
<OpFoldResult
> targetShape
,
833 ArrayRef
<Value
> masterOperands
) {
834 int64_t rank
= cast
<RankedTensorType
>(operand
.getType()).getRank();
835 assert((int64_t)targetShape
.size() == rank
);
836 assert((int64_t)masterOperands
.size() == rank
);
837 for (auto index
: llvm::seq
<int64_t>(0, rank
))
839 broadcastDynamicDimension(rewriter
, loc
, indexPool
, operand
, index
,
840 targetShape
[index
], masterOperands
[index
]);
844 static SmallVector
<Value
>
845 broadcastDynamicDimensions(PatternRewriter
&rewriter
, Location loc
,
846 IndexPool
&indexPool
, ValueRange operands
,
847 ArrayRef
<OpFoldResult
> targetShape
,
848 ArrayRef
<Value
> masterOperands
) {
849 // No need to broadcast for unary operations
850 if (operands
.size() == 1)
853 // Broadcast dynamic dimensions operand by operand
854 return llvm::map_to_vector(operands
, [&](Value operand
) {
855 return broadcastDynamicDimensions(rewriter
, loc
, indexPool
, operand
,
856 targetShape
, masterOperands
);
861 emitElementwiseComputation(ConversionPatternRewriter
&rewriter
, Location loc
,
862 Operation
*operation
, ValueRange operands
,
863 ArrayRef
<OpFoldResult
> targetShape
,
864 const TypeConverter
&converter
) {
865 // Generate output tensor
866 auto resultType
= cast_or_null
<RankedTensorType
>(
867 converter
.convertType(operation
->getResultTypes().front()));
869 return rewriter
.notifyMatchFailure(operation
, "failed to convert type");
871 Value outputTensor
= rewriter
.create
<tensor::EmptyOp
>(
872 loc
, targetShape
, resultType
.getElementType());
874 // Create affine maps. Input affine maps broadcast static dimensions of size
875 // 1. The output affine map is an identity map.
877 auto rank
= resultType
.getRank();
878 auto affineMaps
= llvm::map_to_vector(operands
, [&](Value operand
) {
879 auto shape
= cast
<ShapedType
>(operand
.getType()).getShape();
880 SmallVector
<AffineExpr
> affineExprs
;
881 for (auto it
: llvm::enumerate(shape
)) {
882 auto affineExpr
= it
.value() == 1 ? rewriter
.getAffineConstantExpr(0)
883 : rewriter
.getAffineDimExpr(it
.index());
884 affineExprs
.push_back(affineExpr
);
886 return AffineMap::get(rank
, 0, affineExprs
, rewriter
.getContext());
888 affineMaps
.push_back(rewriter
.getMultiDimIdentityMap(rank
));
890 // Emit 'linalg.generic' op
891 bool encounteredError
= false;
892 auto linalgOp
= rewriter
.create
<linalg::GenericOp
>(
893 loc
, outputTensor
.getType(), operands
, outputTensor
, affineMaps
,
894 getNParallelLoopsAttrs(rank
),
895 [&](OpBuilder
&opBuilder
, Location loc
, ValueRange blockArgs
) {
896 Value opResult
= createLinalgBodyCalculationForElementwiseOp(
897 operation
, blockArgs
.take_front(operation
->getNumOperands()),
898 {resultType
.getElementType()}, rewriter
);
900 encounteredError
= true;
903 opBuilder
.create
<linalg::YieldOp
>(loc
, opResult
);
905 if (encounteredError
)
906 return rewriter
.notifyMatchFailure(
907 operation
, "unable to create linalg.generic body for elementwise op");
909 // Cast 'linalg.generic' result into original result type if needed
910 auto castResult
= rewriter
.createOrFold
<tensor::CastOp
>(
911 loc
, resultType
, linalgOp
->getResult(0));
912 rewriter
.replaceOp(operation
, castResult
);
917 elementwiseMatchAndRewriteHelper(Operation
*operation
, ValueRange operands
,
918 ConversionPatternRewriter
&rewriter
,
919 const TypeConverter
&converter
) {
921 // Collect op properties
922 assert(operation
->getNumResults() == 1 && "elementwise op expects 1 result");
923 assert(operation
->getNumOperands() >= 1 &&
924 "elementwise op expects at least 1 operand");
925 if (!operandsAndResultsRanked(operation
))
926 return rewriter
.notifyMatchFailure(operation
,
927 "Unranked tensors not supported");
931 auto loc
= operation
->getLoc();
933 cast
<RankedTensorType
>(operation
->getResultTypes().front()).getRank();
934 auto expandedOperands
= expandInputRanks(rewriter
, loc
, operands
, rank
);
935 auto [targetShape
, masterOperands
] =
936 computeTargetShape(rewriter
, loc
, indexPool
, expandedOperands
);
937 auto broadcastOperands
= broadcastDynamicDimensions(
938 rewriter
, loc
, indexPool
, expandedOperands
, targetShape
, masterOperands
);
939 return emitElementwiseComputation(rewriter
, loc
, operation
, broadcastOperands
,
940 targetShape
, converter
);
943 // Returns the constant initial value for a given reduction operation. The
944 // attribute type varies depending on the element type required.
945 static TypedAttr
createInitialValueForReduceOp(Operation
*op
, Type elementTy
,
946 PatternRewriter
&rewriter
) {
947 if (isa
<tosa::ReduceSumOp
>(op
) && isa
<FloatType
>(elementTy
))
948 return rewriter
.getFloatAttr(elementTy
, 0.0);
950 if (isa
<tosa::ReduceSumOp
>(op
) && isa
<IntegerType
>(elementTy
))
951 return rewriter
.getIntegerAttr(elementTy
, 0);
953 if (isa
<tosa::ReduceProdOp
>(op
) && isa
<FloatType
>(elementTy
))
954 return rewriter
.getFloatAttr(elementTy
, 1.0);
956 if (isa
<tosa::ReduceProdOp
>(op
) && isa
<IntegerType
>(elementTy
))
957 return rewriter
.getIntegerAttr(elementTy
, 1);
959 if (isa
<tosa::ReduceMinOp
>(op
) && isa
<FloatType
>(elementTy
))
960 return rewriter
.getFloatAttr(
961 elementTy
, APFloat::getLargest(
962 cast
<FloatType
>(elementTy
).getFloatSemantics(), false));
964 if (isa
<tosa::ReduceMinOp
>(op
) && isa
<IntegerType
>(elementTy
))
965 return rewriter
.getIntegerAttr(
966 elementTy
, APInt::getSignedMaxValue(elementTy
.getIntOrFloatBitWidth()));
968 if (isa
<tosa::ReduceMaxOp
>(op
) && isa
<FloatType
>(elementTy
))
969 return rewriter
.getFloatAttr(
970 elementTy
, APFloat::getLargest(
971 cast
<FloatType
>(elementTy
).getFloatSemantics(), true));
973 if (isa
<tosa::ReduceMaxOp
>(op
) && isa
<IntegerType
>(elementTy
))
974 return rewriter
.getIntegerAttr(
975 elementTy
, APInt::getSignedMinValue(elementTy
.getIntOrFloatBitWidth()));
977 if (isa
<tosa::ReduceAllOp
>(op
) && elementTy
.isInteger(1))
978 return rewriter
.getIntegerAttr(elementTy
, APInt::getAllOnes(1));
980 if (isa
<tosa::ReduceAnyOp
>(op
) && elementTy
.isInteger(1))
981 return rewriter
.getIntegerAttr(elementTy
, APInt::getZero(1));
983 if (isa
<tosa::ArgMaxOp
>(op
) && isa
<FloatType
>(elementTy
))
984 return rewriter
.getFloatAttr(
985 elementTy
, APFloat::getLargest(
986 cast
<FloatType
>(elementTy
).getFloatSemantics(), true));
988 if (isa
<tosa::ArgMaxOp
>(op
) && isa
<IntegerType
>(elementTy
))
989 return rewriter
.getIntegerAttr(
990 elementTy
, APInt::getSignedMinValue(elementTy
.getIntOrFloatBitWidth()));
995 // Creates the body calculation for a reduction. The operations vary depending
996 // on the input type.
997 static Value
createLinalgBodyCalculationForReduceOp(Operation
*op
,
1000 PatternRewriter
&rewriter
) {
1001 Location loc
= op
->getLoc();
1002 if (isa
<tosa::ReduceSumOp
>(op
) && isa
<FloatType
>(elementTy
)) {
1003 return rewriter
.create
<arith::AddFOp
>(loc
, args
);
1006 if (isa
<tosa::ReduceSumOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
1007 return rewriter
.create
<arith::AddIOp
>(loc
, args
);
1010 if (isa
<tosa::ReduceProdOp
>(op
) && isa
<FloatType
>(elementTy
)) {
1011 return rewriter
.create
<arith::MulFOp
>(loc
, args
);
1014 if (isa
<tosa::ReduceProdOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
1015 return rewriter
.create
<arith::MulIOp
>(loc
, args
);
1018 if (isa
<tosa::ReduceMinOp
>(op
) && isa
<FloatType
>(elementTy
)) {
1019 return rewriter
.create
<arith::MinimumFOp
>(loc
, args
[0], args
[1]);
1022 if (isa
<tosa::ReduceMinOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
1023 return rewriter
.create
<arith::MinSIOp
>(loc
, args
[0], args
[1]);
1026 if (isa
<tosa::ReduceMaxOp
>(op
) && isa
<FloatType
>(elementTy
)) {
1027 return rewriter
.create
<arith::MaximumFOp
>(loc
, args
[0], args
[1]);
1030 if (isa
<tosa::ReduceMaxOp
>(op
) && isa
<IntegerType
>(elementTy
)) {
1031 return rewriter
.create
<arith::MaxSIOp
>(loc
, args
[0], args
[1]);
1034 if (isa
<tosa::ReduceAllOp
>(op
) && elementTy
.isInteger(1))
1035 return rewriter
.create
<arith::AndIOp
>(loc
, args
);
1037 if (isa
<tosa::ReduceAnyOp
>(op
) && elementTy
.isInteger(1))
1038 return rewriter
.create
<arith::OrIOp
>(loc
, args
);
1043 // Performs the match and rewrite for reduction operations. This includes
1044 // declaring a correctly sized initial value, and the linalg.generic operation
1045 // that reduces across the specified axis.
1046 static LogicalResult
reduceMatchAndRewriteHelper(Operation
*op
, uint64_t axis
,
1047 PatternRewriter
&rewriter
) {
1048 auto loc
= op
->getLoc();
1049 auto inputTy
= cast
<ShapedType
>(op
->getOperand(0).getType());
1050 auto resultTy
= cast
<ShapedType
>(op
->getResult(0).getType());
1051 auto elementTy
= resultTy
.getElementType();
1052 Value input
= op
->getOperand(0);
1054 SmallVector
<int64_t> reduceShape
;
1055 SmallVector
<Value
> dynDims
;
1056 for (unsigned i
= 0; i
< inputTy
.getRank(); i
++) {
1058 reduceShape
.push_back(inputTy
.getDimSize(i
));
1059 if (inputTy
.isDynamicDim(i
))
1060 dynDims
.push_back(rewriter
.create
<tensor::DimOp
>(loc
, input
, i
));
1064 // First fill the output buffer with the init value.
1067 .create
<tensor::EmptyOp
>(loc
, reduceShape
, resultTy
.getElementType(),
1071 auto fillValueAttr
= createInitialValueForReduceOp(op
, elementTy
, rewriter
);
1073 return rewriter
.notifyMatchFailure(
1074 op
, "No initial value found for reduction operation");
1076 auto fillValue
= rewriter
.create
<arith::ConstantOp
>(loc
, fillValueAttr
);
1077 auto filledTensor
= rewriter
1078 .create
<linalg::FillOp
>(loc
, ValueRange
{fillValue
},
1079 ValueRange
{emptyTensor
})
1082 bool didEncounterError
= false;
1083 auto linalgOp
= rewriter
.create
<linalg::ReduceOp
>(
1084 loc
, input
, filledTensor
, axis
,
1085 [&](OpBuilder
&nestedBuilder
, Location nestedLoc
, ValueRange blockArgs
) {
1086 auto result
= createLinalgBodyCalculationForReduceOp(
1087 op
, blockArgs
, elementTy
, rewriter
);
1089 didEncounterError
= true;
1091 nestedBuilder
.create
<linalg::YieldOp
>(loc
, result
);
1094 if (!didEncounterError
)
1095 return rewriter
.notifyMatchFailure(
1096 op
, "unable to create linalg.generic body for reduce op");
1098 SmallVector
<ReassociationExprs
, 4> reassociationMap
;
1099 uint64_t expandInputRank
=
1100 cast
<ShapedType
>(linalgOp
.getResults()[0].getType()).getRank();
1101 reassociationMap
.resize(expandInputRank
);
1103 for (uint64_t i
= 0; i
< expandInputRank
; i
++) {
1104 int32_t dimToPush
= i
> axis
? i
+ 1 : i
;
1105 reassociationMap
[i
].push_back(rewriter
.getAffineDimExpr(dimToPush
));
1108 if (expandInputRank
!= 0) {
1109 int32_t expandedDim
= axis
< expandInputRank
? axis
: expandInputRank
- 1;
1110 reassociationMap
[expandedDim
].push_back(
1111 rewriter
.getAffineDimExpr(expandedDim
+ 1));
1114 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1115 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1116 // not have access to such information. This matters when handling dynamically
1118 rewriter
.replaceOpWithNewOp
<tensor::ExpandShapeOp
>(
1119 op
, resultTy
, linalgOp
.getResults()[0], reassociationMap
);
1125 template <typename SrcOp
>
1126 class PointwiseConverter
: public OpConversionPattern
<SrcOp
> {
1128 using OpConversionPattern
<SrcOp
>::OpConversionPattern
;
1129 using typename OpConversionPattern
<SrcOp
>::OpAdaptor
;
1132 matchAndRewrite(SrcOp op
, OpAdaptor operands
,
1133 ConversionPatternRewriter
&rewriter
) const final
{
1134 return elementwiseMatchAndRewriteHelper(
1135 op
, operands
.getOperands(), rewriter
, *this->getTypeConverter());
1139 class RescaleConverter
: public OpRewritePattern
<tosa::RescaleOp
> {
1141 using OpRewritePattern
<tosa::RescaleOp
>::OpRewritePattern
;
1143 LogicalResult
matchAndRewrite(tosa::RescaleOp op
,
1144 PatternRewriter
&rewriter
) const final
{
1145 auto loc
= op
.getLoc();
1146 auto input
= op
.getInput();
1147 auto inputTy
= cast
<ShapedType
>(op
.getInput().getType());
1148 auto outputTy
= cast
<ShapedType
>(op
.getOutput().getType());
1149 unsigned rank
= inputTy
.getRank();
1151 // This is an illegal configuration. terminate and log an error
1152 if (op
.getDoubleRound() && !op
.getScale32())
1153 return rewriter
.notifyMatchFailure(
1154 op
, "tosa.rescale requires scale32 for double_round to be true");
1156 if (!isa
<IntegerType
>(inputTy
.getElementType()))
1157 return rewriter
.notifyMatchFailure(op
, "only support integer type");
1159 SmallVector
<Value
> dynDims
;
1160 for (int i
= 0; i
< outputTy
.getRank(); i
++) {
1161 if (outputTy
.isDynamicDim(i
)) {
1162 dynDims
.push_back(rewriter
.create
<tensor::DimOp
>(loc
, input
, i
));
1166 // The shift and multiplier values.
1167 SmallVector
<int32_t> multiplierValues(op
.getMultiplier());
1168 SmallVector
<int8_t> shiftValues(op
.getShift());
1170 // If we shift by more than the bitwidth, this just sets to 0.
1171 for (int i
= 0, s
= multiplierValues
.size(); i
< s
; i
++) {
1172 if (shiftValues
[i
] > 63) {
1174 multiplierValues
[i
] = 0;
1178 // Double round only occurs if shift is greater than 31, check that this
1181 op
.getDoubleRound() &&
1182 llvm::any_of(shiftValues
, [](int32_t v
) { return v
> 31; });
1184 SmallVector
<AffineMap
> indexingMaps
= {
1185 rewriter
.getMultiDimIdentityMap(rank
)};
1186 SmallVector
<Value
, 4> genericInputs
= {input
};
1188 // If we are rescaling per-channel then we need to store the multiplier
1189 // values in a buffer.
1190 Value multiplierConstant
;
1191 int64_t multiplierArg
= 0;
1192 if (multiplierValues
.size() == 1) {
1193 multiplierConstant
= rewriter
.create
<arith::ConstantOp
>(
1194 loc
, rewriter
.getI32IntegerAttr(multiplierValues
.front()));
1196 SmallVector
<AffineExpr
, 2> multiplierExprs
{
1197 rewriter
.getAffineDimExpr(rank
- 1)};
1198 auto multiplierType
=
1199 RankedTensorType::get({static_cast<int64_t>(multiplierValues
.size())},
1200 rewriter
.getI32Type());
1201 genericInputs
.push_back(rewriter
.create
<arith::ConstantOp
>(
1202 loc
, DenseIntElementsAttr::get(multiplierType
, multiplierValues
)));
1204 indexingMaps
.push_back(AffineMap::get(/*dimCount=*/rank
,
1205 /*symbolCount=*/0, multiplierExprs
,
1206 rewriter
.getContext()));
1208 multiplierArg
= indexingMaps
.size() - 1;
1211 // If we are rescaling per-channel then we need to store the shift
1212 // values in a buffer.
1213 Value shiftConstant
;
1214 int64_t shiftArg
= 0;
1215 if (shiftValues
.size() == 1) {
1216 shiftConstant
= rewriter
.create
<arith::ConstantOp
>(
1217 loc
, rewriter
.getI8IntegerAttr(shiftValues
.front()));
1219 SmallVector
<AffineExpr
, 2> shiftExprs
= {
1220 rewriter
.getAffineDimExpr(rank
- 1)};
1222 RankedTensorType::get({static_cast<int64_t>(shiftValues
.size())},
1223 rewriter
.getIntegerType(8));
1224 genericInputs
.push_back(rewriter
.create
<arith::ConstantOp
>(
1225 loc
, DenseIntElementsAttr::get(shiftType
, shiftValues
)));
1226 indexingMaps
.push_back(AffineMap::get(/*dimCount=*/rank
,
1227 /*symbolCount=*/0, shiftExprs
,
1228 rewriter
.getContext()));
1229 shiftArg
= indexingMaps
.size() - 1;
1232 // Indexing maps for output values.
1233 indexingMaps
.push_back(rewriter
.getMultiDimIdentityMap(rank
));
1235 // Construct the indexing maps needed for linalg.generic ops.
1236 Value emptyTensor
= rewriter
.create
<tensor::EmptyOp
>(
1237 loc
, outputTy
.getShape(), outputTy
.getElementType(),
1238 ArrayRef
<Value
>({dynDims
}));
1240 auto linalgOp
= rewriter
.create
<linalg::GenericOp
>(
1241 loc
, outputTy
, genericInputs
, ValueRange
{emptyTensor
}, indexingMaps
,
1242 getNParallelLoopsAttrs(rank
),
1243 [&](OpBuilder
&nestedBuilder
, Location nestedLoc
,
1244 ValueRange blockArgs
) {
1245 Value value
= blockArgs
[0];
1246 Type valueTy
= value
.getType();
1248 // For now we do all of our math in 64-bit. This is not optimal but
1249 // should be correct for now, consider computing correct bit depth
1251 int32_t inBitwidth
= valueTy
.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1253 auto inputZp
= createConstFromIntAttribute
<int32_t>(
1254 op
, "input_zp", nestedBuilder
.getIntegerType(inBitwidth
),
1256 auto outputZp
= createConstFromIntAttribute
<int32_t>(
1257 op
, "output_zp", nestedBuilder
.getI32Type(), nestedBuilder
);
1259 Value multiplier
= multiplierConstant
? multiplierConstant
1260 : blockArgs
[multiplierArg
];
1261 Value shift
= shiftConstant
? shiftConstant
: blockArgs
[shiftArg
];
1263 if (valueTy
.getIntOrFloatBitWidth() < 32) {
1264 if (valueTy
.isUnsignedInteger()) {
1265 value
= nestedBuilder
1266 .create
<UnrealizedConversionCastOp
>(
1268 nestedBuilder
.getIntegerType(
1269 valueTy
.getIntOrFloatBitWidth()),
1272 value
= nestedBuilder
.create
<arith::ExtUIOp
>(
1273 nestedLoc
, nestedBuilder
.getI32Type(), value
);
1275 value
= nestedBuilder
.create
<arith::ExtSIOp
>(
1276 nestedLoc
, nestedBuilder
.getI32Type(), value
);
1281 nestedBuilder
.create
<arith::SubIOp
>(nestedLoc
, value
, inputZp
);
1283 value
= nestedBuilder
.create
<tosa::ApplyScaleOp
>(
1284 loc
, nestedBuilder
.getI32Type(), value
, multiplier
, shift
,
1285 nestedBuilder
.getBoolAttr(doubleRound
));
1287 // Move to the new zero-point.
1289 nestedBuilder
.create
<arith::AddIOp
>(nestedLoc
, value
, outputZp
);
1291 // Saturate to the output size.
1292 IntegerType outIntType
=
1293 cast
<IntegerType
>(blockArgs
.back().getType());
1294 unsigned outBitWidth
= outIntType
.getWidth();
1296 int32_t intMin
= APInt::getSignedMinValue(outBitWidth
).getSExtValue();
1297 int32_t intMax
= APInt::getSignedMaxValue(outBitWidth
).getSExtValue();
1299 // Unsigned integers have a difference output value.
1300 if (outIntType
.isUnsignedInteger()) {
1302 intMax
= APInt::getMaxValue(outBitWidth
).getZExtValue();
1305 auto intMinVal
= nestedBuilder
.create
<arith::ConstantOp
>(
1306 loc
, nestedBuilder
.getI32IntegerAttr(intMin
));
1307 auto intMaxVal
= nestedBuilder
.create
<arith::ConstantOp
>(
1308 loc
, nestedBuilder
.getI32IntegerAttr(intMax
));
1310 value
= clampIntHelper(nestedLoc
, value
, intMinVal
, intMaxVal
,
1311 nestedBuilder
, /*isUnsigned=*/false);
1313 if (outIntType
.getWidth() < 32) {
1314 value
= nestedBuilder
.create
<arith::TruncIOp
>(
1315 nestedLoc
, rewriter
.getIntegerType(outIntType
.getWidth()),
1318 if (outIntType
.isUnsignedInteger()) {
1319 value
= nestedBuilder
1320 .create
<UnrealizedConversionCastOp
>(nestedLoc
,
1326 nestedBuilder
.create
<linalg::YieldOp
>(loc
, value
);
1329 rewriter
.replaceOp(op
, linalgOp
->getResults());
1334 // Handle the resize case where the input is a 1x1 image. This case
1335 // can entirely avoiding having extract operations which target much
1336 // more difficult to optimize away.
1337 class ResizeUnaryConverter
: public OpRewritePattern
<tosa::ResizeOp
> {
1339 using OpRewritePattern
<tosa::ResizeOp
>::OpRewritePattern
;
1341 LogicalResult
matchAndRewrite(tosa::ResizeOp op
,
1342 PatternRewriter
&rewriter
) const final
{
1343 Location loc
= op
.getLoc();
1344 ImplicitLocOpBuilder
builder(loc
, rewriter
);
1345 auto input
= op
.getInput();
1346 auto inputTy
= cast
<RankedTensorType
>(input
.getType());
1347 auto resultTy
= cast
<RankedTensorType
>(op
.getType());
1348 const bool isBilinear
= op
.getMode() == "BILINEAR";
1350 auto inputH
= inputTy
.getDimSize(1);
1351 auto inputW
= inputTy
.getDimSize(2);
1352 auto outputH
= resultTy
.getDimSize(1);
1353 auto outputW
= resultTy
.getDimSize(2);
1355 if (inputH
!= 1 || inputW
!= 1 || outputH
!= 1 || outputW
!= 1)
1356 return rewriter
.notifyMatchFailure(
1357 op
, "tosa.resize is not a pure 1x1->1x1 image operation");
1359 // TODO(suderman): These string values should be declared the TOSA dialect.
1360 if (op
.getMode() != "NEAREST_NEIGHBOR" && op
.getMode() != "BILINEAR")
1361 return rewriter
.notifyMatchFailure(
1362 op
, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1364 if (inputTy
== resultTy
) {
1365 rewriter
.replaceOp(op
, input
);
1369 ArrayRef
<int64_t> scale
= op
.getScale();
1371 // Collapse the unit width and height away.
1372 SmallVector
<ReassociationExprs
, 4> reassociationMap(2);
1373 reassociationMap
[0].push_back(builder
.getAffineDimExpr(0));
1374 reassociationMap
[1].push_back(builder
.getAffineDimExpr(1));
1375 reassociationMap
[1].push_back(builder
.getAffineDimExpr(2));
1376 reassociationMap
[1].push_back(builder
.getAffineDimExpr(3));
1379 RankedTensorType::get({inputTy
.getDimSize(0), inputTy
.getDimSize(3)},
1380 inputTy
.getElementType());
1381 Value collapse
= builder
.create
<tensor::CollapseShapeOp
>(collapseTy
, input
,
1384 // Get any dynamic shapes that appear in the input format.
1385 llvm::SmallVector
<Value
> outputDynSize
;
1386 if (inputTy
.isDynamicDim(0))
1387 outputDynSize
.push_back(builder
.create
<tensor::DimOp
>(input
, 0));
1388 if (inputTy
.isDynamicDim(3))
1389 outputDynSize
.push_back(builder
.create
<tensor::DimOp
>(input
, 3));
1391 // Generate the elementwise operation for casting scaling the input value.
1392 auto genericTy
= collapseTy
.clone(resultTy
.getElementType());
1393 Value empty
= builder
.create
<tensor::EmptyOp
>(
1394 genericTy
.getShape(), resultTy
.getElementType(), outputDynSize
);
1395 auto genericMap
= rewriter
.getMultiDimIdentityMap(genericTy
.getRank());
1396 SmallVector
<utils::IteratorType
> iterators(genericTy
.getRank(),
1397 utils::IteratorType::parallel
);
1399 auto generic
= builder
.create
<linalg::GenericOp
>(
1400 genericTy
, ValueRange
{collapse
}, ValueRange
{empty
},
1401 ArrayRef
<AffineMap
>{genericMap
, genericMap
}, iterators
,
1402 [=](OpBuilder
&b
, Location loc
, ValueRange args
) {
1403 Value value
= args
[0];
1404 // This is the quantized case.
1405 if (inputTy
.getElementType() != resultTy
.getElementType()) {
1407 b
.create
<arith::ExtSIOp
>(loc
, resultTy
.getElementType(), value
);
1409 if (isBilinear
&& scale
[0] != 0) {
1410 Value scaleY
= b
.create
<arith::ConstantOp
>(
1411 loc
, b
.getI32IntegerAttr(scale
[0]));
1412 value
= b
.create
<arith::MulIOp
>(loc
, value
, scaleY
);
1415 if (isBilinear
&& scale
[2] != 0) {
1416 Value scaleX
= b
.create
<arith::ConstantOp
>(
1417 loc
, b
.getI32IntegerAttr(scale
[2]));
1418 value
= b
.create
<arith::MulIOp
>(loc
, value
, scaleX
);
1422 b
.create
<linalg::YieldOp
>(loc
, value
);
1425 rewriter
.replaceOpWithNewOp
<tensor::ExpandShapeOp
>(
1426 op
, resultTy
, generic
.getResults()[0], reassociationMap
);
1431 // TOSA resize with width or height of 1 may be broadcasted to a wider
1432 // dimension. This is done by materializing a new tosa.resize without
1433 // the broadcasting behavior, and an explicit broadcast afterwards.
1434 class MaterializeResizeBroadcast
: public OpRewritePattern
<tosa::ResizeOp
> {
1436 using OpRewritePattern
<tosa::ResizeOp
>::OpRewritePattern
;
1438 LogicalResult
matchAndRewrite(tosa::ResizeOp op
,
1439 PatternRewriter
&rewriter
) const final
{
1440 Location loc
= op
.getLoc();
1441 ImplicitLocOpBuilder
builder(loc
, rewriter
);
1442 auto input
= op
.getInput();
1443 auto inputTy
= dyn_cast
<RankedTensorType
>(input
.getType());
1444 auto resultTy
= dyn_cast
<RankedTensorType
>(op
.getType());
1446 if (!inputTy
|| !resultTy
)
1447 return rewriter
.notifyMatchFailure(op
,
1448 "requires ranked input/output types");
1450 auto batch
= inputTy
.getDimSize(0);
1451 auto channels
= inputTy
.getDimSize(3);
1452 auto inputH
= inputTy
.getDimSize(1);
1453 auto inputW
= inputTy
.getDimSize(2);
1454 auto outputH
= resultTy
.getDimSize(1);
1455 auto outputW
= resultTy
.getDimSize(2);
1457 if ((inputH
!= 1 || outputH
== 1) && (inputW
!= 1 || outputW
== 1))
1458 return rewriter
.notifyMatchFailure(
1459 op
, "tosa.resize has no broadcasting behavior");
1461 // For any dimension that is broadcastable we generate a width of 1
1463 llvm::SmallVector
<int64_t> resizeShape
;
1464 resizeShape
.push_back(batch
);
1465 resizeShape
.push_back(inputH
== 1 ? 1 : outputH
);
1466 resizeShape
.push_back(inputW
== 1 ? 1 : outputW
);
1467 resizeShape
.push_back(channels
);
1469 auto resizeTy
= resultTy
.clone(resizeShape
);
1471 builder
.create
<tosa::ResizeOp
>(resizeTy
, input
, op
->getAttrs());
1473 // Collapse an unit result dims.
1474 SmallVector
<ReassociationExprs
, 4> reassociationMap(2);
1475 reassociationMap
[0].push_back(builder
.getAffineDimExpr(0));
1476 reassociationMap
.back().push_back(builder
.getAffineDimExpr(1));
1478 reassociationMap
.push_back({});
1479 reassociationMap
.back().push_back(builder
.getAffineDimExpr(2));
1481 reassociationMap
.push_back({});
1482 reassociationMap
.back().push_back(builder
.getAffineDimExpr(3));
1484 llvm::SmallVector
<int64_t> collapseShape
{batch
};
1486 collapseShape
.push_back(outputH
);
1488 collapseShape
.push_back(outputW
);
1489 collapseShape
.push_back(channels
);
1491 auto collapseTy
= resultTy
.clone(collapseShape
);
1492 Value collapse
= builder
.create
<tensor::CollapseShapeOp
>(collapseTy
, resize
,
1495 // Broadcast the collapsed shape to the output result.
1496 llvm::SmallVector
<Value
> outputDynSize
;
1497 if (inputTy
.isDynamicDim(0))
1498 outputDynSize
.push_back(builder
.create
<tensor::DimOp
>(input
, 0));
1499 if (inputTy
.isDynamicDim(3))
1500 outputDynSize
.push_back(builder
.create
<tensor::DimOp
>(input
, 3));
1502 SmallVector
<utils::IteratorType
> iterators(resultTy
.getRank(),
1503 utils::IteratorType::parallel
);
1504 Value empty
= builder
.create
<tensor::EmptyOp
>(
1505 resultTy
.getShape(), resultTy
.getElementType(), outputDynSize
);
1507 SmallVector
<AffineExpr
, 4> inputExprs
{rewriter
.getAffineDimExpr(0)};
1509 inputExprs
.push_back(rewriter
.getAffineDimExpr(1));
1511 inputExprs
.push_back(rewriter
.getAffineDimExpr(2));
1512 inputExprs
.push_back(rewriter
.getAffineDimExpr(3));
1514 auto inputMap
= AffineMap::get(resultTy
.getRank(), /*symbolCount=*/0,
1515 inputExprs
, rewriter
.getContext());
1517 auto outputMap
= rewriter
.getMultiDimIdentityMap(resultTy
.getRank());
1518 rewriter
.replaceOpWithNewOp
<linalg::GenericOp
>(
1519 op
, resultTy
, ValueRange
{collapse
}, ValueRange
{empty
},
1520 ArrayRef
<AffineMap
>{inputMap
, outputMap
}, iterators
,
1521 [=](OpBuilder
&b
, Location loc
, ValueRange args
) {
1522 Value value
= args
[0];
1523 b
.create
<linalg::YieldOp
>(loc
, value
);
1530 class GenericResizeConverter
: public OpRewritePattern
<tosa::ResizeOp
> {
1532 using OpRewritePattern
<tosa::ResizeOp
>::OpRewritePattern
;
1534 LogicalResult
matchAndRewrite(tosa::ResizeOp op
,
1535 PatternRewriter
&rewriter
) const final
{
1536 Location loc
= op
.getLoc();
1537 ImplicitLocOpBuilder
b(loc
, rewriter
);
1538 auto input
= op
.getInput();
1539 auto inputTy
= cast
<ShapedType
>(input
.getType());
1540 auto resultTy
= cast
<ShapedType
>(op
.getType());
1541 auto resultETy
= resultTy
.getElementType();
1543 bool floatingPointMode
= resultETy
.isF16() || resultETy
.isF32();
1544 auto floatTy
= resultETy
.isF16() ? b
.getF16Type() : b
.getF32Type();
1546 auto imageH
= inputTy
.getShape()[1];
1547 auto imageW
= inputTy
.getShape()[2];
1549 auto dynamicDimsOr
=
1550 checkHasDynamicBatchDims(rewriter
, op
, {input
, op
.getOutput()});
1551 if (!dynamicDimsOr
.has_value())
1552 return rewriter
.notifyMatchFailure(
1553 op
, "unable to get dynamic dimensions of tosa.resize");
1555 if (op
.getMode() != "NEAREST_NEIGHBOR" && op
.getMode() != "BILINEAR")
1556 return rewriter
.notifyMatchFailure(
1557 op
, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1559 SmallVector
<AffineMap
, 2> affineMaps
= {
1560 rewriter
.getMultiDimIdentityMap(resultTy
.getRank())};
1561 auto emptyTensor
= b
.create
<tensor::EmptyOp
>(resultTy
.getShape(), resultETy
,
1563 auto genericOp
= b
.create
<linalg::GenericOp
>(
1564 resultTy
, ValueRange({}), ValueRange
{emptyTensor
}, affineMaps
,
1565 getNParallelLoopsAttrs(resultTy
.getRank()));
1566 Value resize
= genericOp
.getResult(0);
1569 OpBuilder::InsertionGuard
regionGuard(b
);
1570 b
.createBlock(&genericOp
.getRegion(), genericOp
.getRegion().end(),
1571 TypeRange({resultETy
}), loc
);
1572 Value batch
= b
.create
<linalg::IndexOp
>(0);
1573 Value y
= b
.create
<linalg::IndexOp
>(1);
1574 Value x
= b
.create
<linalg::IndexOp
>(2);
1575 Value channel
= b
.create
<linalg::IndexOp
>(3);
1578 b
.create
<arith::ConstantOp
>(b
.getZeroAttr(b
.getI32Type()));
1579 Value zeroFp
= b
.create
<arith::ConstantOp
>(b
.getZeroAttr(floatTy
));
1580 Value hMax
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(imageH
- 1));
1581 Value wMax
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(imageW
- 1));
1583 Value inY
= b
.create
<arith::IndexCastOp
>(b
.getI32Type(), y
);
1584 Value inX
= b
.create
<arith::IndexCastOp
>(b
.getI32Type(), x
);
1586 ArrayRef
<int64_t> offset
= op
.getOffset();
1587 ArrayRef
<int64_t> border
= op
.getBorder();
1588 ArrayRef
<int64_t> scale
= op
.getScale();
1590 Value yScaleN
, yScaleD
, xScaleN
, xScaleD
;
1591 yScaleN
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(scale
[0]));
1592 yScaleD
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(scale
[1]));
1593 xScaleN
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(scale
[2]));
1594 xScaleD
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(scale
[3]));
1596 Value yOffset
, xOffset
, yBorder
, xBorder
;
1597 yOffset
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(offset
[0]));
1598 xOffset
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(offset
[1]));
1599 yBorder
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(border
[0]));
1600 xBorder
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(border
[1]));
1602 // Compute the ix and dx values for both the X and Y dimensions.
1603 auto getIndexAndDeltaFp
= [&](Value
&index
, Value
&delta
, Value in
,
1604 Value scaleN
, Value scaleD
, Value offset
,
1605 int size
, ImplicitLocOpBuilder
&b
) {
1611 // x = x * scale_d + offset;
1612 // ix = floor(x / scale_n)
1613 Value val
= b
.create
<arith::MulIOp
>(in
, scaleD
);
1614 val
= b
.create
<arith::AddIOp
>(val
, offset
);
1615 index
= b
.create
<arith::FloorDivSIOp
>(val
, scaleN
);
1618 // dx = rx / scale_n
1619 Value r
= b
.create
<arith::RemSIOp
>(val
, scaleN
);
1620 Value rFp
= b
.create
<arith::SIToFPOp
>(floatTy
, r
);
1621 Value scaleNfp
= b
.create
<arith::UIToFPOp
>(floatTy
, scaleN
);
1622 delta
= b
.create
<arith::DivFOp
>(rFp
, scaleNfp
);
1625 // Compute the ix and dx values for the X and Y dimensions - int case.
1626 auto getIndexAndDeltaInt
= [&](Value
&index
, Value
&delta
, Value in
,
1627 Value scaleN
, Value scaleD
, Value offset
,
1628 int size
, ImplicitLocOpBuilder
&b
) {
1634 // x = x * scale_d + offset;
1635 // ix = floor(x / scale_n)
1636 // dx = x - ix * scale_n;
1637 Value val
= b
.create
<arith::MulIOp
>(in
, scaleD
);
1638 val
= b
.create
<arith::AddIOp
>(val
, offset
);
1639 index
= b
.create
<arith::DivSIOp
>(val
, scaleN
);
1640 delta
= b
.create
<arith::MulIOp
>(index
, scaleN
);
1641 delta
= b
.create
<arith::SubIOp
>(val
, delta
);
1644 Value ix
, iy
, dx
, dy
;
1645 if (floatingPointMode
) {
1646 getIndexAndDeltaFp(iy
, dy
, inY
, yScaleN
, yScaleD
, yOffset
, imageH
, b
);
1647 getIndexAndDeltaFp(ix
, dx
, inX
, xScaleN
, xScaleD
, xOffset
, imageW
, b
);
1649 getIndexAndDeltaInt(iy
, dy
, inY
, yScaleN
, yScaleD
, yOffset
, imageH
, b
);
1650 getIndexAndDeltaInt(ix
, dx
, inX
, xScaleN
, xScaleD
, xOffset
, imageW
, b
);
1653 if (op
.getMode() == "NEAREST_NEIGHBOR") {
1654 auto one
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(1));
1656 auto getNearestIndexAndClamp
= [&](Value val
, Value dval
, Value scale
,
1657 Value max
, int size
,
1658 ImplicitLocOpBuilder
&b
) -> Value
{
1660 return b
.create
<arith::ConstantIndexOp
>(0);
1664 if (floatingPointMode
) {
1665 auto h
= b
.create
<arith::ConstantOp
>(b
.getFloatAttr(floatTy
, 0.5f
));
1666 pred
= b
.create
<arith::CmpFOp
>(arith::CmpFPredicate::OGE
, dval
, h
);
1668 Value dvalDouble
= b
.create
<arith::ShLIOp
>(dval
, one
);
1669 pred
= b
.create
<arith::CmpIOp
>(arith::CmpIPredicate::sge
,
1673 auto offset
= b
.create
<arith::SelectOp
>(pred
, one
, zeroI32
);
1674 val
= b
.create
<arith::AddIOp
>(val
, offset
);
1675 val
= clampIntHelper(loc
, val
, zeroI32
, max
, b
, /*isUnsigned=*/false);
1676 return b
.create
<arith::IndexCastOp
>(b
.getIndexType(), val
);
1679 iy
= getNearestIndexAndClamp(iy
, dy
, yScaleN
, hMax
, imageH
, b
);
1680 ix
= getNearestIndexAndClamp(ix
, dx
, xScaleN
, wMax
, imageW
, b
);
1682 Value result
= b
.create
<tensor::ExtractOp
>(
1683 input
, ValueRange
{batch
, iy
, ix
, channel
});
1685 b
.create
<linalg::YieldOp
>(result
);
1687 // The mode here must be BILINEAR.
1688 assert(op
.getMode() == "BILINEAR");
1690 auto oneVal
= b
.create
<arith::ConstantOp
>(b
.getI32IntegerAttr(1));
1692 auto getClampedIdxs
= [&](Value
&val0
, Value
&val1
, int size
, Value in
,
1693 Value max
, ImplicitLocOpBuilder
&b
) {
1695 val1
= b
.create
<arith::AddIOp
>(val0
, oneVal
);
1697 clampIntHelper(loc
, val0
, zeroI32
, max
, b
, /*isUnsigned=*/false);
1699 clampIntHelper(loc
, val1
, zeroI32
, max
, b
, /*isUnsigned=*/false);
1700 val0
= b
.create
<arith::IndexCastOp
>(b
.getIndexType(), val0
);
1701 val1
= b
.create
<arith::IndexCastOp
>(b
.getIndexType(), val1
);
1704 // Linalg equivalent to the section below:
1705 // int16_t iy0 = apply_max(iy, 0);
1706 // int16_t iy1 = apply_min(iy + 1, IH - 1);
1707 // int16_t ix0 = apply_max(ix, 0);
1708 // int16_t ix1 = apply_min(ix + 1, IW - 1);
1709 Value x0
, x1
, y0
, y1
;
1710 getClampedIdxs(y0
, y1
, imageH
, iy
, hMax
, b
);
1711 getClampedIdxs(x0
, x1
, imageW
, ix
, wMax
, b
);
1713 Value y0x0
= b
.create
<tensor::ExtractOp
>(
1714 input
, ValueRange
{batch
, y0
, x0
, channel
});
1715 Value y0x1
= b
.create
<tensor::ExtractOp
>(
1716 input
, ValueRange
{batch
, y0
, x1
, channel
});
1717 Value y1x0
= b
.create
<tensor::ExtractOp
>(
1718 input
, ValueRange
{batch
, y1
, x0
, channel
});
1719 Value y1x1
= b
.create
<tensor::ExtractOp
>(
1720 input
, ValueRange
{batch
, y1
, x1
, channel
});
1722 if (floatingPointMode
) {
1724 b
.create
<arith::ConstantOp
>(b
.getFloatAttr(floatTy
, 1.0f
));
1725 auto interpolate
= [&](Value val0
, Value val1
, Value delta
,
1727 ImplicitLocOpBuilder
&b
) -> Value
{
1730 Value oneMinusDelta
= b
.create
<arith::SubFOp
>(oneVal
, delta
);
1731 Value mul0
= b
.create
<arith::MulFOp
>(val0
, oneMinusDelta
);
1732 Value mul1
= b
.create
<arith::MulFOp
>(val1
, delta
);
1733 return b
.create
<arith::AddFOp
>(mul0
, mul1
);
1736 // Linalg equivalent to the section below:
1737 // topAcc = v00 * (unit_x - dx);
1738 // topAcc += v01 * dx;
1739 Value topAcc
= interpolate(y0x0
, y0x1
, dx
, imageW
, b
);
1741 // Linalg equivalent to the section below:
1742 // bottomAcc = v10 * (unit_x - dx);
1743 // bottomAcc += v11 * dx;
1744 Value bottomAcc
= interpolate(y1x0
, y1x1
, dx
, imageW
, b
);
1746 // Linalg equivalent to the section below:
1747 // result = topAcc * (unit_y - dy) + bottomAcc * dy
1748 Value result
= interpolate(topAcc
, bottomAcc
, dy
, imageH
, b
);
1749 b
.create
<linalg::YieldOp
>(result
);
1751 // Perform in quantized space.
1752 y0x0
= b
.create
<arith::ExtSIOp
>(resultETy
, y0x0
);
1753 y0x1
= b
.create
<arith::ExtSIOp
>(resultETy
, y0x1
);
1754 y1x0
= b
.create
<arith::ExtSIOp
>(resultETy
, y1x0
);
1755 y1x1
= b
.create
<arith::ExtSIOp
>(resultETy
, y1x1
);
1757 const int64_t deltaBitwidth
= dx
.getType().getIntOrFloatBitWidth();
1758 if (resultETy
.getIntOrFloatBitWidth() > deltaBitwidth
) {
1759 dx
= b
.create
<arith::ExtSIOp
>(resultETy
, dx
);
1760 dy
= b
.create
<arith::ExtSIOp
>(resultETy
, dy
);
1763 Value yScaleNExt
= yScaleN
;
1764 Value xScaleNExt
= xScaleN
;
1766 const int64_t scaleBitwidth
=
1767 xScaleN
.getType().getIntOrFloatBitWidth();
1768 if (resultETy
.getIntOrFloatBitWidth() > scaleBitwidth
) {
1769 yScaleNExt
= b
.create
<arith::ExtSIOp
>(resultETy
, yScaleN
);
1770 xScaleNExt
= b
.create
<arith::ExtSIOp
>(resultETy
, xScaleN
);
1773 auto interpolate
= [](Value val0
, Value val1
, Value weight1
,
1774 Value scale
, int inputSize
,
1775 ImplicitLocOpBuilder
&b
) -> Value
{
1777 return b
.create
<arith::MulIOp
>(val0
, scale
);
1778 Value weight0
= b
.create
<arith::SubIOp
>(scale
, weight1
);
1779 Value mul0
= b
.create
<arith::MulIOp
>(val0
, weight0
);
1780 Value mul1
= b
.create
<arith::MulIOp
>(val1
, weight1
);
1781 return b
.create
<arith::AddIOp
>(mul0
, mul1
);
1784 Value topAcc
= interpolate(y0x0
, y0x1
, dx
, xScaleNExt
, imageW
, b
);
1785 Value bottomAcc
= interpolate(y1x0
, y1x1
, dx
, xScaleNExt
, imageW
, b
);
1787 interpolate(topAcc
, bottomAcc
, dy
, yScaleNExt
, imageH
, b
);
1788 b
.create
<linalg::YieldOp
>(result
);
1793 rewriter
.replaceOp(op
, resize
);
1798 // At the codegen level any identity operations should be removed. Any cases
1799 // where identity is load-bearing (e.g. cross device computation) should be
1800 // handled before lowering to codegen.
1801 template <typename SrcOp
>
1802 class IdentityNConverter
: public OpRewritePattern
<SrcOp
> {
1804 using OpRewritePattern
<SrcOp
>::OpRewritePattern
;
1806 LogicalResult
matchAndRewrite(SrcOp op
,
1807 PatternRewriter
&rewriter
) const final
{
1808 rewriter
.replaceOp(op
, op
.getOperation()->getOperands());
1813 template <typename SrcOp
>
1814 class ReduceConverter
: public OpRewritePattern
<SrcOp
> {
1816 using OpRewritePattern
<SrcOp
>::OpRewritePattern
;
1818 LogicalResult
matchAndRewrite(SrcOp reduceOp
,
1819 PatternRewriter
&rewriter
) const final
{
1820 return reduceMatchAndRewriteHelper(reduceOp
, reduceOp
.getAxis(), rewriter
);
1824 class ReverseConverter
: public OpRewritePattern
<tosa::ReverseOp
> {
1826 using OpRewritePattern
<tosa::ReverseOp
>::OpRewritePattern
;
1828 LogicalResult
matchAndRewrite(tosa::ReverseOp op
,
1829 PatternRewriter
&rewriter
) const final
{
1830 auto loc
= op
.getLoc();
1831 Value input
= op
.getInput1();
1832 auto inputTy
= cast
<ShapedType
>(input
.getType());
1833 auto resultTy
= cast
<ShapedType
>(op
.getType());
1834 auto axis
= op
.getAxis();
1836 SmallVector
<Value
> dynDims
;
1837 for (int i
= 0; i
< inputTy
.getRank(); i
++) {
1838 if (inputTy
.isDynamicDim(i
)) {
1839 dynDims
.push_back(rewriter
.create
<tensor::DimOp
>(loc
, input
, i
));
1843 Value axisDimSize
= rewriter
.create
<tensor::DimOp
>(loc
, input
, axis
);
1845 // First fill the output buffer with the init value.
1846 auto emptyTensor
= rewriter
1847 .create
<tensor::EmptyOp
>(loc
, inputTy
.getShape(),
1848 inputTy
.getElementType(),
1849 ArrayRef
<Value
>({dynDims
}))
1851 SmallVector
<AffineMap
, 2> affineMaps
= {
1852 rewriter
.getMultiDimIdentityMap(resultTy
.getRank())};
1854 rewriter
.replaceOpWithNewOp
<linalg::GenericOp
>(
1855 op
, resultTy
, ArrayRef
<Value
>({}), ValueRange
{emptyTensor
}, affineMaps
,
1856 getNParallelLoopsAttrs(resultTy
.getRank()),
1857 [&](OpBuilder
&nestedBuilder
, Location nestedLoc
, ValueRange args
) {
1858 llvm::SmallVector
<Value
> indices
;
1859 for (unsigned int i
= 0; i
< inputTy
.getRank(); i
++) {
1861 rewriter
.create
<linalg::IndexOp
>(nestedLoc
, i
).getResult();
1863 auto one
= rewriter
.create
<arith::ConstantIndexOp
>(nestedLoc
, 1);
1865 rewriter
.create
<arith::SubIOp
>(nestedLoc
, axisDimSize
, one
);
1866 index
= rewriter
.create
<arith::SubIOp
>(nestedLoc
, sizeMinusOne
,
1870 indices
.push_back(index
);
1873 auto extract
= nestedBuilder
.create
<tensor::ExtractOp
>(
1874 nestedLoc
, input
, indices
);
1875 nestedBuilder
.create
<linalg::YieldOp
>(op
.getLoc(),
1876 extract
.getResult());
1882 // This converter translate a tile operation to a reshape, broadcast, reshape.
1883 // The first reshape minimally expands each tiled dimension to include a
1884 // proceding size-1 dim. This dim is then broadcasted to the appropriate
1886 struct TileConverter
: public OpConversionPattern
<tosa::TileOp
> {
1887 using OpConversionPattern
<tosa::TileOp
>::OpConversionPattern
;
1890 matchAndRewrite(tosa::TileOp op
, OpAdaptor adaptor
,
1891 ConversionPatternRewriter
&rewriter
) const override
{
1892 auto loc
= op
.getLoc();
1893 auto input
= op
.getInput1();
1894 auto inputTy
= cast
<ShapedType
>(input
.getType());
1895 auto inputShape
= inputTy
.getShape();
1896 auto resultTy
= cast
<ShapedType
>(op
.getType());
1897 auto elementTy
= inputTy
.getElementType();
1898 int64_t rank
= inputTy
.getRank();
1900 ArrayRef
<int64_t> multiples
= op
.getMultiples();
1902 // Broadcast the newly added dimensions to their appropriate multiple.
1903 SmallVector
<int64_t, 2> genericShape
;
1904 for (int i
= 0; i
< rank
; i
++) {
1905 int64_t dim
= multiples
[i
];
1906 genericShape
.push_back(dim
== -1 ? ShapedType::kDynamic
: dim
);
1907 genericShape
.push_back(inputShape
[i
]);
1910 SmallVector
<Value
> dynDims
;
1911 for (int i
= 0; i
< inputTy
.getRank(); i
++) {
1912 if (inputTy
.isDynamicDim(i
) || multiples
[i
] == -1) {
1913 dynDims
.push_back(rewriter
.create
<tensor::DimOp
>(loc
, input
, i
));
1917 auto emptyTensor
= rewriter
.create
<tensor::EmptyOp
>(
1918 op
.getLoc(), genericShape
, elementTy
, dynDims
);
1920 // We needs to map the input shape to the non-broadcasted dimensions.
1921 SmallVector
<AffineExpr
, 4> dimExprs
;
1922 dimExprs
.reserve(rank
);
1923 for (unsigned i
= 0; i
< rank
; ++i
)
1924 dimExprs
.push_back(rewriter
.getAffineDimExpr(i
* 2 + 1));
1926 auto readAffineMap
=
1927 AffineMap::get(/*dimCount=*/rank
* 2, /*symbolCount=*/0, dimExprs
,
1928 rewriter
.getContext());
1930 SmallVector
<AffineMap
, 2> affineMaps
= {
1931 readAffineMap
, rewriter
.getMultiDimIdentityMap(genericShape
.size())};
1933 auto genericOp
= rewriter
.create
<linalg::GenericOp
>(
1934 loc
, RankedTensorType::get(genericShape
, elementTy
), input
,
1935 ValueRange
{emptyTensor
}, affineMaps
,
1936 getNParallelLoopsAttrs(genericShape
.size()),
1937 [&](OpBuilder
&nestedBuilder
, Location nestedLoc
, ValueRange args
) {
1938 nestedBuilder
.create
<linalg::YieldOp
>(op
.getLoc(), *args
.begin());
1941 rewriter
.replaceOpWithNewOp
<tosa::ReshapeOp
>(
1942 op
, resultTy
, genericOp
.getResult(0),
1943 rewriter
.getDenseI64ArrayAttr(resultTy
.getShape()));
1948 // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic
1949 // op, producing two output buffers.
1951 // The first output buffer contains the index of the found maximum value. It is
1952 // initialized to 0 and is resulting integer type.
1954 // The second output buffer contains the maximum value found. It is initialized
1955 // to the minimum representable value of the input element type. After being
1956 // populated by indexed_generic, this buffer is disgarded as only the index is
1959 // The indexed_generic op updates both the maximum value and index if the
1960 // current value exceeds the running max.
1961 class ArgMaxConverter
: public OpRewritePattern
<tosa::ArgMaxOp
> {
1963 using OpRewritePattern
<tosa::ArgMaxOp
>::OpRewritePattern
;
1965 LogicalResult
matchAndRewrite(tosa::ArgMaxOp argmaxOp
,
1966 PatternRewriter
&rewriter
) const final
{
1967 auto loc
= argmaxOp
.getLoc();
1968 Value input
= argmaxOp
.getInput();
1969 auto inputTy
= cast
<ShapedType
>(input
.getType());
1970 auto resultTy
= cast
<ShapedType
>(argmaxOp
.getOutput().getType());
1971 auto inElementTy
= inputTy
.getElementType();
1972 auto outElementTy
= resultTy
.getElementType();
1973 int axis
= argmaxOp
.getAxis();
1974 auto resultMaxTy
= RankedTensorType::get(resultTy
.getShape(), inElementTy
);
1976 if (!isa
<IntegerType
>(outElementTy
))
1977 return rewriter
.notifyMatchFailure(
1979 "tosa.arg_max to linalg.* requires integer-like result type");
1981 SmallVector
<Value
> dynDims
;
1982 for (int i
= 0; i
< inputTy
.getRank(); i
++) {
1983 if (inputTy
.isDynamicDim(i
) && i
!= axis
) {
1984 dynDims
.push_back(rewriter
.create
<tensor::DimOp
>(loc
, input
, i
));
1988 // First fill the output buffer for the index.
1989 auto emptyTensorIdx
= rewriter
1990 .create
<tensor::EmptyOp
>(loc
, resultTy
.getShape(),
1991 outElementTy
, dynDims
)
1993 auto fillValueIdx
= rewriter
.create
<arith::ConstantOp
>(
1994 loc
, rewriter
.getIntegerAttr(outElementTy
, 0));
1995 auto filledTensorIdx
=
1997 .create
<linalg::FillOp
>(loc
, ValueRange
{fillValueIdx
},
1998 ValueRange
{emptyTensorIdx
})
2001 // Second fill the output buffer for the running max.
2002 auto emptyTensorMax
= rewriter
2003 .create
<tensor::EmptyOp
>(loc
, resultTy
.getShape(),
2004 inElementTy
, dynDims
)
2006 auto fillValueMaxAttr
=
2007 createInitialValueForReduceOp(argmaxOp
, inElementTy
, rewriter
);
2009 if (!fillValueMaxAttr
)
2010 return rewriter
.notifyMatchFailure(
2011 argmaxOp
, "unsupported tosa.argmax element type");
2014 rewriter
.create
<arith::ConstantOp
>(loc
, fillValueMaxAttr
);
2015 auto filledTensorMax
=
2017 .create
<linalg::FillOp
>(loc
, ValueRange
{fillValueMax
},
2018 ValueRange
{emptyTensorMax
})
2021 // We need to reduce along the arg-max axis, with parallel operations along
2023 SmallVector
<utils::IteratorType
, 4> iteratorTypes
;
2024 iteratorTypes
.resize(inputTy
.getRank(), utils::IteratorType::parallel
);
2025 iteratorTypes
[axis
] = utils::IteratorType::reduction
;
2027 SmallVector
<AffineExpr
, 2> srcExprs
;
2028 SmallVector
<AffineExpr
, 2> dstExprs
;
2029 for (int i
= 0, rank
= inputTy
.getRank(); i
!= rank
; ++i
) {
2030 srcExprs
.push_back(mlir::getAffineDimExpr(i
, rewriter
.getContext()));
2032 dstExprs
.push_back(mlir::getAffineDimExpr(i
, rewriter
.getContext()));
2035 bool didEncounterError
= false;
2036 auto maps
= AffineMap::inferFromExprList({srcExprs
, dstExprs
, dstExprs
},
2037 rewriter
.getContext());
2038 auto linalgOp
= rewriter
.create
<linalg::GenericOp
>(
2039 loc
, ArrayRef
<Type
>({resultTy
, resultMaxTy
}), input
,
2040 ValueRange({filledTensorIdx
, filledTensorMax
}), maps
, iteratorTypes
,
2041 [&](OpBuilder
&nestedBuilder
, Location nestedLoc
,
2042 ValueRange blockArgs
) {
2043 auto newValue
= blockArgs
[0];
2044 auto oldIndex
= blockArgs
[1];
2045 auto oldValue
= blockArgs
[2];
2047 Value newIndex
= rewriter
.create
<arith::IndexCastOp
>(
2048 nestedLoc
, oldIndex
.getType(),
2049 rewriter
.create
<linalg::IndexOp
>(loc
, axis
));
2052 if (isa
<FloatType
>(inElementTy
)) {
2053 predicate
= rewriter
.create
<arith::CmpFOp
>(
2054 nestedLoc
, arith::CmpFPredicate::OGT
, newValue
, oldValue
);
2055 } else if (isa
<IntegerType
>(inElementTy
)) {
2056 predicate
= rewriter
.create
<arith::CmpIOp
>(
2057 nestedLoc
, arith::CmpIPredicate::sgt
, newValue
, oldValue
);
2059 didEncounterError
= true;
2063 auto resultMax
= rewriter
.create
<arith::SelectOp
>(
2064 nestedLoc
, predicate
, newValue
, oldValue
);
2065 auto resultIndex
= rewriter
.create
<arith::SelectOp
>(
2066 nestedLoc
, predicate
, newIndex
, oldIndex
);
2067 nestedBuilder
.create
<linalg::YieldOp
>(
2068 nestedLoc
, ValueRange({resultIndex
, resultMax
}));
2071 if (didEncounterError
)
2072 return rewriter
.notifyMatchFailure(
2073 argmaxOp
, "unsupported tosa.argmax element type");
2075 rewriter
.replaceOp(argmaxOp
, linalgOp
.getResult(0));
2080 class GatherConverter
: public OpConversionPattern
<tosa::GatherOp
> {
2082 using OpConversionPattern
<tosa::GatherOp
>::OpConversionPattern
;
2084 matchAndRewrite(tosa::GatherOp op
, OpAdaptor adaptor
,
2085 ConversionPatternRewriter
&rewriter
) const final
{
2086 auto input
= adaptor
.getOperands()[0];
2087 auto indices
= adaptor
.getOperands()[1];
2090 dyn_cast_or_null
<RankedTensorType
>(op
.getValues().getType());
2091 auto resultTy
= cast
<ShapedType
>(op
.getType());
2094 return rewriter
.notifyMatchFailure(op
, "unranked tensors not supported");
2096 auto dynamicDims
= inferDynamicDimsForGather(
2097 rewriter
, op
.getLoc(), adaptor
.getValues(), adaptor
.getIndices());
2099 auto resultElementTy
= resultTy
.getElementType();
2101 auto loc
= op
.getLoc();
2104 .create
<tensor::EmptyOp
>(loc
, resultTy
.getShape(), resultElementTy
,
2108 SmallVector
<AffineMap
, 2> affineMaps
= {
2110 /*dimCount=*/resultTy
.getRank(), /*symbolCount=*/0,
2111 {rewriter
.getAffineDimExpr(0), rewriter
.getAffineDimExpr(1)},
2112 rewriter
.getContext()),
2113 rewriter
.getMultiDimIdentityMap(resultTy
.getRank())};
2115 auto genericOp
= rewriter
.create
<linalg::GenericOp
>(
2116 loc
, ArrayRef
<Type
>({resultTy
}), ValueRange
{indices
},
2117 ValueRange
{emptyTensor
}, affineMaps
,
2118 getNParallelLoopsAttrs(resultTy
.getRank()),
2119 [&](OpBuilder
&b
, Location loc
, ValueRange args
) {
2120 auto indexValue
= args
[0];
2121 auto index0
= rewriter
.create
<linalg::IndexOp
>(loc
, 0);
2122 Value index1
= rewriter
.create
<arith::IndexCastOp
>(
2123 loc
, rewriter
.getIndexType(), indexValue
);
2124 auto index2
= rewriter
.create
<linalg::IndexOp
>(loc
, 2);
2125 Value extract
= rewriter
.create
<tensor::ExtractOp
>(
2126 loc
, input
, ValueRange
{index0
, index1
, index2
});
2127 rewriter
.create
<linalg::YieldOp
>(loc
, extract
);
2129 rewriter
.replaceOp(op
, genericOp
.getResult(0));
2133 static llvm::SmallVector
<Value
> inferDynamicDimsForGather(OpBuilder
&builder
,
2137 llvm::SmallVector
<Value
> results
;
2139 auto addDynamicDimension
= [&](Value source
, int64_t dim
) {
2140 auto sz
= tensor::getMixedSize(builder
, loc
, source
, dim
);
2141 if (auto dimValue
= llvm::dyn_cast_if_present
<Value
>(sz
))
2142 results
.push_back(dimValue
);
2145 addDynamicDimension(values
, 0);
2146 addDynamicDimension(indices
, 1);
2147 addDynamicDimension(values
, 2);
2152 // Lowerings the TableOp to a series of gathers and numerica operations. This
2153 // includes interpolation between the high/low values. For the I8 varient, this
2154 // simplifies to a single gather operation.
2155 class TableConverter
: public OpRewritePattern
<tosa::TableOp
> {
2157 using OpRewritePattern
<tosa::TableOp
>::OpRewritePattern
;
2159 LogicalResult
matchAndRewrite(tosa::TableOp op
,
2160 PatternRewriter
&rewriter
) const final
{
2161 auto loc
= op
.getLoc();
2162 Value input
= op
.getInput1();
2163 Value table
= op
.getTable();
2164 auto inputTy
= cast
<ShapedType
>(input
.getType());
2165 auto tableTy
= cast
<ShapedType
>(table
.getType());
2166 auto resultTy
= cast
<ShapedType
>(op
.getType());
2168 auto inputElementTy
= inputTy
.getElementType();
2169 auto tableElementTy
= tableTy
.getElementType();
2170 auto resultElementTy
= resultTy
.getElementType();
2172 SmallVector
<Value
> dynDims
;
2173 for (int i
= 0; i
< resultTy
.getRank(); ++i
) {
2174 if (inputTy
.isDynamicDim(i
)) {
2176 rewriter
.create
<tensor::DimOp
>(loc
, op
.getOperand(0), i
));
2180 auto emptyTensor
= rewriter
2181 .create
<tensor::EmptyOp
>(loc
, resultTy
.getShape(),
2182 resultElementTy
, dynDims
)
2185 SmallVector
<AffineMap
, 2> affineMaps
= {
2186 rewriter
.getMultiDimIdentityMap(resultTy
.getRank()),
2187 rewriter
.getMultiDimIdentityMap(resultTy
.getRank())};
2189 auto genericOp
= rewriter
.create
<linalg::GenericOp
>(
2190 loc
, resultTy
, ValueRange({input
}), ValueRange
{emptyTensor
}, affineMaps
,
2191 getNParallelLoopsAttrs(resultTy
.getRank()));
2192 rewriter
.replaceOp(op
, genericOp
.getResult(0));
2195 OpBuilder::InsertionGuard
regionGuard(rewriter
);
2196 Block
*block
= rewriter
.createBlock(
2197 &genericOp
.getRegion(), genericOp
.getRegion().end(),
2198 TypeRange({inputElementTy
, resultElementTy
}), {loc
, loc
});
2200 auto inputValue
= block
->getArgument(0);
2201 rewriter
.setInsertionPointToStart(block
);
2202 if (inputElementTy
.isInteger(8) && tableElementTy
.isInteger(8) &&
2203 resultElementTy
.isInteger(8)) {
2204 Value index
= rewriter
.create
<arith::IndexCastOp
>(
2205 loc
, rewriter
.getIndexType(), inputValue
);
2206 Value offset
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 128);
2207 index
= rewriter
.create
<arith::AddIOp
>(loc
, rewriter
.getIndexType(),
2210 rewriter
.create
<tensor::ExtractOp
>(loc
, table
, ValueRange
{index
});
2211 rewriter
.create
<linalg::YieldOp
>(loc
, extract
);
2215 if (inputElementTy
.isInteger(16) && tableElementTy
.isInteger(16) &&
2216 resultElementTy
.isInteger(32)) {
2217 Value extend
= rewriter
.create
<arith::ExtSIOp
>(
2218 loc
, rewriter
.getI32Type(), inputValue
);
2220 auto offset
= rewriter
.create
<arith::ConstantOp
>(
2221 loc
, rewriter
.getI32IntegerAttr(32768));
2222 auto seven
= rewriter
.create
<arith::ConstantOp
>(
2223 loc
, rewriter
.getI32IntegerAttr(7));
2224 auto one
= rewriter
.create
<arith::ConstantOp
>(
2225 loc
, rewriter
.getI32IntegerAttr(1));
2226 auto b1111111
= rewriter
.create
<arith::ConstantOp
>(
2227 loc
, rewriter
.getI32IntegerAttr(127));
2229 // Compute the index and fractional part from the input value:
2230 // value = value + 32768
2231 // index = value >> 7;
2232 // fraction = 0x01111111 & value
2233 auto extendAdd
= rewriter
.create
<arith::AddIOp
>(loc
, extend
, offset
);
2234 Value index
= rewriter
.create
<arith::ShRUIOp
>(loc
, extendAdd
, seven
);
2236 rewriter
.create
<arith::AndIOp
>(loc
, extendAdd
, b1111111
);
2238 // Extract the base and next values from the table.
2239 // base = (int32_t) table[index];
2240 // next = (int32_t) table[index + 1];
2241 Value indexPlusOne
= rewriter
.create
<arith::AddIOp
>(loc
, index
, one
);
2243 index
= rewriter
.create
<arith::IndexCastOp
>(
2244 loc
, rewriter
.getIndexType(), index
);
2245 indexPlusOne
= rewriter
.create
<arith::IndexCastOp
>(
2246 loc
, rewriter
.getIndexType(), indexPlusOne
);
2249 rewriter
.create
<tensor::ExtractOp
>(loc
, table
, ValueRange
{index
});
2250 Value next
= rewriter
.create
<tensor::ExtractOp
>(
2251 loc
, table
, ValueRange
{indexPlusOne
});
2254 rewriter
.create
<arith::ExtSIOp
>(loc
, rewriter
.getI32Type(), base
);
2256 rewriter
.create
<arith::ExtSIOp
>(loc
, rewriter
.getI32Type(), next
);
2258 // Use the fractional part to interpolate between the input values:
2259 // result = (base << 7) + (next - base) * fraction
2260 Value baseScaled
= rewriter
.create
<arith::ShLIOp
>(loc
, base
, seven
);
2261 Value diff
= rewriter
.create
<arith::SubIOp
>(loc
, next
, base
);
2262 Value diffScaled
= rewriter
.create
<arith::MulIOp
>(loc
, diff
, fraction
);
2264 rewriter
.create
<arith::AddIOp
>(loc
, baseScaled
, diffScaled
);
2266 rewriter
.create
<linalg::YieldOp
>(loc
, result
);
2272 return rewriter
.notifyMatchFailure(
2273 op
, "unable to create body for tosa.table op");
2277 struct RFFT2dConverter final
: public OpRewritePattern
<RFFT2dOp
> {
2278 using OpRewritePattern
<RFFT2dOp
>::OpRewritePattern
;
2280 static bool isRankedTensor(Type type
) { return isa
<RankedTensorType
>(type
); }
2282 static OpFoldResult
halfPlusOne(OpBuilder
&builder
, Location loc
,
2284 auto one
= builder
.create
<arith::ConstantIndexOp
>(loc
, 1);
2285 auto two
= builder
.create
<arith::ConstantIndexOp
>(loc
, 2);
2287 auto value
= getValueOrCreateConstantIndexOp(builder
, loc
, ofr
);
2288 auto divBy2
= builder
.createOrFold
<arith::DivUIOp
>(loc
, value
, two
);
2289 auto plusOne
= builder
.createOrFold
<arith::AddIOp
>(loc
, divBy2
, one
);
2290 return getAsOpFoldResult(plusOne
);
2293 static RankedTensorType
2294 computeOutputShape(OpBuilder
&builder
, Location loc
, Value input
,
2295 llvm::SmallVectorImpl
<Value
> &dynamicSizes
) {
2297 auto dims
= tensor::getMixedSizes(builder
, loc
, input
);
2299 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2301 dims
[2] = halfPlusOne(builder
, loc
, dims
[2]);
2303 llvm::SmallVector
<int64_t, 3> staticSizes
;
2304 dispatchIndexOpFoldResults(dims
, dynamicSizes
, staticSizes
);
2306 auto elementType
= cast
<RankedTensorType
>(input
.getType()).getElementType();
2307 return RankedTensorType::get(staticSizes
, elementType
);
2310 static Value
createZeroTensor(PatternRewriter
&rewriter
, Location loc
,
2311 RankedTensorType type
,
2312 llvm::ArrayRef
<Value
> dynamicSizes
) {
2314 rewriter
.create
<tensor::EmptyOp
>(loc
, type
, dynamicSizes
);
2315 auto fillValueAttr
= rewriter
.getZeroAttr(type
.getElementType());
2316 auto fillValue
= rewriter
.create
<arith::ConstantOp
>(loc
, fillValueAttr
);
2317 auto filledTensor
= rewriter
2318 .create
<linalg::FillOp
>(loc
, ValueRange
{fillValue
},
2319 ValueRange
{emptyTensor
})
2321 return filledTensor
;
2324 static Value
castIndexToFloat(OpBuilder
&builder
, Location loc
,
2325 FloatType type
, Value value
) {
2326 auto integerVal
= builder
.create
<arith::IndexCastUIOp
>(
2328 type
.getIntOrFloatBitWidth() > 32 ? builder
.getI64Type()
2329 : builder
.getI32Type(),
2332 return builder
.create
<arith::UIToFPOp
>(loc
, type
, integerVal
);
2335 static Value
createLinalgIndex(OpBuilder
&builder
, Location loc
,
2336 FloatType type
, int64_t index
) {
2337 auto indexVal
= builder
.create
<linalg::IndexOp
>(loc
, index
);
2338 return castIndexToFloat(builder
, loc
, type
, indexVal
);
2341 template <typename
... Args
>
2342 static llvm::SmallVector
<AffineExpr
, 4> affineDimsExpr(OpBuilder
&builder
,
2344 return {builder
.getAffineDimExpr(args
)...};
2347 LogicalResult
matchAndRewrite(RFFT2dOp rfft2d
,
2348 PatternRewriter
&rewriter
) const override
{
2349 if (!llvm::all_of(rfft2d
->getOperandTypes(), isRankedTensor
) ||
2350 !llvm::all_of(rfft2d
->getResultTypes(), isRankedTensor
)) {
2351 return rewriter
.notifyMatchFailure(rfft2d
,
2352 "only supports ranked tensors");
2355 auto loc
= rfft2d
.getLoc();
2356 auto input
= rfft2d
.getInput();
2358 dyn_cast
<FloatType
>(cast
<ShapedType
>(input
.getType()).getElementType());
2360 return rewriter
.notifyMatchFailure(rfft2d
,
2361 "only supports float element types");
2363 // Compute the output type and set of dynamic sizes
2364 llvm::SmallVector
<Value
> dynamicSizes
;
2365 auto outputType
= computeOutputShape(rewriter
, loc
, input
, dynamicSizes
);
2367 // Iterator types for the linalg.generic implementation
2368 llvm::SmallVector
<utils::IteratorType
, 5> iteratorTypes
= {
2369 utils::IteratorType::parallel
, utils::IteratorType::parallel
,
2370 utils::IteratorType::parallel
, utils::IteratorType::reduction
,
2371 utils::IteratorType::reduction
};
2373 // Inputs/outputs to the linalg.generic implementation
2374 llvm::SmallVector
<Value
> genericOpInputs
= {input
};
2375 llvm::SmallVector
<Value
> genericOpOutputs
= {
2376 createZeroTensor(rewriter
, loc
, outputType
, dynamicSizes
),
2377 createZeroTensor(rewriter
, loc
, outputType
, dynamicSizes
)};
2379 // Indexing maps for input and output tensors
2380 auto indexingMaps
= AffineMap::inferFromExprList(
2381 llvm::ArrayRef
{affineDimsExpr(rewriter
, 0, 3, 4),
2382 affineDimsExpr(rewriter
, 0, 1, 2),
2383 affineDimsExpr(rewriter
, 0, 1, 2)},
2384 rewriter
.getContext());
2386 // Width and height dimensions of the original input.
2387 auto dimH
= rewriter
.createOrFold
<tensor::DimOp
>(loc
, input
, 1);
2388 auto dimW
= rewriter
.createOrFold
<tensor::DimOp
>(loc
, input
, 2);
2390 // Constants and dimension sizes
2391 auto twoPiAttr
= rewriter
.getFloatAttr(elementType
, 6.283185307179586);
2392 auto twoPi
= rewriter
.create
<arith::ConstantOp
>(loc
, twoPiAttr
);
2393 auto constH
= castIndexToFloat(rewriter
, loc
, elementType
, dimH
);
2394 auto constW
= castIndexToFloat(rewriter
, loc
, elementType
, dimW
);
2396 auto buildBody
= [&](OpBuilder
&builder
, Location loc
, ValueRange args
) {
2397 Value valReal
= args
[0];
2398 Value sumReal
= args
[1];
2399 Value sumImag
= args
[2];
2401 // Indices for angle computation
2402 Value oy
= builder
.create
<linalg::IndexOp
>(loc
, 1);
2403 Value ox
= builder
.create
<linalg::IndexOp
>(loc
, 2);
2404 Value iy
= builder
.create
<linalg::IndexOp
>(loc
, 3);
2405 Value ix
= builder
.create
<linalg::IndexOp
>(loc
, 4);
2407 // Calculating angle without integer parts of components as sin/cos are
2408 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2410 auto iyXoy
= builder
.create
<index::MulOp
>(loc
, iy
, oy
);
2411 auto ixXox
= builder
.create
<index::MulOp
>(loc
, ix
, ox
);
2413 auto iyRem
= builder
.create
<index::RemUOp
>(loc
, iyXoy
, dimH
);
2414 auto ixRem
= builder
.create
<index::RemUOp
>(loc
, ixXox
, dimW
);
2416 auto iyRemFloat
= castIndexToFloat(builder
, loc
, elementType
, iyRem
);
2417 auto ixRemFloat
= castIndexToFloat(builder
, loc
, elementType
, ixRem
);
2419 auto yComponent
= builder
.create
<arith::DivFOp
>(loc
, iyRemFloat
, constH
);
2420 auto xComponent
= builder
.create
<arith::DivFOp
>(loc
, ixRemFloat
, constW
);
2421 auto sumXY
= builder
.create
<arith::AddFOp
>(loc
, yComponent
, xComponent
);
2422 auto angle
= builder
.create
<arith::MulFOp
>(loc
, twoPi
, sumXY
);
2424 // realComponent = valReal * cos(angle)
2425 // imagComponent = valReal * sin(angle)
2426 auto cosAngle
= builder
.create
<math::CosOp
>(loc
, angle
);
2427 auto sinAngle
= builder
.create
<math::SinOp
>(loc
, angle
);
2428 auto realComponent
=
2429 builder
.create
<arith::MulFOp
>(loc
, valReal
, cosAngle
);
2430 auto imagComponent
=
2431 builder
.create
<arith::MulFOp
>(loc
, valReal
, sinAngle
);
2433 // outReal = sumReal + realComponent
2434 // outImag = sumImag - imagComponent
2435 auto outReal
= builder
.create
<arith::AddFOp
>(loc
, sumReal
, realComponent
);
2436 auto outImag
= builder
.create
<arith::SubFOp
>(loc
, sumImag
, imagComponent
);
2438 builder
.create
<linalg::YieldOp
>(loc
, ValueRange
{outReal
, outImag
});
2441 rewriter
.replaceOpWithNewOp
<linalg::GenericOp
>(
2442 rfft2d
, rfft2d
.getResultTypes(), genericOpInputs
, genericOpOutputs
,
2443 indexingMaps
, iteratorTypes
, buildBody
);
2449 struct FFT2dConverter final
: OpRewritePattern
<FFT2dOp
> {
2450 using OpRewritePattern::OpRewritePattern
;
2452 LogicalResult
matchAndRewrite(FFT2dOp fft2d
,
2453 PatternRewriter
&rewriter
) const override
{
2454 if (!llvm::all_of(fft2d
->getOperandTypes(),
2455 RFFT2dConverter::isRankedTensor
) ||
2456 !llvm::all_of(fft2d
->getResultTypes(),
2457 RFFT2dConverter::isRankedTensor
)) {
2458 return rewriter
.notifyMatchFailure(fft2d
, "only supports ranked tensors");
2461 Location loc
= fft2d
.getLoc();
2462 Value input_real
= fft2d
.getInputReal();
2463 Value input_imag
= fft2d
.getInputImag();
2464 BoolAttr inverse
= fft2d
.getInverseAttr();
2466 auto real_el_ty
= cast
<FloatType
>(
2467 cast
<ShapedType
>(input_real
.getType()).getElementType());
2468 [[maybe_unused
]] auto imag_el_ty
= cast
<FloatType
>(
2469 cast
<ShapedType
>(input_imag
.getType()).getElementType());
2471 assert(real_el_ty
== imag_el_ty
);
2473 // Compute the output type and set of dynamic sizes
2474 SmallVector
<Value
> dynamicSizes
;
2477 auto dims
= tensor::getMixedSizes(rewriter
, loc
, input_real
);
2479 SmallVector
<int64_t, 3> staticSizes
;
2480 dispatchIndexOpFoldResults(dims
, dynamicSizes
, staticSizes
);
2482 auto outputType
= RankedTensorType::get(staticSizes
, real_el_ty
);
2484 // Iterator types for the linalg.generic implementation
2485 SmallVector
<utils::IteratorType
, 5> iteratorTypes
= {
2486 utils::IteratorType::parallel
, utils::IteratorType::parallel
,
2487 utils::IteratorType::parallel
, utils::IteratorType::reduction
,
2488 utils::IteratorType::reduction
};
2490 // Inputs/outputs to the linalg.generic implementation
2491 SmallVector
<Value
> genericOpInputs
= {input_real
, input_imag
};
2492 SmallVector
<Value
> genericOpOutputs
= {
2493 RFFT2dConverter::createZeroTensor(rewriter
, loc
, outputType
,
2495 RFFT2dConverter::createZeroTensor(rewriter
, loc
, outputType
,
2498 // Indexing maps for input and output tensors
2499 auto indexingMaps
= AffineMap::inferFromExprList(
2500 ArrayRef
{RFFT2dConverter::affineDimsExpr(rewriter
, 0, 3, 4),
2501 RFFT2dConverter::affineDimsExpr(rewriter
, 0, 3, 4),
2502 RFFT2dConverter::affineDimsExpr(rewriter
, 0, 1, 2),
2503 RFFT2dConverter::affineDimsExpr(rewriter
, 0, 1, 2)},
2504 rewriter
.getContext());
2506 // Width and height dimensions of the original input.
2507 auto dimH
= rewriter
.createOrFold
<tensor::DimOp
>(loc
, input_real
, 1);
2508 auto dimW
= rewriter
.createOrFold
<tensor::DimOp
>(loc
, input_real
, 2);
2510 // Constants and dimension sizes
2511 auto twoPiAttr
= rewriter
.getFloatAttr(real_el_ty
, 6.283185307179586);
2512 auto twoPi
= rewriter
.create
<arith::ConstantOp
>(loc
, twoPiAttr
);
2514 RFFT2dConverter::castIndexToFloat(rewriter
, loc
, real_el_ty
, dimH
);
2516 RFFT2dConverter::castIndexToFloat(rewriter
, loc
, real_el_ty
, dimW
);
2518 auto buildBody
= [&](OpBuilder
&builder
, Location loc
, ValueRange args
) {
2519 Value valReal
= args
[0];
2520 Value valImag
= args
[1];
2521 Value sumReal
= args
[2];
2522 Value sumImag
= args
[3];
2524 // Indices for angle computation
2525 Value oy
= builder
.create
<linalg::IndexOp
>(loc
, 1);
2526 Value ox
= builder
.create
<linalg::IndexOp
>(loc
, 2);
2527 Value iy
= builder
.create
<linalg::IndexOp
>(loc
, 3);
2528 Value ix
= builder
.create
<linalg::IndexOp
>(loc
, 4);
2530 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2532 auto iyXoy
= builder
.create
<index::MulOp
>(loc
, iy
, oy
);
2533 auto ixXox
= builder
.create
<index::MulOp
>(loc
, ix
, ox
);
2535 auto iyRem
= builder
.create
<index::RemUOp
>(loc
, iyXoy
, dimH
);
2536 auto ixRem
= builder
.create
<index::RemUOp
>(loc
, ixXox
, dimW
);
2539 RFFT2dConverter::castIndexToFloat(builder
, loc
, real_el_ty
, iyRem
);
2541 RFFT2dConverter::castIndexToFloat(builder
, loc
, real_el_ty
, ixRem
);
2543 auto yComponent
= builder
.create
<arith::DivFOp
>(loc
, iyRemFloat
, constH
);
2544 auto xComponent
= builder
.create
<arith::DivFOp
>(loc
, ixRemFloat
, constW
);
2546 auto sumXY
= builder
.create
<arith::AddFOp
>(loc
, yComponent
, xComponent
);
2547 auto angle
= builder
.create
<arith::MulFOp
>(loc
, twoPi
, sumXY
);
2549 if (inverse
.getValue()) {
2550 angle
= builder
.create
<arith::MulFOp
>(
2552 rewriter
.create
<arith::ConstantOp
>(
2553 loc
, rewriter
.getFloatAttr(real_el_ty
, -1.0)));
2556 // realComponent = val_real * cos(a) + val_imag * sin(a);
2557 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2558 auto cosAngle
= builder
.create
<math::CosOp
>(loc
, angle
);
2559 auto sinAngle
= builder
.create
<math::SinOp
>(loc
, angle
);
2561 auto rcos
= builder
.create
<arith::MulFOp
>(loc
, valReal
, cosAngle
);
2562 auto rsin
= builder
.create
<arith::MulFOp
>(loc
, valImag
, sinAngle
);
2563 auto realComponent
= builder
.create
<arith::AddFOp
>(loc
, rcos
, rsin
);
2565 auto icos
= builder
.create
<arith::MulFOp
>(loc
, valImag
, cosAngle
);
2566 auto isin
= builder
.create
<arith::MulFOp
>(loc
, valReal
, sinAngle
);
2568 auto imagComponent
= builder
.create
<arith::SubFOp
>(loc
, icos
, isin
);
2570 // outReal = sumReal + realComponent
2571 // outImag = sumImag - imagComponent
2572 auto outReal
= builder
.create
<arith::AddFOp
>(loc
, sumReal
, realComponent
);
2573 auto outImag
= builder
.create
<arith::AddFOp
>(loc
, sumImag
, imagComponent
);
2575 builder
.create
<linalg::YieldOp
>(loc
, ValueRange
{outReal
, outImag
});
2578 rewriter
.replaceOpWithNewOp
<linalg::GenericOp
>(
2579 fft2d
, fft2d
.getResultTypes(), genericOpInputs
, genericOpOutputs
,
2580 indexingMaps
, iteratorTypes
, buildBody
);
2588 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2589 const TypeConverter
&converter
, RewritePatternSet
*patterns
) {
2591 // We have multiple resize coverters to handle degenerate cases.
2592 patterns
->add
<GenericResizeConverter
>(patterns
->getContext(),
2594 patterns
->add
<ResizeUnaryConverter
>(patterns
->getContext(),
2596 patterns
->add
<MaterializeResizeBroadcast
>(patterns
->getContext(),
2601 PointwiseConverter
<tosa::AddOp
>,
2602 PointwiseConverter
<tosa::SubOp
>,
2603 PointwiseConverter
<tosa::MulOp
>,
2604 PointwiseConverter
<tosa::IntDivOp
>,
2605 PointwiseConverter
<tosa::NegateOp
>,
2606 PointwiseConverter
<tosa::PowOp
>,
2607 PointwiseConverter
<tosa::ReciprocalOp
>,
2608 PointwiseConverter
<tosa::RsqrtOp
>,
2609 PointwiseConverter
<tosa::LogOp
>,
2610 PointwiseConverter
<tosa::ExpOp
>,
2611 PointwiseConverter
<tosa::AbsOp
>,
2612 PointwiseConverter
<tosa::SinOp
>,
2613 PointwiseConverter
<tosa::CosOp
>,
2614 PointwiseConverter
<tosa::TanhOp
>,
2615 PointwiseConverter
<tosa::ErfOp
>,
2616 PointwiseConverter
<tosa::BitwiseAndOp
>,
2617 PointwiseConverter
<tosa::BitwiseOrOp
>,
2618 PointwiseConverter
<tosa::BitwiseNotOp
>,
2619 PointwiseConverter
<tosa::BitwiseXorOp
>,
2620 PointwiseConverter
<tosa::LogicalAndOp
>,
2621 PointwiseConverter
<tosa::LogicalNotOp
>,
2622 PointwiseConverter
<tosa::LogicalOrOp
>,
2623 PointwiseConverter
<tosa::LogicalXorOp
>,
2624 PointwiseConverter
<tosa::CastOp
>,
2625 PointwiseConverter
<tosa::LogicalLeftShiftOp
>,
2626 PointwiseConverter
<tosa::LogicalRightShiftOp
>,
2627 PointwiseConverter
<tosa::ArithmeticRightShiftOp
>,
2628 PointwiseConverter
<tosa::ClzOp
>,
2629 PointwiseConverter
<tosa::SelectOp
>,
2630 PointwiseConverter
<tosa::GreaterOp
>,
2631 PointwiseConverter
<tosa::GreaterEqualOp
>,
2632 PointwiseConverter
<tosa::EqualOp
>,
2633 PointwiseConverter
<tosa::MaximumOp
>,
2634 PointwiseConverter
<tosa::MinimumOp
>,
2635 PointwiseConverter
<tosa::CeilOp
>,
2636 PointwiseConverter
<tosa::FloorOp
>,
2637 PointwiseConverter
<tosa::ClampOp
>,
2638 PointwiseConverter
<tosa::SigmoidOp
>
2639 >(converter
, patterns
->getContext());
2642 IdentityNConverter
<tosa::IdentityOp
>,
2643 ReduceConverter
<tosa::ReduceAllOp
>,
2644 ReduceConverter
<tosa::ReduceAnyOp
>,
2645 ReduceConverter
<tosa::ReduceMinOp
>,
2646 ReduceConverter
<tosa::ReduceMaxOp
>,
2647 ReduceConverter
<tosa::ReduceSumOp
>,
2648 ReduceConverter
<tosa::ReduceProdOp
>,
2656 TileConverter
>(patterns
->getContext());