1 //===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Interfaces/InferIntRangeInterface.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
17 bool ConstantIntRanges::operator==(const ConstantIntRanges
&other
) const {
18 return umin().getBitWidth() == other
.umin().getBitWidth() &&
19 umin() == other
.umin() && umax() == other
.umax() &&
20 smin() == other
.smin() && smax() == other
.smax();
23 const APInt
&ConstantIntRanges::umin() const { return uminVal
; }
25 const APInt
&ConstantIntRanges::umax() const { return umaxVal
; }
27 const APInt
&ConstantIntRanges::smin() const { return sminVal
; }
29 const APInt
&ConstantIntRanges::smax() const { return smaxVal
; }
31 unsigned ConstantIntRanges::getStorageBitwidth(Type type
) {
32 type
= getElementTypeOrSelf(type
);
34 return IndexType::kInternalStorageBitWidth
;
35 if (auto integerType
= dyn_cast
<IntegerType
>(type
))
36 return integerType
.getWidth();
37 // Non-integer types have their bounds stored in width 0 `APInt`s.
41 ConstantIntRanges
ConstantIntRanges::maxRange(unsigned bitwidth
) {
42 return fromUnsigned(APInt::getZero(bitwidth
), APInt::getMaxValue(bitwidth
));
45 ConstantIntRanges
ConstantIntRanges::constant(const APInt
&value
) {
46 return {value
, value
, value
, value
};
49 ConstantIntRanges
ConstantIntRanges::range(const APInt
&min
, const APInt
&max
,
52 return fromSigned(min
, max
);
53 return fromUnsigned(min
, max
);
56 ConstantIntRanges
ConstantIntRanges::fromSigned(const APInt
&smin
,
58 unsigned int width
= smin
.getBitWidth();
60 if (smin
.isNonNegative() == smax
.isNonNegative()) {
61 umin
= smin
.ult(smax
) ? smin
: smax
;
62 umax
= smin
.ugt(smax
) ? smin
: smax
;
64 umin
= APInt::getMinValue(width
);
65 umax
= APInt::getMaxValue(width
);
67 return {umin
, umax
, smin
, smax
};
70 ConstantIntRanges
ConstantIntRanges::fromUnsigned(const APInt
&umin
,
72 unsigned int width
= umin
.getBitWidth();
74 if (umin
.isNonNegative() == umax
.isNonNegative()) {
75 smin
= umin
.slt(umax
) ? umin
: umax
;
76 smax
= umin
.sgt(umax
) ? umin
: umax
;
78 smin
= APInt::getSignedMinValue(width
);
79 smax
= APInt::getSignedMaxValue(width
);
81 return {umin
, umax
, smin
, smax
};
85 ConstantIntRanges::rangeUnion(const ConstantIntRanges
&other
) const {
86 // "Not an integer" poisons everything and also cannot be fed to comparison
88 if (umin().getBitWidth() == 0)
90 if (other
.umin().getBitWidth() == 0)
93 const APInt
&uminUnion
= umin().ult(other
.umin()) ? umin() : other
.umin();
94 const APInt
&umaxUnion
= umax().ugt(other
.umax()) ? umax() : other
.umax();
95 const APInt
&sminUnion
= smin().slt(other
.smin()) ? smin() : other
.smin();
96 const APInt
&smaxUnion
= smax().sgt(other
.smax()) ? smax() : other
.smax();
98 return {uminUnion
, umaxUnion
, sminUnion
, smaxUnion
};
102 ConstantIntRanges::intersection(const ConstantIntRanges
&other
) const {
103 // "Not an integer" poisons everything and also cannot be fed to comparison
105 if (umin().getBitWidth() == 0)
107 if (other
.umin().getBitWidth() == 0)
110 const APInt
&uminIntersect
= umin().ugt(other
.umin()) ? umin() : other
.umin();
111 const APInt
&umaxIntersect
= umax().ult(other
.umax()) ? umax() : other
.umax();
112 const APInt
&sminIntersect
= smin().sgt(other
.smin()) ? smin() : other
.smin();
113 const APInt
&smaxIntersect
= smax().slt(other
.smax()) ? smax() : other
.smax();
115 return {uminIntersect
, umaxIntersect
, sminIntersect
, smaxIntersect
};
118 std::optional
<APInt
> ConstantIntRanges::getConstantValue() const {
119 // Note: we need to exclude the trivially-equal width 0 values here.
120 if (umin() == umax() && umin().getBitWidth() != 0)
122 if (smin() == smax() && smin().getBitWidth() != 0)
127 raw_ostream
&mlir::operator<<(raw_ostream
&os
, const ConstantIntRanges
&range
) {
128 return os
<< "unsigned : [" << range
.umin() << ", " << range
.umax()
129 << "] signed : [" << range
.smin() << ", " << range
.smax() << "]";
132 IntegerValueRange
IntegerValueRange::getMaxRange(Value value
) {
133 unsigned width
= ConstantIntRanges::getStorageBitwidth(value
.getType());
137 APInt umin
= APInt::getMinValue(width
);
138 APInt umax
= APInt::getMaxValue(width
);
139 APInt smin
= width
!= 0 ? APInt::getSignedMinValue(width
) : umin
;
140 APInt smax
= width
!= 0 ? APInt::getSignedMaxValue(width
) : umax
;
141 return IntegerValueRange
{ConstantIntRanges
{umin
, umax
, smin
, smax
}};
144 raw_ostream
&mlir::operator<<(raw_ostream
&os
, const IntegerValueRange
&range
) {
149 void mlir::intrange::detail::defaultInferResultRanges(
150 InferIntRangeInterface interface
, ArrayRef
<IntegerValueRange
> argRanges
,
151 SetIntLatticeFn setResultRanges
) {
152 llvm::SmallVector
<ConstantIntRanges
> unpacked
;
153 unpacked
.reserve(argRanges
.size());
155 for (const IntegerValueRange
&range
: argRanges
) {
156 if (range
.isUninitialized())
158 unpacked
.push_back(range
.getValue());
161 interface
.inferResultRanges(
163 [&setResultRanges
](Value value
, const ConstantIntRanges
&argRanges
) {
164 setResultRanges(value
, IntegerValueRange
{argRanges
});
168 void mlir::intrange::detail::defaultInferResultRangesFromOptional(
169 InferIntRangeInterface interface
, ArrayRef
<ConstantIntRanges
> argRanges
,
170 SetIntRangeFn setResultRanges
) {
171 auto ranges
= llvm::to_vector_of
<IntegerValueRange
>(argRanges
);
172 interface
.inferResultRangesFromOptional(
174 [&setResultRanges
](Value value
, const IntegerValueRange
&argRanges
) {
175 if (!argRanges
.isUninitialized())
176 setResultRanges(value
, argRanges
.getValue());