//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
//
//===----------------------------------------------------------------------===//
27 #include "AggressiveInstCombineInternal.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Analysis/ConstantFolding.h"
31 #include "llvm/Analysis/TargetLibraryInfo.h"
32 #include "llvm/Analysis/ValueTracking.h"
33 #include "llvm/IR/DataLayout.h"
34 #include "llvm/IR/Dominators.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/Support/KnownBits.h"
41 #define DEBUG_TYPE "aggressive-instcombine"
45 "Number of truncations eliminated by reducing bit width of expression DAG");
46 STATISTIC(NumInstrsReduced
,
47 "Number of instructions whose bit width was reduced");
49 /// Given an instruction and a container, it fills all the relevant operands of
50 /// that instruction, with respect to the Trunc expression dag optimizaton.
51 static void getRelevantOperands(Instruction
*I
, SmallVectorImpl
<Value
*> &Ops
) {
52 unsigned Opc
= I
->getOpcode();
54 case Instruction::Trunc
:
55 case Instruction::ZExt
:
56 case Instruction::SExt
:
57 // These CastInst are considered leaves of the evaluated expression, thus,
58 // their operands are not relevent.
60 case Instruction::Add
:
61 case Instruction::Sub
:
62 case Instruction::Mul
:
63 case Instruction::And
:
65 case Instruction::Xor
:
66 case Instruction::Shl
:
67 case Instruction::LShr
:
68 Ops
.push_back(I
->getOperand(0));
69 Ops
.push_back(I
->getOperand(1));
71 case Instruction::Select
:
72 Ops
.push_back(I
->getOperand(1));
73 Ops
.push_back(I
->getOperand(2));
76 llvm_unreachable("Unreachable!");
80 bool TruncInstCombine::buildTruncExpressionDag() {
81 SmallVector
<Value
*, 8> Worklist
;
82 SmallVector
<Instruction
*, 8> Stack
;
83 // Clear old expression dag.
86 Worklist
.push_back(CurrentTruncInst
->getOperand(0));
88 while (!Worklist
.empty()) {
89 Value
*Curr
= Worklist
.back();
91 if (isa
<Constant
>(Curr
)) {
96 auto *I
= dyn_cast
<Instruction
>(Curr
);
100 if (!Stack
.empty() && Stack
.back() == I
) {
101 // Already handled all instruction operands, can remove it from both the
102 // Worklist and the Stack, and add it to the instruction info map.
105 // Insert I to the Info map.
106 InstInfoMap
.insert(std::make_pair(I
, Info()));
110 if (InstInfoMap
.count(I
)) {
115 // Add the instruction to the stack before start handling its operands.
118 unsigned Opc
= I
->getOpcode();
120 case Instruction::Trunc
:
121 case Instruction::ZExt
:
122 case Instruction::SExt
:
123 // trunc(trunc(x)) -> trunc(x)
124 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
125 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
128 case Instruction::Add
:
129 case Instruction::Sub
:
130 case Instruction::Mul
:
131 case Instruction::And
:
132 case Instruction::Or
:
133 case Instruction::Xor
:
134 case Instruction::Shl
:
135 case Instruction::LShr
:
136 case Instruction::Select
: {
137 SmallVector
<Value
*, 2> Operands
;
138 getRelevantOperands(I
, Operands
);
139 append_range(Worklist
, Operands
);
143 // TODO: Can handle more cases here:
144 // 1. shufflevector, extractelement, insertelement
147 // 4. phi node(and loop handling)
155 unsigned TruncInstCombine::getMinBitWidth() {
156 SmallVector
<Value
*, 8> Worklist
;
157 SmallVector
<Instruction
*, 8> Stack
;
159 Value
*Src
= CurrentTruncInst
->getOperand(0);
160 Type
*DstTy
= CurrentTruncInst
->getType();
161 unsigned TruncBitWidth
= DstTy
->getScalarSizeInBits();
162 unsigned OrigBitWidth
=
163 CurrentTruncInst
->getOperand(0)->getType()->getScalarSizeInBits();
165 if (isa
<Constant
>(Src
))
166 return TruncBitWidth
;
168 Worklist
.push_back(Src
);
169 InstInfoMap
[cast
<Instruction
>(Src
)].ValidBitWidth
= TruncBitWidth
;
171 while (!Worklist
.empty()) {
172 Value
*Curr
= Worklist
.back();
174 if (isa
<Constant
>(Curr
)) {
179 // Otherwise, it must be an instruction.
180 auto *I
= cast
<Instruction
>(Curr
);
182 auto &Info
= InstInfoMap
[I
];
184 SmallVector
<Value
*, 2> Operands
;
185 getRelevantOperands(I
, Operands
);
187 if (!Stack
.empty() && Stack
.back() == I
) {
188 // Already handled all instruction operands, can remove it from both, the
189 // Worklist and the Stack, and update MinBitWidth.
192 for (auto *Operand
: Operands
)
193 if (auto *IOp
= dyn_cast
<Instruction
>(Operand
))
195 std::max(Info
.MinBitWidth
, InstInfoMap
[IOp
].MinBitWidth
);
199 // Add the instruction to the stack before start handling its operands.
201 unsigned ValidBitWidth
= Info
.ValidBitWidth
;
203 // Update minimum bit-width before handling its operands. This is required
204 // when the instruction is part of a loop.
205 Info
.MinBitWidth
= std::max(Info
.MinBitWidth
, Info
.ValidBitWidth
);
207 for (auto *Operand
: Operands
)
208 if (auto *IOp
= dyn_cast
<Instruction
>(Operand
)) {
209 // If we already calculated the minimum bit-width for this valid
210 // bit-width, or for a smaller valid bit-width, then just keep the
211 // answer we already calculated.
212 unsigned IOpBitwidth
= InstInfoMap
.lookup(IOp
).ValidBitWidth
;
213 if (IOpBitwidth
>= ValidBitWidth
)
215 InstInfoMap
[IOp
].ValidBitWidth
= ValidBitWidth
;
216 Worklist
.push_back(IOp
);
219 unsigned MinBitWidth
= InstInfoMap
.lookup(cast
<Instruction
>(Src
)).MinBitWidth
;
220 assert(MinBitWidth
>= TruncBitWidth
);
222 if (MinBitWidth
> TruncBitWidth
) {
223 // In this case reducing expression with vector type might generate a new
224 // vector type, which is not preferable as it might result in generating
226 if (DstTy
->isVectorTy())
228 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
229 Type
*Ty
= DL
.getSmallestLegalIntType(DstTy
->getContext(), MinBitWidth
);
230 // Update minimum bit-width with the new destination type bit-width if
231 // succeeded to find such, otherwise, with original bit-width.
232 MinBitWidth
= Ty
? Ty
->getScalarSizeInBits() : OrigBitWidth
;
233 } else { // MinBitWidth == TruncBitWidth
234 // In this case the expression can be evaluated with the trunc instruction
235 // destination type, and trunc instruction can be omitted. However, we
236 // should not perform the evaluation if the original type is a legal scalar
237 // type and the target type is illegal.
238 bool FromLegal
= MinBitWidth
== 1 || DL
.isLegalInteger(OrigBitWidth
);
239 bool ToLegal
= MinBitWidth
== 1 || DL
.isLegalInteger(MinBitWidth
);
240 if (!DstTy
->isVectorTy() && FromLegal
&& !ToLegal
)
246 Type
*TruncInstCombine::getBestTruncatedType() {
247 if (!buildTruncExpressionDag())
250 // We don't want to duplicate instructions, which isn't profitable. Thus, we
251 // can't shrink something that has multiple users, unless all users are
252 // post-dominated by the trunc instruction, i.e., were visited during the
253 // expression evaluation.
254 unsigned DesiredBitWidth
= 0;
255 for (auto Itr
: InstInfoMap
) {
256 Instruction
*I
= Itr
.first
;
259 bool IsExtInst
= (isa
<ZExtInst
>(I
) || isa
<SExtInst
>(I
));
260 for (auto *U
: I
->users())
261 if (auto *UI
= dyn_cast
<Instruction
>(U
))
262 if (UI
!= CurrentTruncInst
&& !InstInfoMap
.count(UI
)) {
265 // If this is an extension from the dest type, we can eliminate it,
266 // even if it has multiple users. Thus, update the DesiredBitWidth and
267 // validate all extension instructions agrees on same DesiredBitWidth.
268 unsigned ExtInstBitWidth
=
269 I
->getOperand(0)->getType()->getScalarSizeInBits();
270 if (DesiredBitWidth
&& DesiredBitWidth
!= ExtInstBitWidth
)
272 DesiredBitWidth
= ExtInstBitWidth
;
276 unsigned OrigBitWidth
=
277 CurrentTruncInst
->getOperand(0)->getType()->getScalarSizeInBits();
279 // Initialize MinBitWidth for shift instructions with the minimum number
280 // that is greater than shift amount (i.e. shift amount + 1). For `lshr`
281 // adjust MinBitWidth so that all potentially truncated bits of
282 // the value-to-be-shifted are zeros.
283 // Also normalize MinBitWidth not to be greater than source bitwidth.
284 for (auto &Itr
: InstInfoMap
) {
285 Instruction
*I
= Itr
.first
;
286 if (I
->getOpcode() == Instruction::Shl
||
287 I
->getOpcode() == Instruction::LShr
) {
288 KnownBits KnownRHS
= computeKnownBits(I
->getOperand(1), DL
);
289 unsigned MinBitWidth
= KnownRHS
.getMaxValue()
290 .uadd_sat(APInt(OrigBitWidth
, 1))
291 .getLimitedValue(OrigBitWidth
);
292 if (MinBitWidth
== OrigBitWidth
)
294 if (I
->getOpcode() == Instruction::LShr
) {
295 KnownBits KnownLHS
= computeKnownBits(I
->getOperand(0), DL
);
297 std::max(MinBitWidth
, KnownLHS
.getMaxValue().getActiveBits());
298 if (MinBitWidth
>= OrigBitWidth
)
301 Itr
.second
.MinBitWidth
= MinBitWidth
;
305 // Calculate minimum allowed bit-width allowed for shrinking the currently
306 // visited truncate's operand.
307 unsigned MinBitWidth
= getMinBitWidth();
309 // Check that we can shrink to smaller bit-width than original one and that
310 // it is similar to the DesiredBitWidth is such exists.
311 if (MinBitWidth
>= OrigBitWidth
||
312 (DesiredBitWidth
&& DesiredBitWidth
!= MinBitWidth
))
315 return IntegerType::get(CurrentTruncInst
->getContext(), MinBitWidth
);
318 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
319 /// for \p V, according to its type, if it vector type, return the vector
320 /// version of \p Ty, otherwise return \p Ty.
321 static Type
*getReducedType(Value
*V
, Type
*Ty
) {
322 assert(Ty
&& !Ty
->isVectorTy() && "Expect Scalar Type");
323 if (auto *VTy
= dyn_cast
<VectorType
>(V
->getType()))
324 return VectorType::get(Ty
, VTy
->getElementCount());
328 Value
*TruncInstCombine::getReducedOperand(Value
*V
, Type
*SclTy
) {
329 Type
*Ty
= getReducedType(V
, SclTy
);
330 if (auto *C
= dyn_cast
<Constant
>(V
)) {
331 C
= ConstantExpr::getIntegerCast(C
, Ty
, false);
332 // If we got a constantexpr back, try to simplify it with DL info.
333 return ConstantFoldConstant(C
, DL
, &TLI
);
336 auto *I
= cast
<Instruction
>(V
);
337 Info Entry
= InstInfoMap
.lookup(I
);
338 assert(Entry
.NewValue
);
339 return Entry
.NewValue
;
342 void TruncInstCombine::ReduceExpressionDag(Type
*SclTy
) {
343 NumInstrsReduced
+= InstInfoMap
.size();
344 for (auto &Itr
: InstInfoMap
) { // Forward
345 Instruction
*I
= Itr
.first
;
346 TruncInstCombine::Info
&NodeInfo
= Itr
.second
;
348 assert(!NodeInfo
.NewValue
&& "Instruction has been evaluated");
350 IRBuilder
<> Builder(I
);
351 Value
*Res
= nullptr;
352 unsigned Opc
= I
->getOpcode();
354 case Instruction::Trunc
:
355 case Instruction::ZExt
:
356 case Instruction::SExt
: {
357 Type
*Ty
= getReducedType(I
, SclTy
);
358 // If the source type of the cast is the type we're trying for then we can
359 // just return the source. There's no need to insert it because it is not
361 if (I
->getOperand(0)->getType() == Ty
) {
362 assert(!isa
<TruncInst
>(I
) && "Cannot reach here with TruncInst");
363 NodeInfo
.NewValue
= I
->getOperand(0);
366 // Otherwise, must be the same type of cast, so just reinsert a new one.
367 // This also handles the case of zext(trunc(x)) -> zext(x).
368 Res
= Builder
.CreateIntCast(I
->getOperand(0), Ty
,
369 Opc
== Instruction::SExt
);
371 // Update Worklist entries with new value if needed.
372 // There are three possible changes to the Worklist:
373 // 1. Update Old-TruncInst -> New-TruncInst.
374 // 2. Remove Old-TruncInst (if New node is not TruncInst).
375 // 3. Add New-TruncInst (if Old node was not TruncInst).
376 auto *Entry
= find(Worklist
, I
);
377 if (Entry
!= Worklist
.end()) {
378 if (auto *NewCI
= dyn_cast
<TruncInst
>(Res
))
381 Worklist
.erase(Entry
);
382 } else if (auto *NewCI
= dyn_cast
<TruncInst
>(Res
))
383 Worklist
.push_back(NewCI
);
386 case Instruction::Add
:
387 case Instruction::Sub
:
388 case Instruction::Mul
:
389 case Instruction::And
:
390 case Instruction::Or
:
391 case Instruction::Xor
:
392 case Instruction::Shl
:
393 case Instruction::LShr
: {
394 Value
*LHS
= getReducedOperand(I
->getOperand(0), SclTy
);
395 Value
*RHS
= getReducedOperand(I
->getOperand(1), SclTy
);
396 Res
= Builder
.CreateBinOp((Instruction::BinaryOps
)Opc
, LHS
, RHS
);
397 // Preserve `exact` flag since truncation doesn't change exactness
398 if (Opc
== Instruction::LShr
)
399 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
400 ResI
->setIsExact(I
->isExact());
403 case Instruction::Select
: {
404 Value
*Op0
= I
->getOperand(0);
405 Value
*LHS
= getReducedOperand(I
->getOperand(1), SclTy
);
406 Value
*RHS
= getReducedOperand(I
->getOperand(2), SclTy
);
407 Res
= Builder
.CreateSelect(Op0
, LHS
, RHS
);
411 llvm_unreachable("Unhandled instruction");
414 NodeInfo
.NewValue
= Res
;
415 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
419 Value
*Res
= getReducedOperand(CurrentTruncInst
->getOperand(0), SclTy
);
420 Type
*DstTy
= CurrentTruncInst
->getType();
421 if (Res
->getType() != DstTy
) {
422 IRBuilder
<> Builder(CurrentTruncInst
);
423 Res
= Builder
.CreateIntCast(Res
, DstTy
, false);
424 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
425 ResI
->takeName(CurrentTruncInst
);
427 CurrentTruncInst
->replaceAllUsesWith(Res
);
429 // Erase old expression dag, which was replaced by the reduced expression dag.
430 // We iterate backward, which means we visit the instruction before we visit
431 // any of its operands, this way, when we get to the operand, we already
432 // removed the instructions (from the expression dag) that uses it.
433 CurrentTruncInst
->eraseFromParent();
434 for (auto I
= InstInfoMap
.rbegin(), E
= InstInfoMap
.rend(); I
!= E
; ++I
) {
435 // We still need to check that the instruction has no users before we erase
436 // it, because {SExt, ZExt}Inst Instruction might have other users that was
437 // not reduced, in such case, we need to keep that instruction.
438 if (I
->first
->use_empty())
439 I
->first
->eraseFromParent();
443 bool TruncInstCombine::run(Function
&F
) {
444 bool MadeIRChange
= false;
446 // Collect all TruncInst in the function into the Worklist for evaluating.
448 // Ignore unreachable basic block.
449 if (!DT
.isReachableFromEntry(&BB
))
452 if (auto *CI
= dyn_cast
<TruncInst
>(&I
))
453 Worklist
.push_back(CI
);
456 // Process all TruncInst in the Worklist, for each instruction:
457 // 1. Check if it dominates an eligible expression dag to be reduced.
458 // 2. Create a reduced expression dag and replace the old one with it.
459 while (!Worklist
.empty()) {
460 CurrentTruncInst
= Worklist
.pop_back_val();
462 if (Type
*NewDstSclTy
= getBestTruncatedType()) {
464 dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
466 << CurrentTruncInst
<< '\n');
467 ReduceExpressionDag(NewDstSclTy
);