//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower aggregate copies and memset, memcpy, and memmove intrinsics into loops
// when the copy size is large or is not a compile-time constant.
//
//===----------------------------------------------------------------------===//
15 #include "NVPTXLowerAggrCopies.h"
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/CodeGen/StackProtector.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/DataLayout.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/LLVMContext.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
31 #define DEBUG_TYPE "nvptx"
37 // actual analysis class, which is a functionpass
38 struct NVPTXLowerAggrCopies
: public FunctionPass
{
41 NVPTXLowerAggrCopies() : FunctionPass(ID
) {}
43 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
44 AU
.addPreserved
<StackProtector
>();
45 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
48 bool runOnFunction(Function
&F
) override
;
50 static const unsigned MaxAggrCopySize
= 128;
52 StringRef
getPassName() const override
{
53 return "Lower aggregate copies/intrinsics into loops";
57 char NVPTXLowerAggrCopies::ID
= 0;
59 bool NVPTXLowerAggrCopies::runOnFunction(Function
&F
) {
60 SmallVector
<LoadInst
*, 4> AggrLoads
;
61 SmallVector
<MemIntrinsic
*, 4> MemCalls
;
63 const DataLayout
&DL
= F
.getParent()->getDataLayout();
64 LLVMContext
&Context
= F
.getParent()->getContext();
65 const TargetTransformInfo
&TTI
=
66 getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
68 // Collect all aggregate loads and mem* calls.
69 for (BasicBlock
&BB
: F
) {
70 for (Instruction
&I
: BB
) {
71 if (LoadInst
*LI
= dyn_cast
<LoadInst
>(&I
)) {
75 if (DL
.getTypeStoreSize(LI
->getType()) < MaxAggrCopySize
)
78 if (StoreInst
*SI
= dyn_cast
<StoreInst
>(LI
->user_back())) {
79 if (SI
->getOperand(0) != LI
)
81 AggrLoads
.push_back(LI
);
83 } else if (MemIntrinsic
*IntrCall
= dyn_cast
<MemIntrinsic
>(&I
)) {
84 // Convert intrinsic calls with variable size or with constant size
85 // larger than the MaxAggrCopySize threshold.
86 if (ConstantInt
*LenCI
= dyn_cast
<ConstantInt
>(IntrCall
->getLength())) {
87 if (LenCI
->getZExtValue() >= MaxAggrCopySize
) {
88 MemCalls
.push_back(IntrCall
);
91 MemCalls
.push_back(IntrCall
);
97 if (AggrLoads
.size() == 0 && MemCalls
.size() == 0) {
102 // Do the transformation of an aggr load/copy/set to a loop
104 for (LoadInst
*LI
: AggrLoads
) {
105 auto *SI
= cast
<StoreInst
>(*LI
->user_begin());
106 Value
*SrcAddr
= LI
->getOperand(0);
107 Value
*DstAddr
= SI
->getOperand(1);
108 unsigned NumLoads
= DL
.getTypeStoreSize(LI
->getType());
109 ConstantInt
*CopyLen
=
110 ConstantInt::get(Type::getInt32Ty(Context
), NumLoads
);
112 createMemCpyLoopKnownSize(/* ConvertedInst */ SI
,
113 /* SrcAddr */ SrcAddr
, /* DstAddr */ DstAddr
,
114 /* CopyLen */ CopyLen
,
115 /* SrcAlign */ LI
->getAlign(),
116 /* DestAlign */ SI
->getAlign(),
117 /* SrcIsVolatile */ LI
->isVolatile(),
118 /* DstIsVolatile */ SI
->isVolatile(),
119 /* CanOverlap */ true, TTI
);
121 SI
->eraseFromParent();
122 LI
->eraseFromParent();
125 // Transform mem* intrinsic calls.
126 for (MemIntrinsic
*MemCall
: MemCalls
) {
127 if (MemCpyInst
*Memcpy
= dyn_cast
<MemCpyInst
>(MemCall
)) {
128 expandMemCpyAsLoop(Memcpy
, TTI
);
129 } else if (MemMoveInst
*Memmove
= dyn_cast
<MemMoveInst
>(MemCall
)) {
130 expandMemMoveAsLoop(Memmove
, TTI
);
131 } else if (MemSetInst
*Memset
= dyn_cast
<MemSetInst
>(MemCall
)) {
132 expandMemSetAsLoop(Memset
);
134 MemCall
->eraseFromParent();
143 void initializeNVPTXLowerAggrCopiesPass(PassRegistry
&);
146 INITIALIZE_PASS(NVPTXLowerAggrCopies
, "nvptx-lower-aggr-copies",
147 "Lower aggregate copies, and llvm.mem* intrinsics into loops",
// Factory for the aggregate-copy lowering pass: returns a newly allocated
// NVPTXLowerAggrCopies. NOTE(review): ownership presumably transfers to the
// caller / pass manager that adds it — confirm at the call site.
150 FunctionPass
*llvm::createLowerAggrCopies() {
151 return new NVPTXLowerAggrCopies();