1 //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // Lower aggregate copies and memset, memcpy, memmove intrinsics into loops
11 // when the size is large or is not a compile-time constant.
13 //===----------------------------------------------------------------------===//
15 #include "NVPTXLowerAggrCopies.h"
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/CodeGen/StackProtector.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/DataLayout.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/LLVMContext.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
31 #define DEBUG_TYPE "nvptx"
37 // Function pass that turns large aggregate load/store copies and llvm.mem*
// intrinsic calls into explicit loops; the work is done in runOnFunction.
38 struct NVPTXLowerAggrCopies
: public FunctionPass
{
// Constructs the pass, forwarding ID to the FunctionPass base class.
41 NVPTXLowerAggrCopies() : FunctionPass(ID
) {}
// Declares the analyses this pass interacts with: StackProtector is marked
// as preserved, and TargetTransformInfoWrapperPass is required because
// runOnFunction queries it for the function's TargetTransformInfo.
43 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
44 AU
.addPreserved
<StackProtector
>();
45 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
// Entry point of the pass for a single function; defined out of class.
48 bool runOnFunction(Function
&F
) override
;
// Threshold that runOnFunction compares against the store size of loaded
// types and against constant mem* intrinsic lengths (presumably in bytes --
// TODO confirm the unit).
50 static const unsigned MaxAggrCopySize
= 128;
// Returns the descriptive name of this pass.
52 StringRef
getPassName() const override
{
53 return "Lower aggregate copies/intrinsics into loops";
// Definition of the pass's ID member, initialised to 0. The constructor passes
// ID to FunctionPass. NOTE(review): the in-class declaration of ID is not
// visible in this excerpt.
57 char NVPTXLowerAggrCopies::ID
= 0;
// Collects aggregate loads and mem* intrinsic calls in F, then rewrites each
// collected item into an explicit loop.
//
// Candidates, as far as this excerpt shows:
//  * a LoadInst whose type's store size is tested against MaxAggrCopySize and
//    whose user_back() is a StoreInst whose operand 0 is tested against the
//    load;
//  * a MemIntrinsic whose length is a ConstantInt >= MaxAggrCopySize, or
//    (per the comment in the scan) whose length is not a constant.
//
// NOTE(review): some statements of this function (bodies of a few `if`s,
// closing braces, the return statements) are not visible in this excerpt, so
// the exact skip conditions and the returned value are not documented here.
59 bool NVPTXLowerAggrCopies::runOnFunction(Function
&F
) {
// Work lists: filled by the scan, consumed by the two rewrite loops.
60 SmallVector
<LoadInst
*, 4> AggrLoads
;
61 SmallVector
<MemIntrinsic
*, 4> MemCalls
;
// Data layout and LLVM context come from F's parent module.
63 const DataLayout
&DL
= F
.getParent()->getDataLayout();
64 LLVMContext
&Context
= F
.getParent()->getContext();
// TargetTransformInfo for F, obtained from the required wrapper pass; it is
// forwarded to the memcpy loop expansion helpers below.
65 const TargetTransformInfo
&TTI
=
66 getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
68 // Collect all aggregate loads and mem* calls.
69 for (Function::iterator BI
= F
.begin(), BE
= F
.end(); BI
!= BE
; ++BI
) {
70 for (BasicBlock::iterator II
= BI
->begin(), IE
= BI
->end(); II
!= IE
;
// Case 1: a load instruction.
72 if (LoadInst
*LI
= dyn_cast
<LoadInst
>(II
)) {
// Size test against the threshold. The statement guarded by this test is not
// visible here; presumably loads below the threshold are skipped -- confirm.
76 if (DL
.getTypeStoreSize(LI
->getType()) < MaxAggrCopySize
)
// Only consider the load when its last user is a store.
79 if (StoreInst
*SI
= dyn_cast
<StoreInst
>(LI
->user_back())) {
// Operand 0 of the store is compared with the load, i.e. the test is about
// the load being the stored value rather than some other operand.
80 if (SI
->getOperand(0) != LI
)
82 AggrLoads
.push_back(LI
);
// Case 2: a memory intrinsic call.
84 } else if (MemIntrinsic
*IntrCall
= dyn_cast
<MemIntrinsic
>(II
)) {
85 // Convert intrinsic calls with variable size or with constant size
86 // of at least the MaxAggrCopySize threshold.
87 if (ConstantInt
*LenCI
= dyn_cast
<ConstantInt
>(IntrCall
->getLength())) {
88 if (LenCI
->getZExtValue() >= MaxAggrCopySize
) {
89 MemCalls
.push_back(IntrCall
);
// Second push site; presumably the branch for a non-constant length (see the
// comment above) -- the enclosing `else` is not visible in this excerpt.
92 MemCalls
.push_back(IntrCall
);
// Nothing was collected. The body of this test is not visible; presumably an
// early return reporting that F was left unchanged -- confirm.
98 if (AggrLoads
.size() == 0 && MemCalls
.size() == 0) {
103 // Rewrite each collected load/store pair as a copy loop.
105 for (LoadInst
*LI
: AggrLoads
) {
// NOTE(review): dyn_cast<> may yield null, yet SI is dereferenced without a
// check below. The scan also tested user_back() while this reads
// *user_begin(); the two agree only when the load has exactly one user.
// Confirm the scan guarantees that (cast<> would state the assumption).
106 StoreInst
*SI
= dyn_cast
<StoreInst
>(*LI
->user_begin());
// Source is the load's pointer operand, destination the store's operand 1.
107 Value
*SrcAddr
= LI
->getOperand(0);
108 Value
*DstAddr
= SI
->getOperand(1);
// Despite its name, NumLoads holds the store size of the loaded type and is
// used as the copy length.
// NOTE(review): the size is narrowed to `unsigned` and then encoded as a
// 32-bit constant; confirm larger sizes cannot occur here.
109 unsigned NumLoads
= DL
.getTypeStoreSize(LI
->getType());
110 ConstantInt
*CopyLen
=
111 ConstantInt::get(Type::getInt32Ty(Context
), NumLoads
);
// Emit the copy loop; the store is passed as the instruction being converted,
// and alignment/volatility are taken from the original load and store.
113 createMemCpyLoopKnownSize(/* ConvertedInst */ SI
,
114 /* SrcAddr */ SrcAddr
, /* DstAddr */ DstAddr
,
115 /* CopyLen */ CopyLen
,
116 /* SrcAlign */ LI
->getAlignment(),
117 /* DestAlign */ SI
->getAlignment(),
118 /* SrcIsVolatile */ LI
->isVolatile(),
119 /* DstIsVolatile */ SI
->isVolatile(), TTI
);
// Remove the original store first, then the load it used.
121 SI
->eraseFromParent();
122 LI
->eraseFromParent();
125 // Transform mem* intrinsic calls.
126 for (MemIntrinsic
*MemCall
: MemCalls
) {
// Dispatch on the concrete intrinsic kind; only the memcpy expansion takes
// the TargetTransformInfo.
127 if (MemCpyInst
*Memcpy
= dyn_cast
<MemCpyInst
>(MemCall
)) {
128 expandMemCpyAsLoop(Memcpy
, TTI
);
129 } else if (MemMoveInst
*Memmove
= dyn_cast
<MemMoveInst
>(MemCall
)) {
130 expandMemMoveAsLoop(Memmove
);
131 } else if (MemSetInst
*Memset
= dyn_cast
<MemSetInst
>(MemCall
)) {
132 expandMemSetAsLoop(Memset
);
// Erase the intrinsic call that was collected for expansion.
134 MemCall
->eraseFromParent();
// Declaration of the initialisation function for this pass, taking the
// PassRegistry by reference. Presumably its definition is generated by the
// INITIALIZE_PASS invocation in this file -- confirm.
143 void initializeNVPTXLowerAggrCopiesPass(PassRegistry
&);
// Registration macro for NVPTXLowerAggrCopies: supplies the argument string
// "nvptx-lower-aggr-copies" and a one-line description of the pass.
// NOTE(review): the remaining macro arguments and the closing parenthesis are
// not visible in this excerpt.
146 INITIALIZE_PASS(NVPTXLowerAggrCopies
, "nvptx-lower-aggr-copies",
147 "Lower aggregate copies, and llvm.mem* intrinsics into loops",
// Factory function in namespace llvm: allocates a new NVPTXLowerAggrCopies
// with `new` and returns it as a FunctionPass pointer. Ownership of the
// returned object is presumably taken over by the caller -- confirm.
150 FunctionPass
*llvm::createLowerAggrCopies() {
151 return new NVPTXLowerAggrCopies();