[mlir] Allow fallback from file line col range to loc (#124321)
[llvm-project.git] / mlir / lib / Analysis / DataFlow / IntegerRangeAnalysis.cpp
blob9e9411e5ede12c8874ab1a601436803604d9b067
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Interfaces/ControlFlowInterfaces.h"
25 #include "mlir/Interfaces/InferIntRangeInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include <cassert>
32 #include <optional>
33 #include <utility>
35 #define DEBUG_TYPE "int-range-analysis"
37 using namespace mlir;
38 using namespace mlir::dataflow;
40 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
41 Lattice::onUpdate(solver);
43 // If the integer range can be narrowed to a constant, update the constant
44 // value of the SSA value.
45 std::optional<APInt> constant = getValue().getValue().getConstantValue();
46 auto value = cast<Value>(anchor);
47 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48 if (!constant)
49 return solver->propagateIfChanged(
50 cv, cv->join(ConstantValue::getUnknownConstant()));
52 Dialect *dialect;
53 if (auto *parent = value.getDefiningOp())
54 dialect = parent->getDialect();
55 else
56 dialect = value.getParentBlock()->getParentOp()->getDialect();
58 Type type = getElementTypeOrSelf(value);
59 solver->propagateIfChanged(
60 cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
63 LogicalResult IntegerRangeAnalysis::visitOperation(
64 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
65 ArrayRef<IntegerValueRangeLattice *> results) {
66 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
67 if (!inferrable) {
68 setAllToEntryStates(results);
69 return success();
72 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
73 auto argRanges = llvm::map_to_vector(
74 operands, [](const IntegerValueRangeLattice *lattice) {
75 return lattice->getValue();
76 });
78 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
79 auto result = dyn_cast<OpResult>(v);
80 if (!result)
81 return;
82 assert(llvm::is_contained(op->getResults(), result));
84 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
85 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
86 IntegerValueRange oldRange = lattice->getValue();
88 ChangeResult changed = lattice->join(attrs);
90 // Catch loop results with loop variant bounds and conservatively make
91 // them [-inf, inf] so we don't circle around infinitely often (because
92 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
93 // and often can't).
94 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
95 return op->hasTrait<OpTrait::IsTerminator>();
96 });
97 if (isYieldedResult && !oldRange.isUninitialized() &&
98 !(lattice->getValue() == oldRange)) {
99 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
100 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
102 propagateIfChanged(lattice, changed);
105 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
106 return success();
109 void IntegerRangeAnalysis::visitNonControlFlowArguments(
110 Operation *op, const RegionSuccessor &successor,
111 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
112 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
113 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
115 auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
116 return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
119 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
120 auto arg = dyn_cast<BlockArgument>(v);
121 if (!arg)
122 return;
123 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
124 return;
126 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
127 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
128 IntegerValueRange oldRange = lattice->getValue();
130 ChangeResult changed = lattice->join(attrs);
132 // Catch loop results with loop variant bounds and conservatively make
133 // them [-inf, inf] so we don't circle around infinitely often (because
134 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
135 // and often can't).
136 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
137 return op->hasTrait<OpTrait::IsTerminator>();
139 if (isYieldedValue && !oldRange.isUninitialized() &&
140 !(lattice->getValue() == oldRange)) {
141 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
142 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
144 propagateIfChanged(lattice, changed);
147 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
148 return;
151 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152 /// on a LoopLikeInterface return the lower/upper bound for that result if
153 /// possible.
154 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155 Type boundType, bool getUpper) {
156 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157 if (loopBound.has_value()) {
158 if (auto attr = dyn_cast<Attribute>(*loopBound)) {
159 if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
160 return bound.getValue();
161 } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
162 const IntegerValueRangeLattice *lattice =
163 getLatticeElementFor(getProgramPointAfter(op), value);
164 if (lattice != nullptr && !lattice->getValue().isUninitialized())
165 return getUpper ? lattice->getValue().getValue().smax()
166 : lattice->getValue().getValue().smin();
169 // Given the results of getConstant{Lower,Upper}Bound()
170 // or getConstantStep() on a LoopLikeInterface return the lower/upper
171 // bound
172 return getUpper ? APInt::getSignedMaxValue(width)
173 : APInt::getSignedMinValue(width);
176 // Infer bounds for loop arguments that have static bounds
177 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
178 std::optional<Value> iv = loop.getSingleInductionVar();
179 if (!iv) {
180 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
181 op, successor, argLattices, firstIndex);
183 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
184 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
185 std::optional<OpFoldResult> step = loop.getSingleStep();
186 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
187 /*getUpper=*/false);
188 APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
189 /*getUpper=*/true);
190 // Assume positivity for uniscoverable steps by way of getUpper = true.
191 APInt stepVal =
192 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
194 if (stepVal.isNegative()) {
195 std::swap(min, max);
196 } else {
197 // Correct the upper bound by subtracting 1 so that it becomes a <=
198 // bound, because loops do not generally include their upper bound.
199 max -= 1;
202 // If we infer the lower bound to be larger than the upper bound, the
203 // resulting range is meaningless and should not be used in further
204 // inferences.
205 if (max.sge(min)) {
206 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
207 auto ivRange = ConstantIntRanges::fromSigned(min, max);
208 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
210 return;
213 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
214 op, successor, argLattices, firstIndex);