Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Interfaces / ValueBoundsOpInterface.cpp
blob505e84e3ca0cf3b52bdaaa23edf8984d6ff00462
1 //===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 #include "mlir/Interfaces/ViewLikeInterface.h"
15 #include "llvm/ADT/APSInt.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "value-bounds-op-interface"
20 using namespace mlir;
21 using presburger::BoundType;
22 using presburger::VarKind;
24 namespace mlir {
25 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
26 } // namespace mlir
28 static Operation *getOwnerOfValue(Value value) {
29 if (auto bbArg = dyn_cast<BlockArgument>(value))
30 return bbArg.getOwner()->getParentOp();
31 return value.getDefiningOp();
34 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
35 ArrayRef<OpFoldResult> sizes,
36 ArrayRef<OpFoldResult> strides)
37 : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
38 assert(offsets.size() == sizes.size() &&
39 "expected same number of offsets, sizes, strides");
40 assert(offsets.size() == strides.size() &&
41 "expected same number of offsets, sizes, strides");
44 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
45 ArrayRef<OpFoldResult> sizes)
46 : mixedOffsets(offsets), mixedSizes(sizes) {
47 assert(offsets.size() == sizes.size() &&
48 "expected same number of offsets and sizes");
49 // Assume that all strides are 1.
50 if (offsets.empty())
51 return;
52 MLIRContext *ctx = offsets.front().getContext();
53 mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
56 HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
57 : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
58 op.getMixedStrides()) {}
60 /// If ofr is a constant integer or an IntegerAttr, return the integer.
61 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
62 // Case 1: Check for Constant integer.
63 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
64 APSInt intVal;
65 if (matchPattern(val, m_ConstantInt(&intVal)))
66 return intVal.getSExtValue();
67 return std::nullopt;
69 // Case 2: Check for IntegerAttr.
70 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
71 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
72 return intAttr.getValue().getSExtValue();
73 return std::nullopt;
76 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
77 : Variable(ofr, std::nullopt) {}
79 ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
80 : Variable(static_cast<OpFoldResult>(indexValue)) {}
82 ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
83 : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
85 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
86 std::optional<int64_t> dim) {
87 Builder b(ofr.getContext());
88 if (auto constInt = ::getConstantIntValue(ofr)) {
89 assert(!dim && "expected no dim for index-typed values");
90 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
91 b.getAffineConstantExpr(*constInt));
92 return;
94 Value value = cast<Value>(ofr);
95 #ifndef NDEBUG
96 if (dim) {
97 assert(isa<ShapedType>(value.getType()) && "expected shaped type");
98 } else {
99 assert(value.getType().isIndex() && "expected index type");
101 #endif // NDEBUG
102 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
103 b.getAffineSymbolExpr(0));
104 mapOperands.emplace_back(value, dim);
107 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
108 ArrayRef<Variable> mapOperands) {
109 assert(map.getNumResults() == 1 && "expected single result");
111 // Turn all dims into symbols.
112 Builder b(map.getContext());
113 SmallVector<AffineExpr> dimReplacements, symReplacements;
114 for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
115 dimReplacements.push_back(b.getAffineSymbolExpr(i));
116 for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
117 symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
118 AffineMap tmpMap = map.replaceDimsAndSymbols(
119 dimReplacements, symReplacements, /*numResultDims=*/0,
120 /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
122 // Inline operands.
123 DenseMap<AffineExpr, AffineExpr> replacements;
124 for (auto [index, var] : llvm::enumerate(mapOperands)) {
125 assert(var.map.getNumResults() == 1 && "expected single result");
126 assert(var.map.getNumDims() == 0 && "expected only symbols");
127 SmallVector<AffineExpr> symReplacements;
128 for (auto valueDim : var.mapOperands) {
129 auto it = llvm::find(this->mapOperands, valueDim);
130 if (it != this->mapOperands.end()) {
131 // There is already a symbol for this operand.
132 symReplacements.push_back(b.getAffineSymbolExpr(
133 std::distance(this->mapOperands.begin(), it)));
134 } else {
135 // This is a new operand: add a new symbol.
136 symReplacements.push_back(
137 b.getAffineSymbolExpr(this->mapOperands.size()));
138 this->mapOperands.push_back(valueDim);
141 replacements[b.getAffineSymbolExpr(index)] =
142 var.map.getResult(0).replaceSymbols(symReplacements);
144 this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
145 /*numResultSyms=*/this->mapOperands.size());
148 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
149 ArrayRef<Value> mapOperands)
150 : Variable(map, llvm::map_to_vector(mapOperands,
151 [](Value v) { return Variable(v); })) {}
153 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
154 MLIRContext *ctx, StopConditionFn stopCondition,
155 bool addConservativeSemiAffineBounds)
156 : builder(ctx), stopCondition(stopCondition),
157 addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {
158 assert(stopCondition && "expected non-null stop condition");
161 char ValueBoundsConstraintSet::ID = 0;
163 #ifndef NDEBUG
164 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
165 if (value.getType().isIndex()) {
166 assert(!dim.has_value() && "invalid dim value");
167 } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
168 assert(*dim >= 0 && "invalid dim value");
169 if (shapedType.hasRank())
170 assert(*dim < shapedType.getRank() && "invalid dim value");
171 } else {
172 llvm_unreachable("unsupported type");
175 #endif // NDEBUG
177 void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
178 AffineExpr expr) {
179 // Note: If `addConservativeSemiAffineBounds` is true then the bound
180 // computation function needs to handle the case that the constraints set
181 // could become empty. This is because the conservative bounds add assumptions
182 // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
183 // not to hold, then the bound is invalid.
184 LogicalResult status = cstr.addBound(
185 type, pos,
186 AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr),
187 addConservativeSemiAffineBounds
188 ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes
189 : FlatLinearConstraints::AddConservativeSemiAffineBounds::No);
190 if (failed(status)) {
191 // Not all semi-affine expressions are not yet supported by
192 // FlatLinearConstraints. However, we can just ignore such failures here.
193 // Even without this bound, there may be enough information in the
194 // constraint system to compute the requested bound. In case this bound is
195 // actually needed, `computeBound` will return `failure`.
196 LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n");
200 AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
201 std::optional<int64_t> dim) {
202 #ifndef NDEBUG
203 assertValidValueDim(value, dim);
204 #endif // NDEBUG
206 // Check if the value/dim is statically known. In that case, an affine
207 // constant expression should be returned. This allows us to support
208 // multiplications with constants. (Multiplications of two columns in the
209 // constraint set is not supported.)
210 std::optional<int64_t> constSize = std::nullopt;
211 auto shapedType = dyn_cast<ShapedType>(value.getType());
212 if (shapedType) {
213 if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
214 constSize = shapedType.getDimSize(*dim);
215 } else if (auto constInt = ::getConstantIntValue(value)) {
216 constSize = *constInt;
219 // If the value/dim is already mapped, return the corresponding expression
220 // directly.
221 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
222 if (valueDimToPosition.contains(valueDim)) {
223 // If it is a constant, return an affine constant expression. Otherwise,
224 // return an affine expression that represents the respective column in the
225 // constraint set.
226 if (constSize)
227 return builder.getAffineConstantExpr(*constSize);
228 return getPosExpr(getPos(value, dim));
231 if (constSize) {
232 // Constant index value/dim: add column to the constraint set, add EQ bound
233 // and return an affine constant expression without pushing the newly added
234 // column to the worklist.
235 (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
236 if (shapedType)
237 bound(value)[*dim] == *constSize;
238 else
239 bound(value) == *constSize;
240 return builder.getAffineConstantExpr(*constSize);
243 // Dynamic value/dim: insert column to the constraint set and put it on the
244 // worklist. Return an affine expression that represents the newly inserted
245 // column in the constraint set.
246 return getPosExpr(insert(value, dim, /*isSymbol=*/true));
249 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
250 if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
251 return getExpr(value, /*dim=*/std::nullopt);
252 auto constInt = ::getConstantIntValue(ofr);
253 assert(constInt.has_value() && "expected Integer constant");
254 return builder.getAffineConstantExpr(*constInt);
257 AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
258 return builder.getAffineConstantExpr(constant);
261 int64_t ValueBoundsConstraintSet::insert(Value value,
262 std::optional<int64_t> dim,
263 bool isSymbol, bool addToWorklist) {
264 #ifndef NDEBUG
265 assertValidValueDim(value, dim);
266 #endif // NDEBUG
268 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
269 assert(!valueDimToPosition.contains(valueDim) && "already mapped");
270 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
271 : cstr.appendVar(VarKind::SetDim);
272 LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
273 << " for: " << value
274 << " (dim: " << dim.value_or(kIndexValue)
275 << ", owner: " << getOwnerOfValue(value)->getName()
276 << ")\n");
277 positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
278 // Update reverse mapping.
279 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
280 if (positionToValueDim[i].has_value())
281 valueDimToPosition[*positionToValueDim[i]] = i;
283 if (addToWorklist) {
284 LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
285 << " (dim: " << dim.value_or(kIndexValue) << ")\n");
286 worklist.push(pos);
289 return pos;
292 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
293 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
294 : cstr.appendVar(VarKind::SetDim);
295 LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
296 << "\n");
297 positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
298 // Update reverse mapping.
299 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
300 if (positionToValueDim[i].has_value())
301 valueDimToPosition[*positionToValueDim[i]] = i;
302 return pos;
305 int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
306 bool isSymbol) {
307 assert(map.getNumResults() == 1 && "expected affine map with one result");
308 int64_t pos = insert(isSymbol);
310 // Add map and operands to the constraint set. Dimensions are converted to
311 // symbols. All operands are added to the worklist (unless they were already
312 // processed).
313 auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
314 return getExpr(v.first, v.second);
316 SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
317 llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
318 SmallVector<AffineExpr> symReplacements = llvm::to_vector(
319 llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
320 addBound(
321 presburger::BoundType::EQ, pos,
322 map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
324 return pos;
327 int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
328 return insert(var.map, var.mapOperands, isSymbol);
331 int64_t ValueBoundsConstraintSet::getPos(Value value,
332 std::optional<int64_t> dim) const {
333 #ifndef NDEBUG
334 assertValidValueDim(value, dim);
335 assert((isa<OpResult>(value) ||
336 cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
337 "unstructured control flow is not supported");
338 #endif // NDEBUG
339 LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
340 << " (dim: " << dim.value_or(kIndexValue)
341 << ", owner: " << getOwnerOfValue(value)->getName()
342 << ")\n");
343 auto it =
344 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
345 assert(it != valueDimToPosition.end() && "expected mapped entry");
346 return it->second;
349 AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
350 assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
351 return pos < cstr.getNumDimVars()
352 ? builder.getAffineDimExpr(pos)
353 : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
356 bool ValueBoundsConstraintSet::isMapped(Value value,
357 std::optional<int64_t> dim) const {
358 auto it =
359 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
360 return it != valueDimToPosition.end();
363 void ValueBoundsConstraintSet::processWorklist() {
364 LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
365 while (!worklist.empty()) {
366 int64_t pos = worklist.front();
367 worklist.pop();
368 assert(positionToValueDim[pos].has_value() &&
369 "did not expect std::nullopt on worklist");
370 ValueDim valueDim = *positionToValueDim[pos];
371 Value value = valueDim.first;
372 int64_t dim = valueDim.second;
374 // Check for static dim size.
375 if (dim != kIndexValue) {
376 auto shapedType = cast<ShapedType>(value.getType());
377 if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
378 bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
379 continue;
383 // Do not process any further if the stop condition is met.
384 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
385 if (stopCondition(value, maybeDim, *this)) {
386 LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
387 << " (dim: " << maybeDim << ")\n");
388 continue;
391 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
392 // the worklist.
393 auto valueBoundsOp =
394 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
395 LLVM_DEBUG(llvm::dbgs()
396 << "Query value bounds for: " << value
397 << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
398 if (valueBoundsOp) {
399 if (dim == kIndexValue) {
400 valueBoundsOp.populateBoundsForIndexValue(value, *this);
401 } else {
402 valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
404 continue;
406 LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
408 // If the op does not implement `ValueBoundsOpInterface`, check if it
409 // implements the `DestinationStyleOpInterface`. OpResults of such ops are
410 // tied to OpOperands. Tied values have the same shape.
411 auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
412 if (!dstOp || dim == kIndexValue)
413 continue;
414 Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
415 bound(value)[dim] == getExpr(tiedOperand, dim);
419 void ValueBoundsConstraintSet::projectOut(int64_t pos) {
420 assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
421 "invalid position");
422 cstr.projectOut(pos);
423 if (positionToValueDim[pos].has_value()) {
424 bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
425 (void)erased;
426 assert(erased && "inconsistent reverse mapping");
428 positionToValueDim.erase(positionToValueDim.begin() + pos);
429 // Update reverse mapping.
430 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
431 if (positionToValueDim[i].has_value())
432 valueDimToPosition[*positionToValueDim[i]] = i;
435 void ValueBoundsConstraintSet::projectOut(
436 function_ref<bool(ValueDim)> condition) {
437 int64_t nextPos = 0;
438 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
439 if (positionToValueDim[nextPos].has_value() &&
440 condition(*positionToValueDim[nextPos])) {
441 projectOut(nextPos);
442 // The column was projected out so another column is now at that position.
443 // Do not increase the counter.
444 } else {
445 ++nextPos;
450 void ValueBoundsConstraintSet::projectOutAnonymous(
451 std::optional<int64_t> except) {
452 int64_t nextPos = 0;
453 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
454 if (positionToValueDim[nextPos].has_value() || except == nextPos) {
455 ++nextPos;
456 } else {
457 projectOut(nextPos);
458 // The column was projected out so another column is now at that position.
459 // Do not increase the counter.
464 LogicalResult ValueBoundsConstraintSet::computeBound(
465 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
466 const Variable &var, StopConditionFn stopCondition, bool closedUB) {
467 MLIRContext *ctx = var.getContext();
468 int64_t ubAdjustment = closedUB ? 0 : 1;
469 Builder b(ctx);
470 mapOperands.clear();
472 // Process the backward slice of `value` (i.e., reverse use-def chain) until
473 // `stopCondition` is met.
474 ValueBoundsConstraintSet cstr(ctx, stopCondition);
475 int64_t pos = cstr.insert(var, /*isSymbol=*/false);
476 assert(pos == 0 && "expected first column");
477 cstr.processWorklist();
479 // Project out all variables (apart from `valueDim`) that do not match the
480 // stop condition.
481 cstr.projectOut([&](ValueDim p) {
482 auto maybeDim =
483 p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
484 return !stopCondition(p.first, maybeDim, cstr);
486 cstr.projectOutAnonymous(/*except=*/pos);
488 // Compute lower and upper bounds for `valueDim`.
489 SmallVector<AffineMap> lb(1), ub(1);
490 cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
491 /*closedUB=*/true);
493 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
494 // case, no lower/upper bound can be computed at the moment.
495 // EQ, UB bounds: upper bound is needed.
496 if ((type != BoundType::LB) &&
497 (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
498 return failure();
499 // EQ, LB bounds: lower bound is needed.
500 if ((type != BoundType::UB) &&
501 (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
502 return failure();
504 // TODO: Generate an affine map with multiple results.
505 if (type != BoundType::LB)
506 assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
507 "multiple bounds not supported");
508 if (type != BoundType::UB)
509 assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
510 "multiple bounds not supported");
512 // EQ bound: lower and upper bound must match.
513 if (type == BoundType::EQ && ub[0] != lb[0])
514 return failure();
516 AffineMap bound;
517 if (type == BoundType::EQ || type == BoundType::LB) {
518 bound = lb[0];
519 } else {
520 // Computed UB is a closed bound.
521 bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
522 ub[0].getResult(0) + ubAdjustment);
525 // Gather all SSA values that are used in the computed bound.
526 assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
527 "inconsistent mapping state");
528 SmallVector<AffineExpr> replacementDims, replacementSymbols;
529 int64_t numDims = 0, numSymbols = 0;
530 for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
531 // Skip `value`.
532 if (i == pos)
533 continue;
534 // Check if the position `i` is used in the generated bound. If so, it must
535 // be included in the generated affine.apply op.
536 bool used = false;
537 bool isDim = i < cstr.cstr.getNumDimVars();
538 if (isDim) {
539 if (bound.isFunctionOfDim(i))
540 used = true;
541 } else {
542 if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
543 used = true;
546 if (!used) {
547 // Not used: Remove dim/symbol from the result.
548 if (isDim) {
549 replacementDims.push_back(b.getAffineConstantExpr(0));
550 } else {
551 replacementSymbols.push_back(b.getAffineConstantExpr(0));
553 continue;
556 if (isDim) {
557 replacementDims.push_back(b.getAffineDimExpr(numDims++));
558 } else {
559 replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
562 assert(cstr.positionToValueDim[i].has_value() &&
563 "cannot build affine map in terms of anonymous column");
564 ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
565 Value value = valueDim.first;
566 int64_t dim = valueDim.second;
567 if (dim == ValueBoundsConstraintSet::kIndexValue) {
568 // An index-type value is used: can be used directly in the affine.apply
569 // op.
570 assert(value.getType().isIndex() && "expected index type");
571 mapOperands.push_back(std::make_pair(value, std::nullopt));
572 continue;
575 assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) &&
576 "expected dynamic dim");
577 mapOperands.push_back(std::make_pair(value, dim));
580 resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
581 numDims, numSymbols);
582 return success();
585 LogicalResult ValueBoundsConstraintSet::computeDependentBound(
586 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
587 const Variable &var, ValueDimList dependencies, bool closedUB) {
588 return computeBound(
589 resultMap, mapOperands, type, var,
590 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
591 return llvm::is_contained(dependencies, std::make_pair(v, d));
593 closedUB);
596 LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
597 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
598 const Variable &var, ValueRange independencies, bool closedUB) {
599 // Return "true" if the given value is independent of all values in
600 // `independencies`. I.e., neither the value itself nor any value in the
601 // backward slice (reverse use-def chain) is contained in `independencies`.
602 auto isIndependent = [&](Value v) {
603 SmallVector<Value> worklist;
604 DenseSet<Value> visited;
605 worklist.push_back(v);
606 while (!worklist.empty()) {
607 Value next = worklist.pop_back_val();
608 if (!visited.insert(next).second)
609 continue;
610 if (llvm::is_contained(independencies, next))
611 return false;
612 // TODO: DominanceInfo could be used to stop the traversal early.
613 Operation *op = next.getDefiningOp();
614 if (!op)
615 continue;
616 worklist.append(op->getOperands().begin(), op->getOperands().end());
618 return true;
621 // Reify bounds in terms of any independent values.
622 return computeBound(
623 resultMap, mapOperands, type, var,
624 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
625 return isIndependent(v);
627 closedUB);
630 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
631 presburger::BoundType type, const Variable &var,
632 StopConditionFn stopCondition, bool closedUB) {
633 // Default stop condition if none was specified: Keep adding constraints until
634 // a bound could be computed.
635 int64_t pos = 0;
636 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
637 ValueBoundsConstraintSet &cstr) {
638 return cstr.cstr.getConstantBound64(type, pos).has_value();
641 ValueBoundsConstraintSet cstr(
642 var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
643 pos = cstr.populateConstraints(var.map, var.mapOperands);
644 assert(pos == 0 && "expected `map` is the first column");
646 // Compute constant bound for `valueDim`.
647 int64_t ubAdjustment = closedUB ? 0 : 1;
648 if (auto bound = cstr.cstr.getConstantBound64(type, pos))
649 return type == BoundType::UB ? *bound + ubAdjustment : *bound;
650 return failure();
653 void ValueBoundsConstraintSet::populateConstraints(Value value,
654 std::optional<int64_t> dim) {
655 #ifndef NDEBUG
656 assertValidValueDim(value, dim);
657 #endif // NDEBUG
659 // `getExpr` pushes the value/dim onto the worklist (unless it was already
660 // analyzed).
661 (void)getExpr(value, dim);
662 // Process all values/dims on the worklist. This may traverse and analyze
663 // additional IR, depending the current stop function.
664 processWorklist();
667 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
668 ValueDimList operands) {
669 int64_t pos = insert(map, operands, /*isSymbol=*/false);
670 // Process the backward slice of `operands` (i.e., reverse use-def chain)
671 // until `stopCondition` is met.
672 processWorklist();
673 return pos;
676 FailureOr<int64_t>
677 ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
678 std::optional<int64_t> dim1,
679 std::optional<int64_t> dim2) {
680 #ifndef NDEBUG
681 assertValidValueDim(value1, dim1);
682 assertValidValueDim(value2, dim2);
683 #endif // NDEBUG
685 Builder b(value1.getContext());
686 AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
687 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
688 return computeConstantBound(presburger::BoundType::EQ,
689 Variable(map, {{value1, dim1}, {value2, dim2}}));
692 bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
693 ComparisonOperator cmp,
694 int64_t rhsPos) {
695 // This function returns "true" if "lhs CMP rhs" is proven to hold.
697 // Example for ComparisonOperator::LE and index-typed values: We would like to
698 // prove that lhs <= rhs. Proof by contradiction: add the inverse
699 // relation (lhs > rhs) to the constraint set and check if the resulting
700 // constraint set is "empty" (i.e. has no solution). In that case,
701 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
703 // We cannot prove anything if the constraint set is already empty.
704 if (cstr.isEmpty()) {
705 LLVM_DEBUG(
706 llvm::dbgs()
707 << "cannot compare value/dims: constraint system is already empty");
708 return false;
711 // EQ can be expressed as LE and GE.
712 if (cmp == EQ)
713 return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
714 comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
716 // Construct inequality.
717 SmallVector<int64_t> eq(cstr.getNumCols(), 0);
718 if (cmp == LT || cmp == LE) {
719 ++eq[lhsPos];
720 --eq[rhsPos];
721 } else if (cmp == GT || cmp == GE) {
722 --eq[lhsPos];
723 ++eq[rhsPos];
724 } else {
725 llvm_unreachable("unsupported comparison operator");
727 if (cmp == LE || cmp == GE)
728 eq[cstr.getNumCols() - 1] -= 1;
730 // Add inequality to the constraint set and check if it made the constraint
731 // set empty.
732 int64_t ineqPos = cstr.getNumInequalities();
733 cstr.addInequality(eq);
734 bool isEmpty = cstr.isEmpty();
735 cstr.removeInequality(ineqPos);
736 return isEmpty;
739 bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
740 ComparisonOperator cmp,
741 const Variable &rhs) {
742 int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
743 int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
744 return comparePos(lhsPos, cmp, rhsPos);
747 bool ValueBoundsConstraintSet::compare(const Variable &lhs,
748 ComparisonOperator cmp,
749 const Variable &rhs) {
750 int64_t lhsPos = -1, rhsPos = -1;
751 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
752 ValueBoundsConstraintSet &cstr) {
753 // Keep processing as long as lhs/rhs were not processed.
754 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
755 size_t(rhsPos) >= cstr.positionToValueDim.size())
756 return false;
757 // Keep processing as long as the relation cannot be proven.
758 return cstr.comparePos(lhsPos, cmp, rhsPos);
760 ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
761 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
762 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
763 return cstr.comparePos(lhsPos, cmp, rhsPos);
766 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
767 const Variable &var2) {
768 if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
769 return true;
770 if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
771 ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
772 return false;
773 return failure();
776 FailureOr<bool>
777 ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
778 HyperrectangularSlice slice1,
779 HyperrectangularSlice slice2) {
780 assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
781 "expected slices of same rank");
782 assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
783 "expected slices of same rank");
784 assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
785 "expected slices of same rank");
787 Builder b(ctx);
788 bool foundUnknownBound = false;
789 for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
790 AffineMap map =
791 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
792 b.getAffineSymbolExpr(0) +
793 b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
794 b.getAffineSymbolExpr(3));
796 // Case 1: Slices are guaranteed to be non-overlapping if
797 // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
798 SmallVector<OpFoldResult> ofrOperands;
799 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
800 ofrOperands.push_back(slice1.getMixedSizes()[i]);
801 ofrOperands.push_back(slice1.getMixedStrides()[i]);
802 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
803 SmallVector<Value> valueOperands;
804 AffineMap foldedMap =
805 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
806 FailureOr<int64_t> constBound = computeConstantBound(
807 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
808 foundUnknownBound |= failed(constBound);
809 if (succeeded(constBound) && *constBound <= 0)
810 return false;
813 // Case 2: Slices are guaranteed to be non-overlapping if
814 // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
815 SmallVector<OpFoldResult> ofrOperands;
816 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
817 ofrOperands.push_back(slice2.getMixedSizes()[i]);
818 ofrOperands.push_back(slice2.getMixedStrides()[i]);
819 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
820 SmallVector<Value> valueOperands;
821 AffineMap foldedMap =
822 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
823 FailureOr<int64_t> constBound = computeConstantBound(
824 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
825 foundUnknownBound |= failed(constBound);
826 if (succeeded(constBound) && *constBound <= 0)
827 return false;
831 // If at least one bound could not be computed, we cannot be certain that the
832 // slices are really overlapping.
833 if (foundUnknownBound)
834 return failure();
836 // All bounds could be computed and none of the above cases applied.
837 // Therefore, the slices are guaranteed to overlap.
838 return true;
841 FailureOr<bool>
842 ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
843 HyperrectangularSlice slice1,
844 HyperrectangularSlice slice2) {
845 assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
846 "expected slices of same rank");
847 assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
848 "expected slices of same rank");
849 assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
850 "expected slices of same rank");
852 // The two slices are equivalent if all of their offsets, sizes and strides
853 // are equal. If equality cannot be determined for at least one of those
854 // values, equivalence cannot be determined and this function returns
855 // "failure".
856 for (auto [offset1, offset2] :
857 llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
858 FailureOr<bool> equal = areEqual(offset1, offset2);
859 if (failed(equal))
860 return failure();
861 if (!equal.value())
862 return false;
864 for (auto [size1, size2] :
865 llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
866 FailureOr<bool> equal = areEqual(size1, size2);
867 if (failed(equal))
868 return failure();
869 if (!equal.value())
870 return false;
872 for (auto [stride1, stride2] :
873 llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
874 FailureOr<bool> equal = areEqual(stride1, stride2);
875 if (failed(equal))
876 return failure();
877 if (!equal.value())
878 return false;
880 return true;
883 void ValueBoundsConstraintSet::dump() const {
884 llvm::errs() << "==========\nColumns:\n";
885 llvm::errs() << "(column\tdim\tvalue)\n";
886 for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) {
887 llvm::errs() << " " << index << "\t";
888 if (valueDim) {
889 if (valueDim->second == kIndexValue) {
890 llvm::errs() << "n/a\t";
891 } else {
892 llvm::errs() << valueDim->second << "\t";
894 llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " ";
895 if (OpResult result = dyn_cast<OpResult>(valueDim->first)) {
896 llvm::errs() << "(result " << result.getResultNumber() << ")";
897 } else {
898 llvm::errs() << "(bbarg "
899 << cast<BlockArgument>(valueDim->first).getArgNumber()
900 << ")";
902 llvm::errs() << "\n";
903 } else {
904 llvm::errs() << "n/a\tn/a\n";
907 llvm::errs() << "\nConstraint set:\n";
908 cstr.dump();
909 llvm::errs() << "==========\n";
912 ValueBoundsConstraintSet::BoundBuilder &
913 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
914 assert(!this->dim.has_value() && "dim was already set");
915 this->dim = dim;
916 #ifndef NDEBUG
917 assertValidValueDim(value, this->dim);
918 #endif // NDEBUG
919 return *this;
922 void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) {
923 #ifndef NDEBUG
924 assertValidValueDim(value, this->dim);
925 #endif // NDEBUG
926 cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr);
929 void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) {
930 operator<(expr + 1);
933 void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) {
934 operator>=(expr + 1);
937 void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) {
938 #ifndef NDEBUG
939 assertValidValueDim(value, this->dim);
940 #endif // NDEBUG
941 cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr);
944 void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) {
945 #ifndef NDEBUG
946 assertValidValueDim(value, this->dim);
947 #endif // NDEBUG
948 cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr);
951 void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) {
952 operator<(cstr.getExpr(ofr));
955 void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) {
956 operator<=(cstr.getExpr(ofr));
959 void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) {
960 operator>(cstr.getExpr(ofr));
963 void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) {
964 operator>=(cstr.getExpr(ofr));
967 void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) {
968 operator==(cstr.getExpr(ofr));
971 void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) {
972 operator<(cstr.getExpr(i));
975 void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) {
976 operator<=(cstr.getExpr(i));
979 void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) {
980 operator>(cstr.getExpr(i));
983 void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) {
984 operator>=(cstr.getExpr(i));
987 void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) {
988 operator==(cstr.getExpr(i));