//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// a smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
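//
// For example (an illustrative case, not taken from the pass's tests), a dag
// such as
//   %za  = zext i8 %a to i32
//   %zb  = zext i8 %b to i32
//   %add = add i32 %za, %zb
//   %res = trunc i32 %add to i8
// can be evaluated directly in i8, removing both extensions and the truncate:
//   %res = add i8 %a, %b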
//
//===----------------------------------------------------------------------===//
27 #include "AggressiveInstCombineInternal.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Analysis/ConstantFolding.h"
31 #include "llvm/Analysis/TargetLibraryInfo.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Dominators.h"
34 #include "llvm/IR/IRBuilder.h"
37 #define DEBUG_TYPE "aggressive-instcombine"
/// Given an instruction and a container, it fills all the relevant operands of
/// that instruction, with respect to the Trunc expression dag optimization.
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
  unsigned Opc = I->getOpcode();
  switch (Opc) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
    // These CastInst are considered leaves of the evaluated expression, thus,
    // their operands are not relevant.
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
    Ops.push_back(I->getOperand(0));
    Ops.push_back(I->getOperand(1));
    break;
  default:
    llvm_unreachable("Unreachable!");
  }
}
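
// Build the expression dag reachable from CurrentTruncInst's operand using an
// iterative DFS (a Worklist plus an explicit Stack, so an instruction is added
// to InstInfoMap only after all of its operands have been handled). Returns
// false as soon as an unsupported instruction is reached.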
bool TruncInstCombine::buildTruncExpressionDag() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;
  // Clear old expression dag.
  InstInfoMap.clear();

  Worklist.push_back(CurrentTruncInst->getOperand(0));

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    auto *I = dyn_cast<Instruction>(Curr);
    if (!I)
      return false;

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and add it to the instruction info map.
      Worklist.pop_back();
      Stack.pop_back();
      // Insert I to the Info map.
      InstInfoMap.insert(std::make_pair(I, Info()));
      continue;
    }

    if (InstInfoMap.count(I)) {
      Worklist.pop_back();
      continue;
    }

    // Add the instruction to the stack before we start handling its operands.
    Stack.push_back(I);

    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
      // trunc(trunc(x)) -> trunc(x)
      // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
      // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
      // dest
      break;
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
      for (Value *Operand : Operands)
        Worklist.push_back(Operand);
      break;
    }
    default:
      // TODO: Can handle more cases here:
      // 1. select, shufflevector, extractelement, insertelement
      // 2. udiv, sdiv
      // 3. shl, lshr, ashr
      // 4. phi node (and loop handling)
      return false;
    }
  }
  return true;
}
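
// Walk the expression dag (again an explicit Worklist/Stack DFS) and compute,
// for every instruction, the minimal bit-width it must be evaluated with:
// ValidBitWidth is propagated from the root trunc down to the leaves, and
// MinBitWidth is propagated back up. Returns the minimal legal bit-width for
// the whole dag, or the original bit-width when shrinking is not worthwhile.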
unsigned TruncInstCombine::getMinBitWidth() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;

  Value *Src = CurrentTruncInst->getOperand(0);
  Type *DstTy = CurrentTruncInst->getType();
  unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  if (isa<Constant>(Src))
    return TruncBitWidth;

  Worklist.push_back(Src);
  InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    // Otherwise, it must be an instruction.
    auto *I = cast<Instruction>(Curr);

    auto &Info = InstInfoMap[I];

    SmallVector<Value *, 2> Operands;
    getRelevantOperands(I, Operands);

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and update MinBitWidth.
      Worklist.pop_back();
      Stack.pop_back();
      for (auto *Operand : Operands)
        if (auto *IOp = dyn_cast<Instruction>(Operand))
          Info.MinBitWidth =
              std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
      continue;
    }

    // Add the instruction to the stack before we start handling its operands.
    Stack.push_back(I);
    unsigned ValidBitWidth = Info.ValidBitWidth;

    // Update minimum bit-width before handling its operands. This is required
    // when the instruction is part of a loop.
    Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);

    for (auto *Operand : Operands)
      if (auto *IOp = dyn_cast<Instruction>(Operand)) {
        // If we already calculated the minimum bit-width for this valid
        // bit-width, or for a smaller valid bit-width, then just keep the
        // answer we already calculated.
        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
        if (IOpBitwidth >= ValidBitWidth)
          continue;
        InstInfoMap[IOp].ValidBitWidth = std::max(ValidBitWidth, IOpBitwidth);
        Worklist.push_back(IOp);
      }
  }
  unsigned MinBitWidth =
      InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
  assert(MinBitWidth >= TruncBitWidth);

  if (MinBitWidth > TruncBitWidth) {
    // In this case reducing the expression with a vector type might generate a
    // new vector type, which is not preferable as it might result in generating
    // insert/extract instructions.
    if (DstTy->isVectorTy())
      return OrigBitWidth;
    // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
    Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
    // Update the minimum bit-width with the new destination type bit-width if
    // we succeeded in finding such a type; otherwise, use the original
    // bit-width.
    MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
  } else { // MinBitWidth == TruncBitWidth
    // In this case the expression can be evaluated with the trunc instruction
    // destination type, and trunc instruction can be omitted. However, we
    // should not perform the evaluation if the original type is a legal scalar
    // type and the target type is illegal.
    bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
    bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
    if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
      return OrigBitWidth;
  }
  return MinBitWidth;
}
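
// Decide whether the expression dag rooted at CurrentTruncInst can profitably
// be evaluated in a narrower type. Returns the new scalar integer type on
// success, or nullptr if the dag is not eligible (unsupported instruction,
// users outside the dag that are not extensions, or no smaller legal
// bit-width).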
Type *TruncInstCombine::getBestTruncatedType() {
  if (!buildTruncExpressionDag())
    return nullptr;

  // We don't want to duplicate instructions, which isn't profitable. Thus, we
  // can't shrink something that has multiple users, unless all users are
  // post-dominated by the trunc instruction, i.e., were visited during the
  // expression evaluation.
  unsigned DesiredBitWidth = 0;
  for (auto Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->hasOneUse())
      continue;
    bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
    for (auto *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
          if (!IsExtInst)
            return nullptr;
          // If this is an extension from the dest type, we can eliminate it,
          // even if it has multiple users. Thus, update the DesiredBitWidth and
          // validate that all extension instructions agree on the same
          // DesiredBitWidth.
          unsigned ExtInstBitWidth =
              I->getOperand(0)->getType()->getScalarSizeInBits();
          if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
            return nullptr;
          DesiredBitWidth = ExtInstBitWidth;
        }
  }

  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  // Calculate the minimum bit-width allowed for shrinking the currently
  // visited truncate's operand.
  unsigned MinBitWidth = getMinBitWidth();

  // Check that we can shrink to a smaller bit-width than the original one and
  // that it matches the DesiredBitWidth, if such exists.
  if (MinBitWidth >= OrigBitWidth ||
      (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
    return nullptr;

  return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
}
/// Given a reduced scalar type \p Ty and a value \p V, return a reduced type
/// for \p V according to its type: if it is a vector type, return the vector
/// version of \p Ty, otherwise return \p Ty.
static Type *getReducedType(Value *V, Type *Ty) {
  assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
  if (auto *VTy = dyn_cast<VectorType>(V->getType()))
    return VectorType::get(Ty, VTy->getNumElements());
  return Ty;
}
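
// Return the already-reduced form of \p V in the narrow type: constants are
// cast to the reduced type (and constant-folded) on the spot, while
// instructions must already have a NewValue recorded in InstInfoMap.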
Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
  Type *Ty = getReducedType(V, SclTy);
  if (auto *C = dyn_cast<Constant>(V)) {
    C = ConstantExpr::getIntegerCast(C, Ty, false);
    // If we got a constantexpr back, try to simplify it with DL info.
    if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI))
      C = FoldedC;
    return C;
  }

  auto *I = cast<Instruction>(V);
  Info Entry = InstInfoMap.lookup(I);
  assert(Entry.NewValue);
  return Entry.NewValue;
}
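
// Rewrite the whole dag in the reduced scalar type \p SclTy: walk InstInfoMap
// in insertion (operands-before-users) order creating the narrow replacements,
// hook the result into CurrentTruncInst's users, and erase the now-dead wide
// instructions.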
void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
  for (auto &Itr : InstInfoMap) { // Forward
    Instruction *I = Itr.first;
    TruncInstCombine::Info &NodeInfo = Itr.second;

    assert(!NodeInfo.NewValue && "Instruction has been evaluated");

    IRBuilder<> Builder(I);
    Value *Res = nullptr;
    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt: {
      Type *Ty = getReducedType(I, SclTy);
      // If the source type of the cast is the type we're trying for then we
      // can just return the source. There's no need to insert a new cast.
      if (I->getOperand(0)->getType() == Ty) {
        assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
        NodeInfo.NewValue = I->getOperand(0);
        continue;
      }

      // Otherwise, must be the same type of cast, so just reinsert a new one.
      // This also handles the case of zext(trunc(x)) -> zext(x).
      Res = Builder.CreateIntCast(I->getOperand(0), Ty,
                                  Opc == Instruction::SExt);

      // Update Worklist entries with new value if needed.
      // There are three possible changes to the Worklist:
      // 1. Update Old-TruncInst -> New-TruncInst.
      // 2. Remove Old-TruncInst (if New node is not TruncInst).
      // 3. Add New-TruncInst (if Old node was not TruncInst).
      auto Entry = find(Worklist, I);
      if (Entry != Worklist.end()) {
        if (auto *NewCI = dyn_cast<TruncInst>(Res))
          *Entry = NewCI;
        else
          Worklist.erase(Entry);
      } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
        Worklist.push_back(NewCI);
      break;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
      break;
    }
    default:
      llvm_unreachable("Unhandled instruction");
    }

    NodeInfo.NewValue = Res;
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(I);
  }

  Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  Type *DstTy = CurrentTruncInst->getType();
  if (Res->getType() != DstTy) {
    IRBuilder<> Builder(CurrentTruncInst);
    Res = Builder.CreateIntCast(Res, DstTy, false);
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(CurrentTruncInst);
  }
  CurrentTruncInst->replaceAllUsesWith(Res);

  // Erase the old expression dag, which was replaced by the reduced one.
  // We iterate backward, which means we visit the instruction before we visit
  // any of its operands; this way, when we get to the operand, we have already
  // removed the instructions (from the expression dag) that use it.
  CurrentTruncInst->eraseFromParent();
  for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) {
    // We still need to check that the instruction has no users before we erase
    // it, because a {SExt, ZExt}Inst instruction might have other users that
    // were not reduced; in such a case, we need to keep that instruction.
    if (I->first->use_empty())
      I->first->eraseFromParent();
  }
}
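
// Driver: collect every TruncInst in the reachable blocks of \p F, then try to
// shrink the expression dag feeding each one.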
bool TruncInstCombine::run(Function &F) {
  bool MadeIRChange = false;

  // Collect all TruncInst in the function into the Worklist for evaluating.
  for (auto &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    for (auto &I : BB)
      if (auto *CI = dyn_cast<TruncInst>(&I))
        Worklist.push_back(CI);
  }

  // Process all TruncInst in the Worklist, for each instruction:
  //   1. Check if it dominates an eligible expression dag to be reduced.
  //   2. Create a reduced expression dag and replace the old one with it.
  while (!Worklist.empty()) {
    CurrentTruncInst = Worklist.pop_back_val();

    if (Type *NewDstSclTy = getBestTruncatedType()) {
      LLVM_DEBUG(
          dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
                    "dominated by: "
                 << CurrentTruncInst << '\n');
      ReduceExpressionDag(NewDstSclTy);