[LLD][COFF] Emit tail merge pdata for delay load thunks on ARM64EC (#116810)
[llvm-project.git] / mlir / lib / Analysis / DataFlow / IntegerRangeAnalysis.cpp
bloba97e43708d9a37d8c6d2fd005fb0a88c565a62e6
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Interfaces/ControlFlowInterfaces.h"
25 #include "mlir/Interfaces/InferIntRangeInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include <cassert>
32 #include <optional>
33 #include <utility>
35 #define DEBUG_TYPE "int-range-analysis"
37 using namespace mlir;
38 using namespace mlir::dataflow;
40 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
41 Lattice::onUpdate(solver);
43 // If the integer range can be narrowed to a constant, update the constant
44 // value of the SSA value.
45 std::optional<APInt> constant = getValue().getValue().getConstantValue();
46 auto value = anchor.get<Value>();
47 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48 if (!constant)
49 return solver->propagateIfChanged(
50 cv, cv->join(ConstantValue::getUnknownConstant()));
52 Dialect *dialect;
53 if (auto *parent = value.getDefiningOp())
54 dialect = parent->getDialect();
55 else
56 dialect = value.getParentBlock()->getParentOp()->getDialect();
58 Type type = getElementTypeOrSelf(value);
59 solver->propagateIfChanged(
60 cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
63 LogicalResult IntegerRangeAnalysis::visitOperation(
64 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
65 ArrayRef<IntegerValueRangeLattice *> results) {
66 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
67 if (!inferrable) {
68 setAllToEntryStates(results);
69 return success();
72 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
73 auto argRanges = llvm::map_to_vector(
74 operands, [](const IntegerValueRangeLattice *lattice) {
75 return lattice->getValue();
76 });
78 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
79 auto result = dyn_cast<OpResult>(v);
80 if (!result)
81 return;
82 assert(llvm::is_contained(op->getResults(), result));
84 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
85 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
86 IntegerValueRange oldRange = lattice->getValue();
88 ChangeResult changed = lattice->join(attrs);
90 // Catch loop results with loop variant bounds and conservatively make
91 // them [-inf, inf] so we don't circle around infinitely often (because
92 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
93 // and often can't).
94 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
95 return op->hasTrait<OpTrait::IsTerminator>();
96 });
97 if (isYieldedResult && !oldRange.isUninitialized() &&
98 !(lattice->getValue() == oldRange)) {
99 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
100 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
102 propagateIfChanged(lattice, changed);
105 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
106 return success();
109 void IntegerRangeAnalysis::visitNonControlFlowArguments(
110 Operation *op, const RegionSuccessor &successor,
111 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
112 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
113 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
115 auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
116 return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
119 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
120 auto arg = dyn_cast<BlockArgument>(v);
121 if (!arg)
122 return;
123 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
124 return;
126 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
127 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
128 IntegerValueRange oldRange = lattice->getValue();
130 ChangeResult changed = lattice->join(attrs);
132 // Catch loop results with loop variant bounds and conservatively make
133 // them [-inf, inf] so we don't circle around infinitely often (because
134 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
135 // and often can't).
136 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
137 return op->hasTrait<OpTrait::IsTerminator>();
139 if (isYieldedValue && !oldRange.isUninitialized() &&
140 !(lattice->getValue() == oldRange)) {
141 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
142 changed |= lattice->join(IntegerValueRange::getMaxRange(v));
144 propagateIfChanged(lattice, changed);
147 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
148 return;
151 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152 /// on a LoopLikeInterface return the lower/upper bound for that result if
153 /// possible.
154 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155 Type boundType, bool getUpper) {
156 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157 if (loopBound.has_value()) {
158 if (loopBound->is<Attribute>()) {
159 if (auto bound =
160 dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
161 return bound.getValue();
162 } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
163 const IntegerValueRangeLattice *lattice =
164 getLatticeElementFor(getProgramPointAfter(op), value);
165 if (lattice != nullptr && !lattice->getValue().isUninitialized())
166 return getUpper ? lattice->getValue().getValue().smax()
167 : lattice->getValue().getValue().smin();
170 // Given the results of getConstant{Lower,Upper}Bound()
171 // or getConstantStep() on a LoopLikeInterface return the lower/upper
172 // bound
173 return getUpper ? APInt::getSignedMaxValue(width)
174 : APInt::getSignedMinValue(width);
177 // Infer bounds for loop arguments that have static bounds
178 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
179 std::optional<Value> iv = loop.getSingleInductionVar();
180 if (!iv) {
181 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
182 op, successor, argLattices, firstIndex);
184 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
185 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
186 std::optional<OpFoldResult> step = loop.getSingleStep();
187 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
188 /*getUpper=*/false);
189 APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
190 /*getUpper=*/true);
191 // Assume positivity for uniscoverable steps by way of getUpper = true.
192 APInt stepVal =
193 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
195 if (stepVal.isNegative()) {
196 std::swap(min, max);
197 } else {
198 // Correct the upper bound by subtracting 1 so that it becomes a <=
199 // bound, because loops do not generally include their upper bound.
200 max -= 1;
203 // If we infer the lower bound to be larger than the upper bound, the
204 // resulting range is meaningless and should not be used in further
205 // inferences.
206 if (max.sge(min)) {
207 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
208 auto ivRange = ConstantIntRanges::fromSigned(min, max);
209 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
211 return;
214 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
215 op, successor, argLattices, firstIndex);