Revert "[HLSL] Add `Increment`/`DecrementCounter` methods to structured buffers ...
[llvm-project.git] / mlir / lib / Conversion / TosaToLinalg / TosaToLinalg.cpp
blob5291f95d371442ef5c0d4ceca899e8d6f7857d3e
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/Index/IR/IndexOps.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
22 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24 #include "mlir/Dialect/Utils/StaticValueUtils.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
34 #include <numeric>
36 using namespace mlir;
37 using namespace mlir::tosa;
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42 Type requiredAttrType, OpBuilder &rewriter) {
43 auto castedN = static_cast<T>(
44 cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45 return rewriter.create<arith::ConstantOp>(
46 op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
49 static Value createLinalgBodyCalculationForElementwiseOp(
50 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51 ConversionPatternRewriter &rewriter) {
52 Location loc = op->getLoc();
53 auto elementTy =
54 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
56 // tosa::AbsOp
57 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58 return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
60 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61 auto zero = rewriter.create<arith::ConstantOp>(
62 loc, rewriter.getZeroAttr(elementTy));
63 auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64 return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
67 // tosa::AddOp
68 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69 return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
71 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72 return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
74 // tosa::SubOp
75 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76 return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
78 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79 return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
81 // tosa::IntDivOp
82 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
83 return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
85 // tosa::ReciprocalOp
86 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
87 auto one =
88 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
89 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
92 // tosa::MulOp
93 if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94 return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
96 if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97 Value a = args[0];
98 Value b = args[1];
99 auto shift =
100 cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
101 if (shift > 0) {
102 auto shiftConst =
103 rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
104 if (!a.getType().isInteger(32))
105 a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
107 if (!b.getType().isInteger(32))
108 b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
110 auto result = rewriter.create<tosa::ApplyScaleOp>(
111 loc, rewriter.getI32Type(), a, b, shiftConst,
112 rewriter.getBoolAttr(false));
114 if (elementTy.isInteger(32))
115 return result;
117 return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
120 int aWidth = a.getType().getIntOrFloatBitWidth();
121 int bWidth = b.getType().getIntOrFloatBitWidth();
122 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
124 if (aWidth < cWidth)
125 a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
126 if (bWidth < cWidth)
127 b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
129 return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
132 // tosa::NegateOp
133 if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
134 return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
136 if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
137 int64_t inZp = 0, outZp = 0;
139 if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
140 auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
141 inZp = quantizationInfo.value().getInputZp();
142 outZp = quantizationInfo.value().getOutputZp();
145 int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
146 if (!inZp && !outZp) {
147 auto constant = rewriter.create<arith::ConstantOp>(
148 loc, IntegerAttr::get(elementTy, 0));
149 return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
150 args[0]);
153 // Compute the maximum value that can occur in the intermediate buffer.
154 int64_t zpAdd = inZp + outZp;
155 int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
156 std::abs(zpAdd) + 1;
158 // Convert that maximum value into the maximum bitwidth needed to represent
159 // it. We assume 48-bit numbers may be supported further in the pipeline.
160 int intermediateBitWidth = 64;
161 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162 intermediateBitWidth = 16;
163 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164 intermediateBitWidth = 32;
165 } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166 intermediateBitWidth = 48;
169 Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
170 Value zpAddValue = rewriter.create<arith::ConstantOp>(
171 loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
173 // The negation can be applied by doing:
174 // outputValue = inZp + outZp - inputValue
175 auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176 auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
178 // Clamp to the negation range.
179 Value min = rewriter.create<arith::ConstantIntOp>(
180 loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
181 intermediateType);
182 Value max = rewriter.create<arith::ConstantIntOp>(
183 loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
184 intermediateType);
185 auto clamp =
186 clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
188 // Truncate to the final value.
189 return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
192 // tosa::BitwiseAndOp
193 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
194 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
196 // tosa::BitwiseOrOp
197 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
198 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
200 // tosa::BitwiseNotOp
201 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
202 auto allOnesAttr = rewriter.getIntegerAttr(
203 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204 auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
205 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
208 // tosa::BitwiseXOrOp
209 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
210 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
212 // tosa::LogicalLeftShiftOp
213 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
214 return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
216 // tosa::LogicalRightShiftOp
217 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
218 return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
220 // tosa::ArithmeticRightShiftOp
221 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
222 auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
223 auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
224 if (!round) {
225 return result;
228 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
229 auto one =
230 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
231 auto zero =
232 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
233 auto i1one =
234 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
236 // Checking that input2 != 0
237 auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
238 loc, arith::CmpIPredicate::sgt, args[1], zero);
240 // Checking for the last bit of input1 to be 1
241 auto subtract =
242 rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
243 auto shifted =
244 rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
245 ->getResults();
246 auto truncated =
247 rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
248 auto isInputOdd =
249 rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
251 auto shouldRound = rewriter.create<arith::AndIOp>(
252 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
253 auto extended =
254 rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255 return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
258 // tosa::ClzOp
259 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
260 return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
263 // tosa::LogicalAnd
264 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
267 // tosa::LogicalNot
268 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269 auto one = rewriter.create<arith::ConstantOp>(
270 loc, rewriter.getIntegerAttr(elementTy, 1));
271 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
274 // tosa::LogicalOr
275 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
278 // tosa::LogicalXor
279 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
282 // tosa::PowOp
283 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
284 return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
286 // tosa::RsqrtOp
287 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
288 return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
290 // tosa::LogOp
291 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
292 return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
294 // tosa::ExpOp
295 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
296 return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
298 // tosa::SinOp
299 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
300 return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
302 // tosa::CosOp
303 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
304 return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
306 // tosa::TanhOp
307 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
308 return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
310 // tosa::ErfOp
311 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
312 return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
314 // tosa::GreaterOp
315 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
316 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
317 args[0], args[1]);
319 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
320 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
321 args[0], args[1]);
323 // tosa::GreaterEqualOp
324 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
325 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
326 args[0], args[1]);
328 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
329 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
330 args[0], args[1]);
332 // tosa::EqualOp
333 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
334 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
335 args[0], args[1]);
337 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
338 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
339 args[0], args[1]);
341 // tosa::SelectOp
342 if (isa<tosa::SelectOp>(op)) {
343 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
344 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
345 return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
348 // tosa::MaximumOp
349 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
350 return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
353 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
354 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
357 // tosa::MinimumOp
358 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
359 return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
362 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
363 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
366 // tosa::CeilOp
367 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
368 return rewriter.create<math::CeilOp>(loc, resultTypes, args);
370 // tosa::FloorOp
371 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
372 return rewriter.create<math::FloorOp>(loc, resultTypes, args);
374 // tosa::ClampOp
375 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
376 bool losesInfo = false;
377 APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
378 APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
379 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
380 APFloat::rmNearestTiesToEven, &losesInfo);
381 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
382 APFloat::rmNearestTiesToEven, &losesInfo);
383 auto min = rewriter.create<arith::ConstantOp>(
384 loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
385 auto max = rewriter.create<arith::ConstantOp>(
386 loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
387 return clampFloatHelper(loc, args[0], min, max, rewriter);
390 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
391 auto intTy = cast<IntegerType>(elementTy);
392 int64_t min =
393 cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
394 int64_t max =
395 cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
397 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
398 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
399 if (intTy.isUnsignedInteger()) {
400 minRepresentable = 0;
401 if (intTy.getIntOrFloatBitWidth() <= 63) {
402 maxRepresentable =
403 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
404 .getZExtValue();
406 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
407 // Ensure that min & max fit into signed n-bit constants.
408 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
409 .getSExtValue();
410 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
411 .getSExtValue();
413 // Ensure that the bounds are representable as n-bit signed/unsigned
414 // integers.
415 min = std::max(min, minRepresentable);
416 max = std::max(max, minRepresentable);
417 min = std::min(min, maxRepresentable);
418 max = std::min(max, maxRepresentable);
420 auto minVal = rewriter.create<arith::ConstantIntOp>(
421 loc, min, intTy.getIntOrFloatBitWidth());
422 auto maxVal = rewriter.create<arith::ConstantIntOp>(
423 loc, max, intTy.getIntOrFloatBitWidth());
424 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
425 intTy.isUnsignedInteger());
428 // tosa::SigmoidOp
429 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
430 auto one =
431 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
432 auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
433 auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
434 auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
435 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
438 // tosa::CastOp
439 if (isa<tosa::CastOp>(op)) {
440 Type srcTy = elementTy;
441 Type dstTy = resultTypes.front();
442 bool bitExtend =
443 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
445 if (srcTy == dstTy)
446 return args.front();
448 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
449 return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
450 std::nullopt);
452 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
453 return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
454 std::nullopt);
456 // 1-bit integers need to be treated as signless.
457 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
458 return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
459 std::nullopt);
461 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
462 return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
463 std::nullopt);
465 // Unsigned integers need an unrealized cast so that they can be passed
466 // to UIToFP.
467 if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
468 auto unrealizedCast =
469 rewriter
470 .create<UnrealizedConversionCastOp>(
471 loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
472 args[0])
473 .getResult(0);
474 return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
475 unrealizedCast);
478 // All other si-to-fp conversions should be handled by SIToFP.
479 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
480 return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
481 std::nullopt);
483 // Casting to boolean, floats need to only be checked as not-equal to zero.
484 if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
485 Value zero = rewriter.create<arith::ConstantOp>(
486 loc, rewriter.getFloatAttr(srcTy, 0.0));
487 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
488 args.front(), zero);
491 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
492 auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
494 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
495 // Check whether neither int min nor int max can be represented in the
496 // input floating-point type due to too short exponent range.
497 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
498 APFloat::semanticsMaxExponent(fltSemantics)) {
499 // Use cmp + select to replace infinites by int min / int max. Other
500 // integral values can be represented in the integer space.
501 auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
502 auto posInf = rewriter.create<arith::ConstantOp>(
503 loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
504 APFloat::getInf(fltSemantics)));
505 auto negInf = rewriter.create<arith::ConstantOp>(
506 loc, rewriter.getFloatAttr(
507 getElementTypeOrSelf(srcTy),
508 APFloat::getInf(fltSemantics, /*Negative=*/true)));
509 auto overflow = rewriter.create<arith::CmpFOp>(
510 loc, arith::CmpFPredicate::UEQ, rounded, posInf);
511 auto underflow = rewriter.create<arith::CmpFOp>(
512 loc, arith::CmpFPredicate::UEQ, rounded, negInf);
513 auto intMin = rewriter.create<arith::ConstantOp>(
514 loc, rewriter.getIntegerAttr(
515 getElementTypeOrSelf(dstTy),
516 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
517 auto intMax = rewriter.create<arith::ConstantOp>(
518 loc, rewriter.getIntegerAttr(
519 getElementTypeOrSelf(dstTy),
520 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
521 auto maxClamped =
522 rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
523 return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
524 maxClamped);
527 auto intMinFP = rewriter.create<arith::ConstantOp>(
528 loc, rewriter.getFloatAttr(
529 getElementTypeOrSelf(srcTy),
530 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
531 .getSExtValue()));
533 // Check whether the mantissa has enough bits to represent int max.
534 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
535 dstTy.getIntOrFloatBitWidth() - 1) {
536 // Int min can also be represented since it is a power of two and thus
537 // consists of a single leading bit. Therefore we can clamp the input
538 // in the floating-point domain.
540 auto intMaxFP = rewriter.create<arith::ConstantOp>(
541 loc, rewriter.getFloatAttr(
542 getElementTypeOrSelf(srcTy),
543 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
544 .getSExtValue()));
546 Value clamped =
547 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
548 return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
551 // Due to earlier check we know exponant range is big enough to represent
552 // int min. We can therefore rely on int max + 1 being representable as
553 // well because it's just int min with a positive sign. So clamp the min
554 // value and compare against that to select the max int value if needed.
555 auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
556 loc, rewriter.getFloatAttr(
557 getElementTypeOrSelf(srcTy),
558 static_cast<double>(
559 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
560 .getSExtValue()) +
561 1.0f));
563 auto intMax = rewriter.create<arith::ConstantOp>(
564 loc, rewriter.getIntegerAttr(
565 getElementTypeOrSelf(dstTy),
566 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
567 auto minClampedFP =
568 rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
569 auto minClamped =
570 rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
571 auto overflow = rewriter.create<arith::CmpFOp>(
572 loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
573 return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
574 minClamped);
577 // Casting to boolean, integers need to only be checked as not-equal to
578 // zero.
579 if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
580 Value zero = rewriter.create<arith::ConstantIntOp>(
581 loc, 0, srcTy.getIntOrFloatBitWidth());
582 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
583 args.front(), zero);
586 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
587 return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
588 std::nullopt);
590 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
591 return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
595 (void)rewriter.notifyMatchFailure(
596 op, "unhandled op for linalg body calculation for elementwise op");
597 return nullptr;
600 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
601 int64_t rank) {
602 // No need to expand if we are already at the desired rank
603 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
604 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
605 int64_t numExtraDims = rank - shapedType.getRank();
606 assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
607 if (!numExtraDims)
608 return tensor;
610 // Compute reassociation indices
611 SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
612 shapedType.getRank());
613 int64_t index = 0;
614 for (index = 0; index <= numExtraDims; index++)
615 reassociationIndices[0].push_back(index);
616 for (size_t position = 1; position < reassociationIndices.size(); position++)
617 reassociationIndices[position].push_back(index++);
619 // Compute result type
620 SmallVector<int64_t> resultShape;
621 for (index = 0; index < numExtraDims; index++)
622 resultShape.push_back(1);
623 for (auto size : shapedType.getShape())
624 resultShape.push_back(size);
625 auto resultType =
626 RankedTensorType::get(resultShape, shapedType.getElementType());
628 // Emit 'tensor.expand_shape' op
629 return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
630 reassociationIndices);
633 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
634 Location loc, ValueRange operands,
635 int64_t rank) {
636 return llvm::map_to_vector(operands, [&](Value operand) {
637 return expandRank(rewriter, loc, operand, rank);
641 using IndexPool = DenseMap<int64_t, Value>;
643 // Emit an 'arith.constant' op for the given index if it has not been created
644 // yet, or return an existing constant. This will prevent an excessive creation
645 // of redundant constants, easing readability of emitted code for unit tests.
646 static Value createIndex(PatternRewriter &rewriter, Location loc,
647 IndexPool &indexPool, int64_t index) {
648 auto [it, inserted] = indexPool.try_emplace(index);
649 if (inserted)
650 it->second =
651 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
652 return it->second;
655 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
656 IndexPool &indexPool, Value tensor, int64_t index) {
657 auto indexValue = createIndex(rewriter, loc, indexPool, index);
658 return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
661 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
662 IndexPool &indexPool, Value tensor,
663 int64_t index) {
664 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
665 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
666 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
667 if (shapedType.isDynamicDim(index))
668 return getTensorDim(rewriter, loc, indexPool, tensor, index);
669 return rewriter.getIndexAttr(shapedType.getDimSize(index));
672 static bool operandsAndResultsRanked(Operation *operation) {
673 auto isRanked = [](Value value) {
674 return isa<RankedTensorType>(value.getType());
676 return llvm::all_of(operation->getOperands(), isRanked) &&
677 llvm::all_of(operation->getResults(), isRanked);
680 // Compute the runtime dimension size for dimension 'dim' of the output by
681 // inspecting input 'operands', all of which are expected to have the same rank.
682 // This function returns a pair {targetSize, masterOperand}.
684 // The runtime size of the output dimension is returned either as a statically
685 // computed attribute or as a runtime SSA value.
687 // If the target size was inferred directly from one dominating operand, that
688 // operand is returned in 'masterOperand'. If the target size is inferred from
689 // multiple operands, 'masterOperand' is set to nullptr.
690 static std::pair<OpFoldResult, Value>
691 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
692 ValueRange operands, int64_t dim) {
693 // If any input operand contains a static size greater than 1 for this
694 // dimension, that is the target size. An occurrence of an additional static
695 // dimension greater than 1 with a different value is undefined behavior.
696 for (auto operand : operands) {
697 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
698 if (!ShapedType::isDynamic(size) && size > 1)
699 return {rewriter.getIndexAttr(size), operand};
702 // Filter operands with dynamic dimension
703 auto operandsWithDynamicDim =
704 llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
705 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
706 }));
708 // If no operand has a dynamic dimension, it means all sizes were 1
709 if (operandsWithDynamicDim.empty())
710 return {rewriter.getIndexAttr(1), operands.front()};
712 // Emit code that computes the runtime size for this dimension. If there is
713 // only one operand with a dynamic dimension, it is considered the master
714 // operand that determines the runtime size of the output dimension.
715 auto targetSize =
716 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
717 if (operandsWithDynamicDim.size() == 1)
718 return {targetSize, operandsWithDynamicDim[0]};
720 // Calculate maximum size among all dynamic dimensions
721 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
722 auto nextSize =
723 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
724 targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
726 return {targetSize, nullptr};
729 // Compute the runtime output size for all dimensions. This function returns
730 // a pair {targetShape, masterOperands}.
731 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
732 computeTargetShape(PatternRewriter &rewriter, Location loc,
733 IndexPool &indexPool, ValueRange operands) {
734 assert(!operands.empty());
735 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
736 SmallVector<OpFoldResult> targetShape;
737 SmallVector<Value> masterOperands;
738 for (auto dim : llvm::seq<int64_t>(0, rank)) {
739 auto [targetSize, masterOperand] =
740 computeTargetSize(rewriter, loc, indexPool, operands, dim);
741 targetShape.push_back(targetSize);
742 masterOperands.push_back(masterOperand);
744 return {targetShape, masterOperands};
/// Broadcasts dimension 'dim' of 'operand' to the given 'targetSize' when the
/// dimension is dynamic. The broadcast is guarded at runtime by an 'scf.if':
/// the copy via 'linalg.generic' is only materialized when the actual size of
/// the dimension is 1; otherwise the operand is yielded unchanged.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op. The input map pins 'dim' to constant
  // 0 (the size-1 dimension being broadcast); the output map is the identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary: broadcast only when the runtime size of
  // 'dim' is exactly 1.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op sized with the target size at 'dim' and the
    // operand's own sizes elsewhere.
    // NOTE(review): 'rewriter' (not 'opBuilder') is used below to create the
    // dim ops — assumes the rewriter's insertion point is valid here; verify.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op that copies the operand into the broadcast
    // output shape.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary, so both scf.if branches
    // yield the same type.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if': no broadcast needed, pass through.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
830 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
831 IndexPool &indexPool, Value operand,
832 ArrayRef<OpFoldResult> targetShape,
833 ArrayRef<Value> masterOperands) {
834 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
835 assert((int64_t)targetShape.size() == rank);
836 assert((int64_t)masterOperands.size() == rank);
837 for (auto index : llvm::seq<int64_t>(0, rank))
838 operand =
839 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
840 targetShape[index], masterOperands[index]);
841 return operand;
844 static SmallVector<Value>
845 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
846 IndexPool &indexPool, ValueRange operands,
847 ArrayRef<OpFoldResult> targetShape,
848 ArrayRef<Value> masterOperands) {
849 // No need to broadcast for unary operations
850 if (operands.size() == 1)
851 return operands;
853 // Broadcast dynamic dimensions operand by operand
854 return llvm::map_to_vector(operands, [&](Value operand) {
855 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
856 targetShape, masterOperands);
/// Emits the 'linalg.generic' that performs the actual elementwise
/// computation for 'operation' over already rank-expanded and broadcast
/// 'operands', then replaces 'operation' with the (possibly cast) result.
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Static size-1 dims read index 0; all other dims use the identity.
      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
                                        : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op. The body builder flags failure through
  // 'encounteredError' because it cannot return a LogicalResult itself.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
/// Lowers a TOSA elementwise op to linalg: expands all operands to the result
/// rank, computes the common target shape (tracking "master operands"),
/// broadcasts dynamic dimensions where required, and emits the computation.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}
943 // Returns the constant initial value for a given reduction operation. The
944 // attribute type varies depending on the element type required.
945 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
946 PatternRewriter &rewriter) {
947 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
948 return rewriter.getFloatAttr(elementTy, 0.0);
950 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
951 return rewriter.getIntegerAttr(elementTy, 0);
953 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
954 return rewriter.getFloatAttr(elementTy, 1.0);
956 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
957 return rewriter.getIntegerAttr(elementTy, 1);
959 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
960 return rewriter.getFloatAttr(
961 elementTy, APFloat::getLargest(
962 cast<FloatType>(elementTy).getFloatSemantics(), false));
964 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
965 return rewriter.getIntegerAttr(
966 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
968 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
969 return rewriter.getFloatAttr(
970 elementTy, APFloat::getLargest(
971 cast<FloatType>(elementTy).getFloatSemantics(), true));
973 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
974 return rewriter.getIntegerAttr(
975 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
977 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
978 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
980 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
981 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
983 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
984 return rewriter.getFloatAttr(
985 elementTy, APFloat::getLargest(
986 cast<FloatType>(elementTy).getFloatSemantics(), true));
988 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
989 return rewriter.getIntegerAttr(
990 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
992 return {};
995 // Creates the body calculation for a reduction. The operations vary depending
996 // on the input type.
997 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
998 ValueRange args,
999 Type elementTy,
1000 PatternRewriter &rewriter) {
1001 Location loc = op->getLoc();
1002 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1003 return rewriter.create<arith::AddFOp>(loc, args);
1006 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1007 return rewriter.create<arith::AddIOp>(loc, args);
1010 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
1011 return rewriter.create<arith::MulFOp>(loc, args);
1014 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
1015 return rewriter.create<arith::MulIOp>(loc, args);
1018 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1019 return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1022 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1023 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1026 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1027 return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1030 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1031 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1034 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1035 return rewriter.create<arith::AndIOp>(loc, args);
1037 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1038 return rewriter.create<arith::OrIOp>(loc, args);
1040 return {};
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // The reduced tensor drops the 'axis' dimension; collect the remaining
  // sizes and the dynamic dims needed by tensor.empty.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  // Despite the name, 'didEncounterError' is set when the body builder DID
  // produce a value; the match fails below when it stays false (i.e. no
  // supported body computation exists for this op/element-type combination).
  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  // Re-expand the reduced tensor back to the original rank, re-inserting the
  // reduced axis as a unit dimension. Build one reassociation group per
  // reduced-result dimension, skipping over the dropped axis.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    // Fold the unit dimension produced for 'axis' into the neighboring group
    // (clamped when 'axis' was the last dimension).
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
1123 namespace {
/// Conversion pattern that lowers a TOSA elementwise (pointwise) op 'SrcOp'
/// by delegating to elementwiseMatchAndRewriteHelper with the adaptor's
/// (type-converted) operands.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
/// Lowers 'tosa.rescale' (quantized integer rescaling) to a 'linalg.generic'
/// whose body subtracts the input zero-point, applies 'tosa.apply_scale' with
/// the multiplier/shift, adds the output zero-point, and clamps to the output
/// integer range. Multiplier/shift are scalars for per-tensor rescale or
/// constant 1-D tensors (indexed by the innermost dim) for per-channel.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    // Collect dynamic output dims for the tensor.empty below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      // Per-channel: a 1-D i32 constant indexed by the innermost dimension.
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      // Per-channel: a 1-D i8 constant indexed by the innermost dimension.
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the destination tensor needed for the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in 64-bit. This is not optimal but
          // should be correct for now, consider computing correct bit depth
          // later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          // Per-tensor rescale uses the cached scalar constant; per-channel
          // reads the value from the extra linalg.generic input.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Extend narrow inputs to i32 before doing the arithmetic.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              // Cast away the unsigned type (via a signless integer of the
              // same width) so arith ops can be applied, then zero-extend.
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          // Truncate back down to the output width, restoring unsignedness
          // with an unrealized cast when needed.
          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
// Handle the resize case where the input is a 1x1 image. This case entirely
// avoids the extract operations, which are much more difficult to optimize
// away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // TODO(suderman): These string values should be declared the TOSA dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Same type in and out: the resize is a no-op.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away: NxHxWxC -> Nx(H*W*C) with
    // H == W == 1, i.e. batch and channel remain.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case: extend to the result element type and
          // apply the bilinear scale factors where present.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    // Expand back to the full NxHxWxC result shape.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only applies when at least one spatial dim goes from size 1 to >1.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-emit the resize restricted to the non-broadcast shape.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse any unit result dims: each unit spatial dim is merged into the
    // preceding reassociation group instead of starting a new one.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape{batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // Input map skips the broadcast (unit) spatial dims so each output
    // element reads the single collapsed value.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1530 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1531 public:
1532 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1534 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1535 PatternRewriter &rewriter) const final {
1536 Location loc = op.getLoc();
1537 ImplicitLocOpBuilder b(loc, rewriter);
1538 auto input = op.getInput();
1539 auto inputTy = cast<ShapedType>(input.getType());
1540 auto resultTy = cast<ShapedType>(op.getType());
1541 auto resultETy = resultTy.getElementType();
1543 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1544 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1546 auto imageH = inputTy.getShape()[1];
1547 auto imageW = inputTy.getShape()[2];
1549 auto dynamicDimsOr =
1550 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1551 if (!dynamicDimsOr.has_value())
1552 return rewriter.notifyMatchFailure(
1553 op, "unable to get dynamic dimensions of tosa.resize");
1555 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1556 return rewriter.notifyMatchFailure(
1557 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1559 SmallVector<AffineMap, 2> affineMaps = {
1560 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1561 auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1562 *dynamicDimsOr);
1563 auto genericOp = b.create<linalg::GenericOp>(
1564 resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1565 getNParallelLoopsAttrs(resultTy.getRank()));
1566 Value resize = genericOp.getResult(0);
1569 OpBuilder::InsertionGuard regionGuard(b);
1570 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1571 TypeRange({resultETy}), loc);
1572 Value batch = b.create<linalg::IndexOp>(0);
1573 Value y = b.create<linalg::IndexOp>(1);
1574 Value x = b.create<linalg::IndexOp>(2);
1575 Value channel = b.create<linalg::IndexOp>(3);
1577 Value zeroI32 =
1578 b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1579 Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1580 Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1581 Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1583 Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1584 Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1586 ArrayRef<int64_t> offset = op.getOffset();
1587 ArrayRef<int64_t> border = op.getBorder();
1588 ArrayRef<int64_t> scale = op.getScale();
1590 Value yScaleN, yScaleD, xScaleN, xScaleD;
1591 yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1592 yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1593 xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1594 xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1596 Value yOffset, xOffset, yBorder, xBorder;
1597 yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1598 xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1599 yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1600 xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1602 // Compute the ix and dx values for both the X and Y dimensions.
1603 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1604 Value scaleN, Value scaleD, Value offset,
1605 int size, ImplicitLocOpBuilder &b) {
1606 if (size == 1) {
1607 index = zeroI32;
1608 delta = zeroFp;
1609 return;
1611 // x = x * scale_d + offset;
1612 // ix = floor(x / scale_n)
1613 Value val = b.create<arith::MulIOp>(in, scaleD);
1614 val = b.create<arith::AddIOp>(val, offset);
1615 index = b.create<arith::FloorDivSIOp>(val, scaleN);
1617 // rx = x % scale_n
1618 // dx = rx / scale_n
1619 Value r = b.create<arith::RemSIOp>(val, scaleN);
1620 Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1621 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1622 delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1625 // Compute the ix and dx values for the X and Y dimensions - int case.
1626 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1627 Value scaleN, Value scaleD, Value offset,
1628 int size, ImplicitLocOpBuilder &b) {
1629 if (size == 1) {
1630 index = zeroI32;
1631 delta = zeroI32;
1632 return;
1634 // x = x * scale_d + offset;
1635 // ix = floor(x / scale_n)
1636 // dx = x - ix * scale_n;
1637 Value val = b.create<arith::MulIOp>(in, scaleD);
1638 val = b.create<arith::AddIOp>(val, offset);
1639 index = b.create<arith::DivSIOp>(val, scaleN);
1640 delta = b.create<arith::MulIOp>(index, scaleN);
1641 delta = b.create<arith::SubIOp>(val, delta);
1644 Value ix, iy, dx, dy;
1645 if (floatingPointMode) {
1646 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1647 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1648 } else {
1649 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1650 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1653 if (op.getMode() == "NEAREST_NEIGHBOR") {
1654 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1656 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1657 Value max, int size,
1658 ImplicitLocOpBuilder &b) -> Value {
1659 if (size == 1) {
1660 return b.create<arith::ConstantIndexOp>(0);
1663 Value pred;
1664 if (floatingPointMode) {
1665 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1666 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1667 } else {
1668 Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1669 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1670 dvalDouble, scale);
1673 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1674 val = b.create<arith::AddIOp>(val, offset);
1675 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1676 return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1679 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1680 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1682 Value result = b.create<tensor::ExtractOp>(
1683 input, ValueRange{batch, iy, ix, channel});
1685 b.create<linalg::YieldOp>(result);
1686 } else {
1687 // The mode here must be BILINEAR.
1688 assert(op.getMode() == "BILINEAR");
1690 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1692 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1693 Value max, ImplicitLocOpBuilder &b) {
1694 val0 = in;
1695 val1 = b.create<arith::AddIOp>(val0, oneVal);
1696 val0 =
1697 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1698 val1 =
1699 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1700 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1701 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1704 // Linalg equivalent to the section below:
1705 // int16_t iy0 = apply_max(iy, 0);
1706 // int16_t iy1 = apply_min(iy + 1, IH - 1);
1707 // int16_t ix0 = apply_max(ix, 0);
1708 // int16_t ix1 = apply_min(ix + 1, IW - 1);
1709 Value x0, x1, y0, y1;
1710 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1711 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1713 Value y0x0 = b.create<tensor::ExtractOp>(
1714 input, ValueRange{batch, y0, x0, channel});
1715 Value y0x1 = b.create<tensor::ExtractOp>(
1716 input, ValueRange{batch, y0, x1, channel});
1717 Value y1x0 = b.create<tensor::ExtractOp>(
1718 input, ValueRange{batch, y1, x0, channel});
1719 Value y1x1 = b.create<tensor::ExtractOp>(
1720 input, ValueRange{batch, y1, x1, channel});
1722 if (floatingPointMode) {
1723 auto oneVal =
1724 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1725 auto interpolate = [&](Value val0, Value val1, Value delta,
1726 int inputSize,
1727 ImplicitLocOpBuilder &b) -> Value {
1728 if (inputSize == 1)
1729 return val0;
1730 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1731 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1732 Value mul1 = b.create<arith::MulFOp>(val1, delta);
1733 return b.create<arith::AddFOp>(mul0, mul1);
1736 // Linalg equivalent to the section below:
1737 // topAcc = v00 * (unit_x - dx);
1738 // topAcc += v01 * dx;
1739 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1741 // Linalg equivalent to the section below:
1742 // bottomAcc = v10 * (unit_x - dx);
1743 // bottomAcc += v11 * dx;
1744 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1746 // Linalg equivalent to the section below:
1747 // result = topAcc * (unit_y - dy) + bottomAcc * dy
1748 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1749 b.create<linalg::YieldOp>(result);
1750 } else {
1751 // Perform in quantized space.
1752 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1753 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1754 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1755 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1757 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1758 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1759 dx = b.create<arith::ExtSIOp>(resultETy, dx);
1760 dy = b.create<arith::ExtSIOp>(resultETy, dy);
1763 Value yScaleNExt = yScaleN;
1764 Value xScaleNExt = xScaleN;
1766 const int64_t scaleBitwidth =
1767 xScaleN.getType().getIntOrFloatBitWidth();
1768 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1769 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1770 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1773 auto interpolate = [](Value val0, Value val1, Value weight1,
1774 Value scale, int inputSize,
1775 ImplicitLocOpBuilder &b) -> Value {
1776 if (inputSize == 1)
1777 return b.create<arith::MulIOp>(val0, scale);
1778 Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1779 Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1780 Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1781 return b.create<arith::AddIOp>(mul0, mul1);
1784 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1785 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1786 Value result =
1787 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1788 b.create<linalg::YieldOp>(result);
1793 rewriter.replaceOp(op, resize);
1794 return success();
1798 // At the codegen level any identity operations should be removed. Any cases
1799 // where identity is load-bearing (e.g. cross device computation) should be
1800 // handled before lowering to codegen.
1801 template <typename SrcOp>
1802 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1803 public:
1804 using OpRewritePattern<SrcOp>::OpRewritePattern;
1806 LogicalResult matchAndRewrite(SrcOp op,
1807 PatternRewriter &rewriter) const final {
1808 rewriter.replaceOp(op, op.getOperation()->getOperands());
1809 return success();
1813 template <typename SrcOp>
1814 class ReduceConverter : public OpRewritePattern<SrcOp> {
1815 public:
1816 using OpRewritePattern<SrcOp>::OpRewritePattern;
1818 LogicalResult matchAndRewrite(SrcOp reduceOp,
1819 PatternRewriter &rewriter) const final {
1820 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
// Lowers tosa.reverse to a linalg.generic that, for every output element,
// reads the input with the index along the reversed axis flipped:
// out[..., i, ...] = in[..., size - 1 - i, ...] along `axis`.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect runtime sizes of the dynamic dimensions; tensor.empty needs one
    // SSA value per dynamic dim to materialize the destination.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Runtime extent of the reversed axis, used to compute size - 1 - i.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // Create the destination tensor for the generic op.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    // Single identity map: the generic has no tensor inputs, only the output;
    // the input is read via tensor.extract inside the body.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Gather the loop indices, flipping the one along `axis`.
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              // index = (size - 1) - index for the reversed dimension.
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }
            indices.push_back(index);
          }

          // Load the mirrored element and yield it as the output value.
          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
1882 // This converter translate a tile operation to a reshape, broadcast, reshape.
1883 // The first reshape minimally expands each tiled dimension to include a
1884 // proceding size-1 dim. This dim is then broadcasted to the appropriate
1885 // multiple.
1886 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1887 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1889 LogicalResult
1890 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1891 ConversionPatternRewriter &rewriter) const override {
1892 auto loc = op.getLoc();
1893 auto input = op.getInput1();
1894 auto inputTy = cast<ShapedType>(input.getType());
1895 auto inputShape = inputTy.getShape();
1896 auto resultTy = cast<ShapedType>(op.getType());
1897 auto elementTy = inputTy.getElementType();
1898 int64_t rank = inputTy.getRank();
1900 ArrayRef<int64_t> multiples = op.getMultiples();
1902 // Broadcast the newly added dimensions to their appropriate multiple.
1903 SmallVector<int64_t, 2> genericShape;
1904 for (int i = 0; i < rank; i++) {
1905 int64_t dim = multiples[i];
1906 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1907 genericShape.push_back(inputShape[i]);
1910 SmallVector<Value> dynDims;
1911 for (int i = 0; i < inputTy.getRank(); i++) {
1912 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1913 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1917 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1918 op.getLoc(), genericShape, elementTy, dynDims);
1920 // We needs to map the input shape to the non-broadcasted dimensions.
1921 SmallVector<AffineExpr, 4> dimExprs;
1922 dimExprs.reserve(rank);
1923 for (unsigned i = 0; i < rank; ++i)
1924 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1926 auto readAffineMap =
1927 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1928 rewriter.getContext());
1930 SmallVector<AffineMap, 2> affineMaps = {
1931 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1933 auto genericOp = rewriter.create<linalg::GenericOp>(
1934 loc, RankedTensorType::get(genericShape, elementTy), input,
1935 ValueRange{emptyTensor}, affineMaps,
1936 getNParallelLoopsAttrs(genericShape.size()),
1937 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1938 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1941 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1942 op, resultTy, genericOp.getResult(0),
1943 rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1944 return success();
// Tosa argmax lowering represents the ArgMax op as a linalg generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this buffer is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Type of the auxiliary running-max output (result shape, input element
    // type); it is dropped once the generic op is built.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Runtime sizes for the dynamic dims of the result (the reduction axis is
    // not part of the result shape, so it is skipped).
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Source map is the full identity; both destination maps drop the
    // reduction axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    // Set from inside the body lambda when the element type is neither float
    // nor integer; checked after the generic op is created.
    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current position along the reduction axis, cast to the result's
          // integer type so it can be selected against the running index.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          // predicate = newValue > oldValue (OGT for floats, sgt for ints).
          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          // Keep whichever of (value, index) wins the comparison.
          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is used; the running-max tensor is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
// Lowers tosa.gather to a linalg.generic that, for each (batch, index, chan)
// output position, extracts values[batch, indices[batch, index], chan].
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // The indices tensor is read at (batch d0, index d1); the output uses the
    // full identity map over (batch, index, channel).
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          // Gather values[batch, indexValue, channel] for the current
          // (batch, _, channel) output position.
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Collects the runtime sizes of the result's dynamic dims. The result shape
  // is (values dim 0, indices dim 1, values dim 2); a size is emitted only
  // when the corresponding dimension is dynamic.
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    // Push a size Value only if the mixed size is dynamic (an SSA Value
    // rather than a static attribute).
    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    // Runtime sizes for dynamic dims, taken from the input (input and result
    // are iterated with the same identity map below).
    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    // The generic op is created without a body; the body block is spliced in
    // below once the element-type combination is known.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      // i8 -> i8: a single table lookup, re-biasing the signed input by 128
      // so it indexes the 256-entry table.
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      // i16 -> i32: look up the two neighboring table entries and linearly
      // interpolate between them using the low 7 bits of the input.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7;
        // fraction = 0x01111111 & value
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    // No supported element-type combination matched.
    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
// Lowers tosa.rfft2d to a linalg.generic computing the real DFT directly:
// for each output (batch, oy, ox), reduce over all input (iy, ix), with
// real += val * cos(angle) and imag -= val * sin(angle).
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  // Predicate: the type is a ranked tensor.
  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns (ofr / 2) + 1 as an OpFoldResult, using unsigned division;
  // folds to a constant when `ofr` is static.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Computes the [N, H, W/2 + 1] output tensor type and appends the SSA
  // values of any dynamic sizes to `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Creates a zero-filled tensor of the given type (tensor.empty followed by
  // linalg.fill) to serve as the reduction's initial accumulator.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Converts an index value to the given float type, going through an
  // unsigned integer wide enough (i32 or i64) for the float's precision.
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Emits linalg.index for dimension `index` and casts it to `type`.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Builds a list of affine dim expressions from the given dim positions.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: parallel over
    // (batch, oy, ox), reduction over (iy, ix).
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation; the two outputs
    // accumulate the real and imaginary results respectively.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps: the input is read at (batch, iy, ix) = (d0, d3, d4);
    // both outputs are written at (batch, oy, ox) = (d0, d1, d2).
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2449 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2450 using OpRewritePattern::OpRewritePattern;
2452 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2453 PatternRewriter &rewriter) const override {
2454 if (!llvm::all_of(fft2d->getOperandTypes(),
2455 RFFT2dConverter::isRankedTensor) ||
2456 !llvm::all_of(fft2d->getResultTypes(),
2457 RFFT2dConverter::isRankedTensor)) {
2458 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2461 Location loc = fft2d.getLoc();
2462 Value input_real = fft2d.getInputReal();
2463 Value input_imag = fft2d.getInputImag();
2464 BoolAttr inverse = fft2d.getInverseAttr();
2466 auto real_el_ty = cast<FloatType>(
2467 cast<ShapedType>(input_real.getType()).getElementType());
2468 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2469 cast<ShapedType>(input_imag.getType()).getElementType());
2471 assert(real_el_ty == imag_el_ty);
2473 // Compute the output type and set of dynamic sizes
2474 SmallVector<Value> dynamicSizes;
2476 // Get [N, H, W]
2477 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2479 SmallVector<int64_t, 3> staticSizes;
2480 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2482 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2484 // Iterator types for the linalg.generic implementation
2485 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2486 utils::IteratorType::parallel, utils::IteratorType::parallel,
2487 utils::IteratorType::parallel, utils::IteratorType::reduction,
2488 utils::IteratorType::reduction};
2490 // Inputs/outputs to the linalg.generic implementation
2491 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2492 SmallVector<Value> genericOpOutputs = {
2493 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2494 dynamicSizes),
2495 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2496 dynamicSizes)};
2498 // Indexing maps for input and output tensors
2499 auto indexingMaps = AffineMap::inferFromExprList(
2500 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2501 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2502 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2503 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2504 rewriter.getContext());
2506 // Width and height dimensions of the original input.
2507 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2508 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2510 // Constants and dimension sizes
2511 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2512 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2513 Value constH =
2514 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2515 Value constW =
2516 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2518 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2519 Value valReal = args[0];
2520 Value valImag = args[1];
2521 Value sumReal = args[2];
2522 Value sumImag = args[3];
2524 // Indices for angle computation
2525 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2526 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2527 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2528 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2530 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2531 // ox) % W ) / W);
2532 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2533 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2535 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2536 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2538 auto iyRemFloat =
2539 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2540 auto ixRemFloat =
2541 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2543 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2544 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2546 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2547 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2549 if (inverse.getValue()) {
2550 angle = builder.create<arith::MulFOp>(
2551 loc, angle,
2552 rewriter.create<arith::ConstantOp>(
2553 loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2556 // realComponent = val_real * cos(a) + val_imag * sin(a);
2557 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2558 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2559 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2561 auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2562 auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2563 auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2565 auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2566 auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2568 auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2570 // outReal = sumReal + realComponent
2571 // outImag = sumImag - imagComponent
2572 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2573 auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2575 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2578 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2579 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2580 indexingMaps, iteratorTypes, buildBody);
2582 return success();
2586 } // namespace
2588 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2589 const TypeConverter &converter, RewritePatternSet *patterns) {
  // We have multiple resize converters to handle degenerate cases.
2592 patterns->add<GenericResizeConverter>(patterns->getContext(),
2593 /*benefit=*/100);
2594 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2595 /*benefit=*/200);
2596 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2597 /*benefit=*/300);
2599 patterns->add<
2600 // clang-format off
2601 PointwiseConverter<tosa::AddOp>,
2602 PointwiseConverter<tosa::SubOp>,
2603 PointwiseConverter<tosa::MulOp>,
2604 PointwiseConverter<tosa::IntDivOp>,
2605 PointwiseConverter<tosa::NegateOp>,
2606 PointwiseConverter<tosa::PowOp>,
2607 PointwiseConverter<tosa::ReciprocalOp>,
2608 PointwiseConverter<tosa::RsqrtOp>,
2609 PointwiseConverter<tosa::LogOp>,
2610 PointwiseConverter<tosa::ExpOp>,
2611 PointwiseConverter<tosa::AbsOp>,
2612 PointwiseConverter<tosa::SinOp>,
2613 PointwiseConverter<tosa::CosOp>,
2614 PointwiseConverter<tosa::TanhOp>,
2615 PointwiseConverter<tosa::ErfOp>,
2616 PointwiseConverter<tosa::BitwiseAndOp>,
2617 PointwiseConverter<tosa::BitwiseOrOp>,
2618 PointwiseConverter<tosa::BitwiseNotOp>,
2619 PointwiseConverter<tosa::BitwiseXorOp>,
2620 PointwiseConverter<tosa::LogicalAndOp>,
2621 PointwiseConverter<tosa::LogicalNotOp>,
2622 PointwiseConverter<tosa::LogicalOrOp>,
2623 PointwiseConverter<tosa::LogicalXorOp>,
2624 PointwiseConverter<tosa::CastOp>,
2625 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2626 PointwiseConverter<tosa::LogicalRightShiftOp>,
2627 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2628 PointwiseConverter<tosa::ClzOp>,
2629 PointwiseConverter<tosa::SelectOp>,
2630 PointwiseConverter<tosa::GreaterOp>,
2631 PointwiseConverter<tosa::GreaterEqualOp>,
2632 PointwiseConverter<tosa::EqualOp>,
2633 PointwiseConverter<tosa::MaximumOp>,
2634 PointwiseConverter<tosa::MinimumOp>,
2635 PointwiseConverter<tosa::CeilOp>,
2636 PointwiseConverter<tosa::FloorOp>,
2637 PointwiseConverter<tosa::ClampOp>,
2638 PointwiseConverter<tosa::SigmoidOp>
2639 >(converter, patterns->getContext());
2641 patterns->add<
2642 IdentityNConverter<tosa::IdentityOp>,
2643 ReduceConverter<tosa::ReduceAllOp>,
2644 ReduceConverter<tosa::ReduceAnyOp>,
2645 ReduceConverter<tosa::ReduceMinOp>,
2646 ReduceConverter<tosa::ReduceMaxOp>,
2647 ReduceConverter<tosa::ReduceSumOp>,
2648 ReduceConverter<tosa::ReduceProdOp>,
2649 ArgMaxConverter,
2650 GatherConverter,
2651 RescaleConverter,
2652 ReverseConverter,
2653 RFFT2dConverter,
2654 FFT2dConverter,
2655 TableConverter,
2656 TileConverter>(patterns->getContext());
2657 // clang-format on