1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains a class for representing known zeros and ones used by
12 //===----------------------------------------------------------------------===//
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
21 static KnownBits
computeForAddCarry(
22 const KnownBits
&LHS
, const KnownBits
&RHS
,
23 bool CarryZero
, bool CarryOne
) {
24 assert(!(CarryZero
&& CarryOne
) &&
25 "Carry can't be zero and one at the same time");
27 APInt PossibleSumZero
= LHS
.getMaxValue() + RHS
.getMaxValue() + !CarryZero
;
28 APInt PossibleSumOne
= LHS
.getMinValue() + RHS
.getMinValue() + CarryOne
;
30 // Compute known bits of the carry.
31 APInt CarryKnownZero
= ~(PossibleSumZero
^ LHS
.Zero
^ RHS
.Zero
);
32 APInt CarryKnownOne
= PossibleSumOne
^ LHS
.One
^ RHS
.One
;
34 // Compute set of known bits (where all three relevant bits are known).
35 APInt LHSKnownUnion
= LHS
.Zero
| LHS
.One
;
36 APInt RHSKnownUnion
= RHS
.Zero
| RHS
.One
;
37 APInt CarryKnownUnion
= std::move(CarryKnownZero
) | CarryKnownOne
;
38 APInt Known
= std::move(LHSKnownUnion
) & RHSKnownUnion
& CarryKnownUnion
;
40 assert((PossibleSumZero
& Known
) == (PossibleSumOne
& Known
) &&
41 "known bits of sum differ");
43 // Compute known bits of the result.
45 KnownOut
.Zero
= ~std::move(PossibleSumZero
) & Known
;
46 KnownOut
.One
= std::move(PossibleSumOne
) & Known
;
50 KnownBits
KnownBits::computeForAddCarry(
51 const KnownBits
&LHS
, const KnownBits
&RHS
, const KnownBits
&Carry
) {
52 assert(Carry
.getBitWidth() == 1 && "Carry must be 1-bit");
53 return ::computeForAddCarry(
54 LHS
, RHS
, Carry
.Zero
.getBoolValue(), Carry
.One
.getBoolValue());
57 KnownBits
KnownBits::computeForAddSub(bool Add
, bool NSW
,
58 const KnownBits
&LHS
, KnownBits RHS
) {
61 // Sum = LHS + RHS + 0
62 KnownOut
= ::computeForAddCarry(
63 LHS
, RHS
, /*CarryZero*/true, /*CarryOne*/false);
65 // Sum = LHS + ~RHS + 1
66 std::swap(RHS
.Zero
, RHS
.One
);
67 KnownOut
= ::computeForAddCarry(
68 LHS
, RHS
, /*CarryZero*/false, /*CarryOne*/true);
71 // Are we still trying to solve for the sign bit?
72 if (!KnownOut
.isNegative() && !KnownOut
.isNonNegative()) {
74 // Adding two non-negative numbers, or subtracting a negative number from
75 // a non-negative one, can't wrap into negative.
76 if (LHS
.isNonNegative() && RHS
.isNonNegative())
77 KnownOut
.makeNonNegative();
78 // Adding two negative numbers, or subtracting a non-negative number from
79 // a negative one, can't wrap into non-negative.
80 else if (LHS
.isNegative() && RHS
.isNegative())
81 KnownOut
.makeNegative();
88 KnownBits
KnownBits::computeForSubBorrow(const KnownBits
&LHS
, KnownBits RHS
,
89 const KnownBits
&Borrow
) {
90 assert(Borrow
.getBitWidth() == 1 && "Borrow must be 1-bit");
92 // LHS - RHS = LHS + ~RHS + 1
93 // Carry 1 - Borrow in ::computeForAddCarry
94 std::swap(RHS
.Zero
, RHS
.One
);
95 return ::computeForAddCarry(LHS
, RHS
,
96 /*CarryZero=*/Borrow
.One
.getBoolValue(),
97 /*CarryOne=*/Borrow
.Zero
.getBoolValue());
100 KnownBits
KnownBits::sextInReg(unsigned SrcBitWidth
) const {
101 unsigned BitWidth
= getBitWidth();
102 assert(0 < SrcBitWidth
&& SrcBitWidth
<= BitWidth
&&
103 "Illegal sext-in-register");
105 if (SrcBitWidth
== BitWidth
)
108 unsigned ExtBits
= BitWidth
- SrcBitWidth
;
110 Result
.One
= One
<< ExtBits
;
111 Result
.Zero
= Zero
<< ExtBits
;
112 Result
.One
.ashrInPlace(ExtBits
);
113 Result
.Zero
.ashrInPlace(ExtBits
);
117 KnownBits
KnownBits::makeGE(const APInt
&Val
) const {
118 // Count the number of leading bit positions where our underlying value is
119 // known to be less than or equal to Val.
120 unsigned N
= (Zero
| Val
).countl_one();
122 // For each of those bit positions, if Val has a 1 in that bit then our
123 // underlying value must also have a 1.
124 APInt
MaskedVal(Val
);
125 MaskedVal
.clearLowBits(getBitWidth() - N
);
126 return KnownBits(Zero
, One
| MaskedVal
);
129 KnownBits
KnownBits::umax(const KnownBits
&LHS
, const KnownBits
&RHS
) {
130 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
131 // RHS. Ideally our caller would already have spotted these cases and
132 // optimized away the umax operation, but we handle them here for
134 if (LHS
.getMinValue().uge(RHS
.getMaxValue()))
136 if (RHS
.getMinValue().uge(LHS
.getMaxValue()))
139 // If the result of the umax is LHS then it must be greater than or equal to
140 // the minimum possible value of RHS. Likewise for RHS. Any known bits that
141 // are common to these two values are also known in the result.
142 KnownBits L
= LHS
.makeGE(RHS
.getMinValue());
143 KnownBits R
= RHS
.makeGE(LHS
.getMinValue());
144 return L
.intersectWith(R
);
147 KnownBits
KnownBits::umin(const KnownBits
&LHS
, const KnownBits
&RHS
) {
148 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
149 auto Flip
= [](const KnownBits
&Val
) { return KnownBits(Val
.One
, Val
.Zero
); };
150 return Flip(umax(Flip(LHS
), Flip(RHS
)));
153 KnownBits
KnownBits::smax(const KnownBits
&LHS
, const KnownBits
&RHS
) {
154 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
155 auto Flip
= [](const KnownBits
&Val
) {
156 unsigned SignBitPosition
= Val
.getBitWidth() - 1;
157 APInt Zero
= Val
.Zero
;
159 Zero
.setBitVal(SignBitPosition
, Val
.One
[SignBitPosition
]);
160 One
.setBitVal(SignBitPosition
, Val
.Zero
[SignBitPosition
]);
161 return KnownBits(Zero
, One
);
163 return Flip(umax(Flip(LHS
), Flip(RHS
)));
166 KnownBits
KnownBits::smin(const KnownBits
&LHS
, const KnownBits
&RHS
) {
167 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
168 auto Flip
= [](const KnownBits
&Val
) {
169 unsigned SignBitPosition
= Val
.getBitWidth() - 1;
170 APInt Zero
= Val
.One
;
171 APInt One
= Val
.Zero
;
172 Zero
.setBitVal(SignBitPosition
, Val
.Zero
[SignBitPosition
]);
173 One
.setBitVal(SignBitPosition
, Val
.One
[SignBitPosition
]);
174 return KnownBits(Zero
, One
);
176 return Flip(umax(Flip(LHS
), Flip(RHS
)));
179 static unsigned getMaxShiftAmount(const APInt
&MaxValue
, unsigned BitWidth
) {
180 if (isPowerOf2_32(BitWidth
))
181 return MaxValue
.extractBitsAsZExtValue(Log2_32(BitWidth
), 0);
182 // This is only an approximate upper bound.
183 return MaxValue
.getLimitedValue(BitWidth
- 1);
186 KnownBits
KnownBits::shl(const KnownBits
&LHS
, const KnownBits
&RHS
, bool NUW
,
187 bool NSW
, bool ShAmtNonZero
) {
188 unsigned BitWidth
= LHS
.getBitWidth();
189 auto ShiftByConst
= [&](const KnownBits
&LHS
, unsigned ShiftAmt
) {
191 bool ShiftedOutZero
, ShiftedOutOne
;
192 Known
.Zero
= LHS
.Zero
.ushl_ov(ShiftAmt
, ShiftedOutZero
);
193 Known
.Zero
.setLowBits(ShiftAmt
);
194 Known
.One
= LHS
.One
.ushl_ov(ShiftAmt
, ShiftedOutOne
);
196 // All cases returning poison have been handled by MaxShiftAmount already.
198 if (NUW
&& ShiftAmt
!= 0)
199 // NUW means we can assume anything shifted out was a zero.
200 ShiftedOutZero
= true;
203 Known
.makeNonNegative();
204 else if (ShiftedOutOne
)
205 Known
.makeNegative();
210 // Fast path for a common case when LHS is completely unknown.
211 KnownBits
Known(BitWidth
);
212 unsigned MinShiftAmount
= RHS
.getMinValue().getLimitedValue(BitWidth
);
213 if (MinShiftAmount
== 0 && ShAmtNonZero
)
215 if (LHS
.isUnknown()) {
216 Known
.Zero
.setLowBits(MinShiftAmount
);
217 if (NUW
&& NSW
&& MinShiftAmount
!= 0)
218 Known
.makeNonNegative();
222 // Determine maximum shift amount, taking NUW/NSW flags into account.
223 APInt MaxValue
= RHS
.getMaxValue();
224 unsigned MaxShiftAmount
= getMaxShiftAmount(MaxValue
, BitWidth
);
226 MaxShiftAmount
= std::min(MaxShiftAmount
, LHS
.countMaxLeadingZeros() - 1);
228 MaxShiftAmount
= std::min(MaxShiftAmount
, LHS
.countMaxLeadingZeros());
230 MaxShiftAmount
= std::min(
232 std::max(LHS
.countMaxLeadingZeros(), LHS
.countMaxLeadingOnes()) - 1);
234 // Fast path for common case where the shift amount is unknown.
235 if (MinShiftAmount
== 0 && MaxShiftAmount
== BitWidth
- 1 &&
236 isPowerOf2_32(BitWidth
)) {
237 Known
.Zero
.setLowBits(LHS
.countMinTrailingZeros());
239 Known
.One
.setSignBit();
241 if (LHS
.isNonNegative())
242 Known
.makeNonNegative();
243 if (LHS
.isNegative())
244 Known
.makeNegative();
249 // Find the common bits from all possible shifts.
250 unsigned ShiftAmtZeroMask
= RHS
.Zero
.zextOrTrunc(32).getZExtValue();
251 unsigned ShiftAmtOneMask
= RHS
.One
.zextOrTrunc(32).getZExtValue();
252 Known
.Zero
.setAllBits();
253 Known
.One
.setAllBits();
254 for (unsigned ShiftAmt
= MinShiftAmount
; ShiftAmt
<= MaxShiftAmount
;
256 // Skip if the shift amount is impossible.
257 if ((ShiftAmtZeroMask
& ShiftAmt
) != 0 ||
258 (ShiftAmtOneMask
| ShiftAmt
) != ShiftAmt
)
260 Known
= Known
.intersectWith(ShiftByConst(LHS
, ShiftAmt
));
261 if (Known
.isUnknown())
265 // All shift amounts may result in poison.
266 if (Known
.hasConflict())
271 KnownBits
KnownBits::lshr(const KnownBits
&LHS
, const KnownBits
&RHS
,
273 unsigned BitWidth
= LHS
.getBitWidth();
274 auto ShiftByConst
= [&](const KnownBits
&LHS
, unsigned ShiftAmt
) {
275 KnownBits Known
= LHS
;
276 Known
.Zero
.lshrInPlace(ShiftAmt
);
277 Known
.One
.lshrInPlace(ShiftAmt
);
278 // High bits are known zero.
279 Known
.Zero
.setHighBits(ShiftAmt
);
283 // Fast path for a common case when LHS is completely unknown.
284 KnownBits
Known(BitWidth
);
285 unsigned MinShiftAmount
= RHS
.getMinValue().getLimitedValue(BitWidth
);
286 if (MinShiftAmount
== 0 && ShAmtNonZero
)
288 if (LHS
.isUnknown()) {
289 Known
.Zero
.setHighBits(MinShiftAmount
);
293 // Find the common bits from all possible shifts.
294 APInt MaxValue
= RHS
.getMaxValue();
295 unsigned MaxShiftAmount
= getMaxShiftAmount(MaxValue
, BitWidth
);
296 unsigned ShiftAmtZeroMask
= RHS
.Zero
.zextOrTrunc(32).getZExtValue();
297 unsigned ShiftAmtOneMask
= RHS
.One
.zextOrTrunc(32).getZExtValue();
298 Known
.Zero
.setAllBits();
299 Known
.One
.setAllBits();
300 for (unsigned ShiftAmt
= MinShiftAmount
; ShiftAmt
<= MaxShiftAmount
;
302 // Skip if the shift amount is impossible.
303 if ((ShiftAmtZeroMask
& ShiftAmt
) != 0 ||
304 (ShiftAmtOneMask
| ShiftAmt
) != ShiftAmt
)
306 Known
= Known
.intersectWith(ShiftByConst(LHS
, ShiftAmt
));
307 if (Known
.isUnknown())
311 // All shift amounts may result in poison.
312 if (Known
.hasConflict())
317 KnownBits
KnownBits::ashr(const KnownBits
&LHS
, const KnownBits
&RHS
,
319 unsigned BitWidth
= LHS
.getBitWidth();
320 auto ShiftByConst
= [&](const KnownBits
&LHS
, unsigned ShiftAmt
) {
321 KnownBits Known
= LHS
;
322 Known
.Zero
.ashrInPlace(ShiftAmt
);
323 Known
.One
.ashrInPlace(ShiftAmt
);
327 // Fast path for a common case when LHS is completely unknown.
328 KnownBits
Known(BitWidth
);
329 unsigned MinShiftAmount
= RHS
.getMinValue().getLimitedValue(BitWidth
);
330 if (MinShiftAmount
== 0 && ShAmtNonZero
)
332 if (LHS
.isUnknown()) {
333 if (MinShiftAmount
== BitWidth
) {
334 // Always poison. Return zero because we don't like returning conflict.
341 // Find the common bits from all possible shifts.
342 APInt MaxValue
= RHS
.getMaxValue();
343 unsigned MaxShiftAmount
= getMaxShiftAmount(MaxValue
, BitWidth
);
344 unsigned ShiftAmtZeroMask
= RHS
.Zero
.zextOrTrunc(32).getZExtValue();
345 unsigned ShiftAmtOneMask
= RHS
.One
.zextOrTrunc(32).getZExtValue();
346 Known
.Zero
.setAllBits();
347 Known
.One
.setAllBits();
348 for (unsigned ShiftAmt
= MinShiftAmount
; ShiftAmt
<= MaxShiftAmount
;
350 // Skip if the shift amount is impossible.
351 if ((ShiftAmtZeroMask
& ShiftAmt
) != 0 ||
352 (ShiftAmtOneMask
| ShiftAmt
) != ShiftAmt
)
354 Known
= Known
.intersectWith(ShiftByConst(LHS
, ShiftAmt
));
355 if (Known
.isUnknown())
359 // All shift amounts may result in poison.
360 if (Known
.hasConflict())
365 std::optional
<bool> KnownBits::eq(const KnownBits
&LHS
, const KnownBits
&RHS
) {
366 if (LHS
.isConstant() && RHS
.isConstant())
367 return std::optional
<bool>(LHS
.getConstant() == RHS
.getConstant());
368 if (LHS
.One
.intersects(RHS
.Zero
) || RHS
.One
.intersects(LHS
.Zero
))
369 return std::optional
<bool>(false);
373 std::optional
<bool> KnownBits::ne(const KnownBits
&LHS
, const KnownBits
&RHS
) {
374 if (std::optional
<bool> KnownEQ
= eq(LHS
, RHS
))
375 return std::optional
<bool>(!*KnownEQ
);
379 std::optional
<bool> KnownBits::ugt(const KnownBits
&LHS
, const KnownBits
&RHS
) {
380 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
381 if (LHS
.getMaxValue().ule(RHS
.getMinValue()))
382 return std::optional
<bool>(false);
383 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
384 if (LHS
.getMinValue().ugt(RHS
.getMaxValue()))
385 return std::optional
<bool>(true);
389 std::optional
<bool> KnownBits::uge(const KnownBits
&LHS
, const KnownBits
&RHS
) {
390 if (std::optional
<bool> IsUGT
= ugt(RHS
, LHS
))
391 return std::optional
<bool>(!*IsUGT
);
395 std::optional
<bool> KnownBits::ult(const KnownBits
&LHS
, const KnownBits
&RHS
) {
396 return ugt(RHS
, LHS
);
399 std::optional
<bool> KnownBits::ule(const KnownBits
&LHS
, const KnownBits
&RHS
) {
400 return uge(RHS
, LHS
);
403 std::optional
<bool> KnownBits::sgt(const KnownBits
&LHS
, const KnownBits
&RHS
) {
404 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
405 if (LHS
.getSignedMaxValue().sle(RHS
.getSignedMinValue()))
406 return std::optional
<bool>(false);
407 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
408 if (LHS
.getSignedMinValue().sgt(RHS
.getSignedMaxValue()))
409 return std::optional
<bool>(true);
413 std::optional
<bool> KnownBits::sge(const KnownBits
&LHS
, const KnownBits
&RHS
) {
414 if (std::optional
<bool> KnownSGT
= sgt(RHS
, LHS
))
415 return std::optional
<bool>(!*KnownSGT
);
419 std::optional
<bool> KnownBits::slt(const KnownBits
&LHS
, const KnownBits
&RHS
) {
420 return sgt(RHS
, LHS
);
423 std::optional
<bool> KnownBits::sle(const KnownBits
&LHS
, const KnownBits
&RHS
) {
424 return sge(RHS
, LHS
);
427 KnownBits
KnownBits::abs(bool IntMinIsPoison
) const {
428 // If the source's MSB is zero then we know the rest of the bits already.
432 // Absolute value preserves trailing zero count.
433 KnownBits
KnownAbs(getBitWidth());
435 // If the input is negative, then abs(x) == -x.
437 KnownBits Tmp
= *this;
438 // Special case for IntMinIsPoison. We know the sign bit is set and we know
439 // all the rest of the bits except one to be zero. Since we have
440 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
442 if (IntMinIsPoison
&& (Zero
.popcount() + 2) == getBitWidth())
443 Tmp
.One
.setBit(countMinTrailingZeros());
445 KnownAbs
= computeForAddSub(
446 /*Add*/ false, IntMinIsPoison
,
447 KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp
);
449 // One more special case for IntMinIsPoison. If we don't know any ones other
450 // than the signbit, we know for certain that all the unknowns can't be
451 // zero. So if we know high zero bits, but have unknown low bits, we know
452 // for certain those high-zero bits will end up as one. This is because,
453 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
454 // to the high bits. If we know a known INT_MIN input skip this. The result
455 // is poison anyways.
456 if (IntMinIsPoison
&& Tmp
.countMinPopulation() == 1 &&
457 Tmp
.countMaxPopulation() != 1) {
458 Tmp
.One
.clearSignBit();
459 Tmp
.Zero
.setSignBit();
460 KnownAbs
.One
.setBits(getBitWidth() - Tmp
.countMinLeadingZeros(),
465 unsigned MaxTZ
= countMaxTrailingZeros();
466 unsigned MinTZ
= countMinTrailingZeros();
468 KnownAbs
.Zero
.setLowBits(MinTZ
);
469 // If we know the lowest set 1, then preserve it.
470 if (MaxTZ
== MinTZ
&& MaxTZ
< getBitWidth())
471 KnownAbs
.One
.setBit(MaxTZ
);
473 // We only know that the absolute values's MSB will be zero if INT_MIN is
474 // poison, or there is a set bit that isn't the sign bit (otherwise it could
476 if (IntMinIsPoison
|| (!One
.isZero() && !One
.isMinSignedValue())) {
477 KnownAbs
.One
.clearSignBit();
478 KnownAbs
.Zero
.setSignBit();
482 assert(!KnownAbs
.hasConflict() && "Bad Output");
486 static KnownBits
computeForSatAddSub(bool Add
, bool Signed
,
487 const KnownBits
&LHS
,
488 const KnownBits
&RHS
) {
489 assert(!LHS
.hasConflict() && !RHS
.hasConflict() && "Bad inputs");
490 // We don't see NSW even for sadd/ssub as we want to check if the result has
492 KnownBits Res
= KnownBits::computeForAddSub(Add
, /*NSW*/ false, LHS
, RHS
);
493 unsigned BitWidth
= Res
.getBitWidth();
494 auto SignBitKnown
= [&](const KnownBits
&K
) {
495 return K
.Zero
[BitWidth
- 1] || K
.One
[BitWidth
- 1];
497 std::optional
<bool> Overflow
;
500 // If we can actually detect overflow do so. Otherwise leave Overflow as
501 // nullopt (we assume it may have happened).
502 if (SignBitKnown(LHS
) && SignBitKnown(RHS
) && SignBitKnown(Res
)) {
505 Overflow
= (LHS
.isNonNegative() == RHS
.isNonNegative() &&
506 Res
.isNonNegative() != LHS
.isNonNegative());
509 Overflow
= (LHS
.isNonNegative() != RHS
.isNonNegative() &&
510 Res
.isNonNegative() != LHS
.isNonNegative());
516 (void)LHS
.getMaxValue().uadd_ov(RHS
.getMaxValue(), Of
);
520 (void)LHS
.getMinValue().uadd_ov(RHS
.getMinValue(), Of
);
527 (void)LHS
.getMinValue().usub_ov(RHS
.getMaxValue(), Of
);
531 (void)LHS
.getMaxValue().usub_ov(RHS
.getMinValue(), Of
);
539 if (LHS
.isNonNegative() && RHS
.isNonNegative()) {
541 Res
.One
.clearSignBit();
542 Res
.Zero
.setSignBit();
544 if (LHS
.isNegative() && RHS
.isNegative()) {
546 Res
.One
.setSignBit();
547 Res
.Zero
.clearSignBit();
550 if (LHS
.isNegative() && RHS
.isNonNegative()) {
552 Res
.One
.setSignBit();
553 Res
.Zero
.clearSignBit();
554 } else if (LHS
.isNonNegative() && RHS
.isNegative()) {
556 Res
.One
.clearSignBit();
557 Res
.Zero
.setSignBit();
561 // Add: Leading ones of either operand are preserved.
562 // Sub: Leading zeros of LHS and leading ones of RHS are preserved
563 // as leading zeros in the result.
564 unsigned LeadingKnown
;
567 std::max(LHS
.countMinLeadingOnes(), RHS
.countMinLeadingOnes());
570 std::max(LHS
.countMinLeadingZeros(), RHS
.countMinLeadingOnes());
572 // We select between the operation result and all-ones/zero
573 // respectively, so we can preserve known ones/zeros.
574 APInt Mask
= APInt::getHighBitsSet(BitWidth
, LeadingKnown
);
585 // We know whether or not we overflowed.
588 assert(!Res
.hasConflict() && "Bad Output");
595 // sadd.sat / ssub.sat
596 assert(SignBitKnown(LHS
) &&
597 "We somehow know overflow without knowing input sign");
598 C
= LHS
.isNegative() ? APInt::getSignedMinValue(BitWidth
)
599 : APInt::getSignedMaxValue(BitWidth
);
602 C
= APInt::getMaxValue(BitWidth
);
605 C
= APInt::getMinValue(BitWidth
);
610 assert(!Res
.hasConflict() && "Bad Output");
614 // We don't know if we overflowed.
617 // We can keep our information about the sign bits.
618 Res
.Zero
.clearLowBits(BitWidth
- 1);
619 Res
.One
.clearLowBits(BitWidth
- 1);
622 // We need to clear all the known zeros as we can only use the leading ones.
623 Res
.Zero
.clearAllBits();
626 // We need to clear all the known ones as we can only use the leading zero.
627 Res
.One
.clearAllBits();
630 assert(!Res
.hasConflict() && "Bad Output");
634 KnownBits
KnownBits::sadd_sat(const KnownBits
&LHS
, const KnownBits
&RHS
) {
635 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS
, RHS
);
637 KnownBits
KnownBits::ssub_sat(const KnownBits
&LHS
, const KnownBits
&RHS
) {
638 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS
, RHS
);
640 KnownBits
KnownBits::uadd_sat(const KnownBits
&LHS
, const KnownBits
&RHS
) {
641 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS
, RHS
);
643 KnownBits
KnownBits::usub_sat(const KnownBits
&LHS
, const KnownBits
&RHS
) {
644 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS
, RHS
);
647 KnownBits
KnownBits::mul(const KnownBits
&LHS
, const KnownBits
&RHS
,
648 bool NoUndefSelfMultiply
) {
649 unsigned BitWidth
= LHS
.getBitWidth();
650 assert(BitWidth
== RHS
.getBitWidth() && !LHS
.hasConflict() &&
651 !RHS
.hasConflict() && "Operand mismatch");
652 assert((!NoUndefSelfMultiply
|| LHS
== RHS
) &&
653 "Self multiplication knownbits mismatch");
655 // Compute the high known-0 bits by multiplying the unsigned max of each side.
656 // Conservatively, M active bits * N active bits results in M + N bits in the
657 // result. But if we know a value is a power-of-2 for example, then this
658 // computes one more leading zero.
659 // TODO: This could be generalized to number of sign bits (negative numbers).
660 APInt UMaxLHS
= LHS
.getMaxValue();
661 APInt UMaxRHS
= RHS
.getMaxValue();
663 // For leading zeros in the result to be valid, the unsigned max product must
664 // fit in the bitwidth (it must not overflow).
666 APInt UMaxResult
= UMaxLHS
.umul_ov(UMaxRHS
, HasOverflow
);
667 unsigned LeadZ
= HasOverflow
? 0 : UMaxResult
.countl_zero();
669 // The result of the bottom bits of an integer multiply can be
670 // inferred by looking at the bottom bits of both operands and
671 // multiplying them together.
672 // We can infer at least the minimum number of known trailing bits
673 // of both operands. Depending on number of trailing zeros, we can
674 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
675 // a and b are divisible by m and n respectively.
676 // We then calculate how many of those bits are inferrable and set
677 // the output. For example, the i8 mul:
680 // We know the bottom 3 bits are zero since the first can be divided by
681 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
682 // Applying the multiplication to the trimmed arguments gets:
692 // Which allows us to infer the 2 LSBs. Since we're multiplying the result
693 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
694 // The proof for this can be described as:
695 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
696 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
697 // umin(countTrailingZeros(C2), C6) +
698 // umin(C5 - umin(countTrailingZeros(C1), C5),
699 // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
700 // %aa = shl i8 %a, C5
701 // %bb = shl i8 %b, C6
702 // %aaa = or i8 %aa, C1
703 // %bbb = or i8 %bb, C2
704 // %mul = mul i8 %aaa, %bbb
705 // %mask = and i8 %mul, C7
707 // %mask = i8 ((C1*C2)&C7)
708 // Where C5, C6 describe the known bits of %a, %b
709 // C1, C2 describe the known bottom bits of %a, %b.
710 // C7 describes the mask of the known bits of the result.
711 const APInt
&Bottom0
= LHS
.One
;
712 const APInt
&Bottom1
= RHS
.One
;
714 // How many times we'd be able to divide each argument by 2 (shr by 1).
715 // This gives us the number of trailing zeros on the multiplication result.
716 unsigned TrailBitsKnown0
= (LHS
.Zero
| LHS
.One
).countr_one();
717 unsigned TrailBitsKnown1
= (RHS
.Zero
| RHS
.One
).countr_one();
718 unsigned TrailZero0
= LHS
.countMinTrailingZeros();
719 unsigned TrailZero1
= RHS
.countMinTrailingZeros();
720 unsigned TrailZ
= TrailZero0
+ TrailZero1
;
722 // Figure out the fewest known-bits operand.
723 unsigned SmallestOperand
=
724 std::min(TrailBitsKnown0
- TrailZero0
, TrailBitsKnown1
- TrailZero1
);
725 unsigned ResultBitsKnown
= std::min(SmallestOperand
+ TrailZ
, BitWidth
);
728 Bottom0
.getLoBits(TrailBitsKnown0
) * Bottom1
.getLoBits(TrailBitsKnown1
);
730 KnownBits
Res(BitWidth
);
731 Res
.Zero
.setHighBits(LeadZ
);
732 Res
.Zero
|= (~BottomKnown
).getLoBits(ResultBitsKnown
);
733 Res
.One
= BottomKnown
.getLoBits(ResultBitsKnown
);
735 // If we're self-multiplying then bit[1] is guaranteed to be zero.
736 if (NoUndefSelfMultiply
&& BitWidth
> 1) {
737 assert(Res
.One
[1] == 0 &&
738 "Self-multiplication failed Quadratic Reciprocity!");
745 KnownBits
KnownBits::mulhs(const KnownBits
&LHS
, const KnownBits
&RHS
) {
746 unsigned BitWidth
= LHS
.getBitWidth();
747 assert(BitWidth
== RHS
.getBitWidth() && !LHS
.hasConflict() &&
748 !RHS
.hasConflict() && "Operand mismatch");
749 KnownBits WideLHS
= LHS
.sext(2 * BitWidth
);
750 KnownBits WideRHS
= RHS
.sext(2 * BitWidth
);
751 return mul(WideLHS
, WideRHS
).extractBits(BitWidth
, BitWidth
);
754 KnownBits
KnownBits::mulhu(const KnownBits
&LHS
, const KnownBits
&RHS
) {
755 unsigned BitWidth
= LHS
.getBitWidth();
756 assert(BitWidth
== RHS
.getBitWidth() && !LHS
.hasConflict() &&
757 !RHS
.hasConflict() && "Operand mismatch");
758 KnownBits WideLHS
= LHS
.zext(2 * BitWidth
);
759 KnownBits WideRHS
= RHS
.zext(2 * BitWidth
);
760 return mul(WideLHS
, WideRHS
).extractBits(BitWidth
, BitWidth
);
763 static KnownBits
divComputeLowBit(KnownBits Known
, const KnownBits
&LHS
,
764 const KnownBits
&RHS
, bool Exact
) {
769 // If LHS is Odd, the result is Odd no matter what.
771 // Odd / Even -> Impossible (because its exact division)
776 (int)LHS
.countMinTrailingZeros() - (int)RHS
.countMaxTrailingZeros();
778 (int)LHS
.countMaxTrailingZeros() - (int)RHS
.countMinTrailingZeros();
780 // Result has at least MinTZ trailing zeros.
781 Known
.Zero
.setLowBits(MinTZ
);
782 if (MinTZ
== MaxTZ
) {
783 // Result has exactly MinTZ trailing zeros.
784 Known
.One
.setBit(MinTZ
);
786 } else if (MaxTZ
< 0) {
791 // In the KnownBits exhaustive tests, we have poison inputs for exact values
792 // a LOT. If we have a conflict, just return all zeros.
793 if (Known
.hasConflict())
799 KnownBits
KnownBits::sdiv(const KnownBits
&LHS
, const KnownBits
&RHS
,
801 // Equivalent of `udiv`. We must have caught this before it was folded.
802 if (LHS
.isNonNegative() && RHS
.isNonNegative())
803 return udiv(LHS
, RHS
, Exact
);
805 unsigned BitWidth
= LHS
.getBitWidth();
806 assert(!LHS
.hasConflict() && !RHS
.hasConflict() && "Bad inputs");
807 KnownBits
Known(BitWidth
);
809 if (LHS
.isZero() || RHS
.isZero()) {
810 // Result is either known Zero or UB. Return Zero either way.
811 // Checking this earlier saves us a lot of special cases later on.
816 std::optional
<APInt
> Res
;
817 if (LHS
.isNegative() && RHS
.isNegative()) {
818 // Result non-negative.
819 APInt Denom
= RHS
.getSignedMaxValue();
820 APInt Num
= LHS
.getSignedMinValue();
821 // INT_MIN/-1 would be a poison result (impossible). Estimate the division
822 // as signed max (we will only set sign bit in the result).
823 Res
= (Num
.isMinSignedValue() && Denom
.isAllOnes())
824 ? APInt::getSignedMaxValue(BitWidth
)
826 } else if (LHS
.isNegative() && RHS
.isNonNegative()) {
827 // Result is negative if Exact OR -LHS u>= RHS.
828 if (Exact
|| (-LHS
.getSignedMaxValue()).uge(RHS
.getSignedMaxValue())) {
829 APInt Denom
= RHS
.getSignedMinValue();
830 APInt Num
= LHS
.getSignedMinValue();
831 Res
= Denom
.isZero() ? Num
: Num
.sdiv(Denom
);
833 } else if (LHS
.isStrictlyPositive() && RHS
.isNegative()) {
834 // Result is negative if Exact OR LHS u>= -RHS.
835 if (Exact
|| LHS
.getSignedMinValue().uge(-RHS
.getSignedMinValue())) {
836 APInt Denom
= RHS
.getSignedMaxValue();
837 APInt Num
= LHS
.getSignedMaxValue();
838 Res
= Num
.sdiv(Denom
);
843 if (Res
->isNonNegative()) {
844 unsigned LeadZ
= Res
->countLeadingZeros();
845 Known
.Zero
.setHighBits(LeadZ
);
847 unsigned LeadO
= Res
->countLeadingOnes();
848 Known
.One
.setHighBits(LeadO
);
852 Known
= divComputeLowBit(Known
, LHS
, RHS
, Exact
);
854 assert(!Known
.hasConflict() && "Bad Output");
858 KnownBits
KnownBits::udiv(const KnownBits
&LHS
, const KnownBits
&RHS
,
860 unsigned BitWidth
= LHS
.getBitWidth();
861 assert(!LHS
.hasConflict() && !RHS
.hasConflict());
862 KnownBits
Known(BitWidth
);
864 if (LHS
.isZero() || RHS
.isZero()) {
865 // Result is either known Zero or UB. Return Zero either way.
866 // Checking this earlier saves us a lot of special cases later on.
871 // We can figure out the minimum number of upper zero bits by doing
872 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
873 // gets larger, the number of upper zero bits increases.
874 APInt MinDenom
= RHS
.getMinValue();
875 APInt MaxNum
= LHS
.getMaxValue();
876 APInt MaxRes
= MinDenom
.isZero() ? MaxNum
: MaxNum
.udiv(MinDenom
);
878 unsigned LeadZ
= MaxRes
.countLeadingZeros();
880 Known
.Zero
.setHighBits(LeadZ
);
881 Known
= divComputeLowBit(Known
, LHS
, RHS
, Exact
);
883 assert(!Known
.hasConflict() && "Bad Output");
887 KnownBits
KnownBits::remGetLowBits(const KnownBits
&LHS
, const KnownBits
&RHS
) {
888 unsigned BitWidth
= LHS
.getBitWidth();
889 if (!RHS
.isZero() && RHS
.Zero
[0]) {
890 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
891 unsigned RHSZeros
= RHS
.countMinTrailingZeros();
892 APInt Mask
= APInt::getLowBitsSet(BitWidth
, RHSZeros
);
893 APInt OnesMask
= LHS
.One
& Mask
;
894 APInt ZerosMask
= LHS
.Zero
& Mask
;
895 return KnownBits(ZerosMask
, OnesMask
);
897 return KnownBits(BitWidth
);
900 KnownBits
KnownBits::urem(const KnownBits
&LHS
, const KnownBits
&RHS
) {
901 assert(!LHS
.hasConflict() && !RHS
.hasConflict());
903 KnownBits Known
= remGetLowBits(LHS
, RHS
);
904 if (RHS
.isConstant() && RHS
.getConstant().isPowerOf2()) {
905 // NB: Low bits set in `remGetLowBits`.
906 APInt HighBits
= ~(RHS
.getConstant() - 1);
907 Known
.Zero
|= HighBits
;
911 // Since the result is less than or equal to either operand, any leading
912 // zero bits in either operand must also exist in the result.
914 std::max(LHS
.countMinLeadingZeros(), RHS
.countMinLeadingZeros());
915 Known
.Zero
.setHighBits(Leaders
);
919 KnownBits
KnownBits::srem(const KnownBits
&LHS
, const KnownBits
&RHS
) {
920 assert(!LHS
.hasConflict() && !RHS
.hasConflict());
922 KnownBits Known
= remGetLowBits(LHS
, RHS
);
923 if (RHS
.isConstant() && RHS
.getConstant().isPowerOf2()) {
924 // NB: Low bits are set in `remGetLowBits`.
925 APInt LowBits
= RHS
.getConstant() - 1;
926 // If the first operand is non-negative or has all low bits zero, then
927 // the upper bits are all zero.
928 if (LHS
.isNonNegative() || LowBits
.isSubsetOf(LHS
.Zero
))
929 Known
.Zero
|= ~LowBits
;
931 // If the first operand is negative and not all low bits are zero, then
932 // the upper bits are all one.
933 if (LHS
.isNegative() && LowBits
.intersects(LHS
.One
))
934 Known
.One
|= ~LowBits
;
938 // The sign bit is the LHS's sign bit, except when the result of the
939 // remainder is zero. The magnitude of the result should be less than or
940 // equal to the magnitude of the LHS. Therefore any leading zeros that exist
941 // in the left hand side must also exist in the result.
942 Known
.Zero
.setHighBits(LHS
.countMinLeadingZeros());
946 KnownBits
&KnownBits::operator&=(const KnownBits
&RHS
) {
947 // Result bit is 0 if either operand bit is 0.
949 // Result bit is 1 if both operand bits are 1.
954 KnownBits
&KnownBits::operator|=(const KnownBits
&RHS
) {
955 // Result bit is 0 if both operand bits are 0.
957 // Result bit is 1 if either operand bit is 1.
962 KnownBits
&KnownBits::operator^=(const KnownBits
&RHS
) {
963 // Result bit is 0 if both operand bits are 0 or both are 1.
964 APInt Z
= (Zero
& RHS
.Zero
) | (One
& RHS
.One
);
965 // Result bit is 1 if one operand bit is 0 and the other is 1.
966 One
= (Zero
& RHS
.One
) | (One
& RHS
.Zero
);
971 KnownBits
KnownBits::blsi() const {
972 unsigned BitWidth
= getBitWidth();
973 KnownBits
Known(Zero
, APInt(BitWidth
, 0));
974 unsigned Max
= countMaxTrailingZeros();
975 Known
.Zero
.setBitsFrom(std::min(Max
+ 1, BitWidth
));
976 unsigned Min
= countMinTrailingZeros();
977 if (Max
== Min
&& Max
< BitWidth
)
978 Known
.One
.setBit(Max
);
982 KnownBits
KnownBits::blsmsk() const {
983 unsigned BitWidth
= getBitWidth();
984 KnownBits
Known(BitWidth
);
985 unsigned Max
= countMaxTrailingZeros();
986 Known
.Zero
.setBitsFrom(std::min(Max
+ 1, BitWidth
));
987 unsigned Min
= countMinTrailingZeros();
988 Known
.One
.setLowBits(std::min(Min
+ 1, BitWidth
));
992 void KnownBits::print(raw_ostream
&OS
) const {
993 unsigned BitWidth
= getBitWidth();
994 for (unsigned I
= 0; I
< BitWidth
; ++I
) {
995 unsigned N
= BitWidth
- I
- 1;
996 if (Zero
[N
] && One
[N
])
1006 void KnownBits::dump() const {