//===- AggressiveInstCombine.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//   * Truncate instruction
//
//===----------------------------------------------------------------------===//
15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16 #include "AggressiveInstCombineInternal.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/BasicAliasAnalysis.h"
21 #include "llvm/Analysis/ConstantFolding.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/TargetLibraryInfo.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/PatternMatch.h"
31 #include "llvm/Transforms/Utils/BuildLibCalls.h"
32 #include "llvm/Transforms/Utils/Local.h"
35 using namespace PatternMatch
;
37 #define DEBUG_TYPE "aggressive-instcombine"
39 STATISTIC(NumAnyOrAllBitsSet
, "Number of any/all-bits-set patterns folded");
40 STATISTIC(NumGuardedRotates
,
41 "Number of guarded rotates transformed into funnel shifts");
42 STATISTIC(NumGuardedFunnelShifts
,
43 "Number of guarded funnel shifts transformed into funnel shifts");
44 STATISTIC(NumPopCountRecognized
, "Number of popcount idioms recognized");
46 static cl::opt
<unsigned> MaxInstrsToScan(
47 "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden
,
48 cl::desc("Max number of instructions to scan for aggressive instcombine."));
50 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
51 /// against undefined behavior by branching around the funnel-shift/rotation
52 /// when the shift amount is 0.
53 static bool foldGuardedFunnelShift(Instruction
&I
, const DominatorTree
&DT
) {
54 if (I
.getOpcode() != Instruction::PHI
|| I
.getNumOperands() != 2)
57 // As with the one-use checks below, this is not strictly necessary, but we
58 // are being cautious to avoid potential perf regressions on targets that
59 // do not actually have a funnel/rotate instruction (where the funnel shift
60 // would be expanded back into math/shift/logic ops).
61 if (!isPowerOf2_32(I
.getType()->getScalarSizeInBits()))
64 // Match V to funnel shift left/right and capture the source operands and
66 auto matchFunnelShift
= [](Value
*V
, Value
*&ShVal0
, Value
*&ShVal1
,
68 unsigned Width
= V
->getType()->getScalarSizeInBits();
70 // fshl(ShVal0, ShVal1, ShAmt)
71 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
72 if (match(V
, m_OneUse(m_c_Or(
73 m_Shl(m_Value(ShVal0
), m_Value(ShAmt
)),
74 m_LShr(m_Value(ShVal1
),
75 m_Sub(m_SpecificInt(Width
), m_Deferred(ShAmt
))))))) {
76 return Intrinsic::fshl
;
79 // fshr(ShVal0, ShVal1, ShAmt)
80 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
82 m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0
), m_Sub(m_SpecificInt(Width
),
84 m_LShr(m_Value(ShVal1
), m_Deferred(ShAmt
)))))) {
85 return Intrinsic::fshr
;
88 return Intrinsic::not_intrinsic
;
91 // One phi operand must be a funnel/rotate operation, and the other phi
92 // operand must be the source value of that funnel/rotate operation:
93 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
94 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
95 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
96 PHINode
&Phi
= cast
<PHINode
>(I
);
97 unsigned FunnelOp
= 0, GuardOp
= 1;
98 Value
*P0
= Phi
.getOperand(0), *P1
= Phi
.getOperand(1);
99 Value
*ShVal0
, *ShVal1
, *ShAmt
;
100 Intrinsic::ID IID
= matchFunnelShift(P0
, ShVal0
, ShVal1
, ShAmt
);
101 if (IID
== Intrinsic::not_intrinsic
||
102 (IID
== Intrinsic::fshl
&& ShVal0
!= P1
) ||
103 (IID
== Intrinsic::fshr
&& ShVal1
!= P1
)) {
104 IID
= matchFunnelShift(P1
, ShVal0
, ShVal1
, ShAmt
);
105 if (IID
== Intrinsic::not_intrinsic
||
106 (IID
== Intrinsic::fshl
&& ShVal0
!= P0
) ||
107 (IID
== Intrinsic::fshr
&& ShVal1
!= P0
))
109 assert((IID
== Intrinsic::fshl
|| IID
== Intrinsic::fshr
) &&
110 "Pattern must match funnel shift left or right");
111 std::swap(FunnelOp
, GuardOp
);
114 // The incoming block with our source operand must be the "guard" block.
115 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
116 // amount is equal to 0. The other incoming block is the block with the
118 BasicBlock
*GuardBB
= Phi
.getIncomingBlock(GuardOp
);
119 BasicBlock
*FunnelBB
= Phi
.getIncomingBlock(FunnelOp
);
120 Instruction
*TermI
= GuardBB
->getTerminator();
122 // Ensure that the shift values dominate each block.
123 if (!DT
.dominates(ShVal0
, TermI
) || !DT
.dominates(ShVal1
, TermI
))
126 ICmpInst::Predicate Pred
;
127 BasicBlock
*PhiBB
= Phi
.getParent();
128 if (!match(TermI
, m_Br(m_ICmp(Pred
, m_Specific(ShAmt
), m_ZeroInt()),
129 m_SpecificBB(PhiBB
), m_SpecificBB(FunnelBB
))))
132 if (Pred
!= CmpInst::ICMP_EQ
)
135 IRBuilder
<> Builder(PhiBB
, PhiBB
->getFirstInsertionPt());
137 if (ShVal0
== ShVal1
)
140 ++NumGuardedFunnelShifts
;
142 // If this is not a rotate then the select was blocking poison from the
143 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
144 bool IsFshl
= IID
== Intrinsic::fshl
;
145 if (ShVal0
!= ShVal1
) {
146 if (IsFshl
&& !llvm::isGuaranteedNotToBePoison(ShVal1
))
147 ShVal1
= Builder
.CreateFreeze(ShVal1
);
148 else if (!IsFshl
&& !llvm::isGuaranteedNotToBePoison(ShVal0
))
149 ShVal0
= Builder
.CreateFreeze(ShVal0
);
152 // We matched a variation of this IR pattern:
154 // %cmp = icmp eq i32 %ShAmt, 0
155 // br i1 %cmp, label %PhiBB, label %FunnelBB
157 // %sub = sub i32 32, %ShAmt
158 // %shr = lshr i32 %ShVal1, %sub
159 // %shl = shl i32 %ShVal0, %ShAmt
160 // %fsh = or i32 %shr, %shl
163 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
165 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
166 Function
*F
= Intrinsic::getDeclaration(Phi
.getModule(), IID
, Phi
.getType());
167 Phi
.replaceAllUsesWith(Builder
.CreateCall(F
, {ShVal0
, ShVal1
, ShAmt
}));
171 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
172 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
173 /// of 'and' ops, then we also need to capture the fact that we saw an
174 /// "and X, 1", so that's an extra return value for that case.
176 Value
*Root
= nullptr;
179 bool FoundAnd1
= false;
181 MaskOps(unsigned BitWidth
, bool MatchAnds
)
182 : Mask(APInt::getZero(BitWidth
)), MatchAndChain(MatchAnds
) {}
185 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
186 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
188 /// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
189 /// returns { X, 0x129 }
190 /// and (and (X >> 1), 1), (X >> 4)
191 /// returns { X, 0x12 }
192 static bool matchAndOrChain(Value
*V
, MaskOps
&MOps
) {
194 if (MOps
.MatchAndChain
) {
195 // Recurse through a chain of 'and' operands. This requires an extra check
196 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
197 // in the chain to know that all of the high bits are cleared.
198 if (match(V
, m_And(m_Value(Op0
), m_One()))) {
199 MOps
.FoundAnd1
= true;
200 return matchAndOrChain(Op0
, MOps
);
202 if (match(V
, m_And(m_Value(Op0
), m_Value(Op1
))))
203 return matchAndOrChain(Op0
, MOps
) && matchAndOrChain(Op1
, MOps
);
205 // Recurse through a chain of 'or' operands.
206 if (match(V
, m_Or(m_Value(Op0
), m_Value(Op1
))))
207 return matchAndOrChain(Op0
, MOps
) && matchAndOrChain(Op1
, MOps
);
210 // We need a shift-right or a bare value representing a compare of bit 0 of
211 // the original source operand.
213 const APInt
*BitIndex
= nullptr;
214 if (!match(V
, m_LShr(m_Value(Candidate
), m_APInt(BitIndex
))))
217 // Initialize result source operand.
219 MOps
.Root
= Candidate
;
221 // The shift constant is out-of-range? This code hasn't been simplified.
222 if (BitIndex
&& BitIndex
->uge(MOps
.Mask
.getBitWidth()))
225 // Fill in the mask bit derived from the shift constant.
226 MOps
.Mask
.setBit(BitIndex
? BitIndex
->getZExtValue() : 0);
227 return MOps
.Root
== Candidate
;
230 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
231 /// These will include a chain of 'or' or 'and'-shifted bits from a
232 /// common source value:
233 /// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
234 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
235 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
236 /// that differ only with a final 'not' of the result. We expect that final
237 /// 'not' to be folded with the compare that we create here (invert predicate).
238 static bool foldAnyOrAllBitsSet(Instruction
&I
) {
239 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
240 // final "and X, 1" instruction must be the final op in the sequence.
241 bool MatchAllBitsSet
;
242 if (match(&I
, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
243 MatchAllBitsSet
= true;
244 else if (match(&I
, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
245 MatchAllBitsSet
= false;
249 MaskOps
MOps(I
.getType()->getScalarSizeInBits(), MatchAllBitsSet
);
250 if (MatchAllBitsSet
) {
251 if (!matchAndOrChain(cast
<BinaryOperator
>(&I
), MOps
) || !MOps
.FoundAnd1
)
254 if (!matchAndOrChain(cast
<BinaryOperator
>(&I
)->getOperand(0), MOps
))
258 // The pattern was found. Create a masked compare that replaces all of the
259 // shift and logic ops.
260 IRBuilder
<> Builder(&I
);
261 Constant
*Mask
= ConstantInt::get(I
.getType(), MOps
.Mask
);
262 Value
*And
= Builder
.CreateAnd(MOps
.Root
, Mask
);
263 Value
*Cmp
= MatchAllBitsSet
? Builder
.CreateICmpEQ(And
, Mask
)
264 : Builder
.CreateIsNotNull(And
);
265 Value
*Zext
= Builder
.CreateZExt(Cmp
, I
.getType());
266 I
.replaceAllUsesWith(Zext
);
267 ++NumAnyOrAllBitsSet
;
271 // Try to recognize below function as popcount intrinsic.
272 // This is the "best" algorithm from
273 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
274 // Also used in TargetLowering::expandCTPOP().
276 // int popcount(unsigned int i) {
277 // i = i - ((i >> 1) & 0x55555555);
278 // i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
279 // i = ((i + (i >> 4)) & 0x0F0F0F0F);
280 // return (i * 0x01010101) >> 24;
282 static bool tryToRecognizePopCount(Instruction
&I
) {
283 if (I
.getOpcode() != Instruction::LShr
)
286 Type
*Ty
= I
.getType();
287 if (!Ty
->isIntOrIntVectorTy())
290 unsigned Len
= Ty
->getScalarSizeInBits();
291 // FIXME: fix Len == 8 and other irregular type lengths.
292 if (!(Len
<= 128 && Len
> 8 && Len
% 8 == 0))
295 APInt Mask55
= APInt::getSplat(Len
, APInt(8, 0x55));
296 APInt Mask33
= APInt::getSplat(Len
, APInt(8, 0x33));
297 APInt Mask0F
= APInt::getSplat(Len
, APInt(8, 0x0F));
298 APInt Mask01
= APInt::getSplat(Len
, APInt(8, 0x01));
299 APInt MaskShift
= APInt(Len
, Len
- 8);
301 Value
*Op0
= I
.getOperand(0);
302 Value
*Op1
= I
.getOperand(1);
304 // Matching "(i * 0x01010101...) >> 24".
305 if ((match(Op0
, m_Mul(m_Value(MulOp0
), m_SpecificInt(Mask01
)))) &&
306 match(Op1
, m_SpecificInt(MaskShift
))) {
308 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
309 if (match(MulOp0
, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0
), m_SpecificInt(4)),
310 m_Deferred(ShiftOp0
)),
311 m_SpecificInt(Mask0F
)))) {
313 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
315 m_c_Add(m_And(m_Value(AndOp0
), m_SpecificInt(Mask33
)),
316 m_And(m_LShr(m_Deferred(AndOp0
), m_SpecificInt(2)),
317 m_SpecificInt(Mask33
))))) {
318 Value
*Root
, *SubOp1
;
319 // Matching "i - ((i >> 1) & 0x55555555...)".
320 if (match(AndOp0
, m_Sub(m_Value(Root
), m_Value(SubOp1
))) &&
321 match(SubOp1
, m_And(m_LShr(m_Specific(Root
), m_SpecificInt(1)),
322 m_SpecificInt(Mask55
)))) {
323 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
324 IRBuilder
<> Builder(&I
);
325 Function
*Func
= Intrinsic::getDeclaration(
326 I
.getModule(), Intrinsic::ctpop
, I
.getType());
327 I
.replaceAllUsesWith(Builder
.CreateCall(Func
, {Root
}));
328 ++NumPopCountRecognized
;
338 /// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
339 /// C2 saturate the value of the fp conversion. The transform is not reversable
340 /// as the fptosi.sat is more defined than the input - all values produce a
341 /// valid value for the fptosi.sat, where as some produce poison for original
342 /// that were out of range of the integer conversion. The reversed pattern may
343 /// use fmax and fmin instead. As we cannot directly reverse the transform, and
344 /// it is not always profitable, we make it conditional on the cost being
345 /// reported as lower by TTI.
346 static bool tryToFPToSat(Instruction
&I
, TargetTransformInfo
&TTI
) {
347 // Look for min(max(fptosi, converting to fptosi_sat.
349 const APInt
*MinC
, *MaxC
;
350 if (!match(&I
, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In
))),
353 !match(&I
, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In
))),
358 // Check that the constants clamp a saturate.
359 if (!(*MinC
+ 1).isPowerOf2() || -*MaxC
!= *MinC
+ 1)
362 Type
*IntTy
= I
.getType();
363 Type
*FpTy
= In
->getType();
365 IntegerType::get(IntTy
->getContext(), (*MinC
+ 1).exactLogBase2() + 1);
366 if (auto *VecTy
= dyn_cast
<VectorType
>(IntTy
))
367 SatTy
= VectorType::get(SatTy
, VecTy
->getElementCount());
369 // Get the cost of the intrinsic, and check that against the cost of
371 InstructionCost SatCost
= TTI
.getIntrinsicInstrCost(
372 IntrinsicCostAttributes(Intrinsic::fptosi_sat
, SatTy
, {In
}, {FpTy
}),
373 TTI::TCK_RecipThroughput
);
374 SatCost
+= TTI
.getCastInstrCost(Instruction::SExt
, IntTy
, SatTy
,
375 TTI::CastContextHint::None
,
376 TTI::TCK_RecipThroughput
);
378 InstructionCost MinMaxCost
= TTI
.getCastInstrCost(
379 Instruction::FPToSI
, IntTy
, FpTy
, TTI::CastContextHint::None
,
380 TTI::TCK_RecipThroughput
);
381 MinMaxCost
+= TTI
.getIntrinsicInstrCost(
382 IntrinsicCostAttributes(Intrinsic::smin
, IntTy
, {IntTy
}),
383 TTI::TCK_RecipThroughput
);
384 MinMaxCost
+= TTI
.getIntrinsicInstrCost(
385 IntrinsicCostAttributes(Intrinsic::smax
, IntTy
, {IntTy
}),
386 TTI::TCK_RecipThroughput
);
388 if (SatCost
>= MinMaxCost
)
391 IRBuilder
<> Builder(&I
);
392 Function
*Fn
= Intrinsic::getDeclaration(I
.getModule(), Intrinsic::fptosi_sat
,
394 Value
*Sat
= Builder
.CreateCall(Fn
, In
);
395 I
.replaceAllUsesWith(Builder
.CreateSExt(Sat
, IntTy
));
399 /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
400 /// pessimistic codegen that has to account for setting errno and can enable
402 static bool foldSqrt(Instruction
&I
, TargetTransformInfo
&TTI
,
403 TargetLibraryInfo
&TLI
, AssumptionCache
&AC
,
405 // Match a call to sqrt mathlib function.
406 auto *Call
= dyn_cast
<CallInst
>(&I
);
410 Module
*M
= Call
->getModule();
412 if (!TLI
.getLibFunc(*Call
, Func
) || !isLibFuncEmittable(M
, &TLI
, Func
))
415 if (Func
!= LibFunc_sqrt
&& Func
!= LibFunc_sqrtf
&& Func
!= LibFunc_sqrtl
)
418 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
419 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
420 // would not end up lowering to a libcall anyway (which could change the value
422 // (1) errno won't be set.
423 // (2) it is safe to convert this to an intrinsic call.
424 Type
*Ty
= Call
->getType();
425 Value
*Arg
= Call
->getArgOperand(0);
426 if (TTI
.haveFastSqrt(Ty
) &&
427 (Call
->hasNoNaNs() ||
428 cannotBeOrderedLessThanZero(Arg
, M
->getDataLayout(), &TLI
, 0, &AC
, &I
,
430 IRBuilder
<> Builder(&I
);
431 IRBuilderBase::FastMathFlagGuard
Guard(Builder
);
432 Builder
.setFastMathFlags(Call
->getFastMathFlags());
434 Function
*Sqrt
= Intrinsic::getDeclaration(M
, Intrinsic::sqrt
, Ty
);
435 Value
*NewSqrt
= Builder
.CreateCall(Sqrt
, Arg
, "sqrt");
436 I
.replaceAllUsesWith(NewSqrt
);
438 // Explicitly erase the old call because a call with side effects is not
447 // Check if this array of constants represents a cttz table.
448 // Iterate over the elements from \p Table by trying to find/match all
449 // the numbers from 0 to \p InputBits that should represent cttz results.
450 static bool isCTTZTable(const ConstantDataArray
&Table
, uint64_t Mul
,
451 uint64_t Shift
, uint64_t InputBits
) {
452 unsigned Length
= Table
.getNumElements();
453 if (Length
< InputBits
|| Length
> InputBits
* 2)
456 APInt Mask
= APInt::getBitsSetFrom(InputBits
, Shift
);
457 unsigned Matched
= 0;
459 for (unsigned i
= 0; i
< Length
; i
++) {
460 uint64_t Element
= Table
.getElementAsInteger(i
);
461 if (Element
>= InputBits
)
464 // Check if \p Element matches a concrete answer. It could fail for some
465 // elements that are never accessed, so we keep iterating over each element
466 // from the table. The number of matched elements should be equal to the
467 // number of potential right answers which is \p InputBits actually.
468 if ((((Mul
<< Element
) & Mask
.getZExtValue()) >> Shift
) == i
)
472 return Matched
== InputBits
;
475 // Try to recognize table-based ctz implementation.
476 // E.g., an example in C (for more cases please see the llvm/tests):
477 // int f(unsigned x) {
478 // static const char table[32] =
479 // {0, 1, 28, 2, 29, 14, 24, 3, 30,
480 // 22, 20, 15, 25, 17, 4, 8, 31, 27,
481 // 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
482 // return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
484 // this can be lowered to `cttz` instruction.
485 // There is also a special case when the element is 0.
487 // Here are some examples or LLVM IR for a 64-bit target:
490 // %sub = sub i32 0, %x
491 // %and = and i32 %sub, %x
492 // %mul = mul i32 %and, 125613361
493 // %shr = lshr i32 %mul, 27
494 // %idxprom = zext i32 %shr to i64
495 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
497 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
500 // %sub = sub i32 0, %x
501 // %and = and i32 %sub, %x
502 // %mul = mul i32 %and, 72416175
503 // %shr = lshr i32 %mul, 26
504 // %idxprom = zext i32 %shr to i64
505 // %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
506 // i64 0, i64 %idxprom
507 // %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
510 // %sub = sub i32 0, %x
511 // %and = and i32 %sub, %x
512 // %mul = mul i32 %and, 81224991
513 // %shr = lshr i32 %mul, 27
514 // %idxprom = zext i32 %shr to i64
515 // %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
516 // i64 0, i64 %idxprom
517 // %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
520 // %sub = sub i64 0, %x
521 // %and = and i64 %sub, %x
522 // %mul = mul i64 %and, 283881067100198605
523 // %shr = lshr i64 %mul, 58
524 // %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
526 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
528 // All this can be lowered to @llvm.cttz.i32/64 intrinsic.
529 static bool tryToRecognizeTableBasedCttz(Instruction
&I
) {
530 LoadInst
*LI
= dyn_cast
<LoadInst
>(&I
);
534 Type
*AccessType
= LI
->getType();
535 if (!AccessType
->isIntegerTy())
538 GetElementPtrInst
*GEP
= dyn_cast
<GetElementPtrInst
>(LI
->getPointerOperand());
539 if (!GEP
|| !GEP
->isInBounds() || GEP
->getNumIndices() != 2)
542 if (!GEP
->getSourceElementType()->isArrayTy())
545 uint64_t ArraySize
= GEP
->getSourceElementType()->getArrayNumElements();
546 if (ArraySize
!= 32 && ArraySize
!= 64)
549 GlobalVariable
*GVTable
= dyn_cast
<GlobalVariable
>(GEP
->getPointerOperand());
550 if (!GVTable
|| !GVTable
->hasInitializer() || !GVTable
->isConstant())
553 ConstantDataArray
*ConstData
=
554 dyn_cast
<ConstantDataArray
>(GVTable
->getInitializer());
558 if (!match(GEP
->idx_begin()->get(), m_ZeroInt()))
561 Value
*Idx2
= std::next(GEP
->idx_begin())->get();
563 uint64_t MulConst
, ShiftConst
;
564 // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
565 // probably fail for other (e.g. 32-bit) targets.
566 if (!match(Idx2
, m_ZExtOrSelf(
567 m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1
)), m_Deferred(X1
)),
568 m_ConstantInt(MulConst
)),
569 m_ConstantInt(ShiftConst
)))))
572 unsigned InputBits
= X1
->getType()->getScalarSizeInBits();
573 if (InputBits
!= 32 && InputBits
!= 64)
576 // Shift should extract top 5..7 bits.
577 if (InputBits
- Log2_32(InputBits
) != ShiftConst
&&
578 InputBits
- Log2_32(InputBits
) - 1 != ShiftConst
)
581 if (!isCTTZTable(*ConstData
, MulConst
, ShiftConst
, InputBits
))
584 auto ZeroTableElem
= ConstData
->getElementAsInteger(0);
585 bool DefinedForZero
= ZeroTableElem
== InputBits
;
588 ConstantInt
*BoolConst
= B
.getInt1(!DefinedForZero
);
589 Type
*XType
= X1
->getType();
590 auto Cttz
= B
.CreateIntrinsic(Intrinsic::cttz
, {XType
}, {X1
, BoolConst
});
591 Value
*ZExtOrTrunc
= nullptr;
593 if (DefinedForZero
) {
594 ZExtOrTrunc
= B
.CreateZExtOrTrunc(Cttz
, AccessType
);
596 // If the value in elem 0 isn't the same as InputBits, we still want to
597 // produce the value from the table.
598 auto Cmp
= B
.CreateICmpEQ(X1
, ConstantInt::get(XType
, 0));
600 B
.CreateSelect(Cmp
, ConstantInt::get(XType
, ZeroTableElem
), Cttz
);
602 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
603 // it should be handled as: `cttz(x) & (typeSize - 1)`.
605 ZExtOrTrunc
= B
.CreateZExtOrTrunc(Select
, AccessType
);
608 LI
->replaceAllUsesWith(ZExtOrTrunc
);
613 /// This is used by foldLoadsRecursive() to capture a Root Load node which is
614 /// of type or(load, load) and recursively build the wide load. Also capture the
615 /// shift amount, zero extend type and loadSize.
617 LoadInst
*Root
= nullptr;
618 LoadInst
*RootInsert
= nullptr;
619 bool FoundRoot
= false;
620 uint64_t LoadSize
= 0;
621 const APInt
*Shift
= nullptr;
626 // Identify and Merge consecutive loads recursively which is of the form
627 // (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
628 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
629 static bool foldLoadsRecursive(Value
*V
, LoadOps
&LOps
, const DataLayout
&DL
,
631 const APInt
*ShAmt2
= nullptr;
633 Instruction
*L1
, *L2
;
635 // Go to the last node with loads.
636 if (match(V
, m_OneUse(m_c_Or(
638 m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2
)))),
639 m_APInt(ShAmt2
)))))) ||
640 match(V
, m_OneUse(m_Or(m_Value(X
),
641 m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2
)))))))) {
642 if (!foldLoadsRecursive(X
, LOps
, DL
, AA
) && LOps
.FoundRoot
)
643 // Avoid Partial chain merge.
648 // Check if the pattern has loads
649 LoadInst
*LI1
= LOps
.Root
;
650 const APInt
*ShAmt1
= LOps
.Shift
;
651 if (LOps
.FoundRoot
== false &&
652 (match(X
, m_OneUse(m_ZExt(m_Instruction(L1
)))) ||
653 match(X
, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1
)))),
654 m_APInt(ShAmt1
)))))) {
655 LI1
= dyn_cast
<LoadInst
>(L1
);
657 LoadInst
*LI2
= dyn_cast
<LoadInst
>(L2
);
659 // Check if loads are same, atomic, volatile and having same address space.
660 if (LI1
== LI2
|| !LI1
|| !LI2
|| !LI1
->isSimple() || !LI2
->isSimple() ||
661 LI1
->getPointerAddressSpace() != LI2
->getPointerAddressSpace())
664 // Check if Loads come from same BB.
665 if (LI1
->getParent() != LI2
->getParent())
668 // Find the data layout
669 bool IsBigEndian
= DL
.isBigEndian();
671 // Check if loads are consecutive and same size.
672 Value
*Load1Ptr
= LI1
->getPointerOperand();
673 APInt
Offset1(DL
.getIndexTypeSizeInBits(Load1Ptr
->getType()), 0);
675 Load1Ptr
->stripAndAccumulateConstantOffsets(DL
, Offset1
,
676 /* AllowNonInbounds */ true);
678 Value
*Load2Ptr
= LI2
->getPointerOperand();
679 APInt
Offset2(DL
.getIndexTypeSizeInBits(Load2Ptr
->getType()), 0);
681 Load2Ptr
->stripAndAccumulateConstantOffsets(DL
, Offset2
,
682 /* AllowNonInbounds */ true);
684 // Verify if both loads have same base pointers and load sizes are same.
685 uint64_t LoadSize1
= LI1
->getType()->getPrimitiveSizeInBits();
686 uint64_t LoadSize2
= LI2
->getType()->getPrimitiveSizeInBits();
687 if (Load1Ptr
!= Load2Ptr
|| LoadSize1
!= LoadSize2
)
690 // Support Loadsizes greater or equal to 8bits and only power of 2.
691 if (LoadSize1
< 8 || !isPowerOf2_64(LoadSize1
))
694 // Alias Analysis to check for stores b/w the loads.
695 LoadInst
*Start
= LOps
.FoundRoot
? LOps
.RootInsert
: LI1
, *End
= LI2
;
697 if (!Start
->comesBefore(End
)) {
698 std::swap(Start
, End
);
699 Loc
= MemoryLocation::get(End
);
701 Loc
= Loc
.getWithNewSize(LOps
.LoadSize
);
703 Loc
= MemoryLocation::get(End
);
704 unsigned NumScanned
= 0;
705 for (Instruction
&Inst
:
706 make_range(Start
->getIterator(), End
->getIterator())) {
707 if (Inst
.mayWriteToMemory() && isModSet(AA
.getModRefInfo(&Inst
, Loc
)))
709 if (++NumScanned
> MaxInstrsToScan
)
713 // Make sure Load with lower Offset is at LI1
714 bool Reverse
= false;
715 if (Offset2
.slt(Offset1
)) {
717 std::swap(ShAmt1
, ShAmt2
);
718 std::swap(Offset1
, Offset2
);
719 std::swap(Load1Ptr
, Load2Ptr
);
720 std::swap(LoadSize1
, LoadSize2
);
724 // Big endian swap the shifts
726 std::swap(ShAmt1
, ShAmt2
);
728 // Find Shifts values.
729 uint64_t Shift1
= 0, Shift2
= 0;
731 Shift1
= ShAmt1
->getZExtValue();
733 Shift2
= ShAmt2
->getZExtValue();
735 // First load is always LI1. This is where we put the new load.
736 // Use the merged load size available from LI1 for forward loads.
737 if (LOps
.FoundRoot
) {
739 LoadSize1
= LOps
.LoadSize
;
741 LoadSize2
= LOps
.LoadSize
;
744 // Verify if shift amount and load index aligns and verifies that loads
746 uint64_t ShiftDiff
= IsBigEndian
? LoadSize2
: LoadSize1
;
748 DL
.getTypeStoreSize(IntegerType::get(LI1
->getContext(), LoadSize1
));
749 if ((Shift2
- Shift1
) != ShiftDiff
|| (Offset2
- Offset1
) != PrevSize
)
753 AAMDNodes AATags1
= LOps
.AATags
;
754 AAMDNodes AATags2
= LI2
->getAAMetadata();
755 if (LOps
.FoundRoot
== false) {
756 LOps
.FoundRoot
= true;
757 AATags1
= LI1
->getAAMetadata();
759 LOps
.LoadSize
= LoadSize1
+ LoadSize2
;
760 LOps
.RootInsert
= Start
;
762 // Concatenate the AATags of the Merged Loads.
763 LOps
.AATags
= AATags1
.concat(AATags2
);
767 LOps
.ZextType
= X
->getType();
771 // For a given BB instruction, evaluate all loads in the chain that form a
772 // pattern which suggests that the loads can be combined. The one and only use
773 // of the loads is to form a wider load.
774 static bool foldConsecutiveLoads(Instruction
&I
, const DataLayout
&DL
,
775 TargetTransformInfo
&TTI
, AliasAnalysis
&AA
,
776 const DominatorTree
&DT
) {
777 // Only consider load chains of scalar values.
778 if (isa
<VectorType
>(I
.getType()))
782 if (!foldLoadsRecursive(&I
, LOps
, DL
, AA
) || !LOps
.FoundRoot
)
785 IRBuilder
<> Builder(&I
);
786 LoadInst
*NewLoad
= nullptr, *LI1
= LOps
.Root
;
788 IntegerType
*WiderType
= IntegerType::get(I
.getContext(), LOps
.LoadSize
);
789 // TTI based checks if we want to proceed with wider load
790 bool Allowed
= TTI
.isTypeLegal(WiderType
);
794 unsigned AS
= LI1
->getPointerAddressSpace();
796 Allowed
= TTI
.allowsMisalignedMemoryAccesses(I
.getContext(), LOps
.LoadSize
,
797 AS
, LI1
->getAlign(), &Fast
);
798 if (!Allowed
|| !Fast
)
801 // Get the Index and Ptr for the new GEP.
802 Value
*Load1Ptr
= LI1
->getPointerOperand();
803 Builder
.SetInsertPoint(LOps
.RootInsert
);
804 if (!DT
.dominates(Load1Ptr
, LOps
.RootInsert
)) {
805 APInt
Offset1(DL
.getIndexTypeSizeInBits(Load1Ptr
->getType()), 0);
806 Load1Ptr
= Load1Ptr
->stripAndAccumulateConstantOffsets(
807 DL
, Offset1
, /* AllowNonInbounds */ true);
808 Load1Ptr
= Builder
.CreateGEP(Builder
.getInt8Ty(), Load1Ptr
,
809 Builder
.getInt32(Offset1
.getZExtValue()));
811 // Generate wider load.
812 NewLoad
= Builder
.CreateAlignedLoad(WiderType
, Load1Ptr
, LI1
->getAlign(),
813 LI1
->isVolatile(), "");
814 NewLoad
->takeName(LI1
);
815 // Set the New Load AATags Metadata.
817 NewLoad
->setAAMetadata(LOps
.AATags
);
819 Value
*NewOp
= NewLoad
;
820 // Check if zero extend needed.
822 NewOp
= Builder
.CreateZExt(NewOp
, LOps
.ZextType
);
824 // Check if shift needed. We need to shift with the amount of load1
825 // shift if not zero.
827 NewOp
= Builder
.CreateShl(NewOp
, ConstantInt::get(I
.getContext(), *LOps
.Shift
));
828 I
.replaceAllUsesWith(NewOp
);
833 // Calculate GEP Stride and accumulated const ModOffset. Return Stride and
835 static std::pair
<APInt
, APInt
>
836 getStrideAndModOffsetOfGEP(Value
*PtrOp
, const DataLayout
&DL
) {
837 unsigned BW
= DL
.getIndexTypeSizeInBits(PtrOp
->getType());
838 std::optional
<APInt
> Stride
;
839 APInt
ModOffset(BW
, 0);
840 // Return a minimum gep stride, greatest common divisor of consective gep
841 // index scales(c.f. Bézout's identity).
842 while (auto *GEP
= dyn_cast
<GEPOperator
>(PtrOp
)) {
843 MapVector
<Value
*, APInt
> VarOffsets
;
844 if (!GEP
->collectOffset(DL
, BW
, VarOffsets
, ModOffset
))
847 for (auto [V
, Scale
] : VarOffsets
) {
848 // Only keep a power of two factor for non-inbounds
849 if (!GEP
->isInBounds())
850 Scale
= APInt::getOneBitSet(Scale
.getBitWidth(), Scale
.countr_zero());
855 Stride
= APIntOps::GreatestCommonDivisor(*Stride
, Scale
);
858 PtrOp
= GEP
->getPointerOperand();
861 // Check whether pointer arrives back at Global Variable via at least one GEP.
862 // Even if it doesn't, we can check by alignment.
863 if (!isa
<GlobalVariable
>(PtrOp
) || !Stride
)
864 return {APInt(BW
, 1), APInt(BW
, 0)};
866 // In consideration of signed GEP indices, non-negligible offset become
867 // remainder of division by minimum GEP stride.
868 ModOffset
= ModOffset
.srem(*Stride
);
869 if (ModOffset
.isNegative())
870 ModOffset
+= *Stride
;
872 return {*Stride
, ModOffset
};
875 /// If C is a constant patterned array and all valid loaded results for given
876 /// alignment are same to a constant, return that constant.
877 static bool foldPatternedLoads(Instruction
&I
, const DataLayout
&DL
) {
878 auto *LI
= dyn_cast
<LoadInst
>(&I
);
879 if (!LI
|| LI
->isVolatile())
882 // We can only fold the load if it is from a constant global with definitive
883 // initializer. Skip expensive logic if this is not the case.
884 auto *PtrOp
= LI
->getPointerOperand();
885 auto *GV
= dyn_cast
<GlobalVariable
>(getUnderlyingObject(PtrOp
));
886 if (!GV
|| !GV
->isConstant() || !GV
->hasDefinitiveInitializer())
889 // Bail for large initializers in excess of 4K to avoid too many scans.
890 Constant
*C
= GV
->getInitializer();
891 uint64_t GVSize
= DL
.getTypeAllocSize(C
->getType());
892 if (!GVSize
|| 4096 < GVSize
)
895 Type
*LoadTy
= LI
->getType();
896 unsigned BW
= DL
.getIndexTypeSizeInBits(PtrOp
->getType());
897 auto [Stride
, ConstOffset
] = getStrideAndModOffsetOfGEP(PtrOp
, DL
);
899 // Any possible offset could be multiple of GEP stride. And any valid
900 // offset is multiple of load alignment, so checking only multiples of bigger
901 // one is sufficient to say results' equality.
902 if (auto LA
= LI
->getAlign();
903 LA
<= GV
->getAlign().valueOrOne() && Stride
.getZExtValue() < LA
.value()) {
904 ConstOffset
= APInt(BW
, 0);
905 Stride
= APInt(BW
, LA
.value());
908 Constant
*Ca
= ConstantFoldLoadFromConst(C
, LoadTy
, ConstOffset
, DL
);
912 unsigned E
= GVSize
- DL
.getTypeStoreSize(LoadTy
);
913 for (; ConstOffset
.getZExtValue() <= E
; ConstOffset
+= Stride
)
914 if (Ca
!= ConstantFoldLoadFromConst(C
, LoadTy
, ConstOffset
, DL
))
917 I
.replaceAllUsesWith(Ca
);
922 /// This is the entry point for folds that could be implemented in regular
923 /// InstCombine, but they are separated because they are not expected to
924 /// occur frequently and/or have more than a constant-length pattern match.
925 static bool foldUnusualPatterns(Function
&F
, DominatorTree
&DT
,
926 TargetTransformInfo
&TTI
,
927 TargetLibraryInfo
&TLI
, AliasAnalysis
&AA
,
928 AssumptionCache
&AC
) {
929 bool MadeChange
= false;
930 for (BasicBlock
&BB
: F
) {
931 // Ignore unreachable basic blocks.
932 if (!DT
.isReachableFromEntry(&BB
))
935 const DataLayout
&DL
= F
.getParent()->getDataLayout();
937 // Walk the block backwards for efficiency. We're matching a chain of
938 // use->defs, so we're more likely to succeed by starting from the bottom.
939 // Also, we want to avoid matching partial patterns.
940 // TODO: It would be more efficient if we removed dead instructions
941 // iteratively in this loop rather than waiting until the end.
942 for (Instruction
&I
: make_early_inc_range(llvm::reverse(BB
))) {
943 MadeChange
|= foldAnyOrAllBitsSet(I
);
944 MadeChange
|= foldGuardedFunnelShift(I
, DT
);
945 MadeChange
|= tryToRecognizePopCount(I
);
946 MadeChange
|= tryToFPToSat(I
, TTI
);
947 MadeChange
|= tryToRecognizeTableBasedCttz(I
);
948 MadeChange
|= foldConsecutiveLoads(I
, DL
, TTI
, AA
, DT
);
949 MadeChange
|= foldPatternedLoads(I
, DL
);
950 // NOTE: This function introduces erasing of the instruction `I`, so it
951 // needs to be called at the end of this sequence, otherwise we may make
953 MadeChange
|= foldSqrt(I
, TTI
, TLI
, AC
, DT
);
957 // We're done with transforms, so remove dead instructions.
959 for (BasicBlock
&BB
: F
)
960 SimplifyInstructionsInBlock(&BB
);
965 /// This is the entry point for all transforms. Pass manager differences are
966 /// handled in the callers of this function.
967 static bool runImpl(Function
&F
, AssumptionCache
&AC
, TargetTransformInfo
&TTI
,
968 TargetLibraryInfo
&TLI
, DominatorTree
&DT
,
970 bool MadeChange
= false;
971 const DataLayout
&DL
= F
.getParent()->getDataLayout();
972 TruncInstCombine
TIC(AC
, TLI
, DL
, DT
);
973 MadeChange
|= TIC
.run(F
);
974 MadeChange
|= foldUnusualPatterns(F
, DT
, TTI
, TLI
, AA
, AC
);
978 PreservedAnalyses
AggressiveInstCombinePass::run(Function
&F
,
979 FunctionAnalysisManager
&AM
) {
980 auto &AC
= AM
.getResult
<AssumptionAnalysis
>(F
);
981 auto &TLI
= AM
.getResult
<TargetLibraryAnalysis
>(F
);
982 auto &DT
= AM
.getResult
<DominatorTreeAnalysis
>(F
);
983 auto &TTI
= AM
.getResult
<TargetIRAnalysis
>(F
);
984 auto &AA
= AM
.getResult
<AAManager
>(F
);
985 if (!runImpl(F
, AC
, TTI
, TLI
, DT
, AA
)) {
986 // No changes, all analyses are preserved.
987 return PreservedAnalyses::all();
989 // Mark all the analyses that instcombine updates as preserved.
990 PreservedAnalyses PA
;
991 PA
.preserveSet
<CFGAnalyses
>();