//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression graphs post-dominated by TruncInst
// and for each eligible graph, it will create a reduced bit-width expression,
// replace the old expression with this new one and remove the old expression.
// Eligible expression graph is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the graph must not have users outside the graph.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
//
//===----------------------------------------------------------------------===//
#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>
#include <utility>
39 #define DEBUG_TYPE "aggressive-instcombine"
41 STATISTIC(NumExprsReduced
, "Number of truncations eliminated by reducing bit "
42 "width of expression graph");
43 STATISTIC(NumInstrsReduced
,
44 "Number of instructions whose bit width was reduced");
46 /// Given an instruction and a container, it fills all the relevant operands of
47 /// that instruction, with respect to the Trunc expression graph optimizaton.
48 static void getRelevantOperands(Instruction
*I
, SmallVectorImpl
<Value
*> &Ops
) {
49 unsigned Opc
= I
->getOpcode();
51 case Instruction::Trunc
:
52 case Instruction::ZExt
:
53 case Instruction::SExt
:
54 // These CastInst are considered leaves of the evaluated expression, thus,
55 // their operands are not relevent.
57 case Instruction::Add
:
58 case Instruction::Sub
:
59 case Instruction::Mul
:
60 case Instruction::And
:
62 case Instruction::Xor
:
63 case Instruction::Shl
:
64 case Instruction::LShr
:
65 case Instruction::AShr
:
66 case Instruction::UDiv
:
67 case Instruction::URem
:
68 case Instruction::InsertElement
:
69 Ops
.push_back(I
->getOperand(0));
70 Ops
.push_back(I
->getOperand(1));
72 case Instruction::ExtractElement
:
73 Ops
.push_back(I
->getOperand(0));
75 case Instruction::Select
:
76 Ops
.push_back(I
->getOperand(1));
77 Ops
.push_back(I
->getOperand(2));
79 case Instruction::PHI
:
80 for (Value
*V
: cast
<PHINode
>(I
)->incoming_values())
84 llvm_unreachable("Unreachable!");
88 bool TruncInstCombine::buildTruncExpressionGraph() {
89 SmallVector
<Value
*, 8> Worklist
;
90 SmallVector
<Instruction
*, 8> Stack
;
91 // Clear old instructions info.
94 Worklist
.push_back(CurrentTruncInst
->getOperand(0));
96 while (!Worklist
.empty()) {
97 Value
*Curr
= Worklist
.back();
99 if (isa
<Constant
>(Curr
)) {
104 auto *I
= dyn_cast
<Instruction
>(Curr
);
108 if (!Stack
.empty() && Stack
.back() == I
) {
109 // Already handled all instruction operands, can remove it from both the
110 // Worklist and the Stack, and add it to the instruction info map.
113 // Insert I to the Info map.
114 InstInfoMap
.insert(std::make_pair(I
, Info()));
118 if (InstInfoMap
.count(I
)) {
123 // Add the instruction to the stack before start handling its operands.
126 unsigned Opc
= I
->getOpcode();
128 case Instruction::Trunc
:
129 case Instruction::ZExt
:
130 case Instruction::SExt
:
131 // trunc(trunc(x)) -> trunc(x)
132 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
133 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
136 case Instruction::Add
:
137 case Instruction::Sub
:
138 case Instruction::Mul
:
139 case Instruction::And
:
140 case Instruction::Or
:
141 case Instruction::Xor
:
142 case Instruction::Shl
:
143 case Instruction::LShr
:
144 case Instruction::AShr
:
145 case Instruction::UDiv
:
146 case Instruction::URem
:
147 case Instruction::InsertElement
:
148 case Instruction::ExtractElement
:
149 case Instruction::Select
: {
150 SmallVector
<Value
*, 2> Operands
;
151 getRelevantOperands(I
, Operands
);
152 append_range(Worklist
, Operands
);
155 case Instruction::PHI
: {
156 SmallVector
<Value
*, 2> Operands
;
157 getRelevantOperands(I
, Operands
);
158 // Add only operands not in Stack to prevent cycle
159 for (auto *Op
: Operands
)
160 if (!llvm::is_contained(Stack
, Op
))
161 Worklist
.push_back(Op
);
165 // TODO: Can handle more cases here:
175 unsigned TruncInstCombine::getMinBitWidth() {
176 SmallVector
<Value
*, 8> Worklist
;
177 SmallVector
<Instruction
*, 8> Stack
;
179 Value
*Src
= CurrentTruncInst
->getOperand(0);
180 Type
*DstTy
= CurrentTruncInst
->getType();
181 unsigned TruncBitWidth
= DstTy
->getScalarSizeInBits();
182 unsigned OrigBitWidth
=
183 CurrentTruncInst
->getOperand(0)->getType()->getScalarSizeInBits();
185 if (isa
<Constant
>(Src
))
186 return TruncBitWidth
;
188 Worklist
.push_back(Src
);
189 InstInfoMap
[cast
<Instruction
>(Src
)].ValidBitWidth
= TruncBitWidth
;
191 while (!Worklist
.empty()) {
192 Value
*Curr
= Worklist
.back();
194 if (isa
<Constant
>(Curr
)) {
199 // Otherwise, it must be an instruction.
200 auto *I
= cast
<Instruction
>(Curr
);
202 auto &Info
= InstInfoMap
[I
];
204 SmallVector
<Value
*, 2> Operands
;
205 getRelevantOperands(I
, Operands
);
207 if (!Stack
.empty() && Stack
.back() == I
) {
208 // Already handled all instruction operands, can remove it from both, the
209 // Worklist and the Stack, and update MinBitWidth.
212 for (auto *Operand
: Operands
)
213 if (auto *IOp
= dyn_cast
<Instruction
>(Operand
))
215 std::max(Info
.MinBitWidth
, InstInfoMap
[IOp
].MinBitWidth
);
219 // Add the instruction to the stack before start handling its operands.
221 unsigned ValidBitWidth
= Info
.ValidBitWidth
;
223 // Update minimum bit-width before handling its operands. This is required
224 // when the instruction is part of a loop.
225 Info
.MinBitWidth
= std::max(Info
.MinBitWidth
, Info
.ValidBitWidth
);
227 for (auto *Operand
: Operands
)
228 if (auto *IOp
= dyn_cast
<Instruction
>(Operand
)) {
229 // If we already calculated the minimum bit-width for this valid
230 // bit-width, or for a smaller valid bit-width, then just keep the
231 // answer we already calculated.
232 unsigned IOpBitwidth
= InstInfoMap
.lookup(IOp
).ValidBitWidth
;
233 if (IOpBitwidth
>= ValidBitWidth
)
235 InstInfoMap
[IOp
].ValidBitWidth
= ValidBitWidth
;
236 Worklist
.push_back(IOp
);
239 unsigned MinBitWidth
= InstInfoMap
.lookup(cast
<Instruction
>(Src
)).MinBitWidth
;
240 assert(MinBitWidth
>= TruncBitWidth
);
242 if (MinBitWidth
> TruncBitWidth
) {
243 // In this case reducing expression with vector type might generate a new
244 // vector type, which is not preferable as it might result in generating
246 if (DstTy
->isVectorTy())
248 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
249 Type
*Ty
= DL
.getSmallestLegalIntType(DstTy
->getContext(), MinBitWidth
);
250 // Update minimum bit-width with the new destination type bit-width if
251 // succeeded to find such, otherwise, with original bit-width.
252 MinBitWidth
= Ty
? Ty
->getScalarSizeInBits() : OrigBitWidth
;
253 } else { // MinBitWidth == TruncBitWidth
254 // In this case the expression can be evaluated with the trunc instruction
255 // destination type, and trunc instruction can be omitted. However, we
256 // should not perform the evaluation if the original type is a legal scalar
257 // type and the target type is illegal.
258 bool FromLegal
= MinBitWidth
== 1 || DL
.isLegalInteger(OrigBitWidth
);
259 bool ToLegal
= MinBitWidth
== 1 || DL
.isLegalInteger(MinBitWidth
);
260 if (!DstTy
->isVectorTy() && FromLegal
&& !ToLegal
)
266 Type
*TruncInstCombine::getBestTruncatedType() {
267 if (!buildTruncExpressionGraph())
270 // We don't want to duplicate instructions, which isn't profitable. Thus, we
271 // can't shrink something that has multiple users, unless all users are
272 // post-dominated by the trunc instruction, i.e., were visited during the
273 // expression evaluation.
274 unsigned DesiredBitWidth
= 0;
275 for (auto Itr
: InstInfoMap
) {
276 Instruction
*I
= Itr
.first
;
279 bool IsExtInst
= (isa
<ZExtInst
>(I
) || isa
<SExtInst
>(I
));
280 for (auto *U
: I
->users())
281 if (auto *UI
= dyn_cast
<Instruction
>(U
))
282 if (UI
!= CurrentTruncInst
&& !InstInfoMap
.count(UI
)) {
285 // If this is an extension from the dest type, we can eliminate it,
286 // even if it has multiple users. Thus, update the DesiredBitWidth and
287 // validate all extension instructions agrees on same DesiredBitWidth.
288 unsigned ExtInstBitWidth
=
289 I
->getOperand(0)->getType()->getScalarSizeInBits();
290 if (DesiredBitWidth
&& DesiredBitWidth
!= ExtInstBitWidth
)
292 DesiredBitWidth
= ExtInstBitWidth
;
296 unsigned OrigBitWidth
=
297 CurrentTruncInst
->getOperand(0)->getType()->getScalarSizeInBits();
299 // Initialize MinBitWidth for shift instructions with the minimum number
300 // that is greater than shift amount (i.e. shift amount + 1).
301 // For `lshr` adjust MinBitWidth so that all potentially truncated
302 // bits of the value-to-be-shifted are zeros.
303 // For `ashr` adjust MinBitWidth so that all potentially truncated
304 // bits of the value-to-be-shifted are sign bits (all zeros or ones)
305 // and even one (first) untruncated bit is sign bit.
306 // Exit early if MinBitWidth is not less than original bitwidth.
307 for (auto &Itr
: InstInfoMap
) {
308 Instruction
*I
= Itr
.first
;
310 KnownBits KnownRHS
= computeKnownBits(I
->getOperand(1));
311 unsigned MinBitWidth
= KnownRHS
.getMaxValue()
312 .uadd_sat(APInt(OrigBitWidth
, 1))
313 .getLimitedValue(OrigBitWidth
);
314 if (MinBitWidth
== OrigBitWidth
)
316 if (I
->getOpcode() == Instruction::LShr
) {
317 KnownBits KnownLHS
= computeKnownBits(I
->getOperand(0));
319 std::max(MinBitWidth
, KnownLHS
.getMaxValue().getActiveBits());
321 if (I
->getOpcode() == Instruction::AShr
) {
322 unsigned NumSignBits
= ComputeNumSignBits(I
->getOperand(0));
323 MinBitWidth
= std::max(MinBitWidth
, OrigBitWidth
- NumSignBits
+ 1);
325 if (MinBitWidth
>= OrigBitWidth
)
327 Itr
.second
.MinBitWidth
= MinBitWidth
;
329 if (I
->getOpcode() == Instruction::UDiv
||
330 I
->getOpcode() == Instruction::URem
) {
331 unsigned MinBitWidth
= 0;
332 for (const auto &Op
: I
->operands()) {
333 KnownBits Known
= computeKnownBits(Op
);
335 std::max(Known
.getMaxValue().getActiveBits(), MinBitWidth
);
336 if (MinBitWidth
>= OrigBitWidth
)
339 Itr
.second
.MinBitWidth
= MinBitWidth
;
343 // Calculate minimum allowed bit-width allowed for shrinking the currently
344 // visited truncate's operand.
345 unsigned MinBitWidth
= getMinBitWidth();
347 // Check that we can shrink to smaller bit-width than original one and that
348 // it is similar to the DesiredBitWidth is such exists.
349 if (MinBitWidth
>= OrigBitWidth
||
350 (DesiredBitWidth
&& DesiredBitWidth
!= MinBitWidth
))
353 return IntegerType::get(CurrentTruncInst
->getContext(), MinBitWidth
);
356 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
357 /// for \p V, according to its type, if it vector type, return the vector
358 /// version of \p Ty, otherwise return \p Ty.
359 static Type
*getReducedType(Value
*V
, Type
*Ty
) {
360 assert(Ty
&& !Ty
->isVectorTy() && "Expect Scalar Type");
361 if (auto *VTy
= dyn_cast
<VectorType
>(V
->getType()))
362 return VectorType::get(Ty
, VTy
->getElementCount());
366 Value
*TruncInstCombine::getReducedOperand(Value
*V
, Type
*SclTy
) {
367 Type
*Ty
= getReducedType(V
, SclTy
);
368 if (auto *C
= dyn_cast
<Constant
>(V
)) {
369 C
= ConstantExpr::getIntegerCast(C
, Ty
, false);
370 // If we got a constantexpr back, try to simplify it with DL info.
371 return ConstantFoldConstant(C
, DL
, &TLI
);
374 auto *I
= cast
<Instruction
>(V
);
375 Info Entry
= InstInfoMap
.lookup(I
);
376 assert(Entry
.NewValue
);
377 return Entry
.NewValue
;
380 void TruncInstCombine::ReduceExpressionGraph(Type
*SclTy
) {
381 NumInstrsReduced
+= InstInfoMap
.size();
382 // Pairs of old and new phi-nodes
383 SmallVector
<std::pair
<PHINode
*, PHINode
*>, 2> OldNewPHINodes
;
384 for (auto &Itr
: InstInfoMap
) { // Forward
385 Instruction
*I
= Itr
.first
;
386 TruncInstCombine::Info
&NodeInfo
= Itr
.second
;
388 assert(!NodeInfo
.NewValue
&& "Instruction has been evaluated");
390 IRBuilder
<> Builder(I
);
391 Value
*Res
= nullptr;
392 unsigned Opc
= I
->getOpcode();
394 case Instruction::Trunc
:
395 case Instruction::ZExt
:
396 case Instruction::SExt
: {
397 Type
*Ty
= getReducedType(I
, SclTy
);
398 // If the source type of the cast is the type we're trying for then we can
399 // just return the source. There's no need to insert it because it is not
401 if (I
->getOperand(0)->getType() == Ty
) {
402 assert(!isa
<TruncInst
>(I
) && "Cannot reach here with TruncInst");
403 NodeInfo
.NewValue
= I
->getOperand(0);
406 // Otherwise, must be the same type of cast, so just reinsert a new one.
407 // This also handles the case of zext(trunc(x)) -> zext(x).
408 Res
= Builder
.CreateIntCast(I
->getOperand(0), Ty
,
409 Opc
== Instruction::SExt
);
411 // Update Worklist entries with new value if needed.
412 // There are three possible changes to the Worklist:
413 // 1. Update Old-TruncInst -> New-TruncInst.
414 // 2. Remove Old-TruncInst (if New node is not TruncInst).
415 // 3. Add New-TruncInst (if Old node was not TruncInst).
416 auto *Entry
= find(Worklist
, I
);
417 if (Entry
!= Worklist
.end()) {
418 if (auto *NewCI
= dyn_cast
<TruncInst
>(Res
))
421 Worklist
.erase(Entry
);
422 } else if (auto *NewCI
= dyn_cast
<TruncInst
>(Res
))
423 Worklist
.push_back(NewCI
);
426 case Instruction::Add
:
427 case Instruction::Sub
:
428 case Instruction::Mul
:
429 case Instruction::And
:
430 case Instruction::Or
:
431 case Instruction::Xor
:
432 case Instruction::Shl
:
433 case Instruction::LShr
:
434 case Instruction::AShr
:
435 case Instruction::UDiv
:
436 case Instruction::URem
: {
437 Value
*LHS
= getReducedOperand(I
->getOperand(0), SclTy
);
438 Value
*RHS
= getReducedOperand(I
->getOperand(1), SclTy
);
439 Res
= Builder
.CreateBinOp((Instruction::BinaryOps
)Opc
, LHS
, RHS
);
440 // Preserve `exact` flag since truncation doesn't change exactness
441 if (auto *PEO
= dyn_cast
<PossiblyExactOperator
>(I
))
442 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
443 ResI
->setIsExact(PEO
->isExact());
446 case Instruction::ExtractElement
: {
447 Value
*Vec
= getReducedOperand(I
->getOperand(0), SclTy
);
448 Value
*Idx
= I
->getOperand(1);
449 Res
= Builder
.CreateExtractElement(Vec
, Idx
);
452 case Instruction::InsertElement
: {
453 Value
*Vec
= getReducedOperand(I
->getOperand(0), SclTy
);
454 Value
*NewElt
= getReducedOperand(I
->getOperand(1), SclTy
);
455 Value
*Idx
= I
->getOperand(2);
456 Res
= Builder
.CreateInsertElement(Vec
, NewElt
, Idx
);
459 case Instruction::Select
: {
460 Value
*Op0
= I
->getOperand(0);
461 Value
*LHS
= getReducedOperand(I
->getOperand(1), SclTy
);
462 Value
*RHS
= getReducedOperand(I
->getOperand(2), SclTy
);
463 Res
= Builder
.CreateSelect(Op0
, LHS
, RHS
);
466 case Instruction::PHI
: {
467 Res
= Builder
.CreatePHI(getReducedType(I
, SclTy
), I
->getNumOperands());
468 OldNewPHINodes
.push_back(
469 std::make_pair(cast
<PHINode
>(I
), cast
<PHINode
>(Res
)));
473 llvm_unreachable("Unhandled instruction");
476 NodeInfo
.NewValue
= Res
;
477 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
481 for (auto &Node
: OldNewPHINodes
) {
482 PHINode
*OldPN
= Node
.first
;
483 PHINode
*NewPN
= Node
.second
;
484 for (auto Incoming
: zip(OldPN
->incoming_values(), OldPN
->blocks()))
485 NewPN
->addIncoming(getReducedOperand(std::get
<0>(Incoming
), SclTy
),
486 std::get
<1>(Incoming
));
489 Value
*Res
= getReducedOperand(CurrentTruncInst
->getOperand(0), SclTy
);
490 Type
*DstTy
= CurrentTruncInst
->getType();
491 if (Res
->getType() != DstTy
) {
492 IRBuilder
<> Builder(CurrentTruncInst
);
493 Res
= Builder
.CreateIntCast(Res
, DstTy
, false);
494 if (auto *ResI
= dyn_cast
<Instruction
>(Res
))
495 ResI
->takeName(CurrentTruncInst
);
497 CurrentTruncInst
->replaceAllUsesWith(Res
);
499 // Erase old expression graph, which was replaced by the reduced expression
501 CurrentTruncInst
->eraseFromParent();
502 // First, erase old phi-nodes and its uses
503 for (auto &Node
: OldNewPHINodes
) {
504 PHINode
*OldPN
= Node
.first
;
505 OldPN
->replaceAllUsesWith(PoisonValue::get(OldPN
->getType()));
506 InstInfoMap
.erase(OldPN
);
507 OldPN
->eraseFromParent();
509 // Now we have expression graph turned into dag.
510 // We iterate backward, which means we visit the instruction before we
511 // visit any of its operands, this way, when we get to the operand, we already
512 // removed the instructions (from the expression dag) that uses it.
513 for (auto &I
: llvm::reverse(InstInfoMap
)) {
514 // We still need to check that the instruction has no users before we erase
515 // it, because {SExt, ZExt}Inst Instruction might have other users that was
516 // not reduced, in such case, we need to keep that instruction.
517 if (I
.first
->use_empty())
518 I
.first
->eraseFromParent();
520 assert((isa
<SExtInst
>(I
.first
) || isa
<ZExtInst
>(I
.first
)) &&
521 "Only {SExt, ZExt}Inst might have unreduced users");
525 bool TruncInstCombine::run(Function
&F
) {
526 bool MadeIRChange
= false;
528 // Collect all TruncInst in the function into the Worklist for evaluating.
530 // Ignore unreachable basic block.
531 if (!DT
.isReachableFromEntry(&BB
))
534 if (auto *CI
= dyn_cast
<TruncInst
>(&I
))
535 Worklist
.push_back(CI
);
538 // Process all TruncInst in the Worklist, for each instruction:
539 // 1. Check if it dominates an eligible expression graph to be reduced.
540 // 2. Create a reduced expression graph and replace the old one with it.
541 while (!Worklist
.empty()) {
542 CurrentTruncInst
= Worklist
.pop_back_val();
544 if (Type
*NewDstSclTy
= getBestTruncatedType()) {
546 dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
548 << CurrentTruncInst
<< '\n');
549 ReduceExpressionGraph(NewDstSclTy
);