1 //===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 #include "mlir/Interfaces/ViewLikeInterface.h"
15 #include "llvm/ADT/APSInt.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "value-bounds-op-interface"
21 using presburger::BoundType
;
22 using presburger::VarKind
;
25 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
28 static Operation
*getOwnerOfValue(Value value
) {
29 if (auto bbArg
= dyn_cast
<BlockArgument
>(value
))
30 return bbArg
.getOwner()->getParentOp();
31 return value
.getDefiningOp();
34 HyperrectangularSlice::HyperrectangularSlice(ArrayRef
<OpFoldResult
> offsets
,
35 ArrayRef
<OpFoldResult
> sizes
,
36 ArrayRef
<OpFoldResult
> strides
)
37 : mixedOffsets(offsets
), mixedSizes(sizes
), mixedStrides(strides
) {
38 assert(offsets
.size() == sizes
.size() &&
39 "expected same number of offsets, sizes, strides");
40 assert(offsets
.size() == strides
.size() &&
41 "expected same number of offsets, sizes, strides");
44 HyperrectangularSlice::HyperrectangularSlice(ArrayRef
<OpFoldResult
> offsets
,
45 ArrayRef
<OpFoldResult
> sizes
)
46 : mixedOffsets(offsets
), mixedSizes(sizes
) {
47 assert(offsets
.size() == sizes
.size() &&
48 "expected same number of offsets and sizes");
49 // Assume that all strides are 1.
52 MLIRContext
*ctx
= offsets
.front().getContext();
53 mixedStrides
.append(offsets
.size(), Builder(ctx
).getIndexAttr(1));
56 HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op
)
57 : HyperrectangularSlice(op
.getMixedOffsets(), op
.getMixedSizes(),
58 op
.getMixedStrides()) {}
60 /// If ofr is a constant integer or an IntegerAttr, return the integer.
61 static std::optional
<int64_t> getConstantIntValue(OpFoldResult ofr
) {
62 // Case 1: Check for Constant integer.
63 if (auto val
= llvm::dyn_cast_if_present
<Value
>(ofr
)) {
65 if (matchPattern(val
, m_ConstantInt(&intVal
)))
66 return intVal
.getSExtValue();
69 // Case 2: Check for IntegerAttr.
70 Attribute attr
= llvm::dyn_cast_if_present
<Attribute
>(ofr
);
71 if (auto intAttr
= dyn_cast_or_null
<IntegerAttr
>(attr
))
72 return intAttr
.getValue().getSExtValue();
76 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr
)
77 : Variable(ofr
, std::nullopt
) {}
79 ValueBoundsConstraintSet::Variable::Variable(Value indexValue
)
80 : Variable(static_cast<OpFoldResult
>(indexValue
)) {}
82 ValueBoundsConstraintSet::Variable::Variable(Value shapedValue
, int64_t dim
)
83 : Variable(static_cast<OpFoldResult
>(shapedValue
), std::optional(dim
)) {}
85 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr
,
86 std::optional
<int64_t> dim
) {
87 Builder
b(ofr
.getContext());
88 if (auto constInt
= ::getConstantIntValue(ofr
)) {
89 assert(!dim
&& "expected no dim for index-typed values");
90 map
= AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
91 b
.getAffineConstantExpr(*constInt
));
94 Value value
= cast
<Value
>(ofr
);
97 assert(isa
<ShapedType
>(value
.getType()) && "expected shaped type");
99 assert(value
.getType().isIndex() && "expected index type");
102 map
= AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
103 b
.getAffineSymbolExpr(0));
104 mapOperands
.emplace_back(value
, dim
);
107 ValueBoundsConstraintSet::Variable::Variable(AffineMap map
,
108 ArrayRef
<Variable
> mapOperands
) {
109 assert(map
.getNumResults() == 1 && "expected single result");
111 // Turn all dims into symbols.
112 Builder
b(map
.getContext());
113 SmallVector
<AffineExpr
> dimReplacements
, symReplacements
;
114 for (int64_t i
= 0, e
= map
.getNumDims(); i
< e
; ++i
)
115 dimReplacements
.push_back(b
.getAffineSymbolExpr(i
));
116 for (int64_t i
= 0, e
= map
.getNumSymbols(); i
< e
; ++i
)
117 symReplacements
.push_back(b
.getAffineSymbolExpr(i
+ map
.getNumDims()));
118 AffineMap tmpMap
= map
.replaceDimsAndSymbols(
119 dimReplacements
, symReplacements
, /*numResultDims=*/0,
120 /*numResultSyms=*/map
.getNumSymbols() + map
.getNumDims());
123 DenseMap
<AffineExpr
, AffineExpr
> replacements
;
124 for (auto [index
, var
] : llvm::enumerate(mapOperands
)) {
125 assert(var
.map
.getNumResults() == 1 && "expected single result");
126 assert(var
.map
.getNumDims() == 0 && "expected only symbols");
127 SmallVector
<AffineExpr
> symReplacements
;
128 for (auto valueDim
: var
.mapOperands
) {
129 auto it
= llvm::find(this->mapOperands
, valueDim
);
130 if (it
!= this->mapOperands
.end()) {
131 // There is already a symbol for this operand.
132 symReplacements
.push_back(b
.getAffineSymbolExpr(
133 std::distance(this->mapOperands
.begin(), it
)));
135 // This is a new operand: add a new symbol.
136 symReplacements
.push_back(
137 b
.getAffineSymbolExpr(this->mapOperands
.size()));
138 this->mapOperands
.push_back(valueDim
);
141 replacements
[b
.getAffineSymbolExpr(index
)] =
142 var
.map
.getResult(0).replaceSymbols(symReplacements
);
144 this->map
= tmpMap
.replace(replacements
, /*numResultDims=*/0,
145 /*numResultSyms=*/this->mapOperands
.size());
148 ValueBoundsConstraintSet::Variable::Variable(AffineMap map
,
149 ArrayRef
<Value
> mapOperands
)
150 : Variable(map
, llvm::map_to_vector(mapOperands
,
151 [](Value v
) { return Variable(v
); })) {}
153 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
154 MLIRContext
*ctx
, StopConditionFn stopCondition
,
155 bool addConservativeSemiAffineBounds
)
156 : builder(ctx
), stopCondition(stopCondition
),
157 addConservativeSemiAffineBounds(addConservativeSemiAffineBounds
) {
158 assert(stopCondition
&& "expected non-null stop condition");
161 char ValueBoundsConstraintSet::ID
= 0;
164 static void assertValidValueDim(Value value
, std::optional
<int64_t> dim
) {
165 if (value
.getType().isIndex()) {
166 assert(!dim
.has_value() && "invalid dim value");
167 } else if (auto shapedType
= dyn_cast
<ShapedType
>(value
.getType())) {
168 assert(*dim
>= 0 && "invalid dim value");
169 if (shapedType
.hasRank())
170 assert(*dim
< shapedType
.getRank() && "invalid dim value");
172 llvm_unreachable("unsupported type");
177 void ValueBoundsConstraintSet::addBound(BoundType type
, int64_t pos
,
179 // Note: If `addConservativeSemiAffineBounds` is true then the bound
180 // computation function needs to handle the case that the constraints set
181 // could become empty. This is because the conservative bounds add assumptions
182 // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
183 // not to hold, then the bound is invalid.
184 LogicalResult status
= cstr
.addBound(
186 AffineMap::get(cstr
.getNumDimVars(), cstr
.getNumSymbolVars(), expr
),
187 addConservativeSemiAffineBounds
188 ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes
189 : FlatLinearConstraints::AddConservativeSemiAffineBounds::No
);
190 if (failed(status
)) {
191 // Not all semi-affine expressions are not yet supported by
192 // FlatLinearConstraints. However, we can just ignore such failures here.
193 // Even without this bound, there may be enough information in the
194 // constraint system to compute the requested bound. In case this bound is
195 // actually needed, `computeBound` will return `failure`.
196 LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr
<< "\n");
200 AffineExpr
ValueBoundsConstraintSet::getExpr(Value value
,
201 std::optional
<int64_t> dim
) {
203 assertValidValueDim(value
, dim
);
206 // Check if the value/dim is statically known. In that case, an affine
207 // constant expression should be returned. This allows us to support
208 // multiplications with constants. (Multiplications of two columns in the
209 // constraint set is not supported.)
210 std::optional
<int64_t> constSize
= std::nullopt
;
211 auto shapedType
= dyn_cast
<ShapedType
>(value
.getType());
213 if (shapedType
.hasRank() && !shapedType
.isDynamicDim(*dim
))
214 constSize
= shapedType
.getDimSize(*dim
);
215 } else if (auto constInt
= ::getConstantIntValue(value
)) {
216 constSize
= *constInt
;
219 // If the value/dim is already mapped, return the corresponding expression
221 ValueDim valueDim
= std::make_pair(value
, dim
.value_or(kIndexValue
));
222 if (valueDimToPosition
.contains(valueDim
)) {
223 // If it is a constant, return an affine constant expression. Otherwise,
224 // return an affine expression that represents the respective column in the
227 return builder
.getAffineConstantExpr(*constSize
);
228 return getPosExpr(getPos(value
, dim
));
232 // Constant index value/dim: add column to the constraint set, add EQ bound
233 // and return an affine constant expression without pushing the newly added
234 // column to the worklist.
235 (void)insert(value
, dim
, /*isSymbol=*/true, /*addToWorklist=*/false);
237 bound(value
)[*dim
] == *constSize
;
239 bound(value
) == *constSize
;
240 return builder
.getAffineConstantExpr(*constSize
);
243 // Dynamic value/dim: insert column to the constraint set and put it on the
244 // worklist. Return an affine expression that represents the newly inserted
245 // column in the constraint set.
246 return getPosExpr(insert(value
, dim
, /*isSymbol=*/true));
249 AffineExpr
ValueBoundsConstraintSet::getExpr(OpFoldResult ofr
) {
250 if (Value value
= llvm::dyn_cast_if_present
<Value
>(ofr
))
251 return getExpr(value
, /*dim=*/std::nullopt
);
252 auto constInt
= ::getConstantIntValue(ofr
);
253 assert(constInt
.has_value() && "expected Integer constant");
254 return builder
.getAffineConstantExpr(*constInt
);
257 AffineExpr
ValueBoundsConstraintSet::getExpr(int64_t constant
) {
258 return builder
.getAffineConstantExpr(constant
);
261 int64_t ValueBoundsConstraintSet::insert(Value value
,
262 std::optional
<int64_t> dim
,
263 bool isSymbol
, bool addToWorklist
) {
265 assertValidValueDim(value
, dim
);
268 ValueDim valueDim
= std::make_pair(value
, dim
.value_or(kIndexValue
));
269 assert(!valueDimToPosition
.contains(valueDim
) && "already mapped");
270 int64_t pos
= isSymbol
? cstr
.appendVar(VarKind::Symbol
)
271 : cstr
.appendVar(VarKind::SetDim
);
272 LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
274 << " (dim: " << dim
.value_or(kIndexValue
)
275 << ", owner: " << getOwnerOfValue(value
)->getName()
277 positionToValueDim
.insert(positionToValueDim
.begin() + pos
, valueDim
);
278 // Update reverse mapping.
279 for (int64_t i
= pos
, e
= positionToValueDim
.size(); i
< e
; ++i
)
280 if (positionToValueDim
[i
].has_value())
281 valueDimToPosition
[*positionToValueDim
[i
]] = i
;
284 LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
285 << " (dim: " << dim
.value_or(kIndexValue
) << ")\n");
292 int64_t ValueBoundsConstraintSet::insert(bool isSymbol
) {
293 int64_t pos
= isSymbol
? cstr
.appendVar(VarKind::Symbol
)
294 : cstr
.appendVar(VarKind::SetDim
);
295 LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
297 positionToValueDim
.insert(positionToValueDim
.begin() + pos
, std::nullopt
);
298 // Update reverse mapping.
299 for (int64_t i
= pos
, e
= positionToValueDim
.size(); i
< e
; ++i
)
300 if (positionToValueDim
[i
].has_value())
301 valueDimToPosition
[*positionToValueDim
[i
]] = i
;
305 int64_t ValueBoundsConstraintSet::insert(AffineMap map
, ValueDimList operands
,
307 assert(map
.getNumResults() == 1 && "expected affine map with one result");
308 int64_t pos
= insert(isSymbol
);
310 // Add map and operands to the constraint set. Dimensions are converted to
311 // symbols. All operands are added to the worklist (unless they were already
313 auto mapper
= [&](std::pair
<Value
, std::optional
<int64_t>> v
) {
314 return getExpr(v
.first
, v
.second
);
316 SmallVector
<AffineExpr
> dimReplacements
= llvm::to_vector(
317 llvm::map_range(ArrayRef(operands
).take_front(map
.getNumDims()), mapper
));
318 SmallVector
<AffineExpr
> symReplacements
= llvm::to_vector(
319 llvm::map_range(ArrayRef(operands
).drop_front(map
.getNumDims()), mapper
));
321 presburger::BoundType::EQ
, pos
,
322 map
.getResult(0).replaceDimsAndSymbols(dimReplacements
, symReplacements
));
327 int64_t ValueBoundsConstraintSet::insert(const Variable
&var
, bool isSymbol
) {
328 return insert(var
.map
, var
.mapOperands
, isSymbol
);
331 int64_t ValueBoundsConstraintSet::getPos(Value value
,
332 std::optional
<int64_t> dim
) const {
334 assertValidValueDim(value
, dim
);
335 assert((isa
<OpResult
>(value
) ||
336 cast
<BlockArgument
>(value
).getOwner()->isEntryBlock()) &&
337 "unstructured control flow is not supported");
339 LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
340 << " (dim: " << dim
.value_or(kIndexValue
)
341 << ", owner: " << getOwnerOfValue(value
)->getName()
344 valueDimToPosition
.find(std::make_pair(value
, dim
.value_or(kIndexValue
)));
345 assert(it
!= valueDimToPosition
.end() && "expected mapped entry");
349 AffineExpr
ValueBoundsConstraintSet::getPosExpr(int64_t pos
) {
350 assert(pos
>= 0 && pos
< cstr
.getNumDimAndSymbolVars() && "invalid position");
351 return pos
< cstr
.getNumDimVars()
352 ? builder
.getAffineDimExpr(pos
)
353 : builder
.getAffineSymbolExpr(pos
- cstr
.getNumDimVars());
356 bool ValueBoundsConstraintSet::isMapped(Value value
,
357 std::optional
<int64_t> dim
) const {
359 valueDimToPosition
.find(std::make_pair(value
, dim
.value_or(kIndexValue
)));
360 return it
!= valueDimToPosition
.end();
363 void ValueBoundsConstraintSet::processWorklist() {
364 LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
365 while (!worklist
.empty()) {
366 int64_t pos
= worklist
.front();
368 assert(positionToValueDim
[pos
].has_value() &&
369 "did not expect std::nullopt on worklist");
370 ValueDim valueDim
= *positionToValueDim
[pos
];
371 Value value
= valueDim
.first
;
372 int64_t dim
= valueDim
.second
;
374 // Check for static dim size.
375 if (dim
!= kIndexValue
) {
376 auto shapedType
= cast
<ShapedType
>(value
.getType());
377 if (shapedType
.hasRank() && !shapedType
.isDynamicDim(dim
)) {
378 bound(value
)[dim
] == getExpr(shapedType
.getDimSize(dim
));
383 // Do not process any further if the stop condition is met.
384 auto maybeDim
= dim
== kIndexValue
? std::nullopt
: std::make_optional(dim
);
385 if (stopCondition(value
, maybeDim
, *this)) {
386 LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
387 << " (dim: " << maybeDim
<< ")\n");
391 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
394 dyn_cast
<ValueBoundsOpInterface
>(getOwnerOfValue(value
));
395 LLVM_DEBUG(llvm::dbgs()
396 << "Query value bounds for: " << value
397 << " (owner: " << getOwnerOfValue(value
)->getName() << ")\n");
399 if (dim
== kIndexValue
) {
400 valueBoundsOp
.populateBoundsForIndexValue(value
, *this);
402 valueBoundsOp
.populateBoundsForShapedValueDim(value
, dim
, *this);
406 LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
408 // If the op does not implement `ValueBoundsOpInterface`, check if it
409 // implements the `DestinationStyleOpInterface`. OpResults of such ops are
410 // tied to OpOperands. Tied values have the same shape.
411 auto dstOp
= value
.getDefiningOp
<DestinationStyleOpInterface
>();
412 if (!dstOp
|| dim
== kIndexValue
)
414 Value tiedOperand
= dstOp
.getTiedOpOperand(cast
<OpResult
>(value
))->get();
415 bound(value
)[dim
] == getExpr(tiedOperand
, dim
);
419 void ValueBoundsConstraintSet::projectOut(int64_t pos
) {
420 assert(pos
>= 0 && pos
< static_cast<int64_t>(positionToValueDim
.size()) &&
422 cstr
.projectOut(pos
);
423 if (positionToValueDim
[pos
].has_value()) {
424 bool erased
= valueDimToPosition
.erase(*positionToValueDim
[pos
]);
426 assert(erased
&& "inconsistent reverse mapping");
428 positionToValueDim
.erase(positionToValueDim
.begin() + pos
);
429 // Update reverse mapping.
430 for (int64_t i
= pos
, e
= positionToValueDim
.size(); i
< e
; ++i
)
431 if (positionToValueDim
[i
].has_value())
432 valueDimToPosition
[*positionToValueDim
[i
]] = i
;
435 void ValueBoundsConstraintSet::projectOut(
436 function_ref
<bool(ValueDim
)> condition
) {
438 while (nextPos
< static_cast<int64_t>(positionToValueDim
.size())) {
439 if (positionToValueDim
[nextPos
].has_value() &&
440 condition(*positionToValueDim
[nextPos
])) {
442 // The column was projected out so another column is now at that position.
443 // Do not increase the counter.
450 void ValueBoundsConstraintSet::projectOutAnonymous(
451 std::optional
<int64_t> except
) {
453 while (nextPos
< static_cast<int64_t>(positionToValueDim
.size())) {
454 if (positionToValueDim
[nextPos
].has_value() || except
== nextPos
) {
458 // The column was projected out so another column is now at that position.
459 // Do not increase the counter.
464 LogicalResult
ValueBoundsConstraintSet::computeBound(
465 AffineMap
&resultMap
, ValueDimList
&mapOperands
, presburger::BoundType type
,
466 const Variable
&var
, StopConditionFn stopCondition
, bool closedUB
) {
467 MLIRContext
*ctx
= var
.getContext();
468 int64_t ubAdjustment
= closedUB
? 0 : 1;
472 // Process the backward slice of `value` (i.e., reverse use-def chain) until
473 // `stopCondition` is met.
474 ValueBoundsConstraintSet
cstr(ctx
, stopCondition
);
475 int64_t pos
= cstr
.insert(var
, /*isSymbol=*/false);
476 assert(pos
== 0 && "expected first column");
477 cstr
.processWorklist();
479 // Project out all variables (apart from `valueDim`) that do not match the
481 cstr
.projectOut([&](ValueDim p
) {
483 p
.second
== kIndexValue
? std::nullopt
: std::make_optional(p
.second
);
484 return !stopCondition(p
.first
, maybeDim
, cstr
);
486 cstr
.projectOutAnonymous(/*except=*/pos
);
488 // Compute lower and upper bounds for `valueDim`.
489 SmallVector
<AffineMap
> lb(1), ub(1);
490 cstr
.cstr
.getSliceBounds(pos
, 1, ctx
, &lb
, &ub
,
493 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
494 // case, no lower/upper bound can be computed at the moment.
495 // EQ, UB bounds: upper bound is needed.
496 if ((type
!= BoundType::LB
) &&
497 (ub
.empty() || !ub
[0] || ub
[0].getNumResults() == 0))
499 // EQ, LB bounds: lower bound is needed.
500 if ((type
!= BoundType::UB
) &&
501 (lb
.empty() || !lb
[0] || lb
[0].getNumResults() == 0))
504 // TODO: Generate an affine map with multiple results.
505 if (type
!= BoundType::LB
)
506 assert(ub
.size() == 1 && ub
[0].getNumResults() == 1 &&
507 "multiple bounds not supported");
508 if (type
!= BoundType::UB
)
509 assert(lb
.size() == 1 && lb
[0].getNumResults() == 1 &&
510 "multiple bounds not supported");
512 // EQ bound: lower and upper bound must match.
513 if (type
== BoundType::EQ
&& ub
[0] != lb
[0])
517 if (type
== BoundType::EQ
|| type
== BoundType::LB
) {
520 // Computed UB is a closed bound.
521 bound
= AffineMap::get(ub
[0].getNumDims(), ub
[0].getNumSymbols(),
522 ub
[0].getResult(0) + ubAdjustment
);
525 // Gather all SSA values that are used in the computed bound.
526 assert(cstr
.cstr
.getNumDimAndSymbolVars() == cstr
.positionToValueDim
.size() &&
527 "inconsistent mapping state");
528 SmallVector
<AffineExpr
> replacementDims
, replacementSymbols
;
529 int64_t numDims
= 0, numSymbols
= 0;
530 for (int64_t i
= 0; i
< cstr
.cstr
.getNumDimAndSymbolVars(); ++i
) {
534 // Check if the position `i` is used in the generated bound. If so, it must
535 // be included in the generated affine.apply op.
537 bool isDim
= i
< cstr
.cstr
.getNumDimVars();
539 if (bound
.isFunctionOfDim(i
))
542 if (bound
.isFunctionOfSymbol(i
- cstr
.cstr
.getNumDimVars()))
547 // Not used: Remove dim/symbol from the result.
549 replacementDims
.push_back(b
.getAffineConstantExpr(0));
551 replacementSymbols
.push_back(b
.getAffineConstantExpr(0));
557 replacementDims
.push_back(b
.getAffineDimExpr(numDims
++));
559 replacementSymbols
.push_back(b
.getAffineSymbolExpr(numSymbols
++));
562 assert(cstr
.positionToValueDim
[i
].has_value() &&
563 "cannot build affine map in terms of anonymous column");
564 ValueBoundsConstraintSet::ValueDim valueDim
= *cstr
.positionToValueDim
[i
];
565 Value value
= valueDim
.first
;
566 int64_t dim
= valueDim
.second
;
567 if (dim
== ValueBoundsConstraintSet::kIndexValue
) {
568 // An index-type value is used: can be used directly in the affine.apply
570 assert(value
.getType().isIndex() && "expected index type");
571 mapOperands
.push_back(std::make_pair(value
, std::nullopt
));
575 assert(cast
<ShapedType
>(value
.getType()).isDynamicDim(dim
) &&
576 "expected dynamic dim");
577 mapOperands
.push_back(std::make_pair(value
, dim
));
580 resultMap
= bound
.replaceDimsAndSymbols(replacementDims
, replacementSymbols
,
581 numDims
, numSymbols
);
585 LogicalResult
ValueBoundsConstraintSet::computeDependentBound(
586 AffineMap
&resultMap
, ValueDimList
&mapOperands
, presburger::BoundType type
,
587 const Variable
&var
, ValueDimList dependencies
, bool closedUB
) {
589 resultMap
, mapOperands
, type
, var
,
590 [&](Value v
, std::optional
<int64_t> d
, ValueBoundsConstraintSet
&cstr
) {
591 return llvm::is_contained(dependencies
, std::make_pair(v
, d
));
596 LogicalResult
ValueBoundsConstraintSet::computeIndependentBound(
597 AffineMap
&resultMap
, ValueDimList
&mapOperands
, presburger::BoundType type
,
598 const Variable
&var
, ValueRange independencies
, bool closedUB
) {
599 // Return "true" if the given value is independent of all values in
600 // `independencies`. I.e., neither the value itself nor any value in the
601 // backward slice (reverse use-def chain) is contained in `independencies`.
602 auto isIndependent
= [&](Value v
) {
603 SmallVector
<Value
> worklist
;
604 DenseSet
<Value
> visited
;
605 worklist
.push_back(v
);
606 while (!worklist
.empty()) {
607 Value next
= worklist
.pop_back_val();
608 if (!visited
.insert(next
).second
)
610 if (llvm::is_contained(independencies
, next
))
612 // TODO: DominanceInfo could be used to stop the traversal early.
613 Operation
*op
= next
.getDefiningOp();
616 worklist
.append(op
->getOperands().begin(), op
->getOperands().end());
621 // Reify bounds in terms of any independent values.
623 resultMap
, mapOperands
, type
, var
,
624 [&](Value v
, std::optional
<int64_t> d
, ValueBoundsConstraintSet
&cstr
) {
625 return isIndependent(v
);
630 FailureOr
<int64_t> ValueBoundsConstraintSet::computeConstantBound(
631 presburger::BoundType type
, const Variable
&var
,
632 StopConditionFn stopCondition
, bool closedUB
) {
633 // Default stop condition if none was specified: Keep adding constraints until
634 // a bound could be computed.
636 auto defaultStopCondition
= [&](Value v
, std::optional
<int64_t> dim
,
637 ValueBoundsConstraintSet
&cstr
) {
638 return cstr
.cstr
.getConstantBound64(type
, pos
).has_value();
641 ValueBoundsConstraintSet
cstr(
642 var
.getContext(), stopCondition
? stopCondition
: defaultStopCondition
);
643 pos
= cstr
.populateConstraints(var
.map
, var
.mapOperands
);
644 assert(pos
== 0 && "expected `map` is the first column");
646 // Compute constant bound for `valueDim`.
647 int64_t ubAdjustment
= closedUB
? 0 : 1;
648 if (auto bound
= cstr
.cstr
.getConstantBound64(type
, pos
))
649 return type
== BoundType::UB
? *bound
+ ubAdjustment
: *bound
;
653 void ValueBoundsConstraintSet::populateConstraints(Value value
,
654 std::optional
<int64_t> dim
) {
656 assertValidValueDim(value
, dim
);
659 // `getExpr` pushes the value/dim onto the worklist (unless it was already
661 (void)getExpr(value
, dim
);
662 // Process all values/dims on the worklist. This may traverse and analyze
663 // additional IR, depending the current stop function.
667 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map
,
668 ValueDimList operands
) {
669 int64_t pos
= insert(map
, operands
, /*isSymbol=*/false);
670 // Process the backward slice of `operands` (i.e., reverse use-def chain)
671 // until `stopCondition` is met.
677 ValueBoundsConstraintSet::computeConstantDelta(Value value1
, Value value2
,
678 std::optional
<int64_t> dim1
,
679 std::optional
<int64_t> dim2
) {
681 assertValidValueDim(value1
, dim1
);
682 assertValidValueDim(value2
, dim2
);
685 Builder
b(value1
.getContext());
686 AffineMap map
= AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
687 b
.getAffineDimExpr(0) - b
.getAffineDimExpr(1));
688 return computeConstantBound(presburger::BoundType::EQ
,
689 Variable(map
, {{value1
, dim1
}, {value2
, dim2
}}));
692 bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos
,
693 ComparisonOperator cmp
,
695 // This function returns "true" if "lhs CMP rhs" is proven to hold.
697 // Example for ComparisonOperator::LE and index-typed values: We would like to
698 // prove that lhs <= rhs. Proof by contradiction: add the inverse
699 // relation (lhs > rhs) to the constraint set and check if the resulting
700 // constraint set is "empty" (i.e. has no solution). In that case,
701 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
703 // We cannot prove anything if the constraint set is already empty.
704 if (cstr
.isEmpty()) {
707 << "cannot compare value/dims: constraint system is already empty");
711 // EQ can be expressed as LE and GE.
713 return comparePos(lhsPos
, ComparisonOperator::LE
, rhsPos
) &&
714 comparePos(lhsPos
, ComparisonOperator::GE
, rhsPos
);
716 // Construct inequality.
717 SmallVector
<int64_t> eq(cstr
.getNumCols(), 0);
718 if (cmp
== LT
|| cmp
== LE
) {
721 } else if (cmp
== GT
|| cmp
== GE
) {
725 llvm_unreachable("unsupported comparison operator");
727 if (cmp
== LE
|| cmp
== GE
)
728 eq
[cstr
.getNumCols() - 1] -= 1;
730 // Add inequality to the constraint set and check if it made the constraint
732 int64_t ineqPos
= cstr
.getNumInequalities();
733 cstr
.addInequality(eq
);
734 bool isEmpty
= cstr
.isEmpty();
735 cstr
.removeInequality(ineqPos
);
739 bool ValueBoundsConstraintSet::populateAndCompare(const Variable
&lhs
,
740 ComparisonOperator cmp
,
741 const Variable
&rhs
) {
742 int64_t lhsPos
= populateConstraints(lhs
.map
, lhs
.mapOperands
);
743 int64_t rhsPos
= populateConstraints(rhs
.map
, rhs
.mapOperands
);
744 return comparePos(lhsPos
, cmp
, rhsPos
);
747 bool ValueBoundsConstraintSet::compare(const Variable
&lhs
,
748 ComparisonOperator cmp
,
749 const Variable
&rhs
) {
750 int64_t lhsPos
= -1, rhsPos
= -1;
751 auto stopCondition
= [&](Value v
, std::optional
<int64_t> dim
,
752 ValueBoundsConstraintSet
&cstr
) {
753 // Keep processing as long as lhs/rhs were not processed.
754 if (size_t(lhsPos
) >= cstr
.positionToValueDim
.size() ||
755 size_t(rhsPos
) >= cstr
.positionToValueDim
.size())
757 // Keep processing as long as the relation cannot be proven.
758 return cstr
.comparePos(lhsPos
, cmp
, rhsPos
);
760 ValueBoundsConstraintSet
cstr(lhs
.getContext(), stopCondition
);
761 lhsPos
= cstr
.populateConstraints(lhs
.map
, lhs
.mapOperands
);
762 rhsPos
= cstr
.populateConstraints(rhs
.map
, rhs
.mapOperands
);
763 return cstr
.comparePos(lhsPos
, cmp
, rhsPos
);
766 FailureOr
<bool> ValueBoundsConstraintSet::areEqual(const Variable
&var1
,
767 const Variable
&var2
) {
768 if (ValueBoundsConstraintSet::compare(var1
, ComparisonOperator::EQ
, var2
))
770 if (ValueBoundsConstraintSet::compare(var1
, ComparisonOperator::LT
, var2
) ||
771 ValueBoundsConstraintSet::compare(var1
, ComparisonOperator::GT
, var2
))
777 ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext
*ctx
,
778 HyperrectangularSlice slice1
,
779 HyperrectangularSlice slice2
) {
780 assert(slice1
.getMixedOffsets().size() == slice1
.getMixedOffsets().size() &&
781 "expected slices of same rank");
782 assert(slice1
.getMixedSizes().size() == slice1
.getMixedSizes().size() &&
783 "expected slices of same rank");
784 assert(slice1
.getMixedStrides().size() == slice1
.getMixedStrides().size() &&
785 "expected slices of same rank");
788 bool foundUnknownBound
= false;
789 for (int64_t i
= 0, e
= slice1
.getMixedOffsets().size(); i
< e
; ++i
) {
791 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
792 b
.getAffineSymbolExpr(0) +
793 b
.getAffineSymbolExpr(1) * b
.getAffineSymbolExpr(2) -
794 b
.getAffineSymbolExpr(3));
796 // Case 1: Slices are guaranteed to be non-overlapping if
797 // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
798 SmallVector
<OpFoldResult
> ofrOperands
;
799 ofrOperands
.push_back(slice1
.getMixedOffsets()[i
]);
800 ofrOperands
.push_back(slice1
.getMixedSizes()[i
]);
801 ofrOperands
.push_back(slice1
.getMixedStrides()[i
]);
802 ofrOperands
.push_back(slice2
.getMixedOffsets()[i
]);
803 SmallVector
<Value
> valueOperands
;
804 AffineMap foldedMap
=
805 foldAttributesIntoMap(b
, map
, ofrOperands
, valueOperands
);
806 FailureOr
<int64_t> constBound
= computeConstantBound(
807 presburger::BoundType::EQ
, Variable(foldedMap
, valueOperands
));
808 foundUnknownBound
|= failed(constBound
);
809 if (succeeded(constBound
) && *constBound
<= 0)
813 // Case 2: Slices are guaranteed to be non-overlapping if
814 // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
815 SmallVector
<OpFoldResult
> ofrOperands
;
816 ofrOperands
.push_back(slice2
.getMixedOffsets()[i
]);
817 ofrOperands
.push_back(slice2
.getMixedSizes()[i
]);
818 ofrOperands
.push_back(slice2
.getMixedStrides()[i
]);
819 ofrOperands
.push_back(slice1
.getMixedOffsets()[i
]);
820 SmallVector
<Value
> valueOperands
;
821 AffineMap foldedMap
=
822 foldAttributesIntoMap(b
, map
, ofrOperands
, valueOperands
);
823 FailureOr
<int64_t> constBound
= computeConstantBound(
824 presburger::BoundType::EQ
, Variable(foldedMap
, valueOperands
));
825 foundUnknownBound
|= failed(constBound
);
826 if (succeeded(constBound
) && *constBound
<= 0)
831 // If at least one bound could not be computed, we cannot be certain that the
832 // slices are really overlapping.
833 if (foundUnknownBound
)
836 // All bounds could be computed and none of the above cases applied.
837 // Therefore, the slices are guaranteed to overlap.
842 ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext
*ctx
,
843 HyperrectangularSlice slice1
,
844 HyperrectangularSlice slice2
) {
845 assert(slice1
.getMixedOffsets().size() == slice1
.getMixedOffsets().size() &&
846 "expected slices of same rank");
847 assert(slice1
.getMixedSizes().size() == slice1
.getMixedSizes().size() &&
848 "expected slices of same rank");
849 assert(slice1
.getMixedStrides().size() == slice1
.getMixedStrides().size() &&
850 "expected slices of same rank");
852 // The two slices are equivalent if all of their offsets, sizes and strides
853 // are equal. If equality cannot be determined for at least one of those
854 // values, equivalence cannot be determined and this function returns
856 for (auto [offset1
, offset2
] :
857 llvm::zip_equal(slice1
.getMixedOffsets(), slice2
.getMixedOffsets())) {
858 FailureOr
<bool> equal
= areEqual(offset1
, offset2
);
864 for (auto [size1
, size2
] :
865 llvm::zip_equal(slice1
.getMixedSizes(), slice2
.getMixedSizes())) {
866 FailureOr
<bool> equal
= areEqual(size1
, size2
);
872 for (auto [stride1
, stride2
] :
873 llvm::zip_equal(slice1
.getMixedStrides(), slice2
.getMixedStrides())) {
874 FailureOr
<bool> equal
= areEqual(stride1
, stride2
);
883 void ValueBoundsConstraintSet::dump() const {
884 llvm::errs() << "==========\nColumns:\n";
885 llvm::errs() << "(column\tdim\tvalue)\n";
886 for (auto [index
, valueDim
] : llvm::enumerate(positionToValueDim
)) {
887 llvm::errs() << " " << index
<< "\t";
889 if (valueDim
->second
== kIndexValue
) {
890 llvm::errs() << "n/a\t";
892 llvm::errs() << valueDim
->second
<< "\t";
894 llvm::errs() << getOwnerOfValue(valueDim
->first
)->getName() << " ";
895 if (OpResult result
= dyn_cast
<OpResult
>(valueDim
->first
)) {
896 llvm::errs() << "(result " << result
.getResultNumber() << ")";
898 llvm::errs() << "(bbarg "
899 << cast
<BlockArgument
>(valueDim
->first
).getArgNumber()
902 llvm::errs() << "\n";
904 llvm::errs() << "n/a\tn/a\n";
907 llvm::errs() << "\nConstraint set:\n";
909 llvm::errs() << "==========\n";
912 ValueBoundsConstraintSet::BoundBuilder
&
913 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim
) {
914 assert(!this->dim
.has_value() && "dim was already set");
917 assertValidValueDim(value
, this->dim
);
922 void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr
) {
924 assertValidValueDim(value
, this->dim
);
926 cstr
.addBound(BoundType::UB
, cstr
.getPos(value
, this->dim
), expr
);
929 void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr
) {
933 void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr
) {
934 operator>=(expr
+ 1);
937 void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr
) {
939 assertValidValueDim(value
, this->dim
);
941 cstr
.addBound(BoundType::LB
, cstr
.getPos(value
, this->dim
), expr
);
944 void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr
) {
946 assertValidValueDim(value
, this->dim
);
948 cstr
.addBound(BoundType::EQ
, cstr
.getPos(value
, this->dim
), expr
);
951 void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr
) {
952 operator<(cstr
.getExpr(ofr
));
955 void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr
) {
956 operator<=(cstr
.getExpr(ofr
));
959 void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr
) {
960 operator>(cstr
.getExpr(ofr
));
963 void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr
) {
964 operator>=(cstr
.getExpr(ofr
));
967 void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr
) {
968 operator==(cstr
.getExpr(ofr
));
971 void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i
) {
972 operator<(cstr
.getExpr(i
));
975 void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i
) {
976 operator<=(cstr
.getExpr(i
));
979 void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i
) {
980 operator>(cstr
.getExpr(i
));
983 void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i
) {
984 operator>=(cstr
.getExpr(i
));
987 void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i
) {
988 operator==(cstr
.getExpr(i
));