//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "ReduceOperandsToArgs.h"
11 #include "llvm/ADT/Sequence.h"
12 #include "llvm/IR/InstIterator.h"
13 #include "llvm/IR/InstrTypes.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
16 #include "llvm/Transforms/Utils/Cloning.h"
20 static bool canReplaceFunction(Function
*F
) {
21 return all_of(F
->uses(), [](Use
&Op
) {
22 if (auto *CI
= dyn_cast
<CallBase
>(Op
.getUser()))
23 return &CI
->getCalledOperandUse() == &Op
;
28 static bool canReduceUse(Use
&Op
) {
29 Value
*Val
= Op
.get();
30 Type
*Ty
= Val
->getType();
32 // Only replace operands that can be passed-by-value.
33 if (!Ty
->isFirstClassType())
36 // Don't pass labels/metadata as arguments.
37 if (Ty
->isLabelTy() || Ty
->isMetadataTy())
40 // No need to replace values that are already arguments.
41 if (isa
<Argument
>(Val
))
44 // Do not replace literals.
45 if (isa
<ConstantData
>(Val
))
48 // Do not convert direct function calls to indirect calls.
49 if (auto *CI
= dyn_cast
<CallBase
>(Op
.getUser()))
50 if (&CI
->getCalledOperandUse() == &Op
)
56 /// Goes over OldF calls and replaces them with a call to NewF.
57 static void replaceFunctionCalls(Function
*OldF
, Function
*NewF
) {
58 SmallVector
<CallBase
*> Callers
;
59 for (Use
&U
: OldF
->uses()) {
60 auto *CI
= cast
<CallBase
>(U
.getUser());
61 assert(&U
== &CI
->getCalledOperandUse());
62 assert(CI
->getCalledFunction() == OldF
);
63 Callers
.push_back(CI
);
66 // Call arguments for NewF.
67 SmallVector
<Value
*> Args(NewF
->arg_size(), nullptr);
69 // Fill up the additional parameters with undef values.
70 for (auto ArgIdx
: llvm::seq
<size_t>(OldF
->arg_size(), NewF
->arg_size())) {
71 Type
*NewArgTy
= NewF
->getArg(ArgIdx
)->getType();
72 Args
[ArgIdx
] = UndefValue::get(NewArgTy
);
75 for (CallBase
*CI
: Callers
) {
76 // Preserve the original function arguments.
77 for (auto Z
: zip_first(CI
->args(), Args
))
78 std::get
<1>(Z
) = std::get
<0>(Z
);
80 // Also preserve operand bundles.
81 SmallVector
<OperandBundleDef
> OperandBundles
;
82 CI
->getOperandBundlesAsDefs(OperandBundles
);
84 // Create the new function call.
86 if (auto *II
= dyn_cast
<InvokeInst
>(CI
)) {
87 NewCI
= InvokeInst::Create(NewF
, cast
<InvokeInst
>(II
)->getNormalDest(),
88 cast
<InvokeInst
>(II
)->getUnwindDest(), Args
,
89 OperandBundles
, CI
->getName());
91 assert(isa
<CallInst
>(CI
));
92 NewCI
= CallInst::Create(NewF
, Args
, OperandBundles
, CI
->getName());
94 NewCI
->setCallingConv(NewF
->getCallingConv());
96 // Do the replacement for this use.
98 CI
->replaceAllUsesWith(NewCI
);
99 ReplaceInstWithInst(CI
, NewCI
);
103 /// Add a new function argument to @p F for each use in @OpsToReplace, and
104 /// replace those operand values with the new function argument.
105 static void substituteOperandWithArgument(Function
*OldF
,
106 ArrayRef
<Use
*> OpsToReplace
) {
107 if (OpsToReplace
.empty())
110 SetVector
<Value
*> UniqueValues
;
111 for (Use
*Op
: OpsToReplace
)
112 UniqueValues
.insert(Op
->get());
114 // Determine the new function's signature.
115 SmallVector
<Type
*> NewArgTypes
;
116 llvm::append_range(NewArgTypes
, OldF
->getFunctionType()->params());
117 size_t ArgOffset
= NewArgTypes
.size();
118 for (Value
*V
: UniqueValues
)
119 NewArgTypes
.push_back(V
->getType());
121 FunctionType::get(OldF
->getFunctionType()->getReturnType(), NewArgTypes
,
122 OldF
->getFunctionType()->isVarArg());
124 // Create the new function...
126 Function::Create(FTy
, OldF
->getLinkage(), OldF
->getAddressSpace(),
127 OldF
->getName(), OldF
->getParent());
129 // In order to preserve function order, we move NewF behind OldF
130 NewF
->removeFromParent();
131 OldF
->getParent()->getFunctionList().insertAfter(OldF
->getIterator(), NewF
);
133 // Preserve the parameters of OldF.
134 ValueToValueMapTy VMap
;
135 for (auto Z
: zip_first(OldF
->args(), NewF
->args())) {
136 Argument
&OldArg
= std::get
<0>(Z
);
137 Argument
&NewArg
= std::get
<1>(Z
);
139 NewArg
.setName(OldArg
.getName()); // Copy the name over...
140 VMap
[&OldArg
] = &NewArg
; // Add mapping to VMap
143 // Adjust the new parameters.
144 ValueToValueMapTy OldValMap
;
145 for (auto Z
: zip_first(UniqueValues
, drop_begin(NewF
->args(), ArgOffset
))) {
146 Value
*OldVal
= std::get
<0>(Z
);
147 Argument
&NewArg
= std::get
<1>(Z
);
149 NewArg
.setName(OldVal
->getName());
150 OldValMap
[OldVal
] = &NewArg
;
153 SmallVector
<ReturnInst
*, 8> Returns
; // Ignore returns cloned.
154 CloneFunctionInto(NewF
, OldF
, VMap
, CloneFunctionChangeType::LocalChangesOnly
,
155 Returns
, "", /*CodeInfo=*/nullptr);
157 // Replace the actual operands.
158 for (Use
*Op
: OpsToReplace
) {
159 Value
*NewArg
= OldValMap
.lookup(Op
->get());
160 auto *NewUser
= cast
<Instruction
>(VMap
.lookup(Op
->getUser()));
161 NewUser
->setOperand(Op
->getOperandNo(), NewArg
);
164 // Replace all OldF uses with NewF.
165 replaceFunctionCalls(OldF
, NewF
);
167 // Rename NewF to OldF's name.
168 std::string FName
= OldF
->getName().str();
169 OldF
->replaceAllUsesWith(ConstantExpr::getBitCast(NewF
, OldF
->getType()));
170 OldF
->eraseFromParent();
171 NewF
->setName(FName
);
174 static void reduceOperandsToArgs(Oracle
&O
, Module
&Program
) {
175 SmallVector
<Use
*> OperandsToReduce
;
176 for (Function
&F
: make_early_inc_range(Program
.functions())) {
177 if (!canReplaceFunction(&F
))
179 OperandsToReduce
.clear();
180 for (Instruction
&I
: instructions(&F
)) {
181 for (Use
&Op
: I
.operands()) {
182 if (!canReduceUse(Op
))
187 OperandsToReduce
.push_back(&Op
);
191 substituteOperandWithArgument(&F
, OperandsToReduce
);
195 void llvm::reduceOperandsToArgsDeltaPass(TestRunner
&Test
) {
196 outs() << "*** Converting operands to function arguments ...\n";
197 return runDeltaPass(Test
, reduceOperandsToArgs
);