[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Support / KnownBits.cpp
blob770e4051ca3ffacc620f2576542f339530538efd
1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
19 using namespace llvm;
21 static KnownBits computeForAddCarry(
22 const KnownBits &LHS, const KnownBits &RHS,
23 bool CarryZero, bool CarryOne) {
24 assert(!(CarryZero && CarryOne) &&
25 "Carry can't be zero and one at the same time");
27 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
28 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
30 // Compute known bits of the carry.
31 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
32 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
34 // Compute set of known bits (where all three relevant bits are known).
35 APInt LHSKnownUnion = LHS.Zero | LHS.One;
36 APInt RHSKnownUnion = RHS.Zero | RHS.One;
37 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
38 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
40 assert((PossibleSumZero & Known) == (PossibleSumOne & Known) &&
41 "known bits of sum differ");
43 // Compute known bits of the result.
44 KnownBits KnownOut;
45 KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
46 KnownOut.One = std::move(PossibleSumOne) & Known;
47 return KnownOut;
50 KnownBits KnownBits::computeForAddCarry(
51 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
52 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
53 return ::computeForAddCarry(
54 LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
58 const KnownBits &LHS, KnownBits RHS) {
59 KnownBits KnownOut;
60 if (Add) {
61 // Sum = LHS + RHS + 0
62 KnownOut = ::computeForAddCarry(
63 LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
64 } else {
65 // Sum = LHS + ~RHS + 1
66 std::swap(RHS.Zero, RHS.One);
67 KnownOut = ::computeForAddCarry(
68 LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
71 // Are we still trying to solve for the sign bit?
72 if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
73 if (NSW) {
74 // Adding two non-negative numbers, or subtracting a negative number from
75 // a non-negative one, can't wrap into negative.
76 if (LHS.isNonNegative() && RHS.isNonNegative())
77 KnownOut.makeNonNegative();
78 // Adding two negative numbers, or subtracting a non-negative number from
79 // a negative one, can't wrap into non-negative.
80 else if (LHS.isNegative() && RHS.isNegative())
81 KnownOut.makeNegative();
85 return KnownOut;
88 KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
89 const KnownBits &Borrow) {
90 assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");
92 // LHS - RHS = LHS + ~RHS + 1
93 // Carry 1 - Borrow in ::computeForAddCarry
94 std::swap(RHS.Zero, RHS.One);
95 return ::computeForAddCarry(LHS, RHS,
96 /*CarryZero=*/Borrow.One.getBoolValue(),
97 /*CarryOne=*/Borrow.Zero.getBoolValue());
100 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
101 unsigned BitWidth = getBitWidth();
102 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
103 "Illegal sext-in-register");
105 if (SrcBitWidth == BitWidth)
106 return *this;
108 unsigned ExtBits = BitWidth - SrcBitWidth;
109 KnownBits Result;
110 Result.One = One << ExtBits;
111 Result.Zero = Zero << ExtBits;
112 Result.One.ashrInPlace(ExtBits);
113 Result.Zero.ashrInPlace(ExtBits);
114 return Result;
117 KnownBits KnownBits::makeGE(const APInt &Val) const {
118 // Count the number of leading bit positions where our underlying value is
119 // known to be less than or equal to Val.
120 unsigned N = (Zero | Val).countl_one();
122 // For each of those bit positions, if Val has a 1 in that bit then our
123 // underlying value must also have a 1.
124 APInt MaskedVal(Val);
125 MaskedVal.clearLowBits(getBitWidth() - N);
126 return KnownBits(Zero, One | MaskedVal);
129 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
130 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
131 // RHS. Ideally our caller would already have spotted these cases and
132 // optimized away the umax operation, but we handle them here for
133 // completeness.
134 if (LHS.getMinValue().uge(RHS.getMaxValue()))
135 return LHS;
136 if (RHS.getMinValue().uge(LHS.getMaxValue()))
137 return RHS;
139 // If the result of the umax is LHS then it must be greater than or equal to
140 // the minimum possible value of RHS. Likewise for RHS. Any known bits that
141 // are common to these two values are also known in the result.
142 KnownBits L = LHS.makeGE(RHS.getMinValue());
143 KnownBits R = RHS.makeGE(LHS.getMinValue());
144 return L.intersectWith(R);
147 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
148 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
149 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
150 return Flip(umax(Flip(LHS), Flip(RHS)));
153 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
154 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
155 auto Flip = [](const KnownBits &Val) {
156 unsigned SignBitPosition = Val.getBitWidth() - 1;
157 APInt Zero = Val.Zero;
158 APInt One = Val.One;
159 Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
160 One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
161 return KnownBits(Zero, One);
163 return Flip(umax(Flip(LHS), Flip(RHS)));
166 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
167 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
168 auto Flip = [](const KnownBits &Val) {
169 unsigned SignBitPosition = Val.getBitWidth() - 1;
170 APInt Zero = Val.One;
171 APInt One = Val.Zero;
172 Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
173 One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
174 return KnownBits(Zero, One);
176 return Flip(umax(Flip(LHS), Flip(RHS)));
179 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
180 if (isPowerOf2_32(BitWidth))
181 return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
182 // This is only an approximate upper bound.
183 return MaxValue.getLimitedValue(BitWidth - 1);
186 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
187 bool NSW, bool ShAmtNonZero) {
188 unsigned BitWidth = LHS.getBitWidth();
189 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
190 KnownBits Known;
191 bool ShiftedOutZero, ShiftedOutOne;
192 Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
193 Known.Zero.setLowBits(ShiftAmt);
194 Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);
196 // All cases returning poison have been handled by MaxShiftAmount already.
197 if (NSW) {
198 if (NUW && ShiftAmt != 0)
199 // NUW means we can assume anything shifted out was a zero.
200 ShiftedOutZero = true;
202 if (ShiftedOutZero)
203 Known.makeNonNegative();
204 else if (ShiftedOutOne)
205 Known.makeNegative();
207 return Known;
210 // Fast path for a common case when LHS is completely unknown.
211 KnownBits Known(BitWidth);
212 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
213 if (MinShiftAmount == 0 && ShAmtNonZero)
214 MinShiftAmount = 1;
215 if (LHS.isUnknown()) {
216 Known.Zero.setLowBits(MinShiftAmount);
217 if (NUW && NSW && MinShiftAmount != 0)
218 Known.makeNonNegative();
219 return Known;
222 // Determine maximum shift amount, taking NUW/NSW flags into account.
223 APInt MaxValue = RHS.getMaxValue();
224 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
225 if (NUW && NSW)
226 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
227 if (NUW)
228 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
229 if (NSW)
230 MaxShiftAmount = std::min(
231 MaxShiftAmount,
232 std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);
234 // Fast path for common case where the shift amount is unknown.
235 if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
236 isPowerOf2_32(BitWidth)) {
237 Known.Zero.setLowBits(LHS.countMinTrailingZeros());
238 if (LHS.isAllOnes())
239 Known.One.setSignBit();
240 if (NSW) {
241 if (LHS.isNonNegative())
242 Known.makeNonNegative();
243 if (LHS.isNegative())
244 Known.makeNegative();
246 return Known;
249 // Find the common bits from all possible shifts.
250 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
251 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
252 Known.Zero.setAllBits();
253 Known.One.setAllBits();
254 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
255 ++ShiftAmt) {
256 // Skip if the shift amount is impossible.
257 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
258 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
259 continue;
260 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
261 if (Known.isUnknown())
262 break;
265 // All shift amounts may result in poison.
266 if (Known.hasConflict())
267 Known.setAllZero();
268 return Known;
271 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
272 bool ShAmtNonZero) {
273 unsigned BitWidth = LHS.getBitWidth();
274 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
275 KnownBits Known = LHS;
276 Known.Zero.lshrInPlace(ShiftAmt);
277 Known.One.lshrInPlace(ShiftAmt);
278 // High bits are known zero.
279 Known.Zero.setHighBits(ShiftAmt);
280 return Known;
283 // Fast path for a common case when LHS is completely unknown.
284 KnownBits Known(BitWidth);
285 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
286 if (MinShiftAmount == 0 && ShAmtNonZero)
287 MinShiftAmount = 1;
288 if (LHS.isUnknown()) {
289 Known.Zero.setHighBits(MinShiftAmount);
290 return Known;
293 // Find the common bits from all possible shifts.
294 APInt MaxValue = RHS.getMaxValue();
295 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
296 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
297 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
298 Known.Zero.setAllBits();
299 Known.One.setAllBits();
300 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
301 ++ShiftAmt) {
302 // Skip if the shift amount is impossible.
303 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
304 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
305 continue;
306 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
307 if (Known.isUnknown())
308 break;
311 // All shift amounts may result in poison.
312 if (Known.hasConflict())
313 Known.setAllZero();
314 return Known;
317 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
318 bool ShAmtNonZero) {
319 unsigned BitWidth = LHS.getBitWidth();
320 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
321 KnownBits Known = LHS;
322 Known.Zero.ashrInPlace(ShiftAmt);
323 Known.One.ashrInPlace(ShiftAmt);
324 return Known;
327 // Fast path for a common case when LHS is completely unknown.
328 KnownBits Known(BitWidth);
329 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
330 if (MinShiftAmount == 0 && ShAmtNonZero)
331 MinShiftAmount = 1;
332 if (LHS.isUnknown()) {
333 if (MinShiftAmount == BitWidth) {
334 // Always poison. Return zero because we don't like returning conflict.
335 Known.setAllZero();
336 return Known;
338 return Known;
341 // Find the common bits from all possible shifts.
342 APInt MaxValue = RHS.getMaxValue();
343 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
344 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
345 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
346 Known.Zero.setAllBits();
347 Known.One.setAllBits();
348 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
349 ++ShiftAmt) {
350 // Skip if the shift amount is impossible.
351 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
352 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
353 continue;
354 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
355 if (Known.isUnknown())
356 break;
359 // All shift amounts may result in poison.
360 if (Known.hasConflict())
361 Known.setAllZero();
362 return Known;
365 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
366 if (LHS.isConstant() && RHS.isConstant())
367 return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
368 if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
369 return std::optional<bool>(false);
370 return std::nullopt;
373 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
374 if (std::optional<bool> KnownEQ = eq(LHS, RHS))
375 return std::optional<bool>(!*KnownEQ);
376 return std::nullopt;
379 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
380 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
381 if (LHS.getMaxValue().ule(RHS.getMinValue()))
382 return std::optional<bool>(false);
383 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
384 if (LHS.getMinValue().ugt(RHS.getMaxValue()))
385 return std::optional<bool>(true);
386 return std::nullopt;
389 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
390 if (std::optional<bool> IsUGT = ugt(RHS, LHS))
391 return std::optional<bool>(!*IsUGT);
392 return std::nullopt;
395 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
396 return ugt(RHS, LHS);
399 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
400 return uge(RHS, LHS);
403 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
404 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
405 if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
406 return std::optional<bool>(false);
407 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
408 if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
409 return std::optional<bool>(true);
410 return std::nullopt;
413 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
414 if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
415 return std::optional<bool>(!*KnownSGT);
416 return std::nullopt;
419 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
420 return sgt(RHS, LHS);
423 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
424 return sge(RHS, LHS);
427 KnownBits KnownBits::abs(bool IntMinIsPoison) const {
428 // If the source's MSB is zero then we know the rest of the bits already.
429 if (isNonNegative())
430 return *this;
432 // Absolute value preserves trailing zero count.
433 KnownBits KnownAbs(getBitWidth());
435 // If the input is negative, then abs(x) == -x.
436 if (isNegative()) {
437 KnownBits Tmp = *this;
438 // Special case for IntMinIsPoison. We know the sign bit is set and we know
439 // all the rest of the bits except one to be zero. Since we have
440 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
441 // INT_MIN.
442 if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
443 Tmp.One.setBit(countMinTrailingZeros());
445 KnownAbs = computeForAddSub(
446 /*Add*/ false, IntMinIsPoison,
447 KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);
449 // One more special case for IntMinIsPoison. If we don't know any ones other
450 // than the signbit, we know for certain that all the unknowns can't be
451 // zero. So if we know high zero bits, but have unknown low bits, we know
452 // for certain those high-zero bits will end up as one. This is because,
453 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
454 // to the high bits. If we know a known INT_MIN input skip this. The result
455 // is poison anyways.
456 if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
457 Tmp.countMaxPopulation() != 1) {
458 Tmp.One.clearSignBit();
459 Tmp.Zero.setSignBit();
460 KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
461 getBitWidth() - 1);
464 } else {
465 unsigned MaxTZ = countMaxTrailingZeros();
466 unsigned MinTZ = countMinTrailingZeros();
468 KnownAbs.Zero.setLowBits(MinTZ);
469 // If we know the lowest set 1, then preserve it.
470 if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
471 KnownAbs.One.setBit(MaxTZ);
473 // We only know that the absolute values's MSB will be zero if INT_MIN is
474 // poison, or there is a set bit that isn't the sign bit (otherwise it could
475 // be INT_MIN).
476 if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
477 KnownAbs.One.clearSignBit();
478 KnownAbs.Zero.setSignBit();
482 assert(!KnownAbs.hasConflict() && "Bad Output");
483 return KnownAbs;
486 static KnownBits computeForSatAddSub(bool Add, bool Signed,
487 const KnownBits &LHS,
488 const KnownBits &RHS) {
489 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
490 // We don't see NSW even for sadd/ssub as we want to check if the result has
491 // signed overflow.
492 KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
493 unsigned BitWidth = Res.getBitWidth();
494 auto SignBitKnown = [&](const KnownBits &K) {
495 return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
497 std::optional<bool> Overflow;
499 if (Signed) {
500 // If we can actually detect overflow do so. Otherwise leave Overflow as
501 // nullopt (we assume it may have happened).
502 if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
503 if (Add) {
504 // sadd.sat
505 Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
506 Res.isNonNegative() != LHS.isNonNegative());
507 } else {
508 // ssub.sat
509 Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
510 Res.isNonNegative() != LHS.isNonNegative());
513 } else if (Add) {
514 // uadd.sat
515 bool Of;
516 (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
517 if (!Of) {
518 Overflow = false;
519 } else {
520 (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
521 if (Of)
522 Overflow = true;
524 } else {
525 // usub.sat
526 bool Of;
527 (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
528 if (!Of) {
529 Overflow = false;
530 } else {
531 (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
532 if (Of)
533 Overflow = true;
537 if (Signed) {
538 if (Add) {
539 if (LHS.isNonNegative() && RHS.isNonNegative()) {
540 // Pos + Pos -> Pos
541 Res.One.clearSignBit();
542 Res.Zero.setSignBit();
544 if (LHS.isNegative() && RHS.isNegative()) {
545 // Neg + Neg -> Neg
546 Res.One.setSignBit();
547 Res.Zero.clearSignBit();
549 } else {
550 if (LHS.isNegative() && RHS.isNonNegative()) {
551 // Neg - Pos -> Neg
552 Res.One.setSignBit();
553 Res.Zero.clearSignBit();
554 } else if (LHS.isNonNegative() && RHS.isNegative()) {
555 // Pos - Neg -> Pos
556 Res.One.clearSignBit();
557 Res.Zero.setSignBit();
560 } else {
561 // Add: Leading ones of either operand are preserved.
562 // Sub: Leading zeros of LHS and leading ones of RHS are preserved
563 // as leading zeros in the result.
564 unsigned LeadingKnown;
565 if (Add)
566 LeadingKnown =
567 std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
568 else
569 LeadingKnown =
570 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
572 // We select between the operation result and all-ones/zero
573 // respectively, so we can preserve known ones/zeros.
574 APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
575 if (Add) {
576 Res.One |= Mask;
577 Res.Zero &= ~Mask;
578 } else {
579 Res.Zero |= Mask;
580 Res.One &= ~Mask;
584 if (Overflow) {
585 // We know whether or not we overflowed.
586 if (!(*Overflow)) {
587 // No overflow.
588 assert(!Res.hasConflict() && "Bad Output");
589 return Res;
592 // We overflowed
593 APInt C;
594 if (Signed) {
595 // sadd.sat / ssub.sat
596 assert(SignBitKnown(LHS) &&
597 "We somehow know overflow without knowing input sign");
598 C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
599 : APInt::getSignedMaxValue(BitWidth);
600 } else if (Add) {
601 // uadd.sat
602 C = APInt::getMaxValue(BitWidth);
603 } else {
604 // uadd.sat
605 C = APInt::getMinValue(BitWidth);
608 Res.One = C;
609 Res.Zero = ~C;
610 assert(!Res.hasConflict() && "Bad Output");
611 return Res;
614 // We don't know if we overflowed.
615 if (Signed) {
616 // sadd.sat/ssub.sat
617 // We can keep our information about the sign bits.
618 Res.Zero.clearLowBits(BitWidth - 1);
619 Res.One.clearLowBits(BitWidth - 1);
620 } else if (Add) {
621 // uadd.sat
622 // We need to clear all the known zeros as we can only use the leading ones.
623 Res.Zero.clearAllBits();
624 } else {
625 // usub.sat
626 // We need to clear all the known ones as we can only use the leading zero.
627 Res.One.clearAllBits();
630 assert(!Res.hasConflict() && "Bad Output");
631 return Res;
634 KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
635 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
637 KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
638 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
640 KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
641 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
643 KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
644 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
647 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
648 bool NoUndefSelfMultiply) {
649 unsigned BitWidth = LHS.getBitWidth();
650 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
651 !RHS.hasConflict() && "Operand mismatch");
652 assert((!NoUndefSelfMultiply || LHS == RHS) &&
653 "Self multiplication knownbits mismatch");
655 // Compute the high known-0 bits by multiplying the unsigned max of each side.
656 // Conservatively, M active bits * N active bits results in M + N bits in the
657 // result. But if we know a value is a power-of-2 for example, then this
658 // computes one more leading zero.
659 // TODO: This could be generalized to number of sign bits (negative numbers).
660 APInt UMaxLHS = LHS.getMaxValue();
661 APInt UMaxRHS = RHS.getMaxValue();
663 // For leading zeros in the result to be valid, the unsigned max product must
664 // fit in the bitwidth (it must not overflow).
665 bool HasOverflow;
666 APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
667 unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
669 // The result of the bottom bits of an integer multiply can be
670 // inferred by looking at the bottom bits of both operands and
671 // multiplying them together.
672 // We can infer at least the minimum number of known trailing bits
673 // of both operands. Depending on number of trailing zeros, we can
674 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
675 // a and b are divisible by m and n respectively.
676 // We then calculate how many of those bits are inferrable and set
677 // the output. For example, the i8 mul:
678 // a = XXXX1100 (12)
679 // b = XXXX1110 (14)
680 // We know the bottom 3 bits are zero since the first can be divided by
681 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
682 // Applying the multiplication to the trimmed arguments gets:
683 // XX11 (3)
684 // X111 (7)
685 // -------
686 // XX11
687 // XX11
688 // XX11
689 // XX11
690 // -------
691 // XXXXX01
692 // Which allows us to infer the 2 LSBs. Since we're multiplying the result
693 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
694 // The proof for this can be described as:
695 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
696 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
697 // umin(countTrailingZeros(C2), C6) +
698 // umin(C5 - umin(countTrailingZeros(C1), C5),
699 // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
700 // %aa = shl i8 %a, C5
701 // %bb = shl i8 %b, C6
702 // %aaa = or i8 %aa, C1
703 // %bbb = or i8 %bb, C2
704 // %mul = mul i8 %aaa, %bbb
705 // %mask = and i8 %mul, C7
706 // =>
707 // %mask = i8 ((C1*C2)&C7)
708 // Where C5, C6 describe the known bits of %a, %b
709 // C1, C2 describe the known bottom bits of %a, %b.
710 // C7 describes the mask of the known bits of the result.
711 const APInt &Bottom0 = LHS.One;
712 const APInt &Bottom1 = RHS.One;
714 // How many times we'd be able to divide each argument by 2 (shr by 1).
715 // This gives us the number of trailing zeros on the multiplication result.
716 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
717 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
718 unsigned TrailZero0 = LHS.countMinTrailingZeros();
719 unsigned TrailZero1 = RHS.countMinTrailingZeros();
720 unsigned TrailZ = TrailZero0 + TrailZero1;
722 // Figure out the fewest known-bits operand.
723 unsigned SmallestOperand =
724 std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
725 unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
727 APInt BottomKnown =
728 Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
730 KnownBits Res(BitWidth);
731 Res.Zero.setHighBits(LeadZ);
732 Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
733 Res.One = BottomKnown.getLoBits(ResultBitsKnown);
735 // If we're self-multiplying then bit[1] is guaranteed to be zero.
736 if (NoUndefSelfMultiply && BitWidth > 1) {
737 assert(Res.One[1] == 0 &&
738 "Self-multiplication failed Quadratic Reciprocity!");
739 Res.Zero.setBit(1);
742 return Res;
745 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
746 unsigned BitWidth = LHS.getBitWidth();
747 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
748 !RHS.hasConflict() && "Operand mismatch");
749 KnownBits WideLHS = LHS.sext(2 * BitWidth);
750 KnownBits WideRHS = RHS.sext(2 * BitWidth);
751 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
754 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
755 unsigned BitWidth = LHS.getBitWidth();
756 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
757 !RHS.hasConflict() && "Operand mismatch");
758 KnownBits WideLHS = LHS.zext(2 * BitWidth);
759 KnownBits WideRHS = RHS.zext(2 * BitWidth);
760 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
763 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
764 const KnownBits &RHS, bool Exact) {
766 if (!Exact)
767 return Known;
769 // If LHS is Odd, the result is Odd no matter what.
770 // Odd / Odd -> Odd
771 // Odd / Even -> Impossible (because its exact division)
772 if (LHS.One[0])
773 Known.One.setBit(0);
775 int MinTZ =
776 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
777 int MaxTZ =
778 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
779 if (MinTZ >= 0) {
780 // Result has at least MinTZ trailing zeros.
781 Known.Zero.setLowBits(MinTZ);
782 if (MinTZ == MaxTZ) {
783 // Result has exactly MinTZ trailing zeros.
784 Known.One.setBit(MinTZ);
786 } else if (MaxTZ < 0) {
787 // Poison Result
788 Known.setAllZero();
791 // In the KnownBits exhaustive tests, we have poison inputs for exact values
792 // a LOT. If we have a conflict, just return all zeros.
793 if (Known.hasConflict())
794 Known.setAllZero();
796 return Known;
799 KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
800 bool Exact) {
801 // Equivalent of `udiv`. We must have caught this before it was folded.
802 if (LHS.isNonNegative() && RHS.isNonNegative())
803 return udiv(LHS, RHS, Exact);
805 unsigned BitWidth = LHS.getBitWidth();
806 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
807 KnownBits Known(BitWidth);
809 if (LHS.isZero() || RHS.isZero()) {
810 // Result is either known Zero or UB. Return Zero either way.
811 // Checking this earlier saves us a lot of special cases later on.
812 Known.setAllZero();
813 return Known;
816 std::optional<APInt> Res;
817 if (LHS.isNegative() && RHS.isNegative()) {
818 // Result non-negative.
819 APInt Denom = RHS.getSignedMaxValue();
820 APInt Num = LHS.getSignedMinValue();
821 // INT_MIN/-1 would be a poison result (impossible). Estimate the division
822 // as signed max (we will only set sign bit in the result).
823 Res = (Num.isMinSignedValue() && Denom.isAllOnes())
824 ? APInt::getSignedMaxValue(BitWidth)
825 : Num.sdiv(Denom);
826 } else if (LHS.isNegative() && RHS.isNonNegative()) {
827 // Result is negative if Exact OR -LHS u>= RHS.
828 if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
829 APInt Denom = RHS.getSignedMinValue();
830 APInt Num = LHS.getSignedMinValue();
831 Res = Denom.isZero() ? Num : Num.sdiv(Denom);
833 } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
834 // Result is negative if Exact OR LHS u>= -RHS.
835 if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
836 APInt Denom = RHS.getSignedMaxValue();
837 APInt Num = LHS.getSignedMaxValue();
838 Res = Num.sdiv(Denom);
842 if (Res) {
843 if (Res->isNonNegative()) {
844 unsigned LeadZ = Res->countLeadingZeros();
845 Known.Zero.setHighBits(LeadZ);
846 } else {
847 unsigned LeadO = Res->countLeadingOnes();
848 Known.One.setHighBits(LeadO);
852 Known = divComputeLowBit(Known, LHS, RHS, Exact);
854 assert(!Known.hasConflict() && "Bad Output");
855 return Known;
858 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
859 bool Exact) {
860 unsigned BitWidth = LHS.getBitWidth();
861 assert(!LHS.hasConflict() && !RHS.hasConflict());
862 KnownBits Known(BitWidth);
864 if (LHS.isZero() || RHS.isZero()) {
865 // Result is either known Zero or UB. Return Zero either way.
866 // Checking this earlier saves us a lot of special cases later on.
867 Known.setAllZero();
868 return Known;
871 // We can figure out the minimum number of upper zero bits by doing
872 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
873 // gets larger, the number of upper zero bits increases.
874 APInt MinDenom = RHS.getMinValue();
875 APInt MaxNum = LHS.getMaxValue();
876 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
878 unsigned LeadZ = MaxRes.countLeadingZeros();
880 Known.Zero.setHighBits(LeadZ);
881 Known = divComputeLowBit(Known, LHS, RHS, Exact);
883 assert(!Known.hasConflict() && "Bad Output");
884 return Known;
887 KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
888 unsigned BitWidth = LHS.getBitWidth();
889 if (!RHS.isZero() && RHS.Zero[0]) {
890 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
891 unsigned RHSZeros = RHS.countMinTrailingZeros();
892 APInt Mask = APInt::getLowBitsSet(BitWidth, RHSZeros);
893 APInt OnesMask = LHS.One & Mask;
894 APInt ZerosMask = LHS.Zero & Mask;
895 return KnownBits(ZerosMask, OnesMask);
897 return KnownBits(BitWidth);
900 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
901 assert(!LHS.hasConflict() && !RHS.hasConflict());
903 KnownBits Known = remGetLowBits(LHS, RHS);
904 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
905 // NB: Low bits set in `remGetLowBits`.
906 APInt HighBits = ~(RHS.getConstant() - 1);
907 Known.Zero |= HighBits;
908 return Known;
911 // Since the result is less than or equal to either operand, any leading
912 // zero bits in either operand must also exist in the result.
913 uint32_t Leaders =
914 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
915 Known.Zero.setHighBits(Leaders);
916 return Known;
919 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
920 assert(!LHS.hasConflict() && !RHS.hasConflict());
922 KnownBits Known = remGetLowBits(LHS, RHS);
923 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
924 // NB: Low bits are set in `remGetLowBits`.
925 APInt LowBits = RHS.getConstant() - 1;
926 // If the first operand is non-negative or has all low bits zero, then
927 // the upper bits are all zero.
928 if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
929 Known.Zero |= ~LowBits;
931 // If the first operand is negative and not all low bits are zero, then
932 // the upper bits are all one.
933 if (LHS.isNegative() && LowBits.intersects(LHS.One))
934 Known.One |= ~LowBits;
935 return Known;
938 // The sign bit is the LHS's sign bit, except when the result of the
939 // remainder is zero. The magnitude of the result should be less than or
940 // equal to the magnitude of the LHS. Therefore any leading zeros that exist
941 // in the left hand side must also exist in the result.
942 Known.Zero.setHighBits(LHS.countMinLeadingZeros());
943 return Known;
946 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
947 // Result bit is 0 if either operand bit is 0.
948 Zero |= RHS.Zero;
949 // Result bit is 1 if both operand bits are 1.
950 One &= RHS.One;
951 return *this;
954 KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
955 // Result bit is 0 if both operand bits are 0.
956 Zero &= RHS.Zero;
957 // Result bit is 1 if either operand bit is 1.
958 One |= RHS.One;
959 return *this;
962 KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
963 // Result bit is 0 if both operand bits are 0 or both are 1.
964 APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
965 // Result bit is 1 if one operand bit is 0 and the other is 1.
966 One = (Zero & RHS.One) | (One & RHS.Zero);
967 Zero = std::move(Z);
968 return *this;
971 KnownBits KnownBits::blsi() const {
972 unsigned BitWidth = getBitWidth();
973 KnownBits Known(Zero, APInt(BitWidth, 0));
974 unsigned Max = countMaxTrailingZeros();
975 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
976 unsigned Min = countMinTrailingZeros();
977 if (Max == Min && Max < BitWidth)
978 Known.One.setBit(Max);
979 return Known;
982 KnownBits KnownBits::blsmsk() const {
983 unsigned BitWidth = getBitWidth();
984 KnownBits Known(BitWidth);
985 unsigned Max = countMaxTrailingZeros();
986 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
987 unsigned Min = countMinTrailingZeros();
988 Known.One.setLowBits(std::min(Min + 1, BitWidth));
989 return Known;
992 void KnownBits::print(raw_ostream &OS) const {
993 unsigned BitWidth = getBitWidth();
994 for (unsigned I = 0; I < BitWidth; ++I) {
995 unsigned N = BitWidth - I - 1;
996 if (Zero[N] && One[N])
997 OS << "!";
998 else if (Zero[N])
999 OS << "0";
1000 else if (One[N])
1001 OS << "1";
1002 else
1003 OS << "?";
1006 void KnownBits::dump() const {
1007 print(dbgs());
1008 dbgs() << "\n";