[mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (1/N...
[llvm-project.git] / llvm / lib / Analysis / DemandedBits.cpp
blobb538e16f258595e13af72251578e74dde2fd1709
//===- DemandedBits.cpp - Determine demanded bits -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements a demanded bits analysis. A demanded bit is one that
// contributes to a result; bits that are not demanded can be either zero or
// one without affecting control or data flow. For example in this sequence:
//
//   %1 = add i32 %x, %y
//   %2 = trunc i32 %1 to i16
//
// Only the lowest 16 bits of %1 are demanded; the rest are removed by the
// trunc.
//
//===----------------------------------------------------------------------===//
21 #include "llvm/Analysis/DemandedBits.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/InstIterator.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/PatternMatch.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Use.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/KnownBits.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <algorithm>
41 #include <cstdint>
43 using namespace llvm;
44 using namespace llvm::PatternMatch;
46 #define DEBUG_TYPE "demanded-bits"
48 static bool isAlwaysLive(Instruction *I) {
49 return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
50 I->mayHaveSideEffects();
53 void DemandedBits::determineLiveOperandBits(
54 const Instruction *UserI, const Value *Val, unsigned OperandNo,
55 const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
56 bool &KnownBitsComputed) {
57 unsigned BitWidth = AB.getBitWidth();
59 // We're called once per operand, but for some instructions, we need to
60 // compute known bits of both operands in order to determine the live bits of
61 // either (when both operands are instructions themselves). We don't,
62 // however, want to do this twice, so we cache the result in APInts that live
63 // in the caller. For the two-relevant-operands case, both operand values are
64 // provided here.
65 auto ComputeKnownBits =
66 [&](unsigned BitWidth, const Value *V1, const Value *V2) {
67 if (KnownBitsComputed)
68 return;
69 KnownBitsComputed = true;
71 const DataLayout &DL = UserI->getDataLayout();
72 Known = KnownBits(BitWidth);
73 computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);
75 if (V2) {
76 Known2 = KnownBits(BitWidth);
77 computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
81 switch (UserI->getOpcode()) {
82 default: break;
83 case Instruction::Call:
84 case Instruction::Invoke:
85 if (const auto *II = dyn_cast<IntrinsicInst>(UserI)) {
86 switch (II->getIntrinsicID()) {
87 default: break;
88 case Intrinsic::bswap:
89 // The alive bits of the input are the swapped alive bits of
90 // the output.
91 AB = AOut.byteSwap();
92 break;
93 case Intrinsic::bitreverse:
94 // The alive bits of the input are the reversed alive bits of
95 // the output.
96 AB = AOut.reverseBits();
97 break;
98 case Intrinsic::ctlz:
99 if (OperandNo == 0) {
100 // We need some output bits, so we need all bits of the
101 // input to the left of, and including, the leftmost bit
102 // known to be one.
103 ComputeKnownBits(BitWidth, Val, nullptr);
104 AB = APInt::getHighBitsSet(BitWidth,
105 std::min(BitWidth, Known.countMaxLeadingZeros()+1));
107 break;
108 case Intrinsic::cttz:
109 if (OperandNo == 0) {
110 // We need some output bits, so we need all bits of the
111 // input to the right of, and including, the rightmost bit
112 // known to be one.
113 ComputeKnownBits(BitWidth, Val, nullptr);
114 AB = APInt::getLowBitsSet(BitWidth,
115 std::min(BitWidth, Known.countMaxTrailingZeros()+1));
117 break;
118 case Intrinsic::fshl:
119 case Intrinsic::fshr: {
120 const APInt *SA;
121 if (OperandNo == 2) {
122 // Shift amount is modulo the bitwidth. For powers of two we have
123 // SA % BW == SA & (BW - 1).
124 if (isPowerOf2_32(BitWidth))
125 AB = BitWidth - 1;
126 } else if (match(II->getOperand(2), m_APInt(SA))) {
127 // Normalize to funnel shift left. APInt shifts of BitWidth are well-
128 // defined, so no need to special-case zero shifts here.
129 uint64_t ShiftAmt = SA->urem(BitWidth);
130 if (II->getIntrinsicID() == Intrinsic::fshr)
131 ShiftAmt = BitWidth - ShiftAmt;
133 if (OperandNo == 0)
134 AB = AOut.lshr(ShiftAmt);
135 else if (OperandNo == 1)
136 AB = AOut.shl(BitWidth - ShiftAmt);
138 break;
140 case Intrinsic::umax:
141 case Intrinsic::umin:
142 case Intrinsic::smax:
143 case Intrinsic::smin:
144 // If low bits of result are not demanded, they are also not demanded
145 // for the min/max operands.
146 AB = APInt::getBitsSetFrom(BitWidth, AOut.countr_zero());
147 break;
150 break;
151 case Instruction::Add:
152 if (AOut.isMask()) {
153 AB = AOut;
154 } else {
155 ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
156 AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
158 break;
159 case Instruction::Sub:
160 if (AOut.isMask()) {
161 AB = AOut;
162 } else {
163 ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
164 AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
166 break;
167 case Instruction::Mul:
168 // Find the highest live output bit. We don't need any more input
169 // bits than that (adds, and thus subtracts, ripple only to the
170 // left).
171 AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
172 break;
173 case Instruction::Shl:
174 if (OperandNo == 0) {
175 const APInt *ShiftAmtC;
176 if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
177 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
178 AB = AOut.lshr(ShiftAmt);
180 // If the shift is nuw/nsw, then the high bits are not dead
181 // (because we've promised that they *must* be zero).
182 const auto *S = cast<ShlOperator>(UserI);
183 if (S->hasNoSignedWrap())
184 AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
185 else if (S->hasNoUnsignedWrap())
186 AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
189 break;
190 case Instruction::LShr:
191 if (OperandNo == 0) {
192 const APInt *ShiftAmtC;
193 if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
194 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
195 AB = AOut.shl(ShiftAmt);
197 // If the shift is exact, then the low bits are not dead
198 // (they must be zero).
199 if (cast<LShrOperator>(UserI)->isExact())
200 AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
203 break;
204 case Instruction::AShr:
205 if (OperandNo == 0) {
206 const APInt *ShiftAmtC;
207 if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
208 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
209 AB = AOut.shl(ShiftAmt);
210 // Because the high input bit is replicated into the
211 // high-order bits of the result, if we need any of those
212 // bits, then we must keep the highest input bit.
213 if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
214 .getBoolValue())
215 AB.setSignBit();
217 // If the shift is exact, then the low bits are not dead
218 // (they must be zero).
219 if (cast<AShrOperator>(UserI)->isExact())
220 AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
223 break;
224 case Instruction::And:
225 AB = AOut;
227 // For bits that are known zero, the corresponding bits in the
228 // other operand are dead (unless they're both zero, in which
229 // case they can't both be dead, so just mark the LHS bits as
230 // dead).
231 ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
232 if (OperandNo == 0)
233 AB &= ~Known2.Zero;
234 else
235 AB &= ~(Known.Zero & ~Known2.Zero);
236 break;
237 case Instruction::Or:
238 AB = AOut;
240 // For bits that are known one, the corresponding bits in the
241 // other operand are dead (unless they're both one, in which
242 // case they can't both be dead, so just mark the LHS bits as
243 // dead).
244 ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
245 if (OperandNo == 0)
246 AB &= ~Known2.One;
247 else
248 AB &= ~(Known.One & ~Known2.One);
249 break;
250 case Instruction::Xor:
251 case Instruction::PHI:
252 AB = AOut;
253 break;
254 case Instruction::Trunc:
255 AB = AOut.zext(BitWidth);
256 break;
257 case Instruction::ZExt:
258 AB = AOut.trunc(BitWidth);
259 break;
260 case Instruction::SExt:
261 AB = AOut.trunc(BitWidth);
262 // Because the high input bit is replicated into the
263 // high-order bits of the result, if we need any of those
264 // bits, then we must keep the highest input bit.
265 if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
266 AOut.getBitWidth() - BitWidth))
267 .getBoolValue())
268 AB.setSignBit();
269 break;
270 case Instruction::Select:
271 if (OperandNo != 0)
272 AB = AOut;
273 break;
274 case Instruction::ExtractElement:
275 if (OperandNo == 0)
276 AB = AOut;
277 break;
278 case Instruction::InsertElement:
279 case Instruction::ShuffleVector:
280 if (OperandNo == 0 || OperandNo == 1)
281 AB = AOut;
282 break;
286 void DemandedBits::performAnalysis() {
287 if (Analyzed)
288 // Analysis already completed for this function.
289 return;
290 Analyzed = true;
292 Visited.clear();
293 AliveBits.clear();
294 DeadUses.clear();
296 SmallSetVector<Instruction*, 16> Worklist;
298 // Collect the set of "root" instructions that are known live.
299 for (Instruction &I : instructions(F)) {
300 if (!isAlwaysLive(&I))
301 continue;
303 LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
304 // For integer-valued instructions, set up an initial empty set of alive
305 // bits and add the instruction to the work list. For other instructions
306 // add their operands to the work list (for integer values operands, mark
307 // all bits as live).
308 Type *T = I.getType();
309 if (T->isIntOrIntVectorTy()) {
310 if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
311 Worklist.insert(&I);
313 continue;
316 // Non-integer-typed instructions...
317 for (Use &OI : I.operands()) {
318 if (auto *J = dyn_cast<Instruction>(OI)) {
319 Type *T = J->getType();
320 if (T->isIntOrIntVectorTy())
321 AliveBits[J] = APInt::getAllOnes(T->getScalarSizeInBits());
322 else
323 Visited.insert(J);
324 Worklist.insert(J);
327 // To save memory, we don't add I to the Visited set here. Instead, we
328 // check isAlwaysLive on every instruction when searching for dead
329 // instructions later (we need to check isAlwaysLive for the
330 // integer-typed instructions anyway).
333 // Propagate liveness backwards to operands.
334 while (!Worklist.empty()) {
335 Instruction *UserI = Worklist.pop_back_val();
337 LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
338 APInt AOut;
339 bool InputIsKnownDead = false;
340 if (UserI->getType()->isIntOrIntVectorTy()) {
341 AOut = AliveBits[UserI];
342 LLVM_DEBUG(dbgs() << " Alive Out: 0x"
343 << Twine::utohexstr(AOut.getLimitedValue()));
345 // If all bits of the output are dead, then all bits of the input
346 // are also dead.
347 InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
349 LLVM_DEBUG(dbgs() << "\n");
351 KnownBits Known, Known2;
352 bool KnownBitsComputed = false;
353 // Compute the set of alive bits for each operand. These are anded into the
354 // existing set, if any, and if that changes the set of alive bits, the
355 // operand is added to the work-list.
356 for (Use &OI : UserI->operands()) {
357 // We also want to detect dead uses of arguments, but will only store
358 // demanded bits for instructions.
359 auto *I = dyn_cast<Instruction>(OI);
360 if (!I && !isa<Argument>(OI))
361 continue;
363 Type *T = OI->getType();
364 if (T->isIntOrIntVectorTy()) {
365 unsigned BitWidth = T->getScalarSizeInBits();
366 APInt AB = APInt::getAllOnes(BitWidth);
367 if (InputIsKnownDead) {
368 AB = APInt(BitWidth, 0);
369 } else {
370 // Bits of each operand that are used to compute alive bits of the
371 // output are alive, all others are dead.
372 determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
373 Known, Known2, KnownBitsComputed);
375 // Keep track of uses which have no demanded bits.
376 if (AB.isZero())
377 DeadUses.insert(&OI);
378 else
379 DeadUses.erase(&OI);
382 if (I) {
383 // If we've added to the set of alive bits (or the operand has not
384 // been previously visited), then re-queue the operand to be visited
385 // again.
386 auto Res = AliveBits.try_emplace(I);
387 if (Res.second || (AB |= Res.first->second) != Res.first->second) {
388 Res.first->second = std::move(AB);
389 Worklist.insert(I);
392 } else if (I && Visited.insert(I).second) {
393 Worklist.insert(I);
399 APInt DemandedBits::getDemandedBits(Instruction *I) {
400 performAnalysis();
402 auto Found = AliveBits.find(I);
403 if (Found != AliveBits.end())
404 return Found->second;
406 const DataLayout &DL = I->getDataLayout();
407 return APInt::getAllOnes(DL.getTypeSizeInBits(I->getType()->getScalarType()));
410 APInt DemandedBits::getDemandedBits(Use *U) {
411 Type *T = (*U)->getType();
412 auto *UserI = cast<Instruction>(U->getUser());
413 const DataLayout &DL = UserI->getDataLayout();
414 unsigned BitWidth = DL.getTypeSizeInBits(T->getScalarType());
416 // We only track integer uses, everything else produces a mask with all bits
417 // set
418 if (!T->isIntOrIntVectorTy())
419 return APInt::getAllOnes(BitWidth);
421 if (isUseDead(U))
422 return APInt(BitWidth, 0);
424 performAnalysis();
426 APInt AOut = getDemandedBits(UserI);
427 APInt AB = APInt::getAllOnes(BitWidth);
428 KnownBits Known, Known2;
429 bool KnownBitsComputed = false;
431 determineLiveOperandBits(UserI, *U, U->getOperandNo(), AOut, AB, Known,
432 Known2, KnownBitsComputed);
434 return AB;
437 bool DemandedBits::isInstructionDead(Instruction *I) {
438 performAnalysis();
440 return !Visited.count(I) && !AliveBits.contains(I) && !isAlwaysLive(I);
443 bool DemandedBits::isUseDead(Use *U) {
444 // We only track integer uses, everything else is assumed live.
445 if (!(*U)->getType()->isIntOrIntVectorTy())
446 return false;
448 // Uses by always-live instructions are never dead.
449 auto *UserI = cast<Instruction>(U->getUser());
450 if (isAlwaysLive(UserI))
451 return false;
453 performAnalysis();
454 if (DeadUses.count(U))
455 return true;
457 // If no output bits are demanded, no input bits are demanded and the use
458 // is dead. These uses might not be explicitly present in the DeadUses map.
459 if (UserI->getType()->isIntOrIntVectorTy()) {
460 auto Found = AliveBits.find(UserI);
461 if (Found != AliveBits.end() && Found->second.isZero())
462 return true;
465 return false;
468 void DemandedBits::print(raw_ostream &OS) {
469 auto PrintDB = [&](const Instruction *I, const APInt &A, Value *V = nullptr) {
470 OS << "DemandedBits: 0x" << Twine::utohexstr(A.getLimitedValue())
471 << " for ";
472 if (V) {
473 V->printAsOperand(OS, false);
474 OS << " in ";
476 OS << *I << '\n';
479 OS << "Printing analysis 'Demanded Bits Analysis' for function '" << F.getName() << "':\n";
480 performAnalysis();
481 for (auto &KV : AliveBits) {
482 Instruction *I = KV.first;
483 PrintDB(I, KV.second);
485 for (Use &OI : I->operands()) {
486 PrintDB(I, getDemandedBits(&OI), OI);
491 static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
492 const APInt &AOut,
493 const KnownBits &LHS,
494 const KnownBits &RHS,
495 bool CarryZero, bool CarryOne) {
496 assert(!(CarryZero && CarryOne) &&
497 "Carry can't be zero and one at the same time");
499 // The following check should be done by the caller, as it also indicates
500 // that LHS and RHS don't need to be computed.
502 // if (AOut.isMask())
503 // return AOut;
505 // Boundary bits' carry out is unaffected by their carry in.
506 APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
508 // First, the alive carry bits are determined from the alive output bits:
509 // Let demand ripple to the right but only up to any set bit in Bound.
510 // AOut = -1----
511 // Bound = ----1-
512 // ACarry&~AOut = --111-
513 APInt RBound = Bound.reverseBits();
514 APInt RAOut = AOut.reverseBits();
515 APInt RProp = RAOut + (RAOut | ~RBound);
516 APInt RACarry = RProp ^ ~RBound;
517 APInt ACarry = RACarry.reverseBits();
519 // Then, the alive input bits are determined from the alive carry bits:
520 APInt NeededToMaintainCarryZero;
521 APInt NeededToMaintainCarryOne;
522 if (OperandNo == 0) {
523 NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
524 NeededToMaintainCarryOne = LHS.One | ~RHS.One;
525 } else {
526 NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
527 NeededToMaintainCarryOne = RHS.One | ~LHS.One;
530 // As in computeForAddCarry
531 APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
532 APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
534 // The below is simplified from
536 // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
537 // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
538 // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
540 // APInt NeededToMaintainCarry =
541 // (CarryKnownZero & NeededToMaintainCarryZero) |
542 // (CarryKnownOne & NeededToMaintainCarryOne) |
543 // CarryUnknown;
545 APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
546 (PossibleSumOne | NeededToMaintainCarryOne);
548 APInt AB = AOut | (ACarry & NeededToMaintainCarry);
549 return AB;
552 APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
553 const APInt &AOut,
554 const KnownBits &LHS,
555 const KnownBits &RHS) {
556 return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
557 false);
560 APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
561 const APInt &AOut,
562 const KnownBits &LHS,
563 const KnownBits &RHS) {
564 KnownBits NRHS;
565 NRHS.Zero = RHS.One;
566 NRHS.One = RHS.Zero;
567 return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
568 true);
571 AnalysisKey DemandedBitsAnalysis::Key;
573 DemandedBits DemandedBitsAnalysis::run(Function &F,
574 FunctionAnalysisManager &AM) {
575 auto &AC = AM.getResult<AssumptionAnalysis>(F);
576 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
577 return DemandedBits(F, AC, DT);
580 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
581 FunctionAnalysisManager &AM) {
582 AM.getResult<DemandedBitsAnalysis>(F).print(OS);
583 return PreservedAnalyses::all();