[NFC][analyzer][docs] Crosslink MallocChecker's ownership attributes (#121939)
[llvm-project.git] / llvm / lib / Transforms / AggressiveInstCombine / AggressiveInstCombine.cpp
blobfe7b3b1676e08435a76aac9759094e90fe3f347a
1 //===- AggressiveInstCombine.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the aggressive expression pattern combiner classes.
10 // Currently, it handles expression patterns for:
11 // * Truncate instruction
13 //===----------------------------------------------------------------------===//
15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16 #include "AggressiveInstCombineInternal.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/BasicAliasAnalysis.h"
21 #include "llvm/Analysis/ConstantFolding.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/GlobalsModRef.h"
24 #include "llvm/Analysis/TargetLibraryInfo.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/PatternMatch.h"
32 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
33 #include "llvm/Transforms/Utils/BuildLibCalls.h"
34 #include "llvm/Transforms/Utils/Local.h"
36 using namespace llvm;
37 using namespace PatternMatch;
39 #define DEBUG_TYPE "aggressive-instcombine"
41 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
42 STATISTIC(NumGuardedRotates,
43 "Number of guarded rotates transformed into funnel shifts");
44 STATISTIC(NumGuardedFunnelShifts,
45 "Number of guarded funnel shifts transformed into funnel shifts");
46 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
48 static cl::opt<unsigned> MaxInstrsToScan(
49 "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
50 cl::desc("Max number of instructions to scan for aggressive instcombine."));
52 static cl::opt<unsigned> StrNCmpInlineThreshold(
53 "strncmp-inline-threshold", cl::init(3), cl::Hidden,
54 cl::desc("The maximum length of a constant string for a builtin string cmp "
55 "call eligible for inlining. The default value is 3."));
57 static cl::opt<unsigned>
58 MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
59 cl::desc("The maximum length of a constant string to "
60 "inline a memchr call."));
62 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
63 /// against undefined behavior by branching around the funnel-shift/rotation
64 /// when the shift amount is 0.
65 static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
66 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
67 return false;
69 // As with the one-use checks below, this is not strictly necessary, but we
70 // are being cautious to avoid potential perf regressions on targets that
71 // do not actually have a funnel/rotate instruction (where the funnel shift
72 // would be expanded back into math/shift/logic ops).
73 if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
74 return false;
76 // Match V to funnel shift left/right and capture the source operands and
77 // shift amount.
78 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
79 Value *&ShAmt) {
80 unsigned Width = V->getType()->getScalarSizeInBits();
82 // fshl(ShVal0, ShVal1, ShAmt)
83 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
84 if (match(V, m_OneUse(m_c_Or(
85 m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
86 m_LShr(m_Value(ShVal1),
87 m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
88 return Intrinsic::fshl;
91 // fshr(ShVal0, ShVal1, ShAmt)
92 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
93 if (match(V,
94 m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
95 m_Value(ShAmt))),
96 m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
97 return Intrinsic::fshr;
100 return Intrinsic::not_intrinsic;
103 // One phi operand must be a funnel/rotate operation, and the other phi
104 // operand must be the source value of that funnel/rotate operation:
105 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
106 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
107 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
108 PHINode &Phi = cast<PHINode>(I);
109 unsigned FunnelOp = 0, GuardOp = 1;
110 Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
111 Value *ShVal0, *ShVal1, *ShAmt;
112 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
113 if (IID == Intrinsic::not_intrinsic ||
114 (IID == Intrinsic::fshl && ShVal0 != P1) ||
115 (IID == Intrinsic::fshr && ShVal1 != P1)) {
116 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
117 if (IID == Intrinsic::not_intrinsic ||
118 (IID == Intrinsic::fshl && ShVal0 != P0) ||
119 (IID == Intrinsic::fshr && ShVal1 != P0))
120 return false;
121 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
122 "Pattern must match funnel shift left or right");
123 std::swap(FunnelOp, GuardOp);
126 // The incoming block with our source operand must be the "guard" block.
127 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
128 // amount is equal to 0. The other incoming block is the block with the
129 // funnel/rotate.
130 BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
131 BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
132 Instruction *TermI = GuardBB->getTerminator();
134 // Ensure that the shift values dominate each block.
135 if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
136 return false;
138 BasicBlock *PhiBB = Phi.getParent();
139 if (!match(TermI, m_Br(m_SpecificICmp(CmpInst::ICMP_EQ, m_Specific(ShAmt),
140 m_ZeroInt()),
141 m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
142 return false;
144 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
146 if (ShVal0 == ShVal1)
147 ++NumGuardedRotates;
148 else
149 ++NumGuardedFunnelShifts;
151 // If this is not a rotate then the select was blocking poison from the
152 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
153 bool IsFshl = IID == Intrinsic::fshl;
154 if (ShVal0 != ShVal1) {
155 if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
156 ShVal1 = Builder.CreateFreeze(ShVal1);
157 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
158 ShVal0 = Builder.CreateFreeze(ShVal0);
161 // We matched a variation of this IR pattern:
162 // GuardBB:
163 // %cmp = icmp eq i32 %ShAmt, 0
164 // br i1 %cmp, label %PhiBB, label %FunnelBB
165 // FunnelBB:
166 // %sub = sub i32 32, %ShAmt
167 // %shr = lshr i32 %ShVal1, %sub
168 // %shl = shl i32 %ShVal0, %ShAmt
169 // %fsh = or i32 %shr, %shl
170 // br label %PhiBB
171 // PhiBB:
172 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
173 // -->
174 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
175 Phi.replaceAllUsesWith(
176 Builder.CreateIntrinsic(IID, Phi.getType(), {ShVal0, ShVal1, ShAmt}));
177 return true;
180 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
181 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
182 /// of 'and' ops, then we also need to capture the fact that we saw an
183 /// "and X, 1", so that's an extra return value for that case.
184 namespace {
185 struct MaskOps {
186 Value *Root = nullptr;
187 APInt Mask;
188 bool MatchAndChain;
189 bool FoundAnd1 = false;
191 MaskOps(unsigned BitWidth, bool MatchAnds)
192 : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
194 } // namespace
196 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
197 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
198 /// value. Examples:
199 /// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
200 /// returns { X, 0x129 }
201 /// and (and (X >> 1), 1), (X >> 4)
202 /// returns { X, 0x12 }
203 static bool matchAndOrChain(Value *V, MaskOps &MOps) {
204 Value *Op0, *Op1;
205 if (MOps.MatchAndChain) {
206 // Recurse through a chain of 'and' operands. This requires an extra check
207 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
208 // in the chain to know that all of the high bits are cleared.
209 if (match(V, m_And(m_Value(Op0), m_One()))) {
210 MOps.FoundAnd1 = true;
211 return matchAndOrChain(Op0, MOps);
213 if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
214 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
215 } else {
216 // Recurse through a chain of 'or' operands.
217 if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
218 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
221 // We need a shift-right or a bare value representing a compare of bit 0 of
222 // the original source operand.
223 Value *Candidate;
224 const APInt *BitIndex = nullptr;
225 if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
226 Candidate = V;
228 // Initialize result source operand.
229 if (!MOps.Root)
230 MOps.Root = Candidate;
232 // The shift constant is out-of-range? This code hasn't been simplified.
233 if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
234 return false;
236 // Fill in the mask bit derived from the shift constant.
237 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
238 return MOps.Root == Candidate;
241 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
242 /// These will include a chain of 'or' or 'and'-shifted bits from a
243 /// common source value:
244 /// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
245 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
246 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
247 /// that differ only with a final 'not' of the result. We expect that final
248 /// 'not' to be folded with the compare that we create here (invert predicate).
249 static bool foldAnyOrAllBitsSet(Instruction &I) {
250 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
251 // final "and X, 1" instruction must be the final op in the sequence.
252 bool MatchAllBitsSet;
253 if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
254 MatchAllBitsSet = true;
255 else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
256 MatchAllBitsSet = false;
257 else
258 return false;
260 MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
261 if (MatchAllBitsSet) {
262 if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
263 return false;
264 } else {
265 if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
266 return false;
269 // The pattern was found. Create a masked compare that replaces all of the
270 // shift and logic ops.
271 IRBuilder<> Builder(&I);
272 Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
273 Value *And = Builder.CreateAnd(MOps.Root, Mask);
274 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
275 : Builder.CreateIsNotNull(And);
276 Value *Zext = Builder.CreateZExt(Cmp, I.getType());
277 I.replaceAllUsesWith(Zext);
278 ++NumAnyOrAllBitsSet;
279 return true;
282 // Try to recognize below function as popcount intrinsic.
283 // This is the "best" algorithm from
284 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
285 // Also used in TargetLowering::expandCTPOP().
287 // int popcount(unsigned int i) {
288 // i = i - ((i >> 1) & 0x55555555);
289 // i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
290 // i = ((i + (i >> 4)) & 0x0F0F0F0F);
291 // return (i * 0x01010101) >> 24;
292 // }
293 static bool tryToRecognizePopCount(Instruction &I) {
294 if (I.getOpcode() != Instruction::LShr)
295 return false;
297 Type *Ty = I.getType();
298 if (!Ty->isIntOrIntVectorTy())
299 return false;
301 unsigned Len = Ty->getScalarSizeInBits();
302 // FIXME: fix Len == 8 and other irregular type lengths.
303 if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
304 return false;
306 APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
307 APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
308 APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
309 APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
310 APInt MaskShift = APInt(Len, Len - 8);
312 Value *Op0 = I.getOperand(0);
313 Value *Op1 = I.getOperand(1);
314 Value *MulOp0;
315 // Matching "(i * 0x01010101...) >> 24".
316 if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
317 match(Op1, m_SpecificInt(MaskShift))) {
318 Value *ShiftOp0;
319 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
320 if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
321 m_Deferred(ShiftOp0)),
322 m_SpecificInt(Mask0F)))) {
323 Value *AndOp0;
324 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
325 if (match(ShiftOp0,
326 m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
327 m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
328 m_SpecificInt(Mask33))))) {
329 Value *Root, *SubOp1;
330 // Matching "i - ((i >> 1) & 0x55555555...)".
331 if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
332 match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
333 m_SpecificInt(Mask55)))) {
334 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
335 IRBuilder<> Builder(&I);
336 I.replaceAllUsesWith(
337 Builder.CreateIntrinsic(Intrinsic::ctpop, I.getType(), {Root}));
338 ++NumPopCountRecognized;
339 return true;
345 return false;
348 /// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
349 /// C2 saturate the value of the fp conversion. The transform is not reversable
350 /// as the fptosi.sat is more defined than the input - all values produce a
351 /// valid value for the fptosi.sat, where as some produce poison for original
352 /// that were out of range of the integer conversion. The reversed pattern may
353 /// use fmax and fmin instead. As we cannot directly reverse the transform, and
354 /// it is not always profitable, we make it conditional on the cost being
355 /// reported as lower by TTI.
356 static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
357 // Look for min(max(fptosi, converting to fptosi_sat.
358 Value *In;
359 const APInt *MinC, *MaxC;
360 if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
361 m_APInt(MinC))),
362 m_APInt(MaxC))) &&
363 !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
364 m_APInt(MaxC))),
365 m_APInt(MinC))))
366 return false;
368 // Check that the constants clamp a saturate.
369 if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
370 return false;
372 Type *IntTy = I.getType();
373 Type *FpTy = In->getType();
374 Type *SatTy =
375 IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
376 if (auto *VecTy = dyn_cast<VectorType>(IntTy))
377 SatTy = VectorType::get(SatTy, VecTy->getElementCount());
379 // Get the cost of the intrinsic, and check that against the cost of
380 // fptosi+smin+smax
381 InstructionCost SatCost = TTI.getIntrinsicInstrCost(
382 IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
383 TTI::TCK_RecipThroughput);
384 SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
385 TTI::CastContextHint::None,
386 TTI::TCK_RecipThroughput);
388 InstructionCost MinMaxCost = TTI.getCastInstrCost(
389 Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
390 TTI::TCK_RecipThroughput);
391 MinMaxCost += TTI.getIntrinsicInstrCost(
392 IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
393 TTI::TCK_RecipThroughput);
394 MinMaxCost += TTI.getIntrinsicInstrCost(
395 IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
396 TTI::TCK_RecipThroughput);
398 if (SatCost >= MinMaxCost)
399 return false;
401 IRBuilder<> Builder(&I);
402 Value *Sat =
403 Builder.CreateIntrinsic(Intrinsic::fptosi_sat, {SatTy, FpTy}, In);
404 I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
405 return true;
408 /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
409 /// pessimistic codegen that has to account for setting errno and can enable
410 /// vectorization.
411 static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
412 TargetLibraryInfo &TLI, AssumptionCache &AC,
413 DominatorTree &DT) {
414 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
415 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
416 // would not end up lowering to a libcall anyway (which could change the value
417 // of errno), then:
418 // (1) errno won't be set.
419 // (2) it is safe to convert this to an intrinsic call.
420 Type *Ty = Call->getType();
421 Value *Arg = Call->getArgOperand(0);
422 if (TTI.haveFastSqrt(Ty) &&
423 (Call->hasNoNaNs() ||
424 cannotBeOrderedLessThanZero(
425 Arg, 0,
426 SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
427 IRBuilder<> Builder(Call);
428 Value *NewSqrt =
429 Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt");
430 Call->replaceAllUsesWith(NewSqrt);
432 // Explicitly erase the old call because a call with side effects is not
433 // trivially dead.
434 Call->eraseFromParent();
435 return true;
438 return false;
441 // Check if this array of constants represents a cttz table.
442 // Iterate over the elements from \p Table by trying to find/match all
443 // the numbers from 0 to \p InputBits that should represent cttz results.
444 static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
445 uint64_t Shift, uint64_t InputBits) {
446 unsigned Length = Table.getNumElements();
447 if (Length < InputBits || Length > InputBits * 2)
448 return false;
450 APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
451 unsigned Matched = 0;
453 for (unsigned i = 0; i < Length; i++) {
454 uint64_t Element = Table.getElementAsInteger(i);
455 if (Element >= InputBits)
456 continue;
458 // Check if \p Element matches a concrete answer. It could fail for some
459 // elements that are never accessed, so we keep iterating over each element
460 // from the table. The number of matched elements should be equal to the
461 // number of potential right answers which is \p InputBits actually.
462 if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
463 Matched++;
466 return Matched == InputBits;
469 // Try to recognize table-based ctz implementation.
470 // E.g., an example in C (for more cases please see the llvm/tests):
471 // int f(unsigned x) {
472 // static const char table[32] =
473 // {0, 1, 28, 2, 29, 14, 24, 3, 30,
474 // 22, 20, 15, 25, 17, 4, 8, 31, 27,
475 // 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
476 // return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
477 // }
478 // this can be lowered to `cttz` instruction.
479 // There is also a special case when the element is 0.
481 // Here are some examples or LLVM IR for a 64-bit target:
483 // CASE 1:
484 // %sub = sub i32 0, %x
485 // %and = and i32 %sub, %x
486 // %mul = mul i32 %and, 125613361
487 // %shr = lshr i32 %mul, 27
488 // %idxprom = zext i32 %shr to i64
489 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
490 // i64 %idxprom
491 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
493 // CASE 2:
494 // %sub = sub i32 0, %x
495 // %and = and i32 %sub, %x
496 // %mul = mul i32 %and, 72416175
497 // %shr = lshr i32 %mul, 26
498 // %idxprom = zext i32 %shr to i64
499 // %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
500 // i64 0, i64 %idxprom
501 // %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
503 // CASE 3:
504 // %sub = sub i32 0, %x
505 // %and = and i32 %sub, %x
506 // %mul = mul i32 %and, 81224991
507 // %shr = lshr i32 %mul, 27
508 // %idxprom = zext i32 %shr to i64
509 // %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
510 // i64 0, i64 %idxprom
511 // %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
513 // CASE 4:
514 // %sub = sub i64 0, %x
515 // %and = and i64 %sub, %x
516 // %mul = mul i64 %and, 283881067100198605
517 // %shr = lshr i64 %mul, 58
518 // %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
519 // i64 %shr
520 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
522 // All this can be lowered to @llvm.cttz.i32/64 intrinsic.
523 static bool tryToRecognizeTableBasedCttz(Instruction &I) {
524 LoadInst *LI = dyn_cast<LoadInst>(&I);
525 if (!LI)
526 return false;
528 Type *AccessType = LI->getType();
529 if (!AccessType->isIntegerTy())
530 return false;
532 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
533 if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
534 return false;
536 if (!GEP->getSourceElementType()->isArrayTy())
537 return false;
539 uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
540 if (ArraySize != 32 && ArraySize != 64)
541 return false;
543 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
544 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
545 return false;
547 ConstantDataArray *ConstData =
548 dyn_cast<ConstantDataArray>(GVTable->getInitializer());
549 if (!ConstData)
550 return false;
552 if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
553 return false;
555 Value *Idx2 = std::next(GEP->idx_begin())->get();
556 Value *X1;
557 uint64_t MulConst, ShiftConst;
558 // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
559 // probably fail for other (e.g. 32-bit) targets.
560 if (!match(Idx2, m_ZExtOrSelf(
561 m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
562 m_ConstantInt(MulConst)),
563 m_ConstantInt(ShiftConst)))))
564 return false;
566 unsigned InputBits = X1->getType()->getScalarSizeInBits();
567 if (InputBits != 32 && InputBits != 64)
568 return false;
570 // Shift should extract top 5..7 bits.
571 if (InputBits - Log2_32(InputBits) != ShiftConst &&
572 InputBits - Log2_32(InputBits) - 1 != ShiftConst)
573 return false;
575 if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
576 return false;
578 auto ZeroTableElem = ConstData->getElementAsInteger(0);
579 bool DefinedForZero = ZeroTableElem == InputBits;
581 IRBuilder<> B(LI);
582 ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
583 Type *XType = X1->getType();
584 auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
585 Value *ZExtOrTrunc = nullptr;
587 if (DefinedForZero) {
588 ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
589 } else {
590 // If the value in elem 0 isn't the same as InputBits, we still want to
591 // produce the value from the table.
592 auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
593 auto Select =
594 B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
596 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
597 // it should be handled as: `cttz(x) & (typeSize - 1)`.
599 ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
602 LI->replaceAllUsesWith(ZExtOrTrunc);
604 return true;
607 /// This is used by foldLoadsRecursive() to capture a Root Load node which is
608 /// of type or(load, load) and recursively build the wide load. Also capture the
609 /// shift amount, zero extend type and loadSize.
610 struct LoadOps {
611 LoadInst *Root = nullptr;
612 LoadInst *RootInsert = nullptr;
613 bool FoundRoot = false;
614 uint64_t LoadSize = 0;
615 const APInt *Shift = nullptr;
616 Type *ZextType;
617 AAMDNodes AATags;
620 // Identify and Merge consecutive loads recursively which is of the form
621 // (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
622 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
623 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
624 AliasAnalysis &AA) {
625 const APInt *ShAmt2 = nullptr;
626 Value *X;
627 Instruction *L1, *L2;
629 // Go to the last node with loads.
630 if (match(V, m_OneUse(m_c_Or(
631 m_Value(X),
632 m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
633 m_APInt(ShAmt2)))))) ||
634 match(V, m_OneUse(m_Or(m_Value(X),
635 m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
636 if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
637 // Avoid Partial chain merge.
638 return false;
639 } else
640 return false;
642 // Check if the pattern has loads
643 LoadInst *LI1 = LOps.Root;
644 const APInt *ShAmt1 = LOps.Shift;
645 if (LOps.FoundRoot == false &&
646 (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
647 match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
648 m_APInt(ShAmt1)))))) {
649 LI1 = dyn_cast<LoadInst>(L1);
651 LoadInst *LI2 = dyn_cast<LoadInst>(L2);
653 // Check if loads are same, atomic, volatile and having same address space.
654 if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
655 LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
656 return false;
658 // Check if Loads come from same BB.
659 if (LI1->getParent() != LI2->getParent())
660 return false;
662 // Find the data layout
663 bool IsBigEndian = DL.isBigEndian();
665 // Check if loads are consecutive and same size.
666 Value *Load1Ptr = LI1->getPointerOperand();
667 APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
668 Load1Ptr =
669 Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
670 /* AllowNonInbounds */ true);
672 Value *Load2Ptr = LI2->getPointerOperand();
673 APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
674 Load2Ptr =
675 Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
676 /* AllowNonInbounds */ true);
678 // Verify if both loads have same base pointers and load sizes are same.
679 uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
680 uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
681 if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
682 return false;
684 // Support Loadsizes greater or equal to 8bits and only power of 2.
685 if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
686 return false;
688 // Alias Analysis to check for stores b/w the loads.
689 LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
690 MemoryLocation Loc;
691 if (!Start->comesBefore(End)) {
692 std::swap(Start, End);
693 Loc = MemoryLocation::get(End);
694 if (LOps.FoundRoot)
695 Loc = Loc.getWithNewSize(LOps.LoadSize);
696 } else
697 Loc = MemoryLocation::get(End);
698 unsigned NumScanned = 0;
699 for (Instruction &Inst :
700 make_range(Start->getIterator(), End->getIterator())) {
701 if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
702 return false;
704 // Ignore debug info so that's not counted against MaxInstrsToScan.
705 // Otherwise debug info could affect codegen.
706 if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan)
707 return false;
710 // Make sure Load with lower Offset is at LI1
711 bool Reverse = false;
712 if (Offset2.slt(Offset1)) {
713 std::swap(LI1, LI2);
714 std::swap(ShAmt1, ShAmt2);
715 std::swap(Offset1, Offset2);
716 std::swap(Load1Ptr, Load2Ptr);
717 std::swap(LoadSize1, LoadSize2);
718 Reverse = true;
721 // Big endian swap the shifts
722 if (IsBigEndian)
723 std::swap(ShAmt1, ShAmt2);
725 // Find Shifts values.
726 uint64_t Shift1 = 0, Shift2 = 0;
727 if (ShAmt1)
728 Shift1 = ShAmt1->getZExtValue();
729 if (ShAmt2)
730 Shift2 = ShAmt2->getZExtValue();
732 // First load is always LI1. This is where we put the new load.
733 // Use the merged load size available from LI1 for forward loads.
734 if (LOps.FoundRoot) {
735 if (!Reverse)
736 LoadSize1 = LOps.LoadSize;
737 else
738 LoadSize2 = LOps.LoadSize;
741 // Verify if shift amount and load index aligns and verifies that loads
742 // are consecutive.
743 uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
744 uint64_t PrevSize =
745 DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
746 if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
747 return false;
749 // Update LOps
750 AAMDNodes AATags1 = LOps.AATags;
751 AAMDNodes AATags2 = LI2->getAAMetadata();
752 if (LOps.FoundRoot == false) {
753 LOps.FoundRoot = true;
754 AATags1 = LI1->getAAMetadata();
756 LOps.LoadSize = LoadSize1 + LoadSize2;
757 LOps.RootInsert = Start;
759 // Concatenate the AATags of the Merged Loads.
760 LOps.AATags = AATags1.concat(AATags2);
762 LOps.Root = LI1;
763 LOps.Shift = ShAmt1;
764 LOps.ZextType = X->getType();
765 return true;
768 // For a given BB instruction, evaluate all loads in the chain that form a
769 // pattern which suggests that the loads can be combined. The one and only use
770 // of the loads is to form a wider load.
771 static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
772 TargetTransformInfo &TTI, AliasAnalysis &AA,
773 const DominatorTree &DT) {
774 // Only consider load chains of scalar values.
775 if (isa<VectorType>(I.getType()))
776 return false;
778 LoadOps LOps;
779 if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
780 return false;
782 IRBuilder<> Builder(&I);
783 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
785 IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
786 // TTI based checks if we want to proceed with wider load
787 bool Allowed = TTI.isTypeLegal(WiderType);
788 if (!Allowed)
789 return false;
791 unsigned AS = LI1->getPointerAddressSpace();
792 unsigned Fast = 0;
793 Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
794 AS, LI1->getAlign(), &Fast);
795 if (!Allowed || !Fast)
796 return false;
798 // Get the Index and Ptr for the new GEP.
799 Value *Load1Ptr = LI1->getPointerOperand();
800 Builder.SetInsertPoint(LOps.RootInsert);
801 if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
802 APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
803 Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
804 DL, Offset1, /* AllowNonInbounds */ true);
805 Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1));
807 // Generate wider load.
808 NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
809 LI1->isVolatile(), "");
810 NewLoad->takeName(LI1);
811 // Set the New Load AATags Metadata.
812 if (LOps.AATags)
813 NewLoad->setAAMetadata(LOps.AATags);
815 Value *NewOp = NewLoad;
816 // Check if zero extend needed.
817 if (LOps.ZextType)
818 NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);
820 // Check if shift needed. We need to shift with the amount of load1
821 // shift if not zero.
822 if (LOps.Shift)
823 NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
824 I.replaceAllUsesWith(NewOp);
826 return true;
829 // Calculate GEP Stride and accumulated const ModOffset. Return Stride and
830 // ModOffset
831 static std::pair<APInt, APInt>
832 getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
833 unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
834 std::optional<APInt> Stride;
835 APInt ModOffset(BW, 0);
836 // Return a minimum gep stride, greatest common divisor of consective gep
837 // index scales(c.f. Bézout's identity).
838 while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
839 SmallMapVector<Value *, APInt, 4> VarOffsets;
840 if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
841 break;
843 for (auto [V, Scale] : VarOffsets) {
844 // Only keep a power of two factor for non-inbounds
845 if (!GEP->isInBounds())
846 Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());
848 if (!Stride)
849 Stride = Scale;
850 else
851 Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
854 PtrOp = GEP->getPointerOperand();
857 // Check whether pointer arrives back at Global Variable via at least one GEP.
858 // Even if it doesn't, we can check by alignment.
859 if (!isa<GlobalVariable>(PtrOp) || !Stride)
860 return {APInt(BW, 1), APInt(BW, 0)};
862 // In consideration of signed GEP indices, non-negligible offset become
863 // remainder of division by minimum GEP stride.
864 ModOffset = ModOffset.srem(*Stride);
865 if (ModOffset.isNegative())
866 ModOffset += *Stride;
868 return {*Stride, ModOffset};
871 /// If C is a constant patterned array and all valid loaded results for given
872 /// alignment are same to a constant, return that constant.
873 static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
874 auto *LI = dyn_cast<LoadInst>(&I);
875 if (!LI || LI->isVolatile())
876 return false;
878 // We can only fold the load if it is from a constant global with definitive
879 // initializer. Skip expensive logic if this is not the case.
880 auto *PtrOp = LI->getPointerOperand();
881 auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
882 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
883 return false;
885 // Bail for large initializers in excess of 4K to avoid too many scans.
886 Constant *C = GV->getInitializer();
887 uint64_t GVSize = DL.getTypeAllocSize(C->getType());
888 if (!GVSize || 4096 < GVSize)
889 return false;
891 Type *LoadTy = LI->getType();
892 unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
893 auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
895 // Any possible offset could be multiple of GEP stride. And any valid
896 // offset is multiple of load alignment, so checking only multiples of bigger
897 // one is sufficient to say results' equality.
898 if (auto LA = LI->getAlign();
899 LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
900 ConstOffset = APInt(BW, 0);
901 Stride = APInt(BW, LA.value());
904 Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
905 if (!Ca)
906 return false;
908 unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
909 for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
910 if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
911 return false;
913 I.replaceAllUsesWith(Ca);
915 return true;
918 namespace {
919 class StrNCmpInliner {
920 public:
921 StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
922 const DataLayout &DL)
923 : CI(CI), Func(Func), DTU(DTU), DL(DL) {}
925 bool optimizeStrNCmp();
927 private:
928 void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);
930 CallInst *CI;
931 LibFunc Func;
932 DomTreeUpdater *DTU;
933 const DataLayout &DL;
936 } // namespace
938 /// First we normalize calls to strncmp/strcmp to the form of
939 /// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
940 /// (without considering '\0').
942 /// Examples:
944 /// \code
945 /// strncmp(s, "a", 3) -> compare(s, "a", 2)
946 /// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
947 /// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
948 /// strcmp(s, "a") -> compare(s, "a", 2)
950 /// char s2[] = {'a'}
951 /// strncmp(s, s2, 3) -> compare(s, s2, 3)
953 /// char s2[] = {'a', 'b', 'c', 'd'}
954 /// strncmp(s, s2, 3) -> compare(s, s2, 3)
955 /// \endcode
957 /// We only handle cases where N and exactly one of s1 and s2 are constant.
958 /// Cases that s1 and s2 are both constant are already handled by the
959 /// instcombine pass.
961 /// We do not handle cases where N > StrNCmpInlineThreshold.
963 /// We also do not handles cases where N < 2, which are already
964 /// handled by the instcombine pass.
966 bool StrNCmpInliner::optimizeStrNCmp() {
967 if (StrNCmpInlineThreshold < 2)
968 return false;
970 if (!isOnlyUsedInZeroComparison(CI))
971 return false;
973 Value *Str1P = CI->getArgOperand(0);
974 Value *Str2P = CI->getArgOperand(1);
975 // Should be handled elsewhere.
976 if (Str1P == Str2P)
977 return false;
979 StringRef Str1, Str2;
980 bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false);
981 bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false);
982 if (HasStr1 == HasStr2)
983 return false;
985 // Note that '\0' and characters after it are not trimmed.
986 StringRef Str = HasStr1 ? Str1 : Str2;
987 Value *StrP = HasStr1 ? Str2P : Str1P;
989 size_t Idx = Str.find('\0');
990 uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
991 if (Func == LibFunc_strncmp) {
992 if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
993 N = std::min(N, ConstInt->getZExtValue());
994 else
995 return false;
997 // Now N means how many bytes we need to compare at most.
998 if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
999 return false;
1001 // Cases where StrP has two or more dereferenceable bytes might be better
1002 // optimized elsewhere.
1003 bool CanBeNull = false, CanBeFreed = false;
1004 if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1005 return false;
1006 inlineCompare(StrP, Str, N, HasStr1);
1007 return true;
1010 /// Convert
1012 /// \code
1013 /// ret = compare(s1, s2, N)
1014 /// \endcode
1016 /// into
1018 /// \code
1019 /// ret = (int)s1[0] - (int)s2[0]
1020 /// if (ret != 0)
1021 /// goto NE
1022 /// ...
1023 /// ret = (int)s1[N-2] - (int)s2[N-2]
1024 /// if (ret != 0)
1025 /// goto NE
1026 /// ret = (int)s1[N-1] - (int)s2[N-1]
1027 /// NE:
1028 /// \endcode
1030 /// CFG before and after the transformation:
1032 /// (before)
1033 /// BBCI
1035 /// (after)
1036 /// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
1037 /// | ^
1038 /// E |
1039 /// | |
1040 /// BBSubs[1] (sub,icmp) --NE-----+
1041 /// ... |
1042 /// BBSubs[N-1] (sub) ---------+
1044 void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
1045 bool Swapped) {
1046 auto &Ctx = CI->getContext();
1047 IRBuilder<> B(Ctx);
1048 // We want these instructions to be recognized as inlined instructions for the
1049 // compare call, but we don't have a source location for the definition of
1050 // that function, since we're generating that code now. Because the generated
1051 // code is a viable point for a memory access error, we make the pragmatic
1052 // choice here to directly use CI's location so that we have useful
1053 // attribution for the generated code.
1054 B.SetCurrentDebugLocation(CI->getDebugLoc());
1056 BasicBlock *BBCI = CI->getParent();
1057 BasicBlock *BBTail =
1058 SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");
1060 SmallVector<BasicBlock *> BBSubs;
1061 for (uint64_t I = 0; I < N; ++I)
1062 BBSubs.push_back(
1063 BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail));
1064 BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);
1066 cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);
1068 B.SetInsertPoint(BBNE);
1069 PHINode *Phi = B.CreatePHI(CI->getType(), N);
1070 B.CreateBr(BBTail);
1072 Value *Base = LHS;
1073 for (uint64_t i = 0; i < N; ++i) {
1074 B.SetInsertPoint(BBSubs[i]);
1075 Value *VL =
1076 B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
1077 B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
1078 CI->getType());
1079 Value *VR =
1080 ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
1081 Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
1082 if (i < N - 1)
1083 B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
1084 BBNE, BBSubs[i + 1]);
1085 else
1086 B.CreateBr(BBNE);
1088 Phi->addIncoming(Sub, BBSubs[i]);
1091 CI->replaceAllUsesWith(Phi);
1092 CI->eraseFromParent();
1094 if (DTU) {
1095 SmallVector<DominatorTree::UpdateType, 8> Updates;
1096 Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
1097 for (uint64_t i = 0; i < N; ++i) {
1098 if (i < N - 1)
1099 Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
1100 Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
1102 Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
1103 Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
1104 DTU->applyUpdates(Updates);
1108 /// Convert memchr with a small constant string into a switch
1109 static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
1110 const DataLayout &DL) {
1111 if (isa<Constant>(Call->getArgOperand(1)))
1112 return false;
1114 StringRef Str;
1115 Value *Base = Call->getArgOperand(0);
1116 if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
1117 return false;
1119 uint64_t N = Str.size();
1120 if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
1121 uint64_t Val = ConstInt->getZExtValue();
1122 // Ignore the case that n is larger than the size of string.
1123 if (Val > N)
1124 return false;
1125 N = Val;
1126 } else
1127 return false;
1129 if (N > MemChrInlineThreshold)
1130 return false;
1132 BasicBlock *BB = Call->getParent();
1133 BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
1134 IRBuilder<> IRB(BB);
1135 IntegerType *ByteTy = IRB.getInt8Ty();
1136 BB->getTerminator()->eraseFromParent();
1137 SwitchInst *SI = IRB.CreateSwitch(
1138 IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
1139 Type *IndexTy = DL.getIndexType(Call->getType());
1140 SmallVector<DominatorTree::UpdateType, 8> Updates;
1142 BasicBlock *BBSuccess = BasicBlock::Create(
1143 Call->getContext(), "memchr.success", BB->getParent(), BBNext);
1144 IRB.SetInsertPoint(BBSuccess);
1145 PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
1146 Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
1147 IRB.CreateBr(BBNext);
1148 if (DTU)
1149 Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});
1151 SmallPtrSet<ConstantInt *, 4> Cases;
1152 for (uint64_t I = 0; I < N; ++I) {
1153 ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
1154 if (!Cases.insert(CaseVal).second)
1155 continue;
1157 BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
1158 BB->getParent(), BBSuccess);
1159 SI->addCase(CaseVal, BBCase);
1160 IRB.SetInsertPoint(BBCase);
1161 IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
1162 IRB.CreateBr(BBSuccess);
1163 if (DTU) {
1164 Updates.push_back({DominatorTree::Insert, BB, BBCase});
1165 Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});
1169 PHINode *PHI =
1170 PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
1171 PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
1172 PHI->addIncoming(FirstOccursLocation, BBSuccess);
1174 Call->replaceAllUsesWith(PHI);
1175 Call->eraseFromParent();
1177 if (DTU)
1178 DTU->applyUpdates(Updates);
1180 return true;
1183 static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1184 TargetLibraryInfo &TLI, AssumptionCache &AC,
1185 DominatorTree &DT, const DataLayout &DL,
1186 bool &MadeCFGChange) {
1188 auto *CI = dyn_cast<CallInst>(&I);
1189 if (!CI || CI->isNoBuiltin())
1190 return false;
1192 Function *CalledFunc = CI->getCalledFunction();
1193 if (!CalledFunc)
1194 return false;
1196 LibFunc LF;
1197 if (!TLI.getLibFunc(*CalledFunc, LF) ||
1198 !isLibFuncEmittable(CI->getModule(), &TLI, LF))
1199 return false;
1201 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1203 switch (LF) {
1204 case LibFunc_sqrt:
1205 case LibFunc_sqrtf:
1206 case LibFunc_sqrtl:
1207 return foldSqrt(CI, LF, TTI, TLI, AC, DT);
1208 case LibFunc_strcmp:
1209 case LibFunc_strncmp:
1210 if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1211 MadeCFGChange = true;
1212 return true;
1214 break;
1215 case LibFunc_memchr:
1216 if (foldMemChr(CI, &DTU, DL)) {
1217 MadeCFGChange = true;
1218 return true;
1220 break;
1221 default:;
1223 return false;
1226 /// This is the entry point for folds that could be implemented in regular
1227 /// InstCombine, but they are separated because they are not expected to
1228 /// occur frequently and/or have more than a constant-length pattern match.
1229 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
1230 TargetTransformInfo &TTI,
1231 TargetLibraryInfo &TLI, AliasAnalysis &AA,
1232 AssumptionCache &AC, bool &MadeCFGChange) {
1233 bool MadeChange = false;
1234 for (BasicBlock &BB : F) {
1235 // Ignore unreachable basic blocks.
1236 if (!DT.isReachableFromEntry(&BB))
1237 continue;
1239 const DataLayout &DL = F.getDataLayout();
1241 // Walk the block backwards for efficiency. We're matching a chain of
1242 // use->defs, so we're more likely to succeed by starting from the bottom.
1243 // Also, we want to avoid matching partial patterns.
1244 // TODO: It would be more efficient if we removed dead instructions
1245 // iteratively in this loop rather than waiting until the end.
1246 for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
1247 MadeChange |= foldAnyOrAllBitsSet(I);
1248 MadeChange |= foldGuardedFunnelShift(I, DT);
1249 MadeChange |= tryToRecognizePopCount(I);
1250 MadeChange |= tryToFPToSat(I, TTI);
1251 MadeChange |= tryToRecognizeTableBasedCttz(I);
1252 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
1253 MadeChange |= foldPatternedLoads(I, DL);
1254 // NOTE: This function introduces erasing of the instruction `I`, so it
1255 // needs to be called at the end of this sequence, otherwise we may make
1256 // bugs.
1257 MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
1261 // We're done with transforms, so remove dead instructions.
1262 if (MadeChange)
1263 for (BasicBlock &BB : F)
1264 SimplifyInstructionsInBlock(&BB);
1266 return MadeChange;
1269 /// This is the entry point for all transforms. Pass manager differences are
1270 /// handled in the callers of this function.
1271 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
1272 TargetLibraryInfo &TLI, DominatorTree &DT,
1273 AliasAnalysis &AA, bool &MadeCFGChange) {
1274 bool MadeChange = false;
1275 const DataLayout &DL = F.getDataLayout();
1276 TruncInstCombine TIC(AC, TLI, DL, DT);
1277 MadeChange |= TIC.run(F);
1278 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
1279 return MadeChange;
1282 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
1283 FunctionAnalysisManager &AM) {
1284 auto &AC = AM.getResult<AssumptionAnalysis>(F);
1285 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
1286 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
1287 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1288 auto &AA = AM.getResult<AAManager>(F);
1289 bool MadeCFGChange = false;
1290 if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
1291 // No changes, all analyses are preserved.
1292 return PreservedAnalyses::all();
1294 // Mark all the analyses that instcombine updates as preserved.
1295 PreservedAnalyses PA;
1296 if (MadeCFGChange)
1297 PA.preserve<DominatorTreeAnalysis>();
1298 else
1299 PA.preserveSet<CFGAnalyses>();
1300 return PA;