//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandMemCmp.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

namespace llvm {
class TargetLowering;
}

#define DEBUG_TYPE "expand-memcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));
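
// For example (illustrative, hypothetical invocation): passing
// -max-loads-per-memcmp=4 caps Options.MaxNumLoads at 4, and
// -memcmp-num-loads-per-block=2 lets each zero-equality load/compare block
// consume two load pairs; both overrides are applied in expandMemCmp() below.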

namespace {

// This class provides helper functions to expand a memcmp library call into
// an inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI = nullptr;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize = 0;
  uint64_t NumLoadsNonOneByte = 0;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock = nullptr;
  PHINode *PhiRes = nullptr;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  DomTreeUpdater *DTU = nullptr;
  IRBuilder<> Builder;

  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

  struct LoadPair {
    Value *Lhs = nullptr;
    Value *Rhs = nullptr;
  };
  LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
                       Type *CmpSizeType, unsigned OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

  static void optimiseLoadSequence(
      LoadEntryVector &LoadSequence,
      const TargetTransformInfo::MemCmpExpansionOptions &Options,
      bool IsUsedForZeroCmp);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
                  DomTreeUpdater *DTU);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}
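
// Worked example (illustrative): for Size == 7 and LoadSizes == {4, 2, 1}
// with a permissive MaxNumLoads, the greedy walk produces
// [{4, 0}, {2, 4}, {1, 6}] and NumLoadsNonOneByte == 2: one 4-byte load,
// one 2-byte load, and a single trailing byte.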

MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // end.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load, this is already handled by
  // the greedy approach.
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}
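
// Worked example (illustrative): for Size == 7 and MaxLoadSize == 4, the
// greedy sequence above needs three loads, while this function returns
// [{4, 0}, {4, 3}]: the second load starts at offset 3, re-reading one byte
// so that two 4-byte loads cover all 7 bytes.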

void MemCmpExpansion::optimiseLoadSequence(
    LoadEntryVector &LoadSequence,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    bool IsUsedForZeroCmp) {
  // This part of code attempts to optimize the LoadSequence by merging allowed
  // subsequences into single loads of allowed sizes from
  // `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero
  // comparison or if no allowed tail expansions are specified, we exit early.
  if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty())
    return;

  while (LoadSequence.size() >= 2) {
    auto Last = LoadSequence[LoadSequence.size() - 1];
    auto PreLast = LoadSequence[LoadSequence.size() - 2];

    // Exit the loop if the two sequences are not contiguous.
    if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
      break;

    auto LoadSize = Last.LoadSize + PreLast.LoadSize;
    if (find(Options.AllowedTailExpansions, LoadSize) ==
        Options.AllowedTailExpansions.end())
      break;

    // Remove the last two entries and replace them with the combined entry;
    // note that LoadEntry takes (LoadSize, Offset) in that order.
    LoadSequence.pop_back();
    LoadSequence.pop_back();
    LoadSequence.emplace_back(LoadSize, PreLast.Offset);
  }
}
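
// Worked example (illustrative): with LoadSequence == [{4, 0}, {2, 4}, {1, 6}]
// and AllowedTailExpansions containing 3, the two tail entries are contiguous
// (4 + 2 == 6) and merge into {3, 4}, giving [{4, 0}, {3, 4}]; a further merge
// (combined size 7) happens only if 7 is also an allowed tail expansion.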

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from the original instruction point, which is the
//    block to return from.
// 3. ResultBlock, the block to branch to for early exit when a LoadCmpBlock
//    finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
    DomTreeUpdater *DTU)
    : CI(CI), Size(Size), NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), DTU(DTU),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If we allow overlapping loads and the load sequence is not already
  // optimal, use overlapping loads.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
}

unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}
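
// Example (illustrative): a zero-equality expansion with 5 loads and
// NumLoadsPerBlockForZeroCmp == 2 needs 5/2 + 1 == 3 blocks; the general
// three-way expansion always uses one block per load, i.e. 5 blocks.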

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

MemCmpExpansion::LoadPair
MemCmpExpansion::getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
                             Type *CmpSizeType, unsigned OffsetBytes) {
  // Get the memory source at offset `OffsetBytes`.
  Value *LhsSource = CI->getArgOperand(0);
  Value *RhsSource = CI->getArgOperand(1);
  Align LhsAlign = LhsSource->getPointerAlignment(DL);
  Align RhsAlign = RhsSource->getPointerAlignment(DL);
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    LhsSource = Builder.CreateConstGEP1_64(ByteType, LhsSource, OffsetBytes);
    RhsSource = Builder.CreateConstGEP1_64(ByteType, RhsSource, OffsetBytes);
    LhsAlign = commonAlignment(LhsAlign, OffsetBytes);
    RhsAlign = commonAlignment(RhsAlign, OffsetBytes);
  }

  // Create a constant or a load from the source.
  Value *Lhs = nullptr;
  if (auto *C = dyn_cast<Constant>(LhsSource))
    Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Lhs)
    Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);

  Value *Rhs = nullptr;
  if (auto *C = dyn_cast<Constant>(RhsSource))
    Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Rhs)
    Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);

  // Zero-extend if the byte-swap intrinsic has a different type.
  if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
    Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
    Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
  }

  // Swap bytes if required.
  if (BSwapSizeType) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, BSwapSizeType);
    Lhs = Builder.CreateCall(Bswap, Lhs);
    Rhs = Builder.CreateCall(Bswap, Rhs);
  }

  // Zero-extend if required.
  if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
    Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
    Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
  }
  return {Lhs, Rhs};
}
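
// Illustrative IR (sketch; value names invented): for a 4-byte load pair at
// offset 4 on a little-endian target with BSwapSizeType == i32 and
// CmpSizeType == i64, each side becomes roughly:
//   %gep = getelementptr i8, ptr %src, i64 4
//   %ld  = load i32, ptr %gep
//   %bs  = call i32 @llvm.bswap.i32(i32 %ld)
//   %ext = zext i32 %bs to i64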

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// offset. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  BasicBlock *BB = LoadCmpBlocks[BlockIndex];
  Builder.SetInsertPoint(BB);
  const LoadPair Loads =
      getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
  Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

  PhiRes->addIncoming(Diff, BB);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue
    // to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
    if (DTU)
      DTU->applyUpdates(
          {{DominatorTree::Insert, BB, EndBlock},
           {DominatorTree::Insert, BB, LoadCmpBlocks[BlockIndex + 1]}});
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, BB, EndBlock}});
  }
}
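
// Illustrative IR (sketch; value names invented): a byte block at offset 14
// emits roughly:
//   %lhs     = load i8, ptr %a.gep
//   %rhs     = load i8, ptr %b.gep
//   %lhs.ext = zext i8 %lhs to i32
//   %rhs.ext = zext i8 %rhs to i32
//   %diff    = sub i32 %lhs.ext, %rhs.ext
// The raw difference feeds PhiRes directly; only the branch to the next block
// needs an extra icmp ne 0.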

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);

  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
    const LoadPair Loads = getLoadPair(
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
        MaxLoadType, CurLoadEntry.Offset);

    if (NumLoads != 1) {
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
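
// Illustrative reduction (sketch): with four loads in one block, the block
// computes x_i = zext(xor(lhs_i, rhs_i)) in the max load type, reduces them as
// or(or(x0, x1), or(x2, x3)), and tests the single result with one icmp ne 0,
// so there is one branch per block instead of one per load.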

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, ResBlock.BB},
                       {DominatorTree::Insert, BB, NextBB}});

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then does a subtract to see if
// there was a difference in the loaded values. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock. The special handling
// can simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *BSwapSizeType =
      DL.isLittleEndian()
          ? IntegerType::get(CI->getContext(),
                             PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
          : nullptr;
  Type *MaxLoadType = IntegerType::get(
      CI->getContext(),
      std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
                                     CurLoadEntry.Offset);

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, NextBB},
                       {DominatorTree::Insert, BB, ResBlock.BB}});

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);
  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  PhiRes->addIncoming(Res, ResBlock.BB);
  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(EndBlock, EndBlock->begin());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
/// This function also analyses users of memcmp, and if there is only one user
/// from which we can conclude that only 2 out of 3 memcmp outcomes really
/// matter, then it generates more efficient code with only one comparison.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Type *BSwapSizeType =
      NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
                 : nullptr;
  Type *MaxLoadType =
      IntegerType::get(CI->getContext(),
                       std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);

  // The i8 and i16 cases don't need compares. We zext the loaded values and
  // subtract them to get the suitable negative, zero, or positive i32 result.
  if (Size == 1 || Size == 2) {
    const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
                                       Builder.getInt32Ty(), /*Offset*/ 0);
    return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
  }

  const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
                                     /*Offset*/ 0);

  // If a user of memcmp cares only about two outcomes, for example:
  //    bool result = memcmp(a, b, NBYTES) > 0;
  // we can generate more optimal code with a smaller number of operations.
  if (CI->hasOneUser()) {
    auto *UI = cast<Instruction>(*CI->user_begin());
    ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
    uint64_t Shift;
    bool NeedsZExt = false;
    // This is a special case because instead of checking if the result is less
    // than zero:
    //    bool result = memcmp(a, b, NBYTES) < 0;
    // the compiler is clever enough to generate the following code:
    //    bool result = memcmp(a, b, NBYTES) >> 31;
    if (match(UI, m_LShr(m_Value(), m_ConstantInt(Shift))) &&
        Shift == (CI->getType()->getIntegerBitWidth() - 1)) {
      Pred = ICmpInst::ICMP_SLT;
      NeedsZExt = true;
    } else {
      // In case of a successful match this call will set the `Pred` variable.
      match(UI, m_ICmp(Pred, m_Specific(CI), m_Zero()));
    }
    // Generate new code and remove the original memcmp call and the user.
    if (ICmpInst::isSigned(Pred)) {
      Value *Cmp = Builder.CreateICmp(CmpInst::getUnsignedPredicate(Pred),
                                      Loads.Lhs, Loads.Rhs);
      auto *Result = NeedsZExt ? Builder.CreateZExt(Cmp, UI->getType()) : Cmp;
      UI->replaceAllUsesWith(Result);
      UI->eraseFromParent();
      CI->eraseFromParent();
      return nullptr;
    }
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
  Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
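
// Illustrative example (sketch): when the only user is
//    bool r = memcmp(a, b, 8) < 0;
// the matched predicate is signed (ICMP_SLT), so the code above emits a single
// `icmp ult` on the (byte-swapped, on little-endian targets) 8-byte loads,
// rewires the user, and erases both instructions; no sub/zext sequence is
// emitted.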

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result. Returns nullptr if the memcmp is already replaced.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = SplitBlock(StartBlock, CI, DTU, /*LI=*/nullptr,
                          /*MSSAU=*/nullptr, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp)
      setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by SplitBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, StartBlock, LoadCmpBlocks[0]},
                         {DominatorTree::Delete, StartBlock, EndBlock}});
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:            ; preds = %loadbb2, %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:              ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:              ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:              ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:             ; preds = %res_block, %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL,
                         ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
                         DomTreeUpdater *DTU, const bool IsBCmp) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }

  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp =
      IsBCmp || isOnlyUsedInZeroEqualityComparison(CI);
  bool OptForSize = CI->getFunction()->hasOptSize() ||
                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
  auto Options = TTI->enableMemCmpExpansion(OptForSize, IsUsedForZeroCmp);
  if (!Options)
    return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (OptForSize && MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL, DTU);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  if (Value *Res = Expansion.getMemCmpExpansion()) {
    // Replace call with result of expansion and erase call.
    CI->replaceAllUsesWith(Res);
    CI->eraseFromParent();
  }
  return true;
}

// Returns true if a change was made.
static bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, const TargetLowering *TL,
                       const DataLayout &DL, ProfileSummaryInfo *PSI,
                       BlockFrequencyInfo *BFI, DomTreeUpdater *DTU);

static PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                                 const TargetTransformInfo *TTI,
                                 const TargetLowering *TL,
                                 ProfileSummaryInfo *PSI,
                                 BlockFrequencyInfo *BFI, DominatorTree *DT);

class ExpandMemCmpLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  ExpandMemCmpLegacyPass() : FunctionPass(ID) {
    initializeExpandMemCmpLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC)
      return false;
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
    auto *BFI = (PSI && PSI->hasProfileSummary())
                    ? &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI()
                    : nullptr;
    DominatorTree *DT = nullptr;
    if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
      DT = &DTWP->getDomTree();
    auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI, DT);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<ProfileSummaryInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // namespace

bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                const TargetTransformInfo *TTI, const TargetLowering *TL,
                const DataLayout &DL, ProfileSummaryInfo *PSI,
                BlockFrequencyInfo *BFI, DomTreeUpdater *DTU) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(*CI, Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL, PSI, BFI, DTU, Func == LibFunc_bcmp)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                          const TargetTransformInfo *TTI,
                          const TargetLowering *TL, ProfileSummaryInfo *PSI,
                          BlockFrequencyInfo *BFI, DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI, DTU ? &*DTU : nullptr)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  if (MadeChanges)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);
  if (!MadeChanges)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

PreservedAnalyses ExpandMemCmpPass::run(Function &F,
                                        FunctionAnalysisManager &FAM) {
  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  const auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
  const auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
  auto *PSI = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
                  .getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
  BlockFrequencyInfo *BFI = (PSI && PSI->hasProfileSummary())
                                ? &FAM.getResult<BlockFrequencyAnalysis>(F)
                                : nullptr;
  auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);

  return runImpl(F, &TLI, &TTI, TL, PSI, BFI, DT);
}

char ExpandMemCmpLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpLegacyPass, DEBUG_TYPE,
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpLegacyPass, DEBUG_TYPE,
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpLegacyPass() {
  return new ExpandMemCmpLegacyPass();
}