[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / Analysis / DataFlow / IntegerRangeAnalysis.cpp
bloba43263bc11113b844885db204026650156fc46ee
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/Interfaces/ControlFlowInterfaces.h"
24 #include "mlir/Interfaces/InferIntRangeInterface.h"
25 #include "mlir/Interfaces/LoopLikeInterface.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include <cassert>
31 #include <optional>
32 #include <utility>
34 #define DEBUG_TYPE "int-range-analysis"
36 using namespace mlir;
37 using namespace mlir::dataflow;
39 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
40 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
41 if (width == 0)
42 return {};
43 APInt umin = APInt::getMinValue(width);
44 APInt umax = APInt::getMaxValue(width);
45 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
46 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
47 return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
50 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
51 Lattice::onUpdate(solver);
53 // If the integer range can be narrowed to a constant, update the constant
54 // value of the SSA value.
55 std::optional<APInt> constant = getValue().getValue().getConstantValue();
56 auto value = point.get<Value>();
57 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
58 if (!constant)
59 return solver->propagateIfChanged(
60 cv, cv->join(ConstantValue::getUnknownConstant()));
62 Dialect *dialect;
63 if (auto *parent = value.getDefiningOp())
64 dialect = parent->getDialect();
65 else
66 dialect = value.getParentBlock()->getParentOp()->getDialect();
67 solver->propagateIfChanged(
68 cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
69 dialect)));
72 void IntegerRangeAnalysis::visitOperation(
73 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
74 ArrayRef<IntegerValueRangeLattice *> results) {
75 // If the lattice on any operand is unitialized, bail out.
76 if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
77 return lattice->getValue().isUninitialized();
78 })) {
79 return;
82 // Ignore non-integer outputs - return early if the op has no scalar
83 // integer results
84 bool hasIntegerResult = false;
85 for (auto it : llvm::zip(results, op->getResults())) {
86 Value value = std::get<1>(it);
87 if (value.getType().isIntOrIndex()) {
88 hasIntegerResult = true;
89 } else {
90 IntegerValueRangeLattice *lattice = std::get<0>(it);
91 propagateIfChanged(lattice,
92 lattice->join(IntegerValueRange::getMaxRange(value)));
95 if (!hasIntegerResult)
96 return;
98 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
99 if (!inferrable)
100 return setAllToEntryStates(results);
102 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
103 SmallVector<ConstantIntRanges> argRanges(
104 llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
105 return val->getValue().getValue();
106 }));
108 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
109 auto result = dyn_cast<OpResult>(v);
110 if (!result)
111 return;
112 assert(llvm::is_contained(op->getResults(), result));
114 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
115 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
116 IntegerValueRange oldRange = lattice->getValue();
118 ChangeResult changed = lattice->join(IntegerValueRange{attrs});
120 // Catch loop results with loop variant bounds and conservatively make
121 // them [-inf, inf] so we don't circle around infinitely often (because
122 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
123 // and often can't).
124 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
125 return op->hasTrait<OpTrait::IsTerminator>();
127 if (isYieldedResult && !oldRange.isUninitialized() &&
128 !(lattice->getValue() == oldRange)) {
129 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
130 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
132 propagateIfChanged(lattice, changed);
135 inferrable.inferResultRanges(argRanges, joinCallback);
138 void IntegerRangeAnalysis::visitNonControlFlowArguments(
139 Operation *op, const RegionSuccessor &successor,
140 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
141 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
142 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
143 // If the lattice on any operand is unitialized, bail out.
144 if (llvm::any_of(op->getOperands(), [&](Value value) {
145 return getLatticeElementFor(op, value)->getValue().isUninitialized();
147 return;
148 SmallVector<ConstantIntRanges> argRanges(
149 llvm::map_range(op->getOperands(), [&](Value value) {
150 return getLatticeElementFor(op, value)->getValue().getValue();
151 }));
153 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
154 auto arg = dyn_cast<BlockArgument>(v);
155 if (!arg)
156 return;
157 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
158 return;
160 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
161 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
162 IntegerValueRange oldRange = lattice->getValue();
164 ChangeResult changed = lattice->join(IntegerValueRange{attrs});
166 // Catch loop results with loop variant bounds and conservatively make
167 // them [-inf, inf] so we don't circle around infinitely often (because
168 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
169 // and often can't).
170 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
171 return op->hasTrait<OpTrait::IsTerminator>();
173 if (isYieldedValue && !oldRange.isUninitialized() &&
174 !(lattice->getValue() == oldRange)) {
175 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
176 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
178 propagateIfChanged(lattice, changed);
181 inferrable.inferResultRanges(argRanges, joinCallback);
182 return;
185 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
186 /// on a LoopLikeInterface return the lower/upper bound for that result if
187 /// possible.
188 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
189 Type boundType, bool getUpper) {
190 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
191 if (loopBound.has_value()) {
192 if (loopBound->is<Attribute>()) {
193 if (auto bound =
194 dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
195 return bound.getValue();
196 } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
197 const IntegerValueRangeLattice *lattice =
198 getLatticeElementFor(op, value);
199 if (lattice != nullptr)
200 return getUpper ? lattice->getValue().getValue().smax()
201 : lattice->getValue().getValue().smin();
204 // Given the results of getConstant{Lower,Upper}Bound()
205 // or getConstantStep() on a LoopLikeInterface return the lower/upper
206 // bound
207 return getUpper ? APInt::getSignedMaxValue(width)
208 : APInt::getSignedMinValue(width);
211 // Infer bounds for loop arguments that have static bounds
212 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
213 std::optional<Value> iv = loop.getSingleInductionVar();
214 if (!iv) {
215 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
216 op, successor, argLattices, firstIndex);
218 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
219 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
220 std::optional<OpFoldResult> step = loop.getSingleStep();
221 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
222 /*getUpper=*/false);
223 APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
224 /*getUpper=*/true);
225 // Assume positivity for uniscoverable steps by way of getUpper = true.
226 APInt stepVal =
227 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
229 if (stepVal.isNegative()) {
230 std::swap(min, max);
231 } else {
232 // Correct the upper bound by subtracting 1 so that it becomes a <=
233 // bound, because loops do not generally include their upper bound.
234 max -= 1;
237 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
238 auto ivRange = ConstantIntRanges::fromSigned(min, max);
239 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
240 return;
243 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
244 op, successor, argLattices, firstIndex);