[ASan] Make insertion of version mismatch guard configurable
[llvm-core.git] / lib / Transforms / Utils / LowerMemIntrinsics.cpp
blob0cc085dc366c66890a6f3aeb9cfea10ed064a4ee
1 //===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <algorithm>
15 using namespace llvm;
17 static unsigned getLoopOperandSizeInBytes(Type *Type) {
18 if (VectorType *VTy = dyn_cast<VectorType>(Type)) {
19 return VTy->getBitWidth() / 8;
22 return Type->getPrimitiveSizeInBits() / 8;
25 void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
26 Value *DstAddr, ConstantInt *CopyLen,
27 unsigned SrcAlign, unsigned DestAlign,
28 bool SrcIsVolatile, bool DstIsVolatile,
29 const TargetTransformInfo &TTI) {
30 // No need to expand zero length copies.
31 if (CopyLen->isZero())
32 return;
34 BasicBlock *PreLoopBB = InsertBefore->getParent();
35 BasicBlock *PostLoopBB = nullptr;
36 Function *ParentFunc = PreLoopBB->getParent();
37 LLVMContext &Ctx = PreLoopBB->getContext();
39 Type *TypeOfCopyLen = CopyLen->getType();
40 Type *LoopOpType =
41 TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign);
43 unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType);
44 uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize;
46 unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
47 unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
49 if (LoopEndCount != 0) {
50 // Split
51 PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "memcpy-split");
52 BasicBlock *LoopBB =
53 BasicBlock::Create(Ctx, "load-store-loop", ParentFunc, PostLoopBB);
54 PreLoopBB->getTerminator()->setSuccessor(0, LoopBB);
56 IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
58 // Cast the Src and Dst pointers to pointers to the loop operand type (if
59 // needed).
60 PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS);
61 PointerType *DstOpType = PointerType::get(LoopOpType, DstAS);
62 if (SrcAddr->getType() != SrcOpType) {
63 SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType);
65 if (DstAddr->getType() != DstOpType) {
66 DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType);
69 IRBuilder<> LoopBuilder(LoopBB);
70 PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 2, "loop-index");
71 LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0U), PreLoopBB);
72 // Loop Body
73 Value *SrcGEP =
74 LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex);
75 Value *Load = LoopBuilder.CreateLoad(LoopOpType, SrcGEP, SrcIsVolatile);
76 Value *DstGEP =
77 LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex);
78 LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile);
80 Value *NewIndex =
81 LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1U));
82 LoopIndex->addIncoming(NewIndex, LoopBB);
84 // Create the loop branch condition.
85 Constant *LoopEndCI = ConstantInt::get(TypeOfCopyLen, LoopEndCount);
86 LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, LoopEndCI),
87 LoopBB, PostLoopBB);
90 uint64_t BytesCopied = LoopEndCount * LoopOpSize;
91 uint64_t RemainingBytes = CopyLen->getZExtValue() - BytesCopied;
92 if (RemainingBytes) {
93 IRBuilder<> RBuilder(PostLoopBB ? PostLoopBB->getFirstNonPHI()
94 : InsertBefore);
96 // Update the alignment based on the copy size used in the loop body.
97 SrcAlign = std::min(SrcAlign, LoopOpSize);
98 DestAlign = std::min(DestAlign, LoopOpSize);
100 SmallVector<Type *, 5> RemainingOps;
101 TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
102 SrcAlign, DestAlign);
104 for (auto OpTy : RemainingOps) {
105 // Calaculate the new index
106 unsigned OperandSize = getLoopOperandSizeInBytes(OpTy);
107 uint64_t GepIndex = BytesCopied / OperandSize;
108 assert(GepIndex * OperandSize == BytesCopied &&
109 "Division should have no Remainder!");
110 // Cast source to operand type and load
111 PointerType *SrcPtrType = PointerType::get(OpTy, SrcAS);
112 Value *CastedSrc = SrcAddr->getType() == SrcPtrType
113 ? SrcAddr
114 : RBuilder.CreateBitCast(SrcAddr, SrcPtrType);
115 Value *SrcGEP = RBuilder.CreateInBoundsGEP(
116 OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex));
117 Value *Load = RBuilder.CreateLoad(OpTy, SrcGEP, SrcIsVolatile);
119 // Cast destination to operand type and store.
120 PointerType *DstPtrType = PointerType::get(OpTy, DstAS);
121 Value *CastedDst = DstAddr->getType() == DstPtrType
122 ? DstAddr
123 : RBuilder.CreateBitCast(DstAddr, DstPtrType);
124 Value *DstGEP = RBuilder.CreateInBoundsGEP(
125 OpTy, CastedDst, ConstantInt::get(TypeOfCopyLen, GepIndex));
126 RBuilder.CreateStore(Load, DstGEP, DstIsVolatile);
128 BytesCopied += OperandSize;
131 assert(BytesCopied == CopyLen->getZExtValue() &&
132 "Bytes copied should match size in the call!");
135 void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
136 Value *SrcAddr, Value *DstAddr,
137 Value *CopyLen, unsigned SrcAlign,
138 unsigned DestAlign, bool SrcIsVolatile,
139 bool DstIsVolatile,
140 const TargetTransformInfo &TTI) {
141 BasicBlock *PreLoopBB = InsertBefore->getParent();
142 BasicBlock *PostLoopBB =
143 PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion");
145 Function *ParentFunc = PreLoopBB->getParent();
146 LLVMContext &Ctx = PreLoopBB->getContext();
148 Type *LoopOpType =
149 TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign);
150 unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType);
152 IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
154 unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
155 unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
156 PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS);
157 PointerType *DstOpType = PointerType::get(LoopOpType, DstAS);
158 if (SrcAddr->getType() != SrcOpType) {
159 SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType);
161 if (DstAddr->getType() != DstOpType) {
162 DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType);
165 // Calculate the loop trip count, and remaining bytes to copy after the loop.
166 Type *CopyLenType = CopyLen->getType();
167 IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType);
168 assert(ILengthType &&
169 "expected size argument to memcpy to be an integer type!");
170 Type *Int8Type = Type::getInt8Ty(Ctx);
171 bool LoopOpIsInt8 = LoopOpType == Int8Type;
172 ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize);
173 Value *RuntimeLoopCount = LoopOpIsInt8 ?
174 CopyLen :
175 PLBuilder.CreateUDiv(CopyLen, CILoopOpSize);
176 BasicBlock *LoopBB =
177 BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, PostLoopBB);
178 IRBuilder<> LoopBuilder(LoopBB);
180 PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLenType, 2, "loop-index");
181 LoopIndex->addIncoming(ConstantInt::get(CopyLenType, 0U), PreLoopBB);
183 Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex);
184 Value *Load = LoopBuilder.CreateLoad(LoopOpType, SrcGEP, SrcIsVolatile);
185 Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex);
186 LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile);
188 Value *NewIndex =
189 LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U));
190 LoopIndex->addIncoming(NewIndex, LoopBB);
192 if (!LoopOpIsInt8) {
193 // Add in the
194 Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize);
195 Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual);
197 // Loop body for the residual copy.
198 BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual",
199 PreLoopBB->getParent(),
200 PostLoopBB);
201 // Residual loop header.
202 BasicBlock *ResHeaderBB = BasicBlock::Create(
203 Ctx, "loop-memcpy-residual-header", PreLoopBB->getParent(), nullptr);
205 // Need to update the pre-loop basic block to branch to the correct place.
206 // branch to the main loop if the count is non-zero, branch to the residual
207 // loop if the copy size is smaller then 1 iteration of the main loop but
208 // non-zero and finally branch to after the residual loop if the memcpy
209 // size is zero.
210 ConstantInt *Zero = ConstantInt::get(ILengthType, 0U);
211 PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopCount, Zero),
212 LoopBB, ResHeaderBB);
213 PreLoopBB->getTerminator()->eraseFromParent();
215 LoopBuilder.CreateCondBr(
216 LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopCount), LoopBB,
217 ResHeaderBB);
219 // Determine if we need to branch to the residual loop or bypass it.
220 IRBuilder<> RHBuilder(ResHeaderBB);
221 RHBuilder.CreateCondBr(RHBuilder.CreateICmpNE(RuntimeResidual, Zero),
222 ResLoopBB, PostLoopBB);
224 // Copy the residual with single byte load/store loop.
225 IRBuilder<> ResBuilder(ResLoopBB);
226 PHINode *ResidualIndex =
227 ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index");
228 ResidualIndex->addIncoming(Zero, ResHeaderBB);
230 Value *SrcAsInt8 =
231 ResBuilder.CreateBitCast(SrcAddr, PointerType::get(Int8Type, SrcAS));
232 Value *DstAsInt8 =
233 ResBuilder.CreateBitCast(DstAddr, PointerType::get(Int8Type, DstAS));
234 Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex);
235 Value *SrcGEP =
236 ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset);
237 Value *Load = ResBuilder.CreateLoad(Int8Type, SrcGEP, SrcIsVolatile);
238 Value *DstGEP =
239 ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset);
240 ResBuilder.CreateStore(Load, DstGEP, DstIsVolatile);
242 Value *ResNewIndex =
243 ResBuilder.CreateAdd(ResidualIndex, ConstantInt::get(CopyLenType, 1U));
244 ResidualIndex->addIncoming(ResNewIndex, ResLoopBB);
246 // Create the loop branch condition.
247 ResBuilder.CreateCondBr(
248 ResBuilder.CreateICmpULT(ResNewIndex, RuntimeResidual), ResLoopBB,
249 PostLoopBB);
250 } else {
251 // In this case the loop operand type was a byte, and there is no need for a
252 // residual loop to copy the remaining memory after the main loop.
253 // We do however need to patch up the control flow by creating the
254 // terminators for the preloop block and the memcpy loop.
255 ConstantInt *Zero = ConstantInt::get(ILengthType, 0U);
256 PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopCount, Zero),
257 LoopBB, PostLoopBB);
258 PreLoopBB->getTerminator()->eraseFromParent();
259 LoopBuilder.CreateCondBr(
260 LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopCount), LoopBB,
261 PostLoopBB);
265 // Lower memmove to IR. memmove is required to correctly copy overlapping memory
266 // regions; therefore, it has to check the relative positions of the source and
267 // destination pointers and choose the copy direction accordingly.
269 // The code below is an IR rendition of this C function:
271 // void* memmove(void* dst, const void* src, size_t n) {
272 // unsigned char* d = dst;
273 // const unsigned char* s = src;
274 // if (s < d) {
275 // // copy backwards
276 // while (n--) {
277 // d[n] = s[n];
278 // }
279 // } else {
280 // // copy forward
281 // for (size_t i = 0; i < n; ++i) {
282 // d[i] = s[i];
283 // }
284 // }
285 // return dst;
286 // }
287 static void createMemMoveLoop(Instruction *InsertBefore,
288 Value *SrcAddr, Value *DstAddr, Value *CopyLen,
289 unsigned SrcAlign, unsigned DestAlign,
290 bool SrcIsVolatile, bool DstIsVolatile) {
291 Type *TypeOfCopyLen = CopyLen->getType();
292 BasicBlock *OrigBB = InsertBefore->getParent();
293 Function *F = OrigBB->getParent();
295 Type *EltTy = cast<PointerType>(SrcAddr->getType())->getElementType();
297 // Create the a comparison of src and dst, based on which we jump to either
298 // the forward-copy part of the function (if src >= dst) or the backwards-copy
299 // part (if src < dst).
300 // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
301 // structure. Its block terminators (unconditional branches) are replaced by
302 // the appropriate conditional branches when the loop is built.
303 ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT,
304 SrcAddr, DstAddr, "compare_src_dst");
305 Instruction *ThenTerm, *ElseTerm;
306 SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm,
307 &ElseTerm);
309 // Each part of the function consists of two blocks:
310 // copy_backwards: used to skip the loop when n == 0
311 // copy_backwards_loop: the actual backwards loop BB
312 // copy_forward: used to skip the loop when n == 0
313 // copy_forward_loop: the actual forward loop BB
314 BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
315 CopyBackwardsBB->setName("copy_backwards");
316 BasicBlock *CopyForwardBB = ElseTerm->getParent();
317 CopyForwardBB->setName("copy_forward");
318 BasicBlock *ExitBB = InsertBefore->getParent();
319 ExitBB->setName("memmove_done");
321 // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
322 // between both backwards and forward copy clauses.
323 ICmpInst *CompareN =
324 new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen,
325 ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
327 // Copying backwards.
328 BasicBlock *LoopBB =
329 BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
330 IRBuilder<> LoopBuilder(LoopBB);
331 PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
332 Value *IndexPtr = LoopBuilder.CreateSub(
333 LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
334 Value *Element = LoopBuilder.CreateLoad(
335 EltTy, LoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, IndexPtr),
336 "element");
337 LoopBuilder.CreateStore(
338 Element, LoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, IndexPtr));
339 LoopBuilder.CreateCondBr(
340 LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
341 ExitBB, LoopBB);
342 LoopPhi->addIncoming(IndexPtr, LoopBB);
343 LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
344 BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
345 ThenTerm->eraseFromParent();
347 // Copying forward.
348 BasicBlock *FwdLoopBB =
349 BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB);
350 IRBuilder<> FwdLoopBuilder(FwdLoopBB);
351 PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
352 Value *FwdElement = FwdLoopBuilder.CreateLoad(
353 EltTy, FwdLoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, FwdCopyPhi),
354 "element");
355 FwdLoopBuilder.CreateStore(
356 FwdElement, FwdLoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, FwdCopyPhi));
357 Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
358 FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
359 FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
360 ExitBB, FwdLoopBB);
361 FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
362 FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
364 BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
365 ElseTerm->eraseFromParent();
368 static void createMemSetLoop(Instruction *InsertBefore,
369 Value *DstAddr, Value *CopyLen, Value *SetValue,
370 unsigned Align, bool IsVolatile) {
371 Type *TypeOfCopyLen = CopyLen->getType();
372 BasicBlock *OrigBB = InsertBefore->getParent();
373 Function *F = OrigBB->getParent();
374 BasicBlock *NewBB =
375 OrigBB->splitBasicBlock(InsertBefore, "split");
376 BasicBlock *LoopBB
377 = BasicBlock::Create(F->getContext(), "loadstoreloop", F, NewBB);
379 IRBuilder<> Builder(OrigBB->getTerminator());
381 // Cast pointer to the type of value getting stored
382 unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
383 DstAddr = Builder.CreateBitCast(DstAddr,
384 PointerType::get(SetValue->getType(), dstAS));
386 Builder.CreateCondBr(
387 Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB,
388 LoopBB);
389 OrigBB->getTerminator()->eraseFromParent();
391 IRBuilder<> LoopBuilder(LoopBB);
392 PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
393 LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
395 LoopBuilder.CreateStore(
396 SetValue,
397 LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex),
398 IsVolatile);
400 Value *NewIndex =
401 LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
402 LoopIndex->addIncoming(NewIndex, LoopBB);
404 LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
405 NewBB);
408 void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
409 const TargetTransformInfo &TTI) {
410 if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) {
411 createMemCpyLoopKnownSize(/* InsertBefore */ Memcpy,
412 /* SrcAddr */ Memcpy->getRawSource(),
413 /* DstAddr */ Memcpy->getRawDest(),
414 /* CopyLen */ CI,
415 /* SrcAlign */ Memcpy->getSourceAlignment(),
416 /* DestAlign */ Memcpy->getDestAlignment(),
417 /* SrcIsVolatile */ Memcpy->isVolatile(),
418 /* DstIsVolatile */ Memcpy->isVolatile(),
419 /* TargetTransformInfo */ TTI);
420 } else {
421 createMemCpyLoopUnknownSize(/* InsertBefore */ Memcpy,
422 /* SrcAddr */ Memcpy->getRawSource(),
423 /* DstAddr */ Memcpy->getRawDest(),
424 /* CopyLen */ Memcpy->getLength(),
425 /* SrcAlign */ Memcpy->getSourceAlignment(),
426 /* DestAlign */ Memcpy->getDestAlignment(),
427 /* SrcIsVolatile */ Memcpy->isVolatile(),
428 /* DstIsVolatile */ Memcpy->isVolatile(),
429 /* TargetTransfomrInfo */ TTI);
433 void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) {
434 createMemMoveLoop(/* InsertBefore */ Memmove,
435 /* SrcAddr */ Memmove->getRawSource(),
436 /* DstAddr */ Memmove->getRawDest(),
437 /* CopyLen */ Memmove->getLength(),
438 /* SrcAlign */ Memmove->getSourceAlignment(),
439 /* DestAlign */ Memmove->getDestAlignment(),
440 /* SrcIsVolatile */ Memmove->isVolatile(),
441 /* DstIsVolatile */ Memmove->isVolatile());
444 void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
445 createMemSetLoop(/* InsertBefore */ Memset,
446 /* DstAddr */ Memset->getRawDest(),
447 /* CopyLen */ Memset->getLength(),
448 /* SetValue */ Memset->getValue(),
449 /* Alignment */ Memset->getDestAlignment(),
450 Memset->isVolatile());