[InstCombine] Signed saturation patterns
[llvm-complete.git] / include / llvm / CodeGen / GlobalISel / LegalizationArtifactCombiner.h
blob7f960e727846e1f344f47cb941d54106c4c9c619
1 //===-- llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h -----*- C++ -*-//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// This file contains some helper functions which try to clean up artifacts
// such as G_TRUNCs/G_[ZSA]EXTENDS that were created during legalization to make
// the types match. This file also contains some combines of merges that happen
// at the end of legalization.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/GlobalISel/Legalizer.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
17 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
18 #include "llvm/CodeGen/GlobalISel/Utils.h"
19 #include "llvm/CodeGen/MachineRegisterInfo.h"
20 #include "llvm/Support/Debug.h"
22 #define DEBUG_TYPE "legalizer"
23 using namespace llvm::MIPatternMatch;
25 namespace llvm {
26 class LegalizationArtifactCombiner {
27 MachineIRBuilder &Builder;
28 MachineRegisterInfo &MRI;
29 const LegalizerInfo &LI;
31 static bool isArtifactCast(unsigned Opc) {
32 switch (Opc) {
33 case TargetOpcode::G_TRUNC:
34 case TargetOpcode::G_SEXT:
35 case TargetOpcode::G_ZEXT:
36 case TargetOpcode::G_ANYEXT:
37 return true;
38 default:
39 return false;
43 public:
44 LegalizationArtifactCombiner(MachineIRBuilder &B, MachineRegisterInfo &MRI,
45 const LegalizerInfo &LI)
46 : Builder(B), MRI(MRI), LI(LI) {}
48 bool tryCombineAnyExt(MachineInstr &MI,
49 SmallVectorImpl<MachineInstr *> &DeadInsts) {
50 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
52 Builder.setInstr(MI);
53 Register DstReg = MI.getOperand(0).getReg();
54 Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
56 // aext(trunc x) - > aext/copy/trunc x
57 Register TruncSrc;
58 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
59 LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
60 Builder.buildAnyExtOrTrunc(DstReg, TruncSrc);
61 markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
62 return true;
65 // aext([asz]ext x) -> [asz]ext x
66 Register ExtSrc;
67 MachineInstr *ExtMI;
68 if (mi_match(SrcReg, MRI,
69 m_all_of(m_MInstr(ExtMI), m_any_of(m_GAnyExt(m_Reg(ExtSrc)),
70 m_GSExt(m_Reg(ExtSrc)),
71 m_GZExt(m_Reg(ExtSrc)))))) {
72 Builder.buildInstr(ExtMI->getOpcode(), {DstReg}, {ExtSrc});
73 markInstAndDefDead(MI, *ExtMI, DeadInsts);
74 return true;
77 // Try to fold aext(g_constant) when the larger constant type is legal.
78 // Can't use MIPattern because we don't have a specific constant in mind.
79 auto *SrcMI = MRI.getVRegDef(SrcReg);
80 if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
81 const LLT &DstTy = MRI.getType(DstReg);
82 if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
83 auto &CstVal = SrcMI->getOperand(1);
84 Builder.buildConstant(
85 DstReg, CstVal.getCImm()->getValue().sext(DstTy.getSizeInBits()));
86 markInstAndDefDead(MI, *SrcMI, DeadInsts);
87 return true;
90 return tryFoldImplicitDef(MI, DeadInsts);
93 bool tryCombineZExt(MachineInstr &MI,
94 SmallVectorImpl<MachineInstr *> &DeadInsts) {
95 assert(MI.getOpcode() == TargetOpcode::G_ZEXT);
97 Builder.setInstr(MI);
98 Register DstReg = MI.getOperand(0).getReg();
99 Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
101 // zext(trunc x) - > and (aext/copy/trunc x), mask
102 Register TruncSrc;
103 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
104 LLT DstTy = MRI.getType(DstReg);
105 if (isInstUnsupported({TargetOpcode::G_AND, {DstTy}}) ||
106 isConstantUnsupported(DstTy))
107 return false;
108 LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
109 LLT SrcTy = MRI.getType(SrcReg);
110 APInt Mask = APInt::getAllOnesValue(SrcTy.getScalarSizeInBits());
111 auto MIBMask = Builder.buildConstant(DstTy, Mask.getZExtValue());
112 Builder.buildAnd(DstReg, Builder.buildAnyExtOrTrunc(DstTy, TruncSrc),
113 MIBMask);
114 markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
115 return true;
118 // Try to fold zext(g_constant) when the larger constant type is legal.
119 // Can't use MIPattern because we don't have a specific constant in mind.
120 auto *SrcMI = MRI.getVRegDef(SrcReg);
121 if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
122 const LLT &DstTy = MRI.getType(DstReg);
123 if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
124 auto &CstVal = SrcMI->getOperand(1);
125 Builder.buildConstant(
126 DstReg, CstVal.getCImm()->getValue().zext(DstTy.getSizeInBits()));
127 markInstAndDefDead(MI, *SrcMI, DeadInsts);
128 return true;
131 return tryFoldImplicitDef(MI, DeadInsts);
134 bool tryCombineSExt(MachineInstr &MI,
135 SmallVectorImpl<MachineInstr *> &DeadInsts) {
136 assert(MI.getOpcode() == TargetOpcode::G_SEXT);
138 Builder.setInstr(MI);
139 Register DstReg = MI.getOperand(0).getReg();
140 Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
142 // sext(trunc x) - > (sext_inreg (aext/copy/trunc x), c)
143 Register TruncSrc;
144 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
145 LLT DstTy = MRI.getType(DstReg);
146 if (isInstUnsupported({TargetOpcode::G_SEXT_INREG, {DstTy}}))
147 return false;
148 LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
149 LLT SrcTy = MRI.getType(SrcReg);
150 uint64_t SizeInBits = SrcTy.getScalarSizeInBits();
151 Builder.buildInstr(
152 TargetOpcode::G_SEXT_INREG, {DstReg},
153 {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), SizeInBits});
154 markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
155 return true;
157 return tryFoldImplicitDef(MI, DeadInsts);
160 /// Try to fold G_[ASZ]EXT (G_IMPLICIT_DEF).
161 bool tryFoldImplicitDef(MachineInstr &MI,
162 SmallVectorImpl<MachineInstr *> &DeadInsts) {
163 unsigned Opcode = MI.getOpcode();
164 assert(Opcode == TargetOpcode::G_ANYEXT || Opcode == TargetOpcode::G_ZEXT ||
165 Opcode == TargetOpcode::G_SEXT);
167 if (MachineInstr *DefMI = getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF,
168 MI.getOperand(1).getReg(), MRI)) {
169 Builder.setInstr(MI);
170 Register DstReg = MI.getOperand(0).getReg();
171 LLT DstTy = MRI.getType(DstReg);
173 if (Opcode == TargetOpcode::G_ANYEXT) {
174 // G_ANYEXT (G_IMPLICIT_DEF) -> G_IMPLICIT_DEF
175 if (isInstUnsupported({TargetOpcode::G_IMPLICIT_DEF, {DstTy}}))
176 return false;
177 LLVM_DEBUG(dbgs() << ".. Combine G_ANYEXT(G_IMPLICIT_DEF): " << MI;);
178 Builder.buildInstr(TargetOpcode::G_IMPLICIT_DEF, {DstReg}, {});
179 } else {
180 // G_[SZ]EXT (G_IMPLICIT_DEF) -> G_CONSTANT 0 because the top
181 // bits will be 0 for G_ZEXT and 0/1 for the G_SEXT.
182 if (isConstantUnsupported(DstTy))
183 return false;
184 LLVM_DEBUG(dbgs() << ".. Combine G_[SZ]EXT(G_IMPLICIT_DEF): " << MI;);
185 Builder.buildConstant(DstReg, 0);
188 markInstAndDefDead(MI, *DefMI, DeadInsts);
189 return true;
191 return false;
194 static unsigned canFoldMergeOpcode(unsigned MergeOp, unsigned ConvertOp,
195 LLT OpTy, LLT DestTy) {
196 if (OpTy.isVector() && DestTy.isVector())
197 return MergeOp == TargetOpcode::G_CONCAT_VECTORS;
199 if (OpTy.isVector() && !DestTy.isVector()) {
200 if (MergeOp == TargetOpcode::G_BUILD_VECTOR)
201 return true;
203 if (MergeOp == TargetOpcode::G_CONCAT_VECTORS) {
204 if (ConvertOp == 0)
205 return true;
207 const unsigned OpEltSize = OpTy.getElementType().getSizeInBits();
209 // Don't handle scalarization with a cast that isn't in the same
210 // direction as the vector cast. This could be handled, but it would
211 // require more intermediate unmerges.
212 if (ConvertOp == TargetOpcode::G_TRUNC)
213 return DestTy.getSizeInBits() <= OpEltSize;
214 return DestTy.getSizeInBits() >= OpEltSize;
217 return false;
220 return MergeOp == TargetOpcode::G_MERGE_VALUES;
223 bool tryCombineMerges(MachineInstr &MI,
224 SmallVectorImpl<MachineInstr *> &DeadInsts) {
225 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);
227 unsigned NumDefs = MI.getNumOperands() - 1;
228 MachineInstr *SrcDef =
229 getDefIgnoringCopies(MI.getOperand(NumDefs).getReg(), MRI);
230 if (!SrcDef)
231 return false;
233 LLT OpTy = MRI.getType(MI.getOperand(NumDefs).getReg());
234 LLT DestTy = MRI.getType(MI.getOperand(0).getReg());
235 MachineInstr *MergeI = SrcDef;
236 unsigned ConvertOp = 0;
238 // Handle intermediate conversions
239 unsigned SrcOp = SrcDef->getOpcode();
240 if (isArtifactCast(SrcOp)) {
241 ConvertOp = SrcOp;
242 MergeI = getDefIgnoringCopies(SrcDef->getOperand(1).getReg(), MRI);
245 if (!MergeI || !canFoldMergeOpcode(MergeI->getOpcode(),
246 ConvertOp, OpTy, DestTy))
247 return false;
249 const unsigned NumMergeRegs = MergeI->getNumOperands() - 1;
251 if (NumMergeRegs < NumDefs) {
252 if (NumDefs % NumMergeRegs != 0)
253 return false;
255 Builder.setInstr(MI);
256 // Transform to UNMERGEs, for example
257 // %1 = G_MERGE_VALUES %4, %5
258 // %9, %10, %11, %12 = G_UNMERGE_VALUES %1
259 // to
260 // %9, %10 = G_UNMERGE_VALUES %4
261 // %11, %12 = G_UNMERGE_VALUES %5
263 const unsigned NewNumDefs = NumDefs / NumMergeRegs;
264 for (unsigned Idx = 0; Idx < NumMergeRegs; ++Idx) {
265 SmallVector<Register, 2> DstRegs;
266 for (unsigned j = 0, DefIdx = Idx * NewNumDefs; j < NewNumDefs;
267 ++j, ++DefIdx)
268 DstRegs.push_back(MI.getOperand(DefIdx).getReg());
270 if (ConvertOp) {
271 SmallVector<Register, 2> TmpRegs;
272 // This is a vector that is being scalarized and casted. Extract to
273 // the element type, and do the conversion on the scalars.
274 LLT MergeEltTy
275 = MRI.getType(MergeI->getOperand(0).getReg()).getElementType();
276 for (unsigned j = 0; j < NumMergeRegs; ++j)
277 TmpRegs.push_back(MRI.createGenericVirtualRegister(MergeEltTy));
279 Builder.buildUnmerge(TmpRegs, MergeI->getOperand(Idx + 1).getReg());
281 for (unsigned j = 0; j < NumMergeRegs; ++j)
282 Builder.buildInstr(ConvertOp, {DstRegs[j]}, {TmpRegs[j]});
283 } else {
284 Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
288 } else if (NumMergeRegs > NumDefs) {
289 if (ConvertOp != 0 || NumMergeRegs % NumDefs != 0)
290 return false;
292 Builder.setInstr(MI);
293 // Transform to MERGEs
294 // %6 = G_MERGE_VALUES %17, %18, %19, %20
295 // %7, %8 = G_UNMERGE_VALUES %6
296 // to
297 // %7 = G_MERGE_VALUES %17, %18
298 // %8 = G_MERGE_VALUES %19, %20
300 const unsigned NumRegs = NumMergeRegs / NumDefs;
301 for (unsigned DefIdx = 0; DefIdx < NumDefs; ++DefIdx) {
302 SmallVector<Register, 2> Regs;
303 for (unsigned j = 0, Idx = NumRegs * DefIdx + 1; j < NumRegs;
304 ++j, ++Idx)
305 Regs.push_back(MergeI->getOperand(Idx).getReg());
307 Builder.buildMerge(MI.getOperand(DefIdx).getReg(), Regs);
310 } else {
311 LLT MergeSrcTy = MRI.getType(MergeI->getOperand(1).getReg());
312 if (ConvertOp) {
313 Builder.setInstr(MI);
315 for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
316 Register MergeSrc = MergeI->getOperand(Idx + 1).getReg();
317 Builder.buildInstr(ConvertOp, {MI.getOperand(Idx).getReg()},
318 {MergeSrc});
321 markInstAndDefDead(MI, *MergeI, DeadInsts);
322 return true;
324 // FIXME: is a COPY appropriate if the types mismatch? We know both
325 // registers are allocatable by now.
326 if (DestTy != MergeSrcTy)
327 return false;
329 for (unsigned Idx = 0; Idx < NumDefs; ++Idx)
330 MRI.replaceRegWith(MI.getOperand(Idx).getReg(),
331 MergeI->getOperand(Idx + 1).getReg());
334 markInstAndDefDead(MI, *MergeI, DeadInsts);
335 return true;
338 static bool isMergeLikeOpcode(unsigned Opc) {
339 switch (Opc) {
340 case TargetOpcode::G_MERGE_VALUES:
341 case TargetOpcode::G_BUILD_VECTOR:
342 case TargetOpcode::G_CONCAT_VECTORS:
343 return true;
344 default:
345 return false;
349 bool tryCombineExtract(MachineInstr &MI,
350 SmallVectorImpl<MachineInstr *> &DeadInsts) {
351 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT);
353 // Try to use the source registers from a G_MERGE_VALUES
355 // %2 = G_MERGE_VALUES %0, %1
356 // %3 = G_EXTRACT %2, N
357 // =>
359 // for N < %2.getSizeInBits() / 2
360 // %3 = G_EXTRACT %0, N
362 // for N >= %2.getSizeInBits() / 2
363 // %3 = G_EXTRACT %1, (N - %0.getSizeInBits()
365 unsigned Src = lookThroughCopyInstrs(MI.getOperand(1).getReg());
366 MachineInstr *MergeI = MRI.getVRegDef(Src);
367 if (!MergeI || !isMergeLikeOpcode(MergeI->getOpcode()))
368 return false;
370 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
371 LLT SrcTy = MRI.getType(Src);
373 // TODO: Do we need to check if the resulting extract is supported?
374 unsigned ExtractDstSize = DstTy.getSizeInBits();
375 unsigned Offset = MI.getOperand(2).getImm();
376 unsigned NumMergeSrcs = MergeI->getNumOperands() - 1;
377 unsigned MergeSrcSize = SrcTy.getSizeInBits() / NumMergeSrcs;
378 unsigned MergeSrcIdx = Offset / MergeSrcSize;
380 // Compute the offset of the last bit the extract needs.
381 unsigned EndMergeSrcIdx = (Offset + ExtractDstSize - 1) / MergeSrcSize;
383 // Can't handle the case where the extract spans multiple inputs.
384 if (MergeSrcIdx != EndMergeSrcIdx)
385 return false;
387 // TODO: We could modify MI in place in most cases.
388 Builder.setInstr(MI);
389 Builder.buildExtract(
390 MI.getOperand(0).getReg(),
391 MergeI->getOperand(MergeSrcIdx + 1).getReg(),
392 Offset - MergeSrcIdx * MergeSrcSize);
393 markInstAndDefDead(MI, *MergeI, DeadInsts);
394 return true;
397 /// Try to combine away MI.
398 /// Returns true if it combined away the MI.
399 /// Adds instructions that are dead as a result of the combine
400 /// into DeadInsts, which can include MI.
401 bool tryCombineInstruction(MachineInstr &MI,
402 SmallVectorImpl<MachineInstr *> &DeadInsts,
403 GISelObserverWrapper &WrapperObserver) {
404 // This might be a recursive call, and we might have DeadInsts already
405 // populated. To avoid bad things happening later with multiple vreg defs
406 // etc, process the dead instructions now if any.
407 if (!DeadInsts.empty())
408 deleteMarkedDeadInsts(DeadInsts, WrapperObserver);
409 switch (MI.getOpcode()) {
410 default:
411 return false;
412 case TargetOpcode::G_ANYEXT:
413 return tryCombineAnyExt(MI, DeadInsts);
414 case TargetOpcode::G_ZEXT:
415 return tryCombineZExt(MI, DeadInsts);
416 case TargetOpcode::G_SEXT:
417 return tryCombineSExt(MI, DeadInsts);
418 case TargetOpcode::G_UNMERGE_VALUES:
419 return tryCombineMerges(MI, DeadInsts);
420 case TargetOpcode::G_EXTRACT:
421 return tryCombineExtract(MI, DeadInsts);
422 case TargetOpcode::G_TRUNC: {
423 bool Changed = false;
424 for (auto &Use : MRI.use_instructions(MI.getOperand(0).getReg()))
425 Changed |= tryCombineInstruction(Use, DeadInsts, WrapperObserver);
426 return Changed;
431 private:
433 static unsigned getArtifactSrcReg(const MachineInstr &MI) {
434 switch (MI.getOpcode()) {
435 case TargetOpcode::COPY:
436 case TargetOpcode::G_TRUNC:
437 case TargetOpcode::G_ZEXT:
438 case TargetOpcode::G_ANYEXT:
439 case TargetOpcode::G_SEXT:
440 case TargetOpcode::G_UNMERGE_VALUES:
441 return MI.getOperand(MI.getNumOperands() - 1).getReg();
442 case TargetOpcode::G_EXTRACT:
443 return MI.getOperand(1).getReg();
444 default:
445 llvm_unreachable("Not a legalization artifact happen");
449 /// Mark MI as dead. If a def of one of MI's operands, DefMI, would also be
450 /// dead due to MI being killed, then mark DefMI as dead too.
451 /// Some of the combines (extends(trunc)), try to walk through redundant
452 /// copies in between the extends and the truncs, and this attempts to collect
453 /// the in between copies if they're dead.
454 void markInstAndDefDead(MachineInstr &MI, MachineInstr &DefMI,
455 SmallVectorImpl<MachineInstr *> &DeadInsts) {
456 DeadInsts.push_back(&MI);
458 // Collect all the copy instructions that are made dead, due to deleting
459 // this instruction. Collect all of them until the Trunc(DefMI).
460 // Eg,
461 // %1(s1) = G_TRUNC %0(s32)
462 // %2(s1) = COPY %1(s1)
463 // %3(s1) = COPY %2(s1)
464 // %4(s32) = G_ANYEXT %3(s1)
465 // In this case, we would have replaced %4 with a copy of %0,
466 // and as a result, %3, %2, %1 are dead.
467 MachineInstr *PrevMI = &MI;
468 while (PrevMI != &DefMI) {
469 unsigned PrevRegSrc = getArtifactSrcReg(*PrevMI);
471 MachineInstr *TmpDef = MRI.getVRegDef(PrevRegSrc);
472 if (MRI.hasOneUse(PrevRegSrc)) {
473 if (TmpDef != &DefMI) {
474 assert((TmpDef->getOpcode() == TargetOpcode::COPY ||
475 isArtifactCast(TmpDef->getOpcode())) &&
476 "Expecting copy or artifact cast here");
478 DeadInsts.push_back(TmpDef);
480 } else
481 break;
482 PrevMI = TmpDef;
484 if (PrevMI == &DefMI && MRI.hasOneUse(DefMI.getOperand(0).getReg()))
485 DeadInsts.push_back(&DefMI);
488 /// Erase the dead instructions in the list and call the observer hooks.
489 /// Normally the Legalizer will deal with erasing instructions that have been
490 /// marked dead. However, for the trunc(ext(x)) cases we can end up trying to
491 /// process instructions which have been marked dead, but otherwise break the
492 /// MIR by introducing multiple vreg defs. For those cases, allow the combines
493 /// to explicitly delete the instructions before we run into trouble.
494 void deleteMarkedDeadInsts(SmallVectorImpl<MachineInstr *> &DeadInsts,
495 GISelObserverWrapper &WrapperObserver) {
496 for (auto *DeadMI : DeadInsts) {
497 LLVM_DEBUG(dbgs() << *DeadMI << "Is dead, eagerly deleting\n");
498 WrapperObserver.erasingInstr(*DeadMI);
499 DeadMI->eraseFromParentAndMarkDBGValuesForRemoval();
501 DeadInsts.clear();
504 /// Checks if the target legalizer info has specified anything about the
505 /// instruction, or if unsupported.
506 bool isInstUnsupported(const LegalityQuery &Query) const {
507 using namespace LegalizeActions;
508 auto Step = LI.getAction(Query);
509 return Step.Action == Unsupported || Step.Action == NotFound;
512 bool isInstLegal(const LegalityQuery &Query) const {
513 return LI.getAction(Query).Action == LegalizeActions::Legal;
516 bool isConstantUnsupported(LLT Ty) const {
517 if (!Ty.isVector())
518 return isInstUnsupported({TargetOpcode::G_CONSTANT, {Ty}});
520 LLT EltTy = Ty.getElementType();
521 return isInstUnsupported({TargetOpcode::G_CONSTANT, {EltTy}}) ||
522 isInstUnsupported({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}});
525 /// Looks through copy instructions and returns the actual
526 /// source register.
527 unsigned lookThroughCopyInstrs(Register Reg) {
528 Register TmpReg;
529 while (mi_match(Reg, MRI, m_Copy(m_Reg(TmpReg)))) {
530 if (MRI.getType(TmpReg).isValid())
531 Reg = TmpReg;
532 else
533 break;
535 return Reg;
539 } // namespace llvm