1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file defines the dataflow analysis class for integer range inference,
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting.
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/Interfaces/ControlFlowInterfaces.h"
24 #include "mlir/Interfaces/InferIntRangeInterface.h"
25 #include "mlir/Interfaces/LoopLikeInterface.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
34 #define DEBUG_TYPE "int-range-analysis"
37 using namespace mlir::dataflow
;
39 IntegerValueRange
IntegerValueRange::getMaxRange(Value value
) {
40 unsigned width
= ConstantIntRanges::getStorageBitwidth(value
.getType());
43 APInt umin
= APInt::getMinValue(width
);
44 APInt umax
= APInt::getMaxValue(width
);
45 APInt smin
= width
!= 0 ? APInt::getSignedMinValue(width
) : umin
;
46 APInt smax
= width
!= 0 ? APInt::getSignedMaxValue(width
) : umax
;
47 return IntegerValueRange
{ConstantIntRanges
{umin
, umax
, smin
, smax
}};
50 void IntegerValueRangeLattice::onUpdate(DataFlowSolver
*solver
) const {
51 Lattice::onUpdate(solver
);
53 // If the integer range can be narrowed to a constant, update the constant
54 // value of the SSA value.
55 std::optional
<APInt
> constant
= getValue().getValue().getConstantValue();
56 auto value
= point
.get
<Value
>();
57 auto *cv
= solver
->getOrCreateState
<Lattice
<ConstantValue
>>(value
);
59 return solver
->propagateIfChanged(
60 cv
, cv
->join(ConstantValue::getUnknownConstant()));
63 if (auto *parent
= value
.getDefiningOp())
64 dialect
= parent
->getDialect();
66 dialect
= value
.getParentBlock()->getParentOp()->getDialect();
67 solver
->propagateIfChanged(
68 cv
, cv
->join(ConstantValue(IntegerAttr::get(value
.getType(), *constant
),
72 void IntegerRangeAnalysis::visitOperation(
73 Operation
*op
, ArrayRef
<const IntegerValueRangeLattice
*> operands
,
74 ArrayRef
<IntegerValueRangeLattice
*> results
) {
75 // If the lattice on any operand is unitialized, bail out.
76 if (llvm::any_of(operands
, [](const IntegerValueRangeLattice
*lattice
) {
77 return lattice
->getValue().isUninitialized();
82 // Ignore non-integer outputs - return early if the op has no scalar
84 bool hasIntegerResult
= false;
85 for (auto it
: llvm::zip(results
, op
->getResults())) {
86 Value value
= std::get
<1>(it
);
87 if (value
.getType().isIntOrIndex()) {
88 hasIntegerResult
= true;
90 IntegerValueRangeLattice
*lattice
= std::get
<0>(it
);
91 propagateIfChanged(lattice
,
92 lattice
->join(IntegerValueRange::getMaxRange(value
)));
95 if (!hasIntegerResult
)
98 auto inferrable
= dyn_cast
<InferIntRangeInterface
>(op
);
100 return setAllToEntryStates(results
);
102 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op
<< "\n");
103 SmallVector
<ConstantIntRanges
> argRanges(
104 llvm::map_range(operands
, [](const IntegerValueRangeLattice
*val
) {
105 return val
->getValue().getValue();
108 auto joinCallback
= [&](Value v
, const ConstantIntRanges
&attrs
) {
109 auto result
= dyn_cast
<OpResult
>(v
);
112 assert(llvm::is_contained(op
->getResults(), result
));
114 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs
<< "\n");
115 IntegerValueRangeLattice
*lattice
= results
[result
.getResultNumber()];
116 IntegerValueRange oldRange
= lattice
->getValue();
118 ChangeResult changed
= lattice
->join(IntegerValueRange
{attrs
});
120 // Catch loop results with loop variant bounds and conservatively make
121 // them [-inf, inf] so we don't circle around infinitely often (because
122 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
124 bool isYieldedResult
= llvm::any_of(v
.getUsers(), [](Operation
*op
) {
125 return op
->hasTrait
<OpTrait::IsTerminator
>();
127 if (isYieldedResult
&& !oldRange
.isUninitialized() &&
128 !(lattice
->getValue() == oldRange
)) {
129 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
130 changed
|= lattice
->join(IntegerValueRange::getMaxRange(v
));
132 propagateIfChanged(lattice
, changed
);
135 inferrable
.inferResultRanges(argRanges
, joinCallback
);
138 void IntegerRangeAnalysis::visitNonControlFlowArguments(
139 Operation
*op
, const RegionSuccessor
&successor
,
140 ArrayRef
<IntegerValueRangeLattice
*> argLattices
, unsigned firstIndex
) {
141 if (auto inferrable
= dyn_cast
<InferIntRangeInterface
>(op
)) {
142 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op
<< "\n");
143 // If the lattice on any operand is unitialized, bail out.
144 if (llvm::any_of(op
->getOperands(), [&](Value value
) {
145 return getLatticeElementFor(op
, value
)->getValue().isUninitialized();
148 SmallVector
<ConstantIntRanges
> argRanges(
149 llvm::map_range(op
->getOperands(), [&](Value value
) {
150 return getLatticeElementFor(op
, value
)->getValue().getValue();
153 auto joinCallback
= [&](Value v
, const ConstantIntRanges
&attrs
) {
154 auto arg
= dyn_cast
<BlockArgument
>(v
);
157 if (!llvm::is_contained(successor
.getSuccessor()->getArguments(), arg
))
160 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs
<< "\n");
161 IntegerValueRangeLattice
*lattice
= argLattices
[arg
.getArgNumber()];
162 IntegerValueRange oldRange
= lattice
->getValue();
164 ChangeResult changed
= lattice
->join(IntegerValueRange
{attrs
});
166 // Catch loop results with loop variant bounds and conservatively make
167 // them [-inf, inf] so we don't circle around infinitely often (because
168 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
170 bool isYieldedValue
= llvm::any_of(v
.getUsers(), [](Operation
*op
) {
171 return op
->hasTrait
<OpTrait::IsTerminator
>();
173 if (isYieldedValue
&& !oldRange
.isUninitialized() &&
174 !(lattice
->getValue() == oldRange
)) {
175 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
176 changed
|= lattice
->join(IntegerValueRange::getMaxRange(v
));
178 propagateIfChanged(lattice
, changed
);
181 inferrable
.inferResultRanges(argRanges
, joinCallback
);
185 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
186 /// on a LoopLikeInterface return the lower/upper bound for that result if
188 auto getLoopBoundFromFold
= [&](std::optional
<OpFoldResult
> loopBound
,
189 Type boundType
, bool getUpper
) {
190 unsigned int width
= ConstantIntRanges::getStorageBitwidth(boundType
);
191 if (loopBound
.has_value()) {
192 if (loopBound
->is
<Attribute
>()) {
194 dyn_cast_or_null
<IntegerAttr
>(loopBound
->get
<Attribute
>()))
195 return bound
.getValue();
196 } else if (auto value
= llvm::dyn_cast_if_present
<Value
>(*loopBound
)) {
197 const IntegerValueRangeLattice
*lattice
=
198 getLatticeElementFor(op
, value
);
199 if (lattice
!= nullptr)
200 return getUpper
? lattice
->getValue().getValue().smax()
201 : lattice
->getValue().getValue().smin();
204 // Given the results of getConstant{Lower,Upper}Bound()
205 // or getConstantStep() on a LoopLikeInterface return the lower/upper
207 return getUpper
? APInt::getSignedMaxValue(width
)
208 : APInt::getSignedMinValue(width
);
211 // Infer bounds for loop arguments that have static bounds
212 if (auto loop
= dyn_cast
<LoopLikeOpInterface
>(op
)) {
213 std::optional
<Value
> iv
= loop
.getSingleInductionVar();
215 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
216 op
, successor
, argLattices
, firstIndex
);
218 std::optional
<OpFoldResult
> lowerBound
= loop
.getSingleLowerBound();
219 std::optional
<OpFoldResult
> upperBound
= loop
.getSingleUpperBound();
220 std::optional
<OpFoldResult
> step
= loop
.getSingleStep();
221 APInt min
= getLoopBoundFromFold(lowerBound
, iv
->getType(),
223 APInt max
= getLoopBoundFromFold(upperBound
, iv
->getType(),
225 // Assume positivity for uniscoverable steps by way of getUpper = true.
227 getLoopBoundFromFold(step
, iv
->getType(), /*getUpper=*/true);
229 if (stepVal
.isNegative()) {
232 // Correct the upper bound by subtracting 1 so that it becomes a <=
233 // bound, because loops do not generally include their upper bound.
237 IntegerValueRangeLattice
*ivEntry
= getLatticeElement(*iv
);
238 auto ivRange
= ConstantIntRanges::fromSigned(min
, max
);
239 propagateIfChanged(ivEntry
, ivEntry
->join(IntegerValueRange
{ivRange
}));
243 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
244 op
, successor
, argLattices
, firstIndex
);