1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
14 #include "AffineExprDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineExprVisitor.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/IntegerSet.h"
19 #include "mlir/Support/TypeID.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/MathExtras.h"
26 using namespace mlir::detail
;
28 using llvm::divideCeilSigned
;
29 using llvm::divideFloorSigned
;
30 using llvm::divideSignedWouldOverflow
;
33 MLIRContext
*AffineExpr::getContext() const { return expr
->context
; }
35 AffineExprKind
AffineExpr::getKind() const { return expr
->kind
; }
37 /// Walk all of the AffineExprs in `e` in postorder. This is a private factory
38 /// method to help handle lambda walk functions. Users should use the regular
39 /// (non-static) `walk` method.
40 template <typename WalkRetTy
>
41 WalkRetTy
mlir::AffineExpr::walk(AffineExpr e
,
42 function_ref
<WalkRetTy(AffineExpr
)> callback
) {
43 struct AffineExprWalker
44 : public AffineExprVisitor
<AffineExprWalker
, WalkRetTy
> {
45 function_ref
<WalkRetTy(AffineExpr
)> callback
;
47 AffineExprWalker(function_ref
<WalkRetTy(AffineExpr
)> callback
)
48 : callback(callback
) {}
50 WalkRetTy
visitAffineBinaryOpExpr(AffineBinaryOpExpr expr
) {
51 return callback(expr
);
53 WalkRetTy
visitConstantExpr(AffineConstantExpr expr
) {
54 return callback(expr
);
56 WalkRetTy
visitDimExpr(AffineDimExpr expr
) { return callback(expr
); }
57 WalkRetTy
visitSymbolExpr(AffineSymbolExpr expr
) { return callback(expr
); }
60 return AffineExprWalker(callback
).walkPostOrder(e
);
62 // Explicitly instantiate for the two supported return types.
63 template void mlir::AffineExpr::walk(AffineExpr e
,
64 function_ref
<void(AffineExpr
)> callback
);
66 mlir::AffineExpr::walk(AffineExpr e
,
67 function_ref
<WalkResult(AffineExpr
)> callback
);
69 // Dispatch affine expression construction based on kind.
70 AffineExpr
mlir::getAffineBinaryOpExpr(AffineExprKind kind
, AffineExpr lhs
,
72 if (kind
== AffineExprKind::Add
)
74 if (kind
== AffineExprKind::Mul
)
76 if (kind
== AffineExprKind::FloorDiv
)
77 return lhs
.floorDiv(rhs
);
78 if (kind
== AffineExprKind::CeilDiv
)
79 return lhs
.ceilDiv(rhs
);
80 if (kind
== AffineExprKind::Mod
)
83 llvm_unreachable("unknown binary operation on affine expressions");
86 /// This method substitutes any uses of dimensions and symbols (e.g.
87 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
89 AffineExpr::replaceDimsAndSymbols(ArrayRef
<AffineExpr
> dimReplacements
,
90 ArrayRef
<AffineExpr
> symReplacements
) const {
92 case AffineExprKind::Constant
:
94 case AffineExprKind::DimId
: {
95 unsigned dimId
= llvm::cast
<AffineDimExpr
>(*this).getPosition();
96 if (dimId
>= dimReplacements
.size())
98 return dimReplacements
[dimId
];
100 case AffineExprKind::SymbolId
: {
101 unsigned symId
= llvm::cast
<AffineSymbolExpr
>(*this).getPosition();
102 if (symId
>= symReplacements
.size())
104 return symReplacements
[symId
];
106 case AffineExprKind::Add
:
107 case AffineExprKind::Mul
:
108 case AffineExprKind::FloorDiv
:
109 case AffineExprKind::CeilDiv
:
110 case AffineExprKind::Mod
:
111 auto binOp
= llvm::cast
<AffineBinaryOpExpr
>(*this);
112 auto lhs
= binOp
.getLHS(), rhs
= binOp
.getRHS();
113 auto newLHS
= lhs
.replaceDimsAndSymbols(dimReplacements
, symReplacements
);
114 auto newRHS
= rhs
.replaceDimsAndSymbols(dimReplacements
, symReplacements
);
115 if (newLHS
== lhs
&& newRHS
== rhs
)
117 return getAffineBinaryOpExpr(getKind(), newLHS
, newRHS
);
119 llvm_unreachable("Unknown AffineExpr");
122 AffineExpr
AffineExpr::replaceDims(ArrayRef
<AffineExpr
> dimReplacements
) const {
123 return replaceDimsAndSymbols(dimReplacements
, {});
127 AffineExpr::replaceSymbols(ArrayRef
<AffineExpr
> symReplacements
) const {
128 return replaceDimsAndSymbols({}, symReplacements
);
131 /// Replace dims[offset ... numDims)
132 /// by dims[offset + shift ... shift + numDims).
133 AffineExpr
AffineExpr::shiftDims(unsigned numDims
, unsigned shift
,
134 unsigned offset
) const {
135 SmallVector
<AffineExpr
, 4> dims
;
136 for (unsigned idx
= 0; idx
< offset
; ++idx
)
137 dims
.push_back(getAffineDimExpr(idx
, getContext()));
138 for (unsigned idx
= offset
; idx
< numDims
; ++idx
)
139 dims
.push_back(getAffineDimExpr(idx
+ shift
, getContext()));
140 return replaceDimsAndSymbols(dims
, {});
143 /// Replace symbols[offset ... numSymbols)
144 /// by symbols[offset + shift ... shift + numSymbols).
145 AffineExpr
AffineExpr::shiftSymbols(unsigned numSymbols
, unsigned shift
,
146 unsigned offset
) const {
147 SmallVector
<AffineExpr
, 4> symbols
;
148 for (unsigned idx
= 0; idx
< offset
; ++idx
)
149 symbols
.push_back(getAffineSymbolExpr(idx
, getContext()));
150 for (unsigned idx
= offset
; idx
< numSymbols
; ++idx
)
151 symbols
.push_back(getAffineSymbolExpr(idx
+ shift
, getContext()));
152 return replaceDimsAndSymbols({}, symbols
);
155 /// Sparse replace method. Return the modified expression tree.
157 AffineExpr::replace(const DenseMap
<AffineExpr
, AffineExpr
> &map
) const {
158 auto it
= map
.find(*this);
164 case AffineExprKind::Add
:
165 case AffineExprKind::Mul
:
166 case AffineExprKind::FloorDiv
:
167 case AffineExprKind::CeilDiv
:
168 case AffineExprKind::Mod
:
169 auto binOp
= llvm::cast
<AffineBinaryOpExpr
>(*this);
170 auto lhs
= binOp
.getLHS(), rhs
= binOp
.getRHS();
171 auto newLHS
= lhs
.replace(map
);
172 auto newRHS
= rhs
.replace(map
);
173 if (newLHS
== lhs
&& newRHS
== rhs
)
175 return getAffineBinaryOpExpr(getKind(), newLHS
, newRHS
);
177 llvm_unreachable("Unknown AffineExpr");
180 /// Sparse replace method. Return the modified expression tree.
181 AffineExpr
AffineExpr::replace(AffineExpr expr
, AffineExpr replacement
) const {
182 DenseMap
<AffineExpr
, AffineExpr
> map
;
183 map
.insert(std::make_pair(expr
, replacement
));
186 /// Returns true if this expression is made out of only symbols and
187 /// constants (no dimensional identifiers).
188 bool AffineExpr::isSymbolicOrConstant() const {
190 case AffineExprKind::Constant
:
192 case AffineExprKind::DimId
:
194 case AffineExprKind::SymbolId
:
197 case AffineExprKind::Add
:
198 case AffineExprKind::Mul
:
199 case AffineExprKind::FloorDiv
:
200 case AffineExprKind::CeilDiv
:
201 case AffineExprKind::Mod
: {
202 auto expr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
203 return expr
.getLHS().isSymbolicOrConstant() &&
204 expr
.getRHS().isSymbolicOrConstant();
207 llvm_unreachable("Unknown AffineExpr");
210 /// Returns true if this is a pure affine expression, i.e., multiplication,
211 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
212 bool AffineExpr::isPureAffine() const {
214 case AffineExprKind::SymbolId
:
215 case AffineExprKind::DimId
:
216 case AffineExprKind::Constant
:
218 case AffineExprKind::Add
: {
219 auto op
= llvm::cast
<AffineBinaryOpExpr
>(*this);
220 return op
.getLHS().isPureAffine() && op
.getRHS().isPureAffine();
223 case AffineExprKind::Mul
: {
224 // TODO: Canonicalize the constants in binary operators to the RHS when
225 // possible, allowing this to merge into the next case.
226 auto op
= llvm::cast
<AffineBinaryOpExpr
>(*this);
227 return op
.getLHS().isPureAffine() && op
.getRHS().isPureAffine() &&
228 (llvm::isa
<AffineConstantExpr
>(op
.getLHS()) ||
229 llvm::isa
<AffineConstantExpr
>(op
.getRHS()));
231 case AffineExprKind::FloorDiv
:
232 case AffineExprKind::CeilDiv
:
233 case AffineExprKind::Mod
: {
234 auto op
= llvm::cast
<AffineBinaryOpExpr
>(*this);
235 return op
.getLHS().isPureAffine() &&
236 llvm::isa
<AffineConstantExpr
>(op
.getRHS());
239 llvm_unreachable("Unknown AffineExpr");
242 // Returns the greatest known integral divisor of this affine expression.
243 int64_t AffineExpr::getLargestKnownDivisor() const {
244 AffineBinaryOpExpr
binExpr(nullptr);
246 case AffineExprKind::DimId
:
248 case AffineExprKind::SymbolId
:
250 case AffineExprKind::CeilDiv
:
252 case AffineExprKind::FloorDiv
: {
253 // If the RHS is a constant and divides the known divisor on the LHS, the
254 // quotient is a known divisor of the expression.
255 binExpr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
256 auto rhs
= llvm::dyn_cast
<AffineConstantExpr
>(binExpr
.getRHS());
257 // Leave alone undefined expressions.
258 if (rhs
&& rhs
.getValue() != 0) {
259 int64_t lhsDiv
= binExpr
.getLHS().getLargestKnownDivisor();
260 if (lhsDiv
% rhs
.getValue() == 0)
261 return std::abs(lhsDiv
/ rhs
.getValue());
265 case AffineExprKind::Constant
:
266 return std::abs(llvm::cast
<AffineConstantExpr
>(*this).getValue());
267 case AffineExprKind::Mul
: {
268 binExpr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
269 return binExpr
.getLHS().getLargestKnownDivisor() *
270 binExpr
.getRHS().getLargestKnownDivisor();
272 case AffineExprKind::Add
:
274 case AffineExprKind::Mod
: {
275 binExpr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
276 return std::gcd((uint64_t)binExpr
.getLHS().getLargestKnownDivisor(),
277 (uint64_t)binExpr
.getRHS().getLargestKnownDivisor());
280 llvm_unreachable("Unknown AffineExpr");
283 bool AffineExpr::isMultipleOf(int64_t factor
) const {
284 AffineBinaryOpExpr
binExpr(nullptr);
287 case AffineExprKind::SymbolId
:
289 case AffineExprKind::DimId
:
290 return factor
* factor
== 1;
291 case AffineExprKind::Constant
:
292 return llvm::cast
<AffineConstantExpr
>(*this).getValue() % factor
== 0;
293 case AffineExprKind::Mul
: {
294 binExpr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
295 // It's probably not worth optimizing this further (to not traverse the
296 // whole sub-tree under - it that would require a version of isMultipleOf
297 // that on a 'false' return also returns the largest known divisor).
298 return (l
= binExpr
.getLHS().getLargestKnownDivisor()) % factor
== 0 ||
299 (u
= binExpr
.getRHS().getLargestKnownDivisor()) % factor
== 0 ||
300 (l
* u
) % factor
== 0;
302 case AffineExprKind::Add
:
303 case AffineExprKind::FloorDiv
:
304 case AffineExprKind::CeilDiv
:
305 case AffineExprKind::Mod
: {
306 binExpr
= llvm::cast
<AffineBinaryOpExpr
>(*this);
307 return std::gcd((uint64_t)binExpr
.getLHS().getLargestKnownDivisor(),
308 (uint64_t)binExpr
.getRHS().getLargestKnownDivisor()) %
313 llvm_unreachable("Unknown AffineExpr");
316 bool AffineExpr::isFunctionOfDim(unsigned position
) const {
317 if (getKind() == AffineExprKind::DimId
) {
318 return *this == mlir::getAffineDimExpr(position
, getContext());
320 if (auto expr
= llvm::dyn_cast
<AffineBinaryOpExpr
>(*this)) {
321 return expr
.getLHS().isFunctionOfDim(position
) ||
322 expr
.getRHS().isFunctionOfDim(position
);
327 bool AffineExpr::isFunctionOfSymbol(unsigned position
) const {
328 if (getKind() == AffineExprKind::SymbolId
) {
329 return *this == mlir::getAffineSymbolExpr(position
, getContext());
331 if (auto expr
= llvm::dyn_cast
<AffineBinaryOpExpr
>(*this)) {
332 return expr
.getLHS().isFunctionOfSymbol(position
) ||
333 expr
.getRHS().isFunctionOfSymbol(position
);
338 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType
*ptr
)
340 AffineExpr
AffineBinaryOpExpr::getLHS() const {
341 return static_cast<ImplType
*>(expr
)->lhs
;
343 AffineExpr
AffineBinaryOpExpr::getRHS() const {
344 return static_cast<ImplType
*>(expr
)->rhs
;
347 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType
*ptr
) : AffineExpr(ptr
) {}
348 unsigned AffineDimExpr::getPosition() const {
349 return static_cast<ImplType
*>(expr
)->position
;
352 /// Returns true if the expression is divisible by the given symbol with
353 /// position `symbolPos`. The argument `opKind` specifies here what kind of
354 /// division or mod operation called this division. It helps in implementing the
355 /// commutative property of the floordiv and ceildiv operations. If the argument
356 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357 /// operation, then the commutative property can be used otherwise, the floordiv
358 /// operation is not divisible. The same argument holds for ceildiv operation.
359 static bool canSimplifyDivisionBySymbol(AffineExpr expr
, unsigned symbolPos
,
360 AffineExprKind opKind
,
361 bool fromMul
= false) {
362 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
363 assert((opKind
== AffineExprKind::Mod
|| opKind
== AffineExprKind::FloorDiv
||
364 opKind
== AffineExprKind::CeilDiv
) &&
365 "unexpected opKind");
366 switch (expr
.getKind()) {
367 case AffineExprKind::Constant
:
368 return cast
<AffineConstantExpr
>(expr
).getValue() == 0;
369 case AffineExprKind::DimId
:
371 case AffineExprKind::SymbolId
:
372 return (cast
<AffineSymbolExpr
>(expr
).getPosition() == symbolPos
);
373 // Checks divisibility by the given symbol for both operands.
374 case AffineExprKind::Add
: {
375 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
376 return canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
,
378 canSimplifyDivisionBySymbol(binaryExpr
.getRHS(), symbolPos
, opKind
);
380 // Checks divisibility by the given symbol for both operands. Consider the
381 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
382 // this is a division by s1 and both the operands of modulo are divisible by
383 // s1 but it is not divisible by s1 always. The third argument is
384 // `AffineExprKind::Mod` for this reason.
385 case AffineExprKind::Mod
: {
386 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
387 return canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
,
388 AffineExprKind::Mod
) &&
389 canSimplifyDivisionBySymbol(binaryExpr
.getRHS(), symbolPos
,
390 AffineExprKind::Mod
);
392 // Checks if any of the operand divisible by the given symbol.
393 case AffineExprKind::Mul
: {
394 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
395 return canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
, opKind
,
397 canSimplifyDivisionBySymbol(binaryExpr
.getRHS(), symbolPos
, opKind
,
400 // Floordiv and ceildiv are divisible by the given symbol when the first
401 // operand is divisible, and the affine expression kind of the argument expr
402 // is same as the argument `opKind`. This can be inferred from commutative
403 // property of floordiv and ceildiv operations and are as follow:
404 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
405 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
406 // It will fail 1.if operations are not same. For example:
407 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408 // multiplication operation in the expression. For example:
409 // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410 case AffineExprKind::FloorDiv
:
411 case AffineExprKind::CeilDiv
: {
412 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
413 if (opKind
!= expr
.getKind())
417 return canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
,
421 llvm_unreachable("Unknown AffineExpr");
424 /// Divides the given expression by the given symbol at position `symbolPos`. It
425 /// considers the divisibility condition is checked before calling itself. A
426 /// null expression is returned whenever the divisibility condition fails.
427 static AffineExpr
symbolicDivide(AffineExpr expr
, unsigned symbolPos
,
428 AffineExprKind opKind
) {
429 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
430 assert((opKind
== AffineExprKind::Mod
|| opKind
== AffineExprKind::FloorDiv
||
431 opKind
== AffineExprKind::CeilDiv
) &&
432 "unexpected opKind");
433 switch (expr
.getKind()) {
434 case AffineExprKind::Constant
:
435 if (cast
<AffineConstantExpr
>(expr
).getValue() != 0)
437 return getAffineConstantExpr(0, expr
.getContext());
438 case AffineExprKind::DimId
:
440 case AffineExprKind::SymbolId
:
441 return getAffineConstantExpr(1, expr
.getContext());
442 // Dividing both operands by the given symbol.
443 case AffineExprKind::Add
: {
444 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
445 return getAffineBinaryOpExpr(
446 expr
.getKind(), symbolicDivide(binaryExpr
.getLHS(), symbolPos
, opKind
),
447 symbolicDivide(binaryExpr
.getRHS(), symbolPos
, opKind
));
449 // Dividing both operands by the given symbol.
450 case AffineExprKind::Mod
: {
451 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
452 return getAffineBinaryOpExpr(
454 symbolicDivide(binaryExpr
.getLHS(), symbolPos
, expr
.getKind()),
455 symbolicDivide(binaryExpr
.getRHS(), symbolPos
, expr
.getKind()));
457 // Dividing any of the operand by the given symbol.
458 case AffineExprKind::Mul
: {
459 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
460 if (!canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
, opKind
))
461 return binaryExpr
.getLHS() *
462 symbolicDivide(binaryExpr
.getRHS(), symbolPos
, opKind
);
463 return symbolicDivide(binaryExpr
.getLHS(), symbolPos
, opKind
) *
466 // Dividing first operand only by the given symbol.
467 case AffineExprKind::FloorDiv
:
468 case AffineExprKind::CeilDiv
: {
469 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
470 return getAffineBinaryOpExpr(
472 symbolicDivide(binaryExpr
.getLHS(), symbolPos
, expr
.getKind()),
473 binaryExpr
.getRHS());
476 llvm_unreachable("Unknown AffineExpr");
479 /// Populate `result` with all summand operands of given (potentially nested)
480 /// addition. If the given expression is not an addition, just populate the
481 /// expression itself.
482 /// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
483 static void getSummandExprs(AffineExpr expr
, SmallVector
<AffineExpr
> &result
) {
484 auto addExpr
= dyn_cast
<AffineBinaryOpExpr
>(expr
);
485 if (!addExpr
|| addExpr
.getKind() != AffineExprKind::Add
) {
486 result
.push_back(expr
);
489 getSummandExprs(addExpr
.getLHS(), result
);
490 getSummandExprs(addExpr
.getRHS(), result
);
493 /// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
494 /// If so, also return the non-negated expression via `expr`.
495 static bool isNegatedAffineExpr(AffineExpr candidate
, AffineExpr
&expr
) {
496 auto mulExpr
= dyn_cast
<AffineBinaryOpExpr
>(candidate
);
497 if (!mulExpr
|| mulExpr
.getKind() != AffineExprKind::Mul
)
499 if (auto lhs
= dyn_cast
<AffineConstantExpr
>(mulExpr
.getLHS())) {
500 if (lhs
.getValue() == -1) {
501 expr
= mulExpr
.getRHS();
505 if (auto rhs
= dyn_cast
<AffineConstantExpr
>(mulExpr
.getRHS())) {
506 if (rhs
.getValue() == -1) {
507 expr
= mulExpr
.getLHS();
514 /// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
515 /// the fact that `lhs` contains another modulo expression that ensures that
516 /// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
517 /// after loop peeling.
519 /// Example: lhs = ub - ub % step
521 /// => (ub - ub % step) % step is guaranteed to evaluate to 0.
522 static bool isModOfModSubtraction(AffineExpr lhs
, AffineExpr rhs
,
523 unsigned numDims
, unsigned numSymbols
) {
524 // TODO: Try to unify this function with `getBoundForAffineExpr`.
525 // Collect all summands in lhs.
526 SmallVector
<AffineExpr
> summands
;
527 getSummandExprs(lhs
, summands
);
528 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
529 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
530 for (int64_t i
= 0, e
= summands
.size(); i
< e
; ++i
) {
531 AffineExpr current
= summands
[i
];
532 AffineExpr beforeNegation
;
533 if (!isNegatedAffineExpr(current
, beforeNegation
))
535 AffineBinaryOpExpr innerMod
= dyn_cast
<AffineBinaryOpExpr
>(beforeNegation
);
536 if (!innerMod
|| innerMod
.getKind() != AffineExprKind::Mod
)
538 if (innerMod
.getRHS() != rhs
)
540 // Sum all remaining summands and subtract x. If that expression can be
541 // simplified to zero, then the remaining summands and x are equal.
542 AffineExpr diff
= getAffineConstantExpr(0, lhs
.getContext());
543 for (int64_t j
= 0; j
< e
; ++j
)
545 diff
= diff
+ summands
[j
];
546 diff
= diff
- innerMod
.getLHS();
547 diff
= simplifyAffineExpr(diff
, numDims
, numSymbols
);
548 auto constExpr
= dyn_cast
<AffineConstantExpr
>(diff
);
549 if (constExpr
&& constExpr
.getValue() == 0)
555 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
556 /// operations when the second operand simplifies to a symbol and the first
557 /// operand is divisible by that symbol. It can be applied to any semi-affine
558 /// expression. Returned expression can either be a semi-affine or pure affine
560 static AffineExpr
simplifySemiAffine(AffineExpr expr
, unsigned numDims
,
561 unsigned numSymbols
) {
562 switch (expr
.getKind()) {
563 case AffineExprKind::Constant
:
564 case AffineExprKind::DimId
:
565 case AffineExprKind::SymbolId
:
567 case AffineExprKind::Add
:
568 case AffineExprKind::Mul
: {
569 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
570 return getAffineBinaryOpExpr(
572 simplifySemiAffine(binaryExpr
.getLHS(), numDims
, numSymbols
),
573 simplifySemiAffine(binaryExpr
.getRHS(), numDims
, numSymbols
));
575 // Check if the simplification of the second operand is a symbol, and the
576 // first operand is divisible by it. If the operation is a modulo, a constant
577 // zero expression is returned. In the case of floordiv and ceildiv, the
578 // symbol from the simplification of the second operand divides the first
579 // operand. Otherwise, simplification is not possible.
580 case AffineExprKind::FloorDiv
:
581 case AffineExprKind::CeilDiv
:
582 case AffineExprKind::Mod
: {
583 AffineBinaryOpExpr binaryExpr
= cast
<AffineBinaryOpExpr
>(expr
);
585 simplifySemiAffine(binaryExpr
.getLHS(), numDims
, numSymbols
);
587 simplifySemiAffine(binaryExpr
.getRHS(), numDims
, numSymbols
);
588 if (isModOfModSubtraction(sLHS
, sRHS
, numDims
, numSymbols
))
589 return getAffineConstantExpr(0, expr
.getContext());
590 AffineSymbolExpr symbolExpr
= dyn_cast
<AffineSymbolExpr
>(
591 simplifySemiAffine(binaryExpr
.getRHS(), numDims
, numSymbols
));
593 return getAffineBinaryOpExpr(expr
.getKind(), sLHS
, sRHS
);
594 unsigned symbolPos
= symbolExpr
.getPosition();
595 if (!canSimplifyDivisionBySymbol(binaryExpr
.getLHS(), symbolPos
,
597 return getAffineBinaryOpExpr(expr
.getKind(), sLHS
, sRHS
);
598 if (expr
.getKind() == AffineExprKind::Mod
)
599 return getAffineConstantExpr(0, expr
.getContext());
600 return symbolicDivide(sLHS
, symbolPos
, expr
.getKind());
603 llvm_unreachable("Unknown AffineExpr");
606 static AffineExpr
getAffineDimOrSymbol(AffineExprKind kind
, unsigned position
,
607 MLIRContext
*context
) {
608 auto assignCtx
= [context
](AffineDimExprStorage
*storage
) {
609 storage
->context
= context
;
612 StorageUniquer
&uniquer
= context
->getAffineUniquer();
613 return uniquer
.get
<AffineDimExprStorage
>(
614 assignCtx
, static_cast<unsigned>(kind
), position
);
617 AffineExpr
mlir::getAffineDimExpr(unsigned position
, MLIRContext
*context
) {
618 return getAffineDimOrSymbol(AffineExprKind::DimId
, position
, context
);
621 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType
*ptr
)
623 unsigned AffineSymbolExpr::getPosition() const {
624 return static_cast<ImplType
*>(expr
)->position
;
627 AffineExpr
mlir::getAffineSymbolExpr(unsigned position
, MLIRContext
*context
) {
628 return getAffineDimOrSymbol(AffineExprKind::SymbolId
, position
, context
);
631 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType
*ptr
)
633 int64_t AffineConstantExpr::getValue() const {
634 return static_cast<ImplType
*>(expr
)->constant
;
637 bool AffineExpr::operator==(int64_t v
) const {
638 return *this == getAffineConstantExpr(v
, getContext());
641 AffineExpr
mlir::getAffineConstantExpr(int64_t constant
, MLIRContext
*context
) {
642 auto assignCtx
= [context
](AffineConstantExprStorage
*storage
) {
643 storage
->context
= context
;
646 StorageUniquer
&uniquer
= context
->getAffineUniquer();
647 return uniquer
.get
<AffineConstantExprStorage
>(assignCtx
, constant
);
650 SmallVector
<AffineExpr
>
651 mlir::getAffineConstantExprs(ArrayRef
<int64_t> constants
,
652 MLIRContext
*context
) {
653 return llvm::to_vector(llvm::map_range(constants
, [&](int64_t constant
) {
654 return getAffineConstantExpr(constant
, context
);
658 /// Simplify add expression. Return nullptr if it can't be simplified.
659 static AffineExpr
simplifyAdd(AffineExpr lhs
, AffineExpr rhs
) {
660 auto lhsConst
= dyn_cast
<AffineConstantExpr
>(lhs
);
661 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(rhs
);
662 // Fold if both LHS, RHS are a constant and the sum does not overflow.
663 if (lhsConst
&& rhsConst
) {
665 if (llvm::AddOverflow(lhsConst
.getValue(), rhsConst
.getValue(), sum
)) {
668 return getAffineConstantExpr(sum
, lhs
.getContext());
671 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
672 // If only one of them is a symbolic expressions, make it the RHS.
673 if (isa
<AffineConstantExpr
>(lhs
) ||
674 (lhs
.isSymbolicOrConstant() && !rhs
.isSymbolicOrConstant())) {
678 // At this point, if there was a constant, it would be on the right.
680 // Addition with a zero is a noop, return the other input.
682 if (rhsConst
.getValue() == 0)
685 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
686 auto lBin
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
687 if (lBin
&& rhsConst
&& lBin
.getKind() == AffineExprKind::Add
) {
688 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS()))
689 return lBin
.getLHS() + (lrhs
.getValue() + rhsConst
.getValue());
692 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
693 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
694 // respective multiplicands.
695 std::optional
<int64_t> rLhsConst
, rRhsConst
;
696 AffineExpr firstExpr
, secondExpr
;
697 AffineConstantExpr rLhsConstExpr
;
698 auto lBinOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
699 if (lBinOpExpr
&& lBinOpExpr
.getKind() == AffineExprKind::Mul
&&
700 (rLhsConstExpr
= dyn_cast
<AffineConstantExpr
>(lBinOpExpr
.getRHS()))) {
701 rLhsConst
= rLhsConstExpr
.getValue();
702 firstExpr
= lBinOpExpr
.getLHS();
708 auto rBinOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(rhs
);
709 AffineConstantExpr rRhsConstExpr
;
710 if (rBinOpExpr
&& rBinOpExpr
.getKind() == AffineExprKind::Mul
&&
711 (rRhsConstExpr
= dyn_cast
<AffineConstantExpr
>(rBinOpExpr
.getRHS()))) {
712 rRhsConst
= rRhsConstExpr
.getValue();
713 secondExpr
= rBinOpExpr
.getLHS();
719 if (rLhsConst
&& rRhsConst
&& firstExpr
== secondExpr
)
720 return getAffineBinaryOpExpr(
721 AffineExprKind::Mul
, firstExpr
,
722 getAffineConstantExpr(*rLhsConst
+ *rRhsConst
, lhs
.getContext()));
724 // When doing successive additions, bring constant to the right: turn (d0 + 2)
725 // + d1 into (d0 + d1) + 2.
726 if (lBin
&& lBin
.getKind() == AffineExprKind::Add
) {
727 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS())) {
728 return lBin
.getLHS() + rhs
+ lrhs
;
732 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
733 // q may be a constant or symbolic expression. This leads to a much more
734 // efficient form when 'c' is a power of two, and in general a more compact
735 // and readable form.
737 // Process '(expr floordiv c) * (-c)'.
741 auto lrhs
= rBinOpExpr
.getLHS();
742 auto rrhs
= rBinOpExpr
.getRHS();
744 AffineExpr llrhs
, rlrhs
;
746 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
747 // symbolic expression.
748 auto lrhsBinOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(lrhs
);
749 // Check rrhsConstOpExpr = -1.
750 auto rrhsConstOpExpr
= dyn_cast
<AffineConstantExpr
>(rrhs
);
751 if (rrhsConstOpExpr
&& rrhsConstOpExpr
.getValue() == -1 && lrhsBinOpExpr
&&
752 lrhsBinOpExpr
.getKind() == AffineExprKind::Mul
) {
753 // Check llrhs = expr floordiv q.
754 llrhs
= lrhsBinOpExpr
.getLHS();
756 rlrhs
= lrhsBinOpExpr
.getRHS();
757 auto llrhsBinOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(llrhs
);
758 if (!llrhsBinOpExpr
|| llrhsBinOpExpr
.getKind() != AffineExprKind::FloorDiv
)
760 if (llrhsBinOpExpr
.getRHS() == rlrhs
&& lhs
== llrhsBinOpExpr
.getLHS())
764 // Process lrhs, which is 'expr floordiv c'.
765 // expr + (expr // c * -c) = expr % c
766 AffineBinaryOpExpr lrBinOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(lrhs
);
767 if (!lrBinOpExpr
|| rhs
.getKind() != AffineExprKind::Mul
||
768 lrBinOpExpr
.getKind() != AffineExprKind::FloorDiv
)
771 llrhs
= lrBinOpExpr
.getLHS();
772 rlrhs
= lrBinOpExpr
.getRHS();
773 auto rlrhsConstOpExpr
= dyn_cast
<AffineConstantExpr
>(rlrhs
);
774 // We don't support modulo with a negative RHS.
775 bool isPositiveRhs
= rlrhsConstOpExpr
&& rlrhsConstOpExpr
.getValue() > 0;
777 if (isPositiveRhs
&& lhs
== llrhs
&& rlrhs
== -rrhs
) {
783 AffineExpr
AffineExpr::operator+(int64_t v
) const {
784 return *this + getAffineConstantExpr(v
, getContext());
786 AffineExpr
AffineExpr::operator+(AffineExpr other
) const {
787 if (auto simplified
= simplifyAdd(*this, other
))
790 StorageUniquer
&uniquer
= getContext()->getAffineUniquer();
791 return uniquer
.get
<AffineBinaryOpExprStorage
>(
792 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add
), *this, other
);
795 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
796 static AffineExpr
simplifyMul(AffineExpr lhs
, AffineExpr rhs
) {
797 auto lhsConst
= dyn_cast
<AffineConstantExpr
>(lhs
);
798 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(rhs
);
800 if (lhsConst
&& rhsConst
) {
802 if (llvm::MulOverflow(lhsConst
.getValue(), rhsConst
.getValue(), product
)) {
805 return getAffineConstantExpr(product
, lhs
.getContext());
808 if (!lhs
.isSymbolicOrConstant() && !rhs
.isSymbolicOrConstant())
811 // Canonicalize the mul expression so that the constant/symbolic term is the
812 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
813 // constant. (Note that a constant is trivially symbolic).
814 if (!rhs
.isSymbolicOrConstant() || isa
<AffineConstantExpr
>(lhs
)) {
815 // At least one of them has to be symbolic.
819 // At this point, if there was a constant, it would be on the right.
821 // Multiplication with a one is a noop, return the other input.
823 if (rhsConst
.getValue() == 1)
825 // Multiplication with zero.
826 if (rhsConst
.getValue() == 0)
830 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
831 auto lBin
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
832 if (lBin
&& rhsConst
&& lBin
.getKind() == AffineExprKind::Mul
) {
833 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS()))
834 return lBin
.getLHS() * (lrhs
.getValue() * rhsConst
.getValue());
837 // When doing successive multiplication, bring constant to the right: turn (d0
838 // * 2) * d1 into (d0 * d1) * 2.
839 if (lBin
&& lBin
.getKind() == AffineExprKind::Mul
) {
840 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS())) {
841 return (lBin
.getLHS() * rhs
) * lrhs
;
848 AffineExpr
AffineExpr::operator*(int64_t v
) const {
849 return *this * getAffineConstantExpr(v
, getContext());
851 AffineExpr
AffineExpr::operator*(AffineExpr other
) const {
852 if (auto simplified
= simplifyMul(*this, other
))
855 StorageUniquer
&uniquer
= getContext()->getAffineUniquer();
856 return uniquer
.get
<AffineBinaryOpExprStorage
>(
857 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul
), *this, other
);
860 // Unary minus, delegate to operator*.
861 AffineExpr
AffineExpr::operator-() const {
862 return *this * getAffineConstantExpr(-1, getContext());
865 // Delegate to operator+.
866 AffineExpr
AffineExpr::operator-(int64_t v
) const { return *this + (-v
); }
867 AffineExpr
AffineExpr::operator-(AffineExpr other
) const {
868 return *this + (-other
);
871 static AffineExpr
simplifyFloorDiv(AffineExpr lhs
, AffineExpr rhs
) {
872 auto lhsConst
= dyn_cast
<AffineConstantExpr
>(lhs
);
873 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(rhs
);
875 if (!rhsConst
|| rhsConst
.getValue() == 0)
879 if (divideSignedWouldOverflow(lhsConst
.getValue(), rhsConst
.getValue()))
881 return getAffineConstantExpr(
882 divideFloorSigned(lhsConst
.getValue(), rhsConst
.getValue()),
886 // Fold floordiv of a multiply with a constant that is a multiple of the
887 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
891 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
892 // multiple of `rhsConst`.
893 auto lBin
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
894 if (lBin
&& lBin
.getKind() == AffineExprKind::Mul
) {
895 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS())) {
896 // `rhsConst` is known to be a nonzero constant.
897 if (lrhs
.getValue() % rhsConst
.getValue() == 0)
898 return lBin
.getLHS() * (lrhs
.getValue() / rhsConst
.getValue());
902 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
903 // known to be a multiple of divConst.
904 if (lBin
&& lBin
.getKind() == AffineExprKind::Add
) {
905 int64_t llhsDiv
= lBin
.getLHS().getLargestKnownDivisor();
906 int64_t lrhsDiv
= lBin
.getRHS().getLargestKnownDivisor();
907 // rhsConst is known to be a nonzero constant.
908 if (llhsDiv
% rhsConst
.getValue() == 0 ||
909 lrhsDiv
% rhsConst
.getValue() == 0)
910 return lBin
.getLHS().floorDiv(rhsConst
.getValue()) +
911 lBin
.getRHS().floorDiv(rhsConst
.getValue());
917 AffineExpr
AffineExpr::floorDiv(uint64_t v
) const {
918 return floorDiv(getAffineConstantExpr(v
, getContext()));
920 AffineExpr
AffineExpr::floorDiv(AffineExpr other
) const {
921 if (auto simplified
= simplifyFloorDiv(*this, other
))
924 StorageUniquer
&uniquer
= getContext()->getAffineUniquer();
925 return uniquer
.get
<AffineBinaryOpExprStorage
>(
926 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv
), *this,
930 static AffineExpr
simplifyCeilDiv(AffineExpr lhs
, AffineExpr rhs
) {
931 auto lhsConst
= dyn_cast
<AffineConstantExpr
>(lhs
);
932 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(rhs
);
934 if (!rhsConst
|| rhsConst
.getValue() == 0)
938 if (divideSignedWouldOverflow(lhsConst
.getValue(), rhsConst
.getValue()))
940 return getAffineConstantExpr(
941 divideCeilSigned(lhsConst
.getValue(), rhsConst
.getValue()),
945 // Fold ceildiv of a multiply with a constant that is a multiple of the
946 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
947 if (rhsConst
.getValue() == 1)
950 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
951 // multiple of `rhsConst`.
952 auto lBin
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
953 if (lBin
&& lBin
.getKind() == AffineExprKind::Mul
) {
954 if (auto lrhs
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS())) {
955 // `rhsConst` is known to be a nonzero constant.
956 if (lrhs
.getValue() % rhsConst
.getValue() == 0)
957 return lBin
.getLHS() * (lrhs
.getValue() / rhsConst
.getValue());
964 AffineExpr
AffineExpr::ceilDiv(uint64_t v
) const {
965 return ceilDiv(getAffineConstantExpr(v
, getContext()));
967 AffineExpr
AffineExpr::ceilDiv(AffineExpr other
) const {
968 if (auto simplified
= simplifyCeilDiv(*this, other
))
971 StorageUniquer
&uniquer
= getContext()->getAffineUniquer();
972 return uniquer
.get
<AffineBinaryOpExprStorage
>(
973 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv
), *this,
977 static AffineExpr
simplifyMod(AffineExpr lhs
, AffineExpr rhs
) {
978 auto lhsConst
= dyn_cast
<AffineConstantExpr
>(lhs
);
979 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(rhs
);
981 // mod w.r.t zero or negative numbers is undefined and preserved as is.
982 if (!rhsConst
|| rhsConst
.getValue() < 1)
986 // mod never overflows.
987 return getAffineConstantExpr(mod(lhsConst
.getValue(), rhsConst
.getValue()),
991 // Fold modulo of an expression that is known to be a multiple of a constant
992 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
993 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
994 if (lhs
.getLargestKnownDivisor() % rhsConst
.getValue() == 0)
995 return getAffineConstantExpr(0, lhs
.getContext());
997 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
998 // known to be a multiple of divConst.
999 auto lBin
= dyn_cast
<AffineBinaryOpExpr
>(lhs
);
1000 if (lBin
&& lBin
.getKind() == AffineExprKind::Add
) {
1001 int64_t llhsDiv
= lBin
.getLHS().getLargestKnownDivisor();
1002 int64_t lrhsDiv
= lBin
.getRHS().getLargestKnownDivisor();
1003 // rhsConst is known to be a positive constant.
1004 if (llhsDiv
% rhsConst
.getValue() == 0)
1005 return lBin
.getRHS() % rhsConst
.getValue();
1006 if (lrhsDiv
% rhsConst
.getValue() == 0)
1007 return lBin
.getLHS() % rhsConst
.getValue();
1010 // Simplify (e % a) % b to e % b when b evenly divides a
1011 if (lBin
&& lBin
.getKind() == AffineExprKind::Mod
) {
1012 auto intermediate
= dyn_cast
<AffineConstantExpr
>(lBin
.getRHS());
1013 if (intermediate
&& intermediate
.getValue() >= 1 &&
1014 mod(intermediate
.getValue(), rhsConst
.getValue()) == 0) {
1015 return lBin
.getLHS() % rhsConst
.getValue();
1022 AffineExpr
AffineExpr::operator%(uint64_t v
) const {
1023 return *this % getAffineConstantExpr(v
, getContext());
1025 AffineExpr
AffineExpr::operator%(AffineExpr other
) const {
1026 if (auto simplified
= simplifyMod(*this, other
))
1029 StorageUniquer
&uniquer
= getContext()->getAffineUniquer();
1030 return uniquer
.get
<AffineBinaryOpExprStorage
>(
1031 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod
), *this, other
);
1034 AffineExpr
AffineExpr::compose(AffineMap map
) const {
1035 SmallVector
<AffineExpr
, 8> dimReplacements(map
.getResults());
1036 return replaceDimsAndSymbols(dimReplacements
, {});
1038 raw_ostream
&mlir::operator<<(raw_ostream
&os
, AffineExpr expr
) {
1043 /// Constructs an affine expression from a flat ArrayRef. If there are local
1044 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
1045 /// products expression, `localExprs` is expected to have the AffineExpr
1046 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1047 /// in the format [dims, symbols, locals, constant term].
1048 AffineExpr
mlir::getAffineExprFromFlatForm(ArrayRef
<int64_t> flatExprs
,
1050 unsigned numSymbols
,
1051 ArrayRef
<AffineExpr
> localExprs
,
1052 MLIRContext
*context
) {
1053 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1054 assert(flatExprs
.size() - numDims
- numSymbols
- 1 == localExprs
.size() &&
1055 "unexpected number of local expressions");
1057 auto expr
= getAffineConstantExpr(0, context
);
1058 // Dimensions and symbols.
1059 for (unsigned j
= 0; j
< numDims
+ numSymbols
; j
++) {
1060 if (flatExprs
[j
] == 0)
1062 auto id
= j
< numDims
? getAffineDimExpr(j
, context
)
1063 : getAffineSymbolExpr(j
- numDims
, context
);
1064 expr
= expr
+ id
* flatExprs
[j
];
1067 // Local identifiers.
1068 for (unsigned j
= numDims
+ numSymbols
, e
= flatExprs
.size() - 1; j
< e
;
1070 if (flatExprs
[j
] == 0)
1072 auto term
= localExprs
[j
- numDims
- numSymbols
] * flatExprs
[j
];
1077 int64_t constTerm
= flatExprs
[flatExprs
.size() - 1];
1079 expr
= expr
+ constTerm
;
1083 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
1084 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
1085 /// of products expression, `localExprs` is expected to have the AffineExprs for
1086 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1087 /// the format [dims, symbols, locals, constant term]. The semi-affine
1088 /// expression is constructed in the sorted order of dimension and symbol
1089 /// position numbers. Note: local expressions/ids are used for mod, div as well
1090 /// as symbolic RHS terms for terms that are not pure affine.
1091 static AffineExpr
getSemiAffineExprFromFlatForm(ArrayRef
<int64_t> flatExprs
,
1093 unsigned numSymbols
,
1094 ArrayRef
<AffineExpr
> localExprs
,
1095 MLIRContext
*context
) {
1096 assert(!flatExprs
.empty() && "flatExprs cannot be empty");
1098 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1099 assert(flatExprs
.size() - numDims
- numSymbols
- 1 == localExprs
.size() &&
1100 "unexpected number of local expressions");
1102 AffineExpr expr
= getAffineConstantExpr(0, context
);
1104 // We design indices as a pair which help us present the semi-affine map as
1105 // sum of product where terms are sorted based on dimension or symbol
1106 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1107 // where keyA is the position number of the dimension and keyB is the
1108 // position number of the symbol. For dimensional expressions we set the index
1109 // as (position number of the dimension, -1), as we want dimensional
1110 // expressions to appear before symbolic and product of dimensional and
1111 // symbolic expressions having the dimension with the same position number.
1112 // For symbolic expression set the index as (position number of the symbol,
1113 // maximum of last dimension and symbol position) number. For example, we want
1114 // the expression we are constructing to look something like: d0 + d0 * s0 +
1117 // Stores the affine expression corresponding to a given index.
1118 DenseMap
<std::pair
<unsigned, signed>, AffineExpr
> indexToExprMap
;
1119 // Stores the constant coefficient value corresponding to a given
1120 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1121 DenseMap
<std::pair
<unsigned, signed>, int64_t> coefficients
;
1122 // Stores the indices as defined above, and later sorted to produce
1123 // the semi-affine expression in the desired form.
1124 SmallVector
<std::pair
<unsigned, signed>, 8> indices
;
1126 // Example: expression = d0 + d0 * s0 + 2 * s0.
1127 // indices = [{0,-1}, {0, 0}, {0, 1}]
1128 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1129 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1131 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1132 auto addEntry
= [&](std::pair
<unsigned, signed> index
, int64_t coefficient
,
1134 assert(!llvm::is_contained(indices
, index
) &&
1135 "Key is already present in indices vector and overwriting will "
1136 "happen in `indexToExprMap` and `coefficients`!");
1138 indices
.push_back(index
);
1139 coefficients
.insert({index
, coefficient
});
1140 indexToExprMap
.insert({index
, expr
});
1143 // Design indices for dimensional or symbolic terms, and store the indices,
1144 // constant coefficient corresponding to the indices in `coefficients` map,
1145 // and affine expression corresponding to indices in `indexToExprMap` map.
1147 // Ensure we do not have duplicate keys in `indexToExpr` map.
1148 unsigned offsetSym
= 0;
1149 signed offsetDim
= -1;
1150 for (unsigned j
= numDims
; j
< numDims
+ numSymbols
; ++j
) {
1151 if (flatExprs
[j
] == 0)
1153 // For symbolic expression set the index as <position number
1154 // of the symbol, max(dimCount, symCount)> number,
1155 // as we want symbolic expressions with the same positional number to
1156 // appear after dimensional expressions having the same positional number.
1157 std::pair
<unsigned, signed> indexEntry(
1158 j
- numDims
, std::max(numDims
, numSymbols
) + offsetSym
++);
1159 addEntry(indexEntry
, flatExprs
[j
],
1160 getAffineSymbolExpr(j
- numDims
, context
));
1163 // Denotes semi-affine product, modulo or division terms, which has been added
1164 // to the `indexToExpr` map.
1165 SmallVector
<bool, 4> addedToMap(flatExprs
.size() - numDims
- numSymbols
- 1,
1167 unsigned lhsPos
, rhsPos
;
1168 // Construct indices for product terms involving dimension, symbol or constant
1169 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1170 // the indices in `coefficients` map, and affine expression corresponding to
1171 // in indices in `indexToExprMap` map.
1172 for (const auto &it
: llvm::enumerate(localExprs
)) {
1173 AffineExpr expr
= it
.value();
1174 if (flatExprs
[numDims
+ numSymbols
+ it
.index()] == 0)
1176 AffineExpr lhs
= cast
<AffineBinaryOpExpr
>(expr
).getLHS();
1177 AffineExpr rhs
= cast
<AffineBinaryOpExpr
>(expr
).getRHS();
1178 if (!((isa
<AffineDimExpr
>(lhs
) || isa
<AffineSymbolExpr
>(lhs
)) &&
1179 (isa
<AffineDimExpr
>(rhs
) || isa
<AffineSymbolExpr
>(rhs
) ||
1180 isa
<AffineConstantExpr
>(rhs
)))) {
1183 if (isa
<AffineConstantExpr
>(rhs
)) {
1184 // For product/modulo/division expressions, when rhs of modulo/division
1185 // expression is constant, we put 0 in place of keyB, because we want
1186 // them to appear earlier in the semi-affine expression we are
1187 // constructing. When rhs is constant, we place 0 in place of keyB.
1188 if (isa
<AffineDimExpr
>(lhs
)) {
1189 lhsPos
= cast
<AffineDimExpr
>(lhs
).getPosition();
1190 std::pair
<unsigned, signed> indexEntry(lhsPos
, offsetDim
--);
1191 addEntry(indexEntry
, flatExprs
[numDims
+ numSymbols
+ it
.index()],
1194 lhsPos
= cast
<AffineSymbolExpr
>(lhs
).getPosition();
1195 std::pair
<unsigned, signed> indexEntry(
1196 lhsPos
, std::max(numDims
, numSymbols
) + offsetSym
++);
1197 addEntry(indexEntry
, flatExprs
[numDims
+ numSymbols
+ it
.index()],
1200 } else if (isa
<AffineDimExpr
>(lhs
)) {
1201 // For product/modulo/division expressions having lhs as dimension and rhs
1202 // as symbol, we order the terms in the semi-affine expression based on
1203 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1204 // where keyA is the position number of the dimension and keyB is the
1205 // position number of the symbol.
1206 lhsPos
= cast
<AffineDimExpr
>(lhs
).getPosition();
1207 rhsPos
= cast
<AffineSymbolExpr
>(rhs
).getPosition();
1208 std::pair
<unsigned, signed> indexEntry(lhsPos
, rhsPos
);
1209 addEntry(indexEntry
, flatExprs
[numDims
+ numSymbols
+ it
.index()], expr
);
1211 // For product/modulo/division expressions having both lhs and rhs as
1212 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1213 // of the form dimension * symbol, where keyA is the position number of
1214 // the dimension and keyB is the position number of the symbol.
1215 lhsPos
= cast
<AffineSymbolExpr
>(lhs
).getPosition();
1216 rhsPos
= cast
<AffineSymbolExpr
>(rhs
).getPosition();
1217 std::pair
<unsigned, signed> indexEntry(
1218 lhsPos
, std::max(numDims
, numSymbols
) + offsetSym
++);
1219 addEntry(indexEntry
, flatExprs
[numDims
+ numSymbols
+ it
.index()], expr
);
1221 addedToMap
[it
.index()] = true;
1224 for (unsigned j
= 0; j
< numDims
; ++j
) {
1225 if (flatExprs
[j
] == 0)
1227 // For dimensional expressions we set the index as <position number of the
1228 // dimension, 0>, as we want dimensional expressions to appear before
1229 // symbolic ones and products of dimensional and symbolic expressions
1230 // having the dimension with the same position number.
1231 std::pair
<unsigned, signed> indexEntry(j
, offsetDim
--);
1232 addEntry(indexEntry
, flatExprs
[j
], getAffineDimExpr(j
, context
));
1235 // Constructing the simplified semi-affine sum of product/division/mod
1236 // expression from the flattened form in the desired sorted order of indices
1237 // of the various individual product/division/mod expressions.
1238 llvm::sort(indices
);
1239 for (const std::pair
<unsigned, unsigned> index
: indices
) {
1240 assert(indexToExprMap
.lookup(index
) &&
1241 "cannot find key in `indexToExprMap` map");
1242 expr
= expr
+ indexToExprMap
.lookup(index
) * coefficients
.lookup(index
);
1245 // Local identifiers.
1246 for (unsigned j
= numDims
+ numSymbols
, e
= flatExprs
.size() - 1; j
< e
;
1248 // If the coefficient of the local expression is 0, continue as we need not
1249 // add it in out final expression.
1250 if (flatExprs
[j
] == 0 || addedToMap
[j
- numDims
- numSymbols
])
1252 auto term
= localExprs
[j
- numDims
- numSymbols
] * flatExprs
[j
];
1257 int64_t constTerm
= flatExprs
.back();
1259 expr
= expr
+ constTerm
;
1263 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims
,
1264 unsigned numSymbols
)
1265 : numDims(numDims
), numSymbols(numSymbols
), numLocals(0) {
1266 operandExprStack
.reserve(8);
1269 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1271 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1272 // introduce a local variable p (= expr * symbolic_expr), and the affine
1273 // expression expr * symbolic_expr is added to `localExprs`.
1274 LogicalResult
SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr
) {
1275 assert(operandExprStack
.size() >= 2);
1276 SmallVector
<int64_t, 8> rhs
= operandExprStack
.back();
1277 operandExprStack
.pop_back();
1278 SmallVector
<int64_t, 8> &lhs
= operandExprStack
.back();
1280 // Flatten semi-affine multiplication expressions by introducing a local
1281 // variable in place of the product; the affine expression
1282 // corresponding to the quantifier is added to `localExprs`.
1283 if (!isa
<AffineConstantExpr
>(expr
.getRHS())) {
1284 SmallVector
<int64_t, 8> mulLhs(lhs
);
1285 MLIRContext
*context
= expr
.getContext();
1286 AffineExpr a
= getAffineExprFromFlatForm(lhs
, numDims
, numSymbols
,
1287 localExprs
, context
);
1288 AffineExpr b
= getAffineExprFromFlatForm(rhs
, numDims
, numSymbols
,
1289 localExprs
, context
);
1290 return addLocalVariableSemiAffine(mulLhs
, rhs
, a
* b
, lhs
, lhs
.size());
1293 // Get the RHS constant.
1294 int64_t rhsConst
= rhs
[getConstantIndex()];
1295 for (int64_t &lhsElt
: lhs
)
1301 LogicalResult
SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr
) {
1302 assert(operandExprStack
.size() >= 2);
1303 const auto &rhs
= operandExprStack
.back();
1304 auto &lhs
= operandExprStack
[operandExprStack
.size() - 2];
1305 assert(lhs
.size() == rhs
.size());
1306 // Update the LHS in place.
1307 for (unsigned i
= 0, e
= rhs
.size(); i
< e
; i
++) {
1311 operandExprStack
.pop_back();
1316 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1318 // A mod expression "expr mod c" is thus flattened by introducing a new local
1319 // variable q (= expr floordiv c), such that expr mod c is replaced with
1320 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1322 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1323 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1324 // expression expr mod symbolic_expr is added to `localExprs`.
1325 LogicalResult
SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr
) {
1326 assert(operandExprStack
.size() >= 2);
1328 SmallVector
<int64_t, 8> rhs
= operandExprStack
.back();
1329 operandExprStack
.pop_back();
1330 SmallVector
<int64_t, 8> &lhs
= operandExprStack
.back();
1331 MLIRContext
*context
= expr
.getContext();
1333 // Flatten semi affine modulo expressions by introducing a local
1334 // variable in place of the modulo value, and the affine expression
1335 // corresponding to the quantifier is added to `localExprs`.
1336 if (!isa
<AffineConstantExpr
>(expr
.getRHS())) {
1337 SmallVector
<int64_t, 8> modLhs(lhs
);
1338 AffineExpr dividendExpr
= getAffineExprFromFlatForm(
1339 lhs
, numDims
, numSymbols
, localExprs
, context
);
1340 AffineExpr divisorExpr
= getAffineExprFromFlatForm(rhs
, numDims
, numSymbols
,
1341 localExprs
, context
);
1342 AffineExpr modExpr
= dividendExpr
% divisorExpr
;
1343 return addLocalVariableSemiAffine(modLhs
, rhs
, modExpr
, lhs
, lhs
.size());
1346 int64_t rhsConst
= rhs
[getConstantIndex()];
1350 // Check if the LHS expression is a multiple of modulo factor.
1352 for (i
= 0, e
= lhs
.size(); i
< e
; i
++)
1353 if (lhs
[i
] % rhsConst
!= 0)
1355 // If yes, modulo expression here simplifies to zero.
1356 if (i
== lhs
.size()) {
1357 std::fill(lhs
.begin(), lhs
.end(), 0);
1361 // Add a local variable for the quotient, i.e., expr % c is replaced by
1362 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1363 // the GCD of expr and c.
1364 SmallVector
<int64_t, 8> floorDividend(lhs
);
1365 uint64_t gcd
= rhsConst
;
1366 for (int64_t lhsElt
: lhs
)
1367 gcd
= std::gcd(gcd
, (uint64_t)std::abs(lhsElt
));
1368 // Simplify the numerator and the denominator.
1370 for (int64_t &floorDividendElt
: floorDividend
)
1371 floorDividendElt
= floorDividendElt
/ static_cast<int64_t>(gcd
);
1373 int64_t floorDivisor
= rhsConst
/ static_cast<int64_t>(gcd
);
1375 // Construct the AffineExpr form of the floordiv to store in localExprs.
1377 AffineExpr dividendExpr
= getAffineExprFromFlatForm(
1378 floorDividend
, numDims
, numSymbols
, localExprs
, context
);
1379 AffineExpr divisorExpr
= getAffineConstantExpr(floorDivisor
, context
);
1380 AffineExpr floorDivExpr
= dividendExpr
.floorDiv(divisorExpr
);
1382 if ((loc
= findLocalId(floorDivExpr
)) == -1) {
1383 addLocalFloorDivId(floorDividend
, floorDivisor
, floorDivExpr
);
1384 // Set result at top of stack to "lhs - rhsConst * q".
1385 lhs
[getLocalVarStartIndex() + numLocals
- 1] = -rhsConst
;
1387 // Reuse the existing local id.
1388 lhs
[getLocalVarStartIndex() + loc
] = -rhsConst
;
1394 SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr
) {
1395 return visitDivExpr(expr
, /*isCeil=*/true);
1398 SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr
) {
1399 return visitDivExpr(expr
, /*isCeil=*/false);
1402 LogicalResult
SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr
) {
1403 operandExprStack
.emplace_back(SmallVector
<int64_t, 32>(getNumCols(), 0));
1404 auto &eq
= operandExprStack
.back();
1405 assert(expr
.getPosition() < numDims
&& "Inconsistent number of dims");
1406 eq
[getDimStartIndex() + expr
.getPosition()] = 1;
1411 SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr
) {
1412 operandExprStack
.emplace_back(SmallVector
<int64_t, 32>(getNumCols(), 0));
1413 auto &eq
= operandExprStack
.back();
1414 assert(expr
.getPosition() < numSymbols
&& "inconsistent number of symbols");
1415 eq
[getSymbolStartIndex() + expr
.getPosition()] = 1;
1420 SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr
) {
1421 operandExprStack
.emplace_back(SmallVector
<int64_t, 32>(getNumCols(), 0));
1422 auto &eq
= operandExprStack
.back();
1423 eq
[getConstantIndex()] = expr
.getValue();
1427 LogicalResult
SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1428 ArrayRef
<int64_t> lhs
, ArrayRef
<int64_t> rhs
, AffineExpr localExpr
,
1429 SmallVectorImpl
<int64_t> &result
, unsigned long resultSize
) {
1430 assert(result
.size() == resultSize
&&
1431 "`result` vector passed is not of correct size");
1433 if ((loc
= findLocalId(localExpr
)) == -1) {
1434 if (failed(addLocalIdSemiAffine(lhs
, rhs
, localExpr
)))
1437 std::fill(result
.begin(), result
.end(), 0);
1439 result
[getLocalVarStartIndex() + numLocals
- 1] = 1;
1441 result
[getLocalVarStartIndex() + loc
] = 1;
1445 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1446 // A floordiv is thus flattened by introducing a new local variable q, and
1447 // replacing that expression with 'q' while adding the constraints
1448 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1449 // IntegerRelation::addLocalFloorDiv).
1451 // A ceildiv is similarly flattened:
1452 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1454 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1455 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1456 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1458 LogicalResult
SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr
,
1460 assert(operandExprStack
.size() >= 2);
1462 MLIRContext
*context
= expr
.getContext();
1463 SmallVector
<int64_t, 8> rhs
= operandExprStack
.back();
1464 operandExprStack
.pop_back();
1465 SmallVector
<int64_t, 8> &lhs
= operandExprStack
.back();
1467 // Flatten semi affine division expressions by introducing a local
1468 // variable in place of the quotient, and the affine expression corresponding
1469 // to the quantifier is added to `localExprs`.
1470 if (!isa
<AffineConstantExpr
>(expr
.getRHS())) {
1471 SmallVector
<int64_t, 8> divLhs(lhs
);
1472 AffineExpr a
= getAffineExprFromFlatForm(lhs
, numDims
, numSymbols
,
1473 localExprs
, context
);
1474 AffineExpr b
= getAffineExprFromFlatForm(rhs
, numDims
, numSymbols
,
1475 localExprs
, context
);
1476 AffineExpr divExpr
= isCeil
? a
.ceilDiv(b
) : a
.floorDiv(b
);
1477 return addLocalVariableSemiAffine(divLhs
, rhs
, divExpr
, lhs
, lhs
.size());
1480 // This is a pure affine expr; the RHS is a positive constant.
1481 int64_t rhsConst
= rhs
[getConstantIndex()];
1485 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1486 // common divisors of the numerator and denominator.
1487 uint64_t gcd
= std::abs(rhsConst
);
1488 for (int64_t lhsElt
: lhs
)
1489 gcd
= std::gcd(gcd
, (uint64_t)std::abs(lhsElt
));
1490 // Simplify the numerator and the denominator.
1492 for (int64_t &lhsElt
: lhs
)
1493 lhsElt
= lhsElt
/ static_cast<int64_t>(gcd
);
1495 int64_t divisor
= rhsConst
/ static_cast<int64_t>(gcd
);
1496 // If the divisor becomes 1, the updated LHS is the result. (The
1497 // divisor can't be negative since rhsConst is positive).
1501 // If the divisor cannot be simplified to one, we will have to retain
1502 // the ceil/floor expr (simplified up until here). Add an existential
1503 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1504 // by a new identifier, q.
1506 getAffineExprFromFlatForm(lhs
, numDims
, numSymbols
, localExprs
, context
);
1507 AffineExpr b
= getAffineConstantExpr(divisor
, context
);
1510 AffineExpr divExpr
= isCeil
? a
.ceilDiv(b
) : a
.floorDiv(b
);
1511 if ((loc
= findLocalId(divExpr
)) == -1) {
1513 SmallVector
<int64_t, 8> dividend(lhs
);
1514 addLocalFloorDivId(dividend
, divisor
, divExpr
);
1516 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1517 SmallVector
<int64_t, 8> dividend(lhs
);
1518 dividend
.back() += divisor
- 1;
1519 addLocalFloorDivId(dividend
, divisor
, divExpr
);
1522 // Set the expression on stack to the local var introduced to capture the
1523 // result of the division (floor or ceil).
1524 std::fill(lhs
.begin(), lhs
.end(), 0);
1526 lhs
[getLocalVarStartIndex() + numLocals
- 1] = 1;
1528 lhs
[getLocalVarStartIndex() + loc
] = 1;
1532 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1533 // The local identifier added is always a floordiv of a pure add/mul affine
1534 // function of other identifiers, coefficients of which are specified in
1535 // dividend and with respect to a positive constant divisor. localExpr is the
1536 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1537 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef
<int64_t> dividend
,
1539 AffineExpr localExpr
) {
1540 assert(divisor
> 0 && "positive constant divisor expected");
1541 for (SmallVector
<int64_t, 8> &subExpr
: operandExprStack
)
1542 subExpr
.insert(subExpr
.begin() + getLocalVarStartIndex() + numLocals
, 0);
1543 localExprs
.push_back(localExpr
);
1545 // dividend and divisor are not used here; an override of this method uses it.
1548 LogicalResult
SimpleAffineExprFlattener::addLocalIdSemiAffine(
1549 ArrayRef
<int64_t> lhs
, ArrayRef
<int64_t> rhs
, AffineExpr localExpr
) {
1550 for (SmallVector
<int64_t, 8> &subExpr
: operandExprStack
)
1551 subExpr
.insert(subExpr
.begin() + getLocalVarStartIndex() + numLocals
, 0);
1552 localExprs
.push_back(localExpr
);
1554 // lhs and rhs are not used here; an override of this method uses them.
1558 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr
) {
1559 SmallVectorImpl
<AffineExpr
>::iterator it
;
1560 if ((it
= llvm::find(localExprs
, localExpr
)) == localExprs
.end())
1562 return it
- localExprs
.begin();
1565 /// Simplify the affine expression by flattening it and reconstructing it.
1566 AffineExpr
mlir::simplifyAffineExpr(AffineExpr expr
, unsigned numDims
,
1567 unsigned numSymbols
) {
1568 // Simplify semi-affine expressions separately.
1569 if (!expr
.isPureAffine())
1570 expr
= simplifySemiAffine(expr
, numDims
, numSymbols
);
1572 SimpleAffineExprFlattener
flattener(numDims
, numSymbols
);
1573 // has poison expression
1574 if (failed(flattener
.walkPostOrder(expr
)))
1576 ArrayRef
<int64_t> flattenedExpr
= flattener
.operandExprStack
.back();
1577 if (!expr
.isPureAffine() &&
1578 expr
== getAffineExprFromFlatForm(flattenedExpr
, numDims
, numSymbols
,
1579 flattener
.localExprs
,
1582 AffineExpr simplifiedExpr
=
1584 ? getAffineExprFromFlatForm(flattenedExpr
, numDims
, numSymbols
,
1585 flattener
.localExprs
, expr
.getContext())
1586 : getSemiAffineExprFromFlatForm(flattenedExpr
, numDims
, numSymbols
,
1587 flattener
.localExprs
,
1590 flattener
.operandExprStack
.pop_back();
1591 assert(flattener
.operandExprStack
.empty());
1592 return simplifiedExpr
;
1595 std::optional
<int64_t> mlir::getBoundForAffineExpr(
1596 AffineExpr expr
, unsigned numDims
, unsigned numSymbols
,
1597 ArrayRef
<std::optional
<int64_t>> constLowerBounds
,
1598 ArrayRef
<std::optional
<int64_t>> constUpperBounds
, bool isUpper
) {
1599 // Handle divs and mods.
1600 if (auto binOpExpr
= dyn_cast
<AffineBinaryOpExpr
>(expr
)) {
1601 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1602 // can compute an upper bound.
1603 if (binOpExpr
.getKind() == AffineExprKind::FloorDiv
) {
1604 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(binOpExpr
.getRHS());
1605 if (!rhsConst
|| rhsConst
.getValue() < 1)
1606 return std::nullopt
;
1608 getBoundForAffineExpr(binOpExpr
.getLHS(), numDims
, numSymbols
,
1609 constLowerBounds
, constUpperBounds
, isUpper
);
1611 return std::nullopt
;
1612 return divideFloorSigned(*bound
, rhsConst
.getValue());
1614 if (binOpExpr
.getKind() == AffineExprKind::CeilDiv
) {
1615 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(binOpExpr
.getRHS());
1616 if (rhsConst
&& rhsConst
.getValue() >= 1) {
1618 getBoundForAffineExpr(binOpExpr
.getLHS(), numDims
, numSymbols
,
1619 constLowerBounds
, constUpperBounds
, isUpper
);
1621 return std::nullopt
;
1622 return divideCeilSigned(*bound
, rhsConst
.getValue());
1624 return std::nullopt
;
1626 if (binOpExpr
.getKind() == AffineExprKind::Mod
) {
1627 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1628 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1629 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1630 auto rhsConst
= dyn_cast
<AffineConstantExpr
>(binOpExpr
.getRHS());
1631 if (rhsConst
&& rhsConst
.getValue() >= 1) {
1632 int64_t rhsConstVal
= rhsConst
.getValue();
1633 auto lb
= getBoundForAffineExpr(binOpExpr
.getLHS(), numDims
, numSymbols
,
1634 constLowerBounds
, constUpperBounds
,
1637 getBoundForAffineExpr(binOpExpr
.getLHS(), numDims
, numSymbols
,
1638 constLowerBounds
, constUpperBounds
, isUpper
);
1640 divideFloorSigned(*lb
, rhsConstVal
) ==
1641 divideFloorSigned(*ub
, rhsConstVal
))
1642 return isUpper
? mod(*ub
, rhsConstVal
) : mod(*lb
, rhsConstVal
);
1643 return isUpper
? rhsConstVal
- 1 : 0;
1647 // Flatten the expression.
1648 SimpleAffineExprFlattener
flattener(numDims
, numSymbols
);
1649 auto simpleResult
= flattener
.walkPostOrder(expr
);
1650 // has poison expression
1651 if (failed(simpleResult
))
1652 return std::nullopt
;
1653 ArrayRef
<int64_t> flattenedExpr
= flattener
.operandExprStack
.back();
1654 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1655 // get bound on the local expr recursively.
1656 if (flattener
.numLocals
> 0)
1657 return std::nullopt
;
1659 // Substitute the constant lower or upper bound for the dimensional or
1660 // symbolic input depending on `isUpper` to determine the bound.
1661 for (unsigned i
= 0, e
= numDims
+ numSymbols
; i
< e
; ++i
) {
1662 if (flattenedExpr
[i
] > 0) {
1663 auto &constBound
= isUpper
? constUpperBounds
[i
] : constLowerBounds
[i
];
1665 return std::nullopt
;
1666 bound
+= *constBound
* flattenedExpr
[i
];
1667 } else if (flattenedExpr
[i
] < 0) {
1668 auto &constBound
= isUpper
? constLowerBounds
[i
] : constUpperBounds
[i
];
1670 return std::nullopt
;
1671 bound
+= *constBound
* flattenedExpr
[i
];
1675 bound
+= flattenedExpr
.back();