1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
13 //===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "int-range-analysis"

using namespace mlir;
using namespace mlir::dataflow;
40 void IntegerValueRangeLattice::onUpdate(DataFlowSolver
*solver
) const {
41 Lattice::onUpdate(solver
);
43 // If the integer range can be narrowed to a constant, update the constant
44 // value of the SSA value.
45 std::optional
<APInt
> constant
= getValue().getValue().getConstantValue();
46 auto value
= anchor
.get
<Value
>();
47 auto *cv
= solver
->getOrCreateState
<Lattice
<ConstantValue
>>(value
);
49 return solver
->propagateIfChanged(
50 cv
, cv
->join(ConstantValue::getUnknownConstant()));
53 if (auto *parent
= value
.getDefiningOp())
54 dialect
= parent
->getDialect();
56 dialect
= value
.getParentBlock()->getParentOp()->getDialect();
58 Type type
= getElementTypeOrSelf(value
);
59 solver
->propagateIfChanged(
60 cv
, cv
->join(ConstantValue(IntegerAttr::get(type
, *constant
), dialect
)));
63 LogicalResult
IntegerRangeAnalysis::visitOperation(
64 Operation
*op
, ArrayRef
<const IntegerValueRangeLattice
*> operands
,
65 ArrayRef
<IntegerValueRangeLattice
*> results
) {
66 auto inferrable
= dyn_cast
<InferIntRangeInterface
>(op
);
68 setAllToEntryStates(results
);
72 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op
<< "\n");
73 auto argRanges
= llvm::map_to_vector(
74 operands
, [](const IntegerValueRangeLattice
*lattice
) {
75 return lattice
->getValue();
78 auto joinCallback
= [&](Value v
, const IntegerValueRange
&attrs
) {
79 auto result
= dyn_cast
<OpResult
>(v
);
82 assert(llvm::is_contained(op
->getResults(), result
));
84 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs
<< "\n");
85 IntegerValueRangeLattice
*lattice
= results
[result
.getResultNumber()];
86 IntegerValueRange oldRange
= lattice
->getValue();
88 ChangeResult changed
= lattice
->join(attrs
);
90 // Catch loop results with loop variant bounds and conservatively make
91 // them [-inf, inf] so we don't circle around infinitely often (because
92 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
94 bool isYieldedResult
= llvm::any_of(v
.getUsers(), [](Operation
*op
) {
95 return op
->hasTrait
<OpTrait::IsTerminator
>();
97 if (isYieldedResult
&& !oldRange
.isUninitialized() &&
98 !(lattice
->getValue() == oldRange
)) {
99 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
100 changed
|= lattice
->join(IntegerValueRange::getMaxRange(v
));
102 propagateIfChanged(lattice
, changed
);
105 inferrable
.inferResultRangesFromOptional(argRanges
, joinCallback
);
109 void IntegerRangeAnalysis::visitNonControlFlowArguments(
110 Operation
*op
, const RegionSuccessor
&successor
,
111 ArrayRef
<IntegerValueRangeLattice
*> argLattices
, unsigned firstIndex
) {
112 if (auto inferrable
= dyn_cast
<InferIntRangeInterface
>(op
)) {
113 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op
<< "\n");
115 auto argRanges
= llvm::map_to_vector(op
->getOperands(), [&](Value value
) {
116 return getLatticeElementFor(getProgramPointAfter(op
), value
)->getValue();
119 auto joinCallback
= [&](Value v
, const IntegerValueRange
&attrs
) {
120 auto arg
= dyn_cast
<BlockArgument
>(v
);
123 if (!llvm::is_contained(successor
.getSuccessor()->getArguments(), arg
))
126 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs
<< "\n");
127 IntegerValueRangeLattice
*lattice
= argLattices
[arg
.getArgNumber()];
128 IntegerValueRange oldRange
= lattice
->getValue();
130 ChangeResult changed
= lattice
->join(attrs
);
132 // Catch loop results with loop variant bounds and conservatively make
133 // them [-inf, inf] so we don't circle around infinitely often (because
134 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
136 bool isYieldedValue
= llvm::any_of(v
.getUsers(), [](Operation
*op
) {
137 return op
->hasTrait
<OpTrait::IsTerminator
>();
139 if (isYieldedValue
&& !oldRange
.isUninitialized() &&
140 !(lattice
->getValue() == oldRange
)) {
141 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
142 changed
|= lattice
->join(IntegerValueRange::getMaxRange(v
));
144 propagateIfChanged(lattice
, changed
);
147 inferrable
.inferResultRangesFromOptional(argRanges
, joinCallback
);
151 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152 /// on a LoopLikeInterface return the lower/upper bound for that result if
154 auto getLoopBoundFromFold
= [&](std::optional
<OpFoldResult
> loopBound
,
155 Type boundType
, bool getUpper
) {
156 unsigned int width
= ConstantIntRanges::getStorageBitwidth(boundType
);
157 if (loopBound
.has_value()) {
158 if (loopBound
->is
<Attribute
>()) {
160 dyn_cast_or_null
<IntegerAttr
>(loopBound
->get
<Attribute
>()))
161 return bound
.getValue();
162 } else if (auto value
= llvm::dyn_cast_if_present
<Value
>(*loopBound
)) {
163 const IntegerValueRangeLattice
*lattice
=
164 getLatticeElementFor(getProgramPointAfter(op
), value
);
165 if (lattice
!= nullptr && !lattice
->getValue().isUninitialized())
166 return getUpper
? lattice
->getValue().getValue().smax()
167 : lattice
->getValue().getValue().smin();
170 // Given the results of getConstant{Lower,Upper}Bound()
171 // or getConstantStep() on a LoopLikeInterface return the lower/upper
173 return getUpper
? APInt::getSignedMaxValue(width
)
174 : APInt::getSignedMinValue(width
);
177 // Infer bounds for loop arguments that have static bounds
178 if (auto loop
= dyn_cast
<LoopLikeOpInterface
>(op
)) {
179 std::optional
<Value
> iv
= loop
.getSingleInductionVar();
181 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
182 op
, successor
, argLattices
, firstIndex
);
184 std::optional
<OpFoldResult
> lowerBound
= loop
.getSingleLowerBound();
185 std::optional
<OpFoldResult
> upperBound
= loop
.getSingleUpperBound();
186 std::optional
<OpFoldResult
> step
= loop
.getSingleStep();
187 APInt min
= getLoopBoundFromFold(lowerBound
, iv
->getType(),
189 APInt max
= getLoopBoundFromFold(upperBound
, iv
->getType(),
191 // Assume positivity for uniscoverable steps by way of getUpper = true.
193 getLoopBoundFromFold(step
, iv
->getType(), /*getUpper=*/true);
195 if (stepVal
.isNegative()) {
198 // Correct the upper bound by subtracting 1 so that it becomes a <=
199 // bound, because loops do not generally include their upper bound.
203 // If we infer the lower bound to be larger than the upper bound, the
204 // resulting range is meaningless and should not be used in further
207 IntegerValueRangeLattice
*ivEntry
= getLatticeElement(*iv
);
208 auto ivRange
= ConstantIntRanges::fromSigned(min
, max
);
209 propagateIfChanged(ivEntry
, ivEntry
->join(IntegerValueRange
{ivRange
}));
214 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
215 op
, successor
, argLattices
, firstIndex
);