1 //===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9 #include "llvm/ADT/APFloat.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/ADT/SetVector.h"
12 #include "llvm/ADT/SmallBitVector.h"
13 #include "llvm/Analysis/CmpInstAnalysis.h"
14 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
15 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
16 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21 #include "llvm/CodeGen/GlobalISel/Utils.h"
22 #include "llvm/CodeGen/LowLevelTypeUtils.h"
23 #include "llvm/CodeGen/MachineBasicBlock.h"
24 #include "llvm/CodeGen/MachineDominators.h"
25 #include "llvm/CodeGen/MachineInstr.h"
26 #include "llvm/CodeGen/MachineMemOperand.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/RegisterBankInfo.h"
29 #include "llvm/CodeGen/TargetInstrInfo.h"
30 #include "llvm/CodeGen/TargetLowering.h"
31 #include "llvm/CodeGen/TargetOpcodes.h"
32 #include "llvm/IR/ConstantRange.h"
33 #include "llvm/IR/DataLayout.h"
34 #include "llvm/IR/InstrTypes.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/DivisionByConstantInfo.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/MathExtras.h"
39 #include "llvm/Target/TargetMachine.h"
40 #include <cmath>
41 #include <optional>
42 #include <tuple>
44 #define DEBUG_TYPE "gi-combiner"
46 using namespace llvm;
47 using namespace MIPatternMatch;
49 // Option to allow testing of the combiner while no targets know about indexed
50 // addressing.
51 static cl::opt<bool>
52 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
53 cl::desc("Force all indexed operations to be "
54 "legal for the GlobalISel combiner"));
56 CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
57 MachineIRBuilder &B, bool IsPreLegalize,
58 GISelKnownBits *KB, MachineDominatorTree *MDT,
59 const LegalizerInfo *LI)
60 : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
61 MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
62 RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
63 TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
64 (void)this->KB;
67 const TargetLowering &CombinerHelper::getTargetLowering() const {
68 return *Builder.getMF().getSubtarget().getTargetLowering();
71 const MachineFunction &CombinerHelper::getMachineFunction() const {
72 return Builder.getMF();
75 const DataLayout &CombinerHelper::getDataLayout() const {
76 return getMachineFunction().getDataLayout();
79 LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }
81 /// \returns The little endian in-memory byte position of byte \p I in a
82 /// \p ByteWidth bytes wide type.
83 ///
84 /// E.g. Given a 4-byte type x, x[0] -> byte 0
85 static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
86 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
87 return I;
90 /// Determines the LogBase2 value for a non-null input value using the
91 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
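/// For example, for a 32-bit value V = 8 (so ctlz(V) = 28), this computes
/// LogBase2(8) = 31 - 28 = 3.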
92 static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
93 auto &MRI = *MIB.getMRI();
94 LLT Ty = MRI.getType(V);
95 auto Ctlz = MIB.buildCTLZ(Ty, V);
96 auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
97 return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
100 /// \returns The big endian in-memory byte position of byte \p I in a
101 /// \p ByteWidth bytes wide type.
103 /// E.g. Given a 4-byte type x, x[0] -> byte 3
104 static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
105 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
106 return ByteWidth - I - 1;
109 /// Given a map from byte offsets in memory to indices in a load/store,
110 /// determine if that map corresponds to a little or big endian byte pattern.
112 /// \param MemOffset2Idx maps memory offsets to address offsets.
113 /// \param LowestIdx is the lowest index in \p MemOffset2Idx.
115 /// \returns true if the map corresponds to a big endian byte pattern, false if
116 /// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
118 /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
119 /// are as follows:
121 /// AddrOffset Little endian Big endian
122 /// 0 0 3
123 /// 1 1 2
124 /// 2 2 1
125 /// 3 3 0
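/// The indices are normalized by subtracting \p LowestIdx first, so the check
/// does not depend on where the combined pattern starts in the wider value.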
126 static std::optional<bool>
127 isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
128 int64_t LowestIdx) {
129 // Need at least two byte positions to decide on endianness.
130 unsigned Width = MemOffset2Idx.size();
131 if (Width < 2)
132 return std::nullopt;
133 bool BigEndian = true, LittleEndian = true;
134 for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
135 auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
136 if (MemOffsetAndIdx == MemOffset2Idx.end())
137 return std::nullopt;
138 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
139 assert(Idx >= 0 && "Expected non-negative byte offset?");
140 LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
141 BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
142 if (!BigEndian && !LittleEndian)
143 return std::nullopt;
146 assert((BigEndian != LittleEndian) &&
147 "Pattern cannot be both big and little endian!");
148 return BigEndian;
151 bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
153 bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
154 assert(LI && "Must have LegalizerInfo to query isLegal!");
155 return LI->getAction(Query).Action == LegalizeActions::Legal;
158 bool CombinerHelper::isLegalOrBeforeLegalizer(
159 const LegalityQuery &Query) const {
160 return isPreLegalize() || isLegal(Query);
163 bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
164 if (!Ty.isVector())
165 return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
166 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
167 if (isPreLegalize())
168 return true;
169 LLT EltTy = Ty.getElementType();
170 return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
171 isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
174 void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
175 Register ToReg) const {
176 Observer.changingAllUsesOfReg(MRI, FromReg);
178 if (MRI.constrainRegAttrs(ToReg, FromReg))
179 MRI.replaceRegWith(FromReg, ToReg);
180 else
181 Builder.buildCopy(FromReg, ToReg);
183 Observer.finishedChangingAllUsesOfReg();
186 void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
187 MachineOperand &FromRegOp,
188 Register ToReg) const {
189 assert(FromRegOp.getParent() && "Expected an operand in an MI");
190 Observer.changingInstr(*FromRegOp.getParent());
192 FromRegOp.setReg(ToReg);
194 Observer.changedInstr(*FromRegOp.getParent());
197 void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
198 unsigned ToOpcode) const {
199 Observer.changingInstr(FromMI);
201 FromMI.setDesc(Builder.getTII().get(ToOpcode));
203 Observer.changedInstr(FromMI);
206 const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
207 return RBI->getRegBank(Reg, MRI, *TRI);
210 void CombinerHelper::setRegBank(Register Reg,
211 const RegisterBank *RegBank) const {
212 if (RegBank)
213 MRI.setRegBank(Reg, *RegBank);
216 bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const {
217 if (matchCombineCopy(MI)) {
218 applyCombineCopy(MI);
219 return true;
221 return false;
223 bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const {
224 if (MI.getOpcode() != TargetOpcode::COPY)
225 return false;
226 Register DstReg = MI.getOperand(0).getReg();
227 Register SrcReg = MI.getOperand(1).getReg();
228 return canReplaceReg(DstReg, SrcReg, MRI);
230 void CombinerHelper::applyCombineCopy(MachineInstr &MI) const {
231 Register DstReg = MI.getOperand(0).getReg();
232 Register SrcReg = MI.getOperand(1).getReg();
233 replaceRegWith(MRI, DstReg, SrcReg);
234 MI.eraseFromParent();
237 bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand(
238 MachineInstr &MI, BuildFnTy &MatchInfo) const {
239 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating.
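// Roughly (illustrative): for a single-use def with one maybe-poison operand,
//   %x:_(s32) = G_ADD %a, %b        ; only %a may be poison
//   %y:_(s32) = G_FREEZE %x
// the freeze is pushed onto the maybe-poison operand instead,
//   %fa:_(s32) = G_FREEZE %a
//   %x:_(s32) = G_ADD %fa, %b       ; poison-generating flags dropped
// and uses of %y are rewritten to use %x.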
240 Register DstOp = MI.getOperand(0).getReg();
241 Register OrigOp = MI.getOperand(1).getReg();
243 if (!MRI.hasOneNonDBGUse(OrigOp))
244 return false;
246 MachineInstr *OrigDef = MRI.getUniqueVRegDef(OrigOp);
247 // Even if only a single operand of the PHI is not guaranteed non-poison,
248 // moving freeze() backwards across a PHI can cause optimization issues for
249 // other users of that operand.
251 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to
252 // the source register is unprofitable because it makes the freeze() more
253 // strict than is necessary (it would affect the whole register instead of
254 // just the subreg being frozen).
255 if (OrigDef->isPHI() || isa<GUnmerge>(OrigDef))
256 return false;
258 if (canCreateUndefOrPoison(OrigOp, MRI,
259 /*ConsiderFlagsAndMetadata=*/false))
260 return false;
262 std::optional<MachineOperand> MaybePoisonOperand;
263 for (MachineOperand &Operand : OrigDef->uses()) {
264 if (!Operand.isReg())
265 return false;
267 if (isGuaranteedNotToBeUndefOrPoison(Operand.getReg(), MRI))
268 continue;
270 if (!MaybePoisonOperand)
271 MaybePoisonOperand = Operand;
272 else {
273 // We have more than one maybe-poison operand. Moving the freeze is
274 // unsafe.
275 return false;
279 // Eliminate freeze if all operands are guaranteed non-poison.
280 if (!MaybePoisonOperand) {
281 MatchInfo = [=](MachineIRBuilder &B) {
282 Observer.changingInstr(*OrigDef);
283 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags();
284 Observer.changedInstr(*OrigDef);
285 B.buildCopy(DstOp, OrigOp);
287 return true;
290 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg();
291 LLT MaybePoisonOperandRegTy = MRI.getType(MaybePoisonOperandReg);
293 MatchInfo = [=](MachineIRBuilder &B) mutable {
294 Observer.changingInstr(*OrigDef);
295 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags();
296 Observer.changedInstr(*OrigDef);
297 B.setInsertPt(*OrigDef->getParent(), OrigDef->getIterator());
298 auto Freeze = B.buildFreeze(MaybePoisonOperandRegTy, MaybePoisonOperandReg);
299 replaceRegOpWith(
300 MRI, *OrigDef->findRegisterUseOperand(MaybePoisonOperandReg, TRI),
301 Freeze.getReg(0));
302 replaceRegWith(MRI, DstOp, OrigOp);
304 return true;
307 bool CombinerHelper::matchCombineConcatVectors(
308 MachineInstr &MI, SmallVector<Register> &Ops) const {
309 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
310 "Invalid instruction");
311 bool IsUndef = true;
312 MachineInstr *Undef = nullptr;
314 // Walk over all the operands of concat vectors and check if they are
315 // build_vector themselves or undef.
316 // Then collect their operands in Ops.
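// E.g. (illustrative):
//   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a, %b
//   %v2:_(<2 x s32>) = G_BUILD_VECTOR %c, %d
//   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %v1, %v2
// can be flattened into
//   %cat:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d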
317 for (const MachineOperand &MO : MI.uses()) {
318 Register Reg = MO.getReg();
319 MachineInstr *Def = MRI.getVRegDef(Reg);
320 assert(Def && "Operand not defined");
321 if (!MRI.hasOneNonDBGUse(Reg))
322 return false;
323 switch (Def->getOpcode()) {
324 case TargetOpcode::G_BUILD_VECTOR:
325 IsUndef = false;
326 // Remember the operands of the build_vector to fold
327 // them into the yet-to-build flattened concat vectors.
328 for (const MachineOperand &BuildVecMO : Def->uses())
329 Ops.push_back(BuildVecMO.getReg());
330 break;
331 case TargetOpcode::G_IMPLICIT_DEF: {
332 LLT OpType = MRI.getType(Reg);
333 // Keep one undef value for all the undef operands.
334 if (!Undef) {
335 Builder.setInsertPt(*MI.getParent(), MI);
336 Undef = Builder.buildUndef(OpType.getScalarType());
338 assert(MRI.getType(Undef->getOperand(0).getReg()) ==
339 OpType.getScalarType() &&
340 "All undefs should have the same type");
341 // Break the undef vector in as many scalar elements as needed
342 // for the flattening.
343 for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
344 EltIdx != EltEnd; ++EltIdx)
345 Ops.push_back(Undef->getOperand(0).getReg());
346 break;
348 default:
349 return false;
353 // Check if the combine is illegal
354 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
355 if (!isLegalOrBeforeLegalizer(
356 {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Ops[0])}})) {
357 return false;
360 if (IsUndef)
361 Ops.clear();
363 return true;
365 void CombinerHelper::applyCombineConcatVectors(
366 MachineInstr &MI, SmallVector<Register> &Ops) const {
367 // We determined that the concat_vectors can be flattened.
368 // Generate the flattened build_vector.
369 Register DstReg = MI.getOperand(0).getReg();
370 Builder.setInsertPt(*MI.getParent(), MI);
371 Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
373 // Note: IsUndef is sort of redundant. We could have determined it by
374 // checking that all Ops are undef. Alternatively, we could have
375 // generated a build_vector of undefs and rely on another combine to
376 // clean that up. For now, given we already gather this information
377 // in matchCombineConcatVectors, just save compile time and issue the
378 // right thing.
379 if (Ops.empty())
380 Builder.buildUndef(NewDstReg);
381 else
382 Builder.buildBuildVector(NewDstReg, Ops);
383 replaceRegWith(MRI, DstReg, NewDstReg);
384 MI.eraseFromParent();
387 bool CombinerHelper::matchCombineShuffleConcat(
388 MachineInstr &MI, SmallVector<Register> &Ops) const {
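  // Try to rewrite a shuffle of two G_CONCAT_VECTORS into a single
  // G_CONCAT_VECTORS of their sources, when every run of ConcatSrcNumElt mask
  // entries selects one whole concat source (or is entirely undef).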
389 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
390 auto ConcatMI1 =
391 dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg()));
392 auto ConcatMI2 =
393 dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg()));
394 if (!ConcatMI1 || !ConcatMI2)
395 return false;
397 // Check that the sources of the Concat instructions have the same type
398 if (MRI.getType(ConcatMI1->getSourceReg(0)) !=
399 MRI.getType(ConcatMI2->getSourceReg(0)))
400 return false;
402 LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1));
403 LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg());
404 unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
405 for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
406 // Check if the index takes a whole source register from G_CONCAT_VECTORS
407 // Assumes that all Sources of G_CONCAT_VECTORS are the same type
408 if (Mask[i] == -1) {
409 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
410 if (i + j >= Mask.size())
411 return false;
412 if (Mask[i + j] != -1)
413 return false;
415 if (!isLegalOrBeforeLegalizer(
416 {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
417 return false;
418 Ops.push_back(0);
419 } else if (Mask[i] % ConcatSrcNumElt == 0) {
420 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
421 if (i + j >= Mask.size())
422 return false;
423 if (Mask[i + j] != Mask[i] + static_cast<int>(j))
424 return false;
426 // Retrieve the source register from its respective G_CONCAT_VECTORS
427 // instruction
428 if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
429 Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt));
430 } else {
431 Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt -
432 ConcatMI1->getNumSources()));
434 } else {
435 return false;
439 if (!isLegalOrBeforeLegalizer(
440 {TargetOpcode::G_CONCAT_VECTORS,
441 {MRI.getType(MI.getOperand(0).getReg()), ConcatSrcTy}}))
442 return false;
444 return !Ops.empty();
447 void CombinerHelper::applyCombineShuffleConcat(
448 MachineInstr &MI, SmallVector<Register> &Ops) const {
449 LLT SrcTy;
450 for (Register &Reg : Ops) {
451 if (Reg != 0)
452 SrcTy = MRI.getType(Reg);
454 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine");
456 Register UndefReg = 0;
458 for (Register &Reg : Ops) {
459 if (Reg == 0) {
460 if (UndefReg == 0)
461 UndefReg = Builder.buildUndef(SrcTy).getReg(0);
462 Reg = UndefReg;
466 if (Ops.size() > 1)
467 Builder.buildConcatVectors(MI.getOperand(0).getReg(), Ops);
468 else
469 Builder.buildCopy(MI.getOperand(0).getReg(), Ops[0]);
470 MI.eraseFromParent();
473 bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const {
474 SmallVector<Register, 4> Ops;
475 if (matchCombineShuffleVector(MI, Ops)) {
476 applyCombineShuffleVector(MI, Ops);
477 return true;
479 return false;
482 bool CombinerHelper::matchCombineShuffleVector(
483 MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
484 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
485 "Invalid instruction kind");
486 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
487 Register Src1 = MI.getOperand(1).getReg();
488 LLT SrcType = MRI.getType(Src1);
489 // As bizarre as it may look, shuffle vector can actually produce
490 // scalar! This is because at the IR level a <1 x ty> shuffle
491 // vector is perfectly valid.
492 unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
493 unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;
495 // If the resulting vector is smaller than the size of the source
496 // vectors being concatenated, we won't be able to replace the
497 // shuffle vector into a concat_vectors.
499 // Note: We may still be able to produce a concat_vectors fed by
500 // extract_vector_elt and so on. It is less clear that would
501 // be better though, so don't bother for now.
503 // If the destination is a scalar, the size of the sources doesn't
504 // matter. We will lower the shuffle to a plain copy. This will
505 // work only if the source and destination have the same size. But
506 // that's covered by the next condition.
508 // TODO: If the sizes of the source and destination don't match
509 // we could still emit an extract vector element in that case.
510 if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
511 return false;
513 // Check that the shuffle mask can be broken evenly between the
514 // different sources.
515 if (DstNumElts % SrcNumElts != 0)
516 return false;
518 // Mask length is a multiple of the source vector length.
519 // Check if the shuffle is some kind of concatenation of the input
520 // vectors.
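// E.g. (illustrative), with <2 x s32> sources:
//   %dst:_(<4 x s32>) = G_SHUFFLE_VECTOR %src1, %src2, shufflemask(0, 1, 2, 3)
// behaves like
//   %dst:_(<4 x s32>) = G_CONCAT_VECTORS %src1, %src2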
521 unsigned NumConcat = DstNumElts / SrcNumElts;
522 SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
523 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
524 for (unsigned i = 0; i != DstNumElts; ++i) {
525 int Idx = Mask[i];
526 // Undef value.
527 if (Idx < 0)
528 continue;
529 // Ensure the indices in each SrcType sized piece are sequential and that
530 // the same source is used for the whole piece.
531 if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
532 (ConcatSrcs[i / SrcNumElts] >= 0 &&
533 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
534 return false;
535 // Remember which source this index came from.
536 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
539 // The shuffle is concatenating multiple vectors together.
540 // Collect the different operands for that.
541 Register UndefReg;
542 Register Src2 = MI.getOperand(2).getReg();
543 for (auto Src : ConcatSrcs) {
544 if (Src < 0) {
545 if (!UndefReg) {
546 Builder.setInsertPt(*MI.getParent(), MI);
547 UndefReg = Builder.buildUndef(SrcType).getReg(0);
549 Ops.push_back(UndefReg);
550 } else if (Src == 0)
551 Ops.push_back(Src1);
552 else
553 Ops.push_back(Src2);
555 return true;
558 void CombinerHelper::applyCombineShuffleVector(
559 MachineInstr &MI, const ArrayRef<Register> Ops) const {
560 Register DstReg = MI.getOperand(0).getReg();
561 Builder.setInsertPt(*MI.getParent(), MI);
562 Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
564 if (Ops.size() == 1)
565 Builder.buildCopy(NewDstReg, Ops[0]);
566 else
567 Builder.buildMergeLikeInstr(NewDstReg, Ops);
569 replaceRegWith(MRI, DstReg, NewDstReg);
570 MI.eraseFromParent();
573 bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) const {
574 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
575 "Invalid instruction kind");
577 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
578 return Mask.size() == 1;
581 void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) const {
582 Register DstReg = MI.getOperand(0).getReg();
583 Builder.setInsertPt(*MI.getParent(), MI);
585 int I = MI.getOperand(3).getShuffleMask()[0];
586 Register Src1 = MI.getOperand(1).getReg();
587 LLT Src1Ty = MRI.getType(Src1);
588 int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;
589 Register SrcReg;
590 if (I >= Src1NumElts) {
591 SrcReg = MI.getOperand(2).getReg();
592 I -= Src1NumElts;
593 } else if (I >= 0)
594 SrcReg = Src1;
596 if (I < 0)
597 Builder.buildUndef(DstReg);
598 else if (!MRI.getType(SrcReg).isVector())
599 Builder.buildCopy(DstReg, SrcReg);
600 else
601 Builder.buildExtractVectorElementConstant(DstReg, SrcReg, I);
603 MI.eraseFromParent();
606 namespace {
608 /// Select a preference between two uses. CurrentUse is the current preference
609 /// while the *ForCandidate parameters describe the candidate under consideration.
610 PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
611 PreferredTuple &CurrentUse,
612 const LLT TyForCandidate,
613 unsigned OpcodeForCandidate,
614 MachineInstr *MIForCandidate) {
615 if (!CurrentUse.Ty.isValid()) {
616 if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
617 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
618 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
619 return CurrentUse;
622 // We permit the extend to hoist through basic blocks but this is only
623 // sensible if the target has extending loads. If you end up lowering back
624 // into a load and extend during the legalizer then the end result is
625 // hoisting the extend up to the load.
627 // Prefer defined extensions to undefined extensions as these are more
628 // likely to reduce the number of instructions.
629 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
630 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
631 return CurrentUse;
632 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
633 OpcodeForCandidate != TargetOpcode::G_ANYEXT)
634 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
636 // Prefer sign extensions to zero extensions: a standalone sign extension is
637 // usually more expensive than a zero extension, so folding it into the load
638 // is the bigger win. Don't do this if the load is already a zero-extend load
639 // though, otherwise we'll rewrite a zero-extend load into a sign-extend later.
640 if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) {
641 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
642 OpcodeForCandidate == TargetOpcode::G_ZEXT)
643 return CurrentUse;
644 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
645 OpcodeForCandidate == TargetOpcode::G_SEXT)
646 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
649 // This is potentially target specific. We've chosen the largest type
650 // because G_TRUNC is usually free. One potential catch with this is that
651 // some targets have a reduced number of larger registers than smaller
652 // registers and this choice potentially increases the live-range for the
653 // larger value.
654 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
655 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
657 return CurrentUse;
660 /// Find a suitable place to insert some instructions and insert them. This
661 /// function accounts for special cases like inserting before a PHI node.
662 /// The current strategy for inserting before PHI's is to duplicate the
663 /// instructions for each predecessor. However, while that's ok for G_TRUNC
664 /// on most targets since it generally requires no code, other targets/cases may
665 /// want to try harder to find a dominating block.
666 static void InsertInsnsWithoutSideEffectsBeforeUse(
667 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
668 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
669 MachineOperand &UseMO)>
670 Inserter) {
671 MachineInstr &UseMI = *UseMO.getParent();
673 MachineBasicBlock *InsertBB = UseMI.getParent();
675 // If the use is a PHI then we want the predecessor block instead.
676 if (UseMI.isPHI()) {
677 MachineOperand *PredBB = std::next(&UseMO);
678 InsertBB = PredBB->getMBB();
681 // If the block is the same block as the def then we want to insert just after
682 // the def instead of at the start of the block.
683 if (InsertBB == DefMI.getParent()) {
684 MachineBasicBlock::iterator InsertPt = &DefMI;
685 Inserter(InsertBB, std::next(InsertPt), UseMO);
686 return;
689 // Otherwise we want the start of the BB
690 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
692 } // end anonymous namespace
694 bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const {
695 PreferredTuple Preferred;
696 if (matchCombineExtendingLoads(MI, Preferred)) {
697 applyCombineExtendingLoads(MI, Preferred);
698 return true;
700 return false;
703 static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
704 unsigned CandidateLoadOpc;
705 switch (ExtOpc) {
706 case TargetOpcode::G_ANYEXT:
707 CandidateLoadOpc = TargetOpcode::G_LOAD;
708 break;
709 case TargetOpcode::G_SEXT:
710 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
711 break;
712 case TargetOpcode::G_ZEXT:
713 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
714 break;
715 default:
716 llvm_unreachable("Unexpected extend opc");
718 return CandidateLoadOpc;
721 bool CombinerHelper::matchCombineExtendingLoads(
722 MachineInstr &MI, PreferredTuple &Preferred) const {
723 // We match the loads and follow the uses to the extend instead of matching
724 // the extends and following the def to the load. This is because the load
725 // must remain in the same position for correctness (unless we also add code
726 // to find a safe place to sink it) whereas the extend is freely movable.
727 // It also prevents us from duplicating the load for the volatile case or just
728 // for performance.
729 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
730 if (!LoadMI)
731 return false;
733 Register LoadReg = LoadMI->getDstReg();
735 LLT LoadValueTy = MRI.getType(LoadReg);
736 if (!LoadValueTy.isScalar())
737 return false;
739 // Most architectures are going to legalize <s8 loads into at least a 1 byte
740 // load, and the MMOs can only describe memory accesses in multiples of bytes.
741 // If we try to perform extload combining on those, we can end up with
742 // %a(s8) = extload %ptr (load 1 byte from %ptr)
743 // ... which is an illegal extload instruction.
744 if (LoadValueTy.getSizeInBits() < 8)
745 return false;
747 // For non power-of-2 types, they will very likely be legalized into multiple
748 // loads. Don't bother trying to match them into extending loads.
749 if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits()))
750 return false;
752 // Find the preferred type aside from the any-extends (unless it's the only
753 // one) and non-extending ops. We'll emit an extending load to that type
754 // and emit a variant of (extend (trunc X)) for the others according to the
755 // relative type sizes. At the same time, pick an extend to use based on the
756 // extend involved in the chosen type.
757 unsigned PreferredOpcode =
758 isa<GLoad>(&MI)
759 ? TargetOpcode::G_ANYEXT
760 : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
761 Preferred = {LLT(), PreferredOpcode, nullptr};
762 for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
763 if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
764 UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
765 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
766 const auto &MMO = LoadMI->getMMO();
767 // Don't do anything for atomics.
768 if (MMO.isAtomic())
769 continue;
770 // Check for legality.
771 if (!isPreLegalize()) {
772 LegalityQuery::MemDesc MMDesc(MMO);
773 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
774 LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
775 LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
776 if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
777 .Action != LegalizeActions::Legal)
778 continue;
780 Preferred = ChoosePreferredUse(MI, Preferred,
781 MRI.getType(UseMI.getOperand(0).getReg()),
782 UseMI.getOpcode(), &UseMI);
786 // There were no extends
787 if (!Preferred.MI)
788 return false;
789 // It should be impossible to choose an extend without selecting a different
790 // type since by definition the result of an extend is larger.
791 assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
793 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
794 return true;
797 void CombinerHelper::applyCombineExtendingLoads(
798 MachineInstr &MI, PreferredTuple &Preferred) const {
799 // Rewrite the load to the chosen extending load.
800 Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();
802 // Inserter to insert a truncate back to the original type at a given point
803 // with some basic CSE to limit truncate duplication to one per BB.
804 DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
805 auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
806 MachineBasicBlock::iterator InsertBefore,
807 MachineOperand &UseMO) {
808 MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
809 if (PreviouslyEmitted) {
810 Observer.changingInstr(*UseMO.getParent());
811 UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
812 Observer.changedInstr(*UseMO.getParent());
813 return;
816 Builder.setInsertPt(*InsertIntoBB, InsertBefore);
817 Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
818 MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
819 EmittedInsns[InsertIntoBB] = NewMI;
820 replaceRegOpWith(MRI, UseMO, NewDstReg);
823 Observer.changingInstr(MI);
824 unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
825 MI.setDesc(Builder.getTII().get(LoadOpc));
827 // Rewrite all the uses to fix up the types.
828 auto &LoadValue = MI.getOperand(0);
829 SmallVector<MachineOperand *, 4> Uses;
830 for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
831 Uses.push_back(&UseMO);
833 for (auto *UseMO : Uses) {
834 MachineInstr *UseMI = UseMO->getParent();
836 // If the extend is compatible with the preferred extend then we should fix
837 // up the type and extend so that it uses the preferred use.
838 if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
839 UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
840 Register UseDstReg = UseMI->getOperand(0).getReg();
841 MachineOperand &UseSrcMO = UseMI->getOperand(1);
842 const LLT UseDstTy = MRI.getType(UseDstReg);
843 if (UseDstReg != ChosenDstReg) {
844 if (Preferred.Ty == UseDstTy) {
845 // If the use has the same type as the preferred use, then merge
846 // the vregs and erase the extend. For example:
847 // %1:_(s8) = G_LOAD ...
848 // %2:_(s32) = G_SEXT %1(s8)
849 // %3:_(s32) = G_ANYEXT %1(s8)
850 // ... = ... %3(s32)
851 // rewrites to:
852 // %2:_(s32) = G_SEXTLOAD ...
853 // ... = ... %2(s32)
854 replaceRegWith(MRI, UseDstReg, ChosenDstReg);
855 Observer.erasingInstr(*UseMO->getParent());
856 UseMO->getParent()->eraseFromParent();
857 } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
858 // If the preferred size is smaller, then keep the extend but extend
859 // from the result of the extending load. For example:
860 // %1:_(s8) = G_LOAD ...
861 // %2:_(s32) = G_SEXT %1(s8)
862 // %3:_(s64) = G_ANYEXT %1(s8)
863 // ... = ... %3(s64)
864 // rewrites to:
865 // %2:_(s32) = G_SEXTLOAD ...
866 // %3:_(s64) = G_ANYEXT %2:_(s32)
867 // ... = ... %3(s64)
868 replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
869 } else {
870 // If the preferred size is large, then insert a truncate. For
871 // example:
872 // %1:_(s8) = G_LOAD ...
873 // %2:_(s64) = G_SEXT %1(s8)
874 // %3:_(s32) = G_ZEXT %1(s8)
875 // ... = ... %3(s32)
876 // rewrites to:
877 // %2:_(s64) = G_SEXTLOAD ...
878 // %4:_(s8) = G_TRUNC %2:_(s64)
879 // %3:_(s32) = G_ZEXT %4:_(s8)
880 // ... = ... %3(s32)
881 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
882 InsertTruncAt);
884 continue;
886 // The use is (one of) the uses of the preferred use we chose earlier.
887 // We're going to update the load to def this value later so just erase
888 // the old extend.
889 Observer.erasingInstr(*UseMO->getParent());
890 UseMO->getParent()->eraseFromParent();
891 continue;
894 // The use isn't an extend. Truncate back to the type we originally loaded.
895 // This is free on many targets.
896 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
899 MI.getOperand(0).setReg(ChosenDstReg);
900 Observer.changedInstr(MI);
903 bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
904 BuildFnTy &MatchInfo) const {
905 assert(MI.getOpcode() == TargetOpcode::G_AND);
907 // If we have the following code:
908 // %mask = G_CONSTANT 255
909 // %ld = G_LOAD %ptr, (load s16)
910 // %and = G_AND %ld, %mask
912 // Try to fold it into
913 // %ld = G_ZEXTLOAD %ptr, (load s8)
915 Register Dst = MI.getOperand(0).getReg();
916 if (MRI.getType(Dst).isVector())
917 return false;
919 auto MaybeMask =
920 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
921 if (!MaybeMask)
922 return false;
924 APInt MaskVal = MaybeMask->Value;
926 if (!MaskVal.isMask())
927 return false;
929 Register SrcReg = MI.getOperand(1).getReg();
930 // Don't use getOpcodeDef() here since intermediate instructions may have
931 // multiple users.
932 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
933 if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
934 return false;
936 Register LoadReg = LoadMI->getDstReg();
937 LLT RegTy = MRI.getType(LoadReg);
938 Register PtrReg = LoadMI->getPointerReg();
939 unsigned RegSize = RegTy.getSizeInBits();
940 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
941 unsigned MaskSizeBits = MaskVal.countr_one();
943 // The mask must not be wider than the in-memory type, as the extra bits may
944 // be sign-extension bits.
945 if (MaskSizeBits > LoadSizeBits.getValue())
946 return false;
948 // If the mask covers the whole destination register, there's nothing to
949 // extend
950 if (MaskSizeBits >= RegSize)
951 return false;
953 // Most targets cannot deal with loads of size < 8 and need to re-legalize to
954 // at least byte loads. Avoid creating such loads here
955 if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
956 return false;
958 const MachineMemOperand &MMO = LoadMI->getMMO();
959 LegalityQuery::MemDesc MemDesc(MMO);
961 // Don't modify the memory access size if this is atomic/volatile, but we can
962 // still adjust the opcode to indicate the high bit behavior.
963 if (LoadMI->isSimple())
964 MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
965 else if (LoadSizeBits.getValue() > MaskSizeBits ||
966 LoadSizeBits.getValue() == RegSize)
967 return false;
969 // TODO: Could check if it's legal with the reduced or original memory size.
970 if (!isLegalOrBeforeLegalizer(
971 {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
972 return false;
974 MatchInfo = [=](MachineIRBuilder &B) {
975 B.setInstrAndDebugLoc(*LoadMI);
976 auto &MF = B.getMF();
977 auto PtrInfo = MMO.getPointerInfo();
978 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
979 B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
980 LoadMI->eraseFromParent();
982 return true;
985 bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
986 const MachineInstr &UseMI) const {
987 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
988 "shouldn't consider debug uses");
989 assert(DefMI.getParent() == UseMI.getParent());
990 if (&DefMI == &UseMI)
991 return true;
992 const MachineBasicBlock &MBB = *DefMI.getParent();
993 auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
994 return &MI == &DefMI || &MI == &UseMI;
996 if (DefOrUse == MBB.end())
997 llvm_unreachable("Block must contain both DefMI and UseMI!");
998 return &*DefOrUse == &DefMI;
1001 bool CombinerHelper::dominates(const MachineInstr &DefMI,
1002 const MachineInstr &UseMI) const {
1003 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
1004 "shouldn't consider debug uses");
1005 if (MDT)
1006 return MDT->dominates(&DefMI, &UseMI);
1007 else if (DefMI.getParent() != UseMI.getParent())
1008 return false;
1010 return isPredecessor(DefMI, UseMI);
1013 bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const {
1014 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1015 Register SrcReg = MI.getOperand(1).getReg();
1016 Register LoadUser = SrcReg;
1018 if (MRI.getType(SrcReg).isVector())
1019 return false;
1021 Register TruncSrc;
1022 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
1023 LoadUser = TruncSrc;
1025 uint64_t SizeInBits = MI.getOperand(2).getImm();
1026 // If the source is a G_SEXTLOAD from the same bit width, then we don't
1027 // need any extend at all, just a truncate.
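// E.g. (illustrative):
//   %ld:_(s32) = G_SEXTLOAD %ptr (load 1)
//   %trunc:_(s16) = G_TRUNC %ld
//   %ext:_(s16) = G_SEXT_INREG %trunc, 8
// %ext is already sign extended from bit 8, so it can be replaced by a copy
// of %trunc.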
1028 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
1029 // If truncating more than the original extended value, abort.
1030 auto LoadSizeBits = LoadMI->getMemSizeInBits();
1031 if (TruncSrc &&
1032 MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
1033 return false;
1034 if (LoadSizeBits == SizeInBits)
1035 return true;
1037 return false;
1040 void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const {
1041 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1042 Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
1043 MI.eraseFromParent();
1046 bool CombinerHelper::matchSextInRegOfLoad(
1047 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1048 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1050 Register DstReg = MI.getOperand(0).getReg();
1051 LLT RegTy = MRI.getType(DstReg);
1053 // Only supports scalars for now.
1054 if (RegTy.isVector())
1055 return false;
1057 Register SrcReg = MI.getOperand(1).getReg();
1058 auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
1059 if (!LoadDef || !MRI.hasOneNonDBGUse(SrcReg))
1060 return false;
1062 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();
1064 // If the sign extend extends from a narrower width than the load's width,
1065 // then we can narrow the load width when we combine to a G_SEXTLOAD.
1066 // Avoid widening the load at all.
1067 unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);
1069 // Don't generate G_SEXTLOADs with a < 1 byte width.
1070 if (NewSizeBits < 8)
1071 return false;
1072 // Don't bother creating a non-power-2 sextload, it will likely be broken up
1073 // anyway for most targets.
1074 if (!isPowerOf2_32(NewSizeBits))
1075 return false;
1077 const MachineMemOperand &MMO = LoadDef->getMMO();
1078 LegalityQuery::MemDesc MMDesc(MMO);
1080 // Don't modify the memory access size if this is atomic/volatile, but we can
1081 // still adjust the opcode to indicate the high bit behavior.
1082 if (LoadDef->isSimple())
1083 MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
1084 else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
1085 return false;
1087 // TODO: Could check if it's legal with the reduced or original memory size.
1088 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
1089 {MRI.getType(LoadDef->getDstReg()),
1090 MRI.getType(LoadDef->getPointerReg())},
1091 {MMDesc}}))
1092 return false;
1094 MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
1095 return true;
1098 void CombinerHelper::applySextInRegOfLoad(
1099 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1100 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1101 Register LoadReg;
1102 unsigned ScalarSizeBits;
1103 std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
1104 GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));
1106 // If we have the following:
1107 // %ld = G_LOAD %ptr, (load 2)
1108 // %ext = G_SEXT_INREG %ld, 8
1109 // ==>
1110 // %ld = G_SEXTLOAD %ptr (load 1)
1112 auto &MMO = LoadDef->getMMO();
1113 Builder.setInstrAndDebugLoc(*LoadDef);
1114 auto &MF = Builder.getMF();
1115 auto PtrInfo = MMO.getPointerInfo();
1116 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
1117 Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
1118 LoadDef->getPointerReg(), *NewMMO);
1119 MI.eraseFromParent();
1121 // Not all loads can be deleted, so make sure the old one is removed.
1122 LoadDef->eraseFromParent();
1125 /// Return true if 'MI' is a load or a store whose address operand may be
1126 /// folded into the load / store addressing mode.
1127 static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
1128 MachineRegisterInfo &MRI) {
1129 TargetLowering::AddrMode AM;
1130 auto *MF = MI->getMF();
1131 auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI);
1132 if (!Addr)
1133 return false;
1135 AM.HasBaseReg = true;
1136 if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI))
1137 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
1138 else
1139 AM.Scale = 1; // [reg +/- reg]
1141 return TLI.isLegalAddressingMode(
1142 MF->getDataLayout(), AM,
1143 getTypeForLLT(MI->getMMO().getMemoryType(),
1144 MF->getFunction().getContext()),
1145 MI->getMMO().getAddrSpace());
1148 static unsigned getIndexedOpc(unsigned LdStOpc) {
1149 switch (LdStOpc) {
1150 case TargetOpcode::G_LOAD:
1151 return TargetOpcode::G_INDEXED_LOAD;
1152 case TargetOpcode::G_STORE:
1153 return TargetOpcode::G_INDEXED_STORE;
1154 case TargetOpcode::G_ZEXTLOAD:
1155 return TargetOpcode::G_INDEXED_ZEXTLOAD;
1156 case TargetOpcode::G_SEXTLOAD:
1157 return TargetOpcode::G_INDEXED_SEXTLOAD;
1158 default:
1159 llvm_unreachable("Unexpected opcode");
1163 bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
1164 // Check for legality.
1165 LLT PtrTy = MRI.getType(LdSt.getPointerReg());
1166 LLT Ty = MRI.getType(LdSt.getReg(0));
1167 LLT MemTy = LdSt.getMMO().getMemoryType();
1168 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
1169 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
1170 AtomicOrdering::NotAtomic}});
1171 unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
1172 SmallVector<LLT> OpTys;
1173 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
1174 OpTys = {PtrTy, Ty, Ty};
1175 else
1176 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD
1178 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
1179 return isLegal(Q);
1182 static cl::opt<unsigned> PostIndexUseThreshold(
1183 "post-index-use-threshold", cl::Hidden, cl::init(32),
1184 cl::desc("Number of uses of a base pointer to check before it is no longer "
1185 "considered for post-indexing."));
1187 bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
1188 Register &Base, Register &Offset,
1189 bool &RematOffset) const {
1190 // We're looking for the following pattern, for either load or store:
1191 // %baseptr:_(p0) = ...
1192 // G_STORE %val(s64), %baseptr(p0)
1193 // %offset:_(s64) = G_CONSTANT i64 -256
1194 // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
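// Such a sequence can instead be emitted as a post-indexed store that writes
// %val and produces the adjusted address in one operation (illustrative):
//   %new_addr:_(p0) = G_INDEXED_STORE %val(s64), %baseptr(p0), %offset(s64), 0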
1195 const auto &TLI = getTargetLowering();
1197 Register Ptr = LdSt.getPointerReg();
1198 // If the store is the only use, don't bother.
1199 if (MRI.hasOneNonDBGUse(Ptr))
1200 return false;
1202 if (!isIndexedLoadStoreLegal(LdSt))
1203 return false;
1205 if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI))
1206 return false;
1208 MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI);
1209 auto *PtrDef = MRI.getVRegDef(Ptr);
1211 unsigned NumUsesChecked = 0;
1212 for (auto &Use : MRI.use_nodbg_instructions(Ptr)) {
1213 if (++NumUsesChecked > PostIndexUseThreshold)
1214 return false; // Try to avoid exploding compile time.
1216 auto *PtrAdd = dyn_cast<GPtrAdd>(&Use);
1217 // The use itself might be dead. This can happen during combines if DCE
1218 // hasn't had a chance to run yet. Don't allow it to form an indexed op.
1219 if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0)))
1220 continue;
1222 // Check the user of this isn't the store, otherwise we'd be generating an
1223 // indexed store defining its own use.
1224 if (StoredValDef == &Use)
1225 continue;
1227 Offset = PtrAdd->getOffsetReg();
1228 if (!ForceLegalIndexing &&
1229 !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset,
1230 /*IsPre*/ false, MRI))
1231 continue;
1233 // Make sure the offset calculation is before the potentially indexed op.
1234 MachineInstr *OffsetDef = MRI.getVRegDef(Offset);
1235 RematOffset = false;
1236 if (!dominates(*OffsetDef, LdSt)) {
1237 // If the offset however is just a G_CONSTANT, we can always just
1238 // rematerialize it where we need it.
1239 if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
1240 continue;
1241 RematOffset = true;
1244 for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) {
1245 if (&BasePtrUse == PtrDef)
1246 continue;
1248 // If the user is a later load/store that can be post-indexed, then don't
1249 // combine this one.
1250 auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse);
1251 if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
1252 dominates(LdSt, *BasePtrLdSt) &&
1253 isIndexedLoadStoreLegal(*BasePtrLdSt))
1254 return false;
1256 // Now we're looking for the key G_PTR_ADD instruction, which contains
1257 // the offset add that we want to fold.
1258 if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) {
1259 Register PtrAddDefReg = BasePtrUseDef->getReg(0);
1260 for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) {
1261 // If the use is in a different block, then we may produce worse code
1262 // due to the extra register pressure.
1263 if (BaseUseUse.getParent() != LdSt.getParent())
1264 return false;
1266 if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse))
1267 if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI))
1268 return false;
1270 if (!dominates(LdSt, BasePtrUse))
1271 return false; // All uses must be dominated by the load/store.
1275 Addr = PtrAdd->getReg(0);
1276 Base = PtrAdd->getBaseReg();
1277 return true;
1280 return false;
1283 bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
1284 Register &Base,
1285 Register &Offset) const {
1286 auto &MF = *LdSt.getParent()->getParent();
1287 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1289 Addr = LdSt.getPointerReg();
1290 if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) ||
1291 MRI.hasOneNonDBGUse(Addr))
1292 return false;
1294 if (!ForceLegalIndexing &&
1295 !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI))
1296 return false;
1298 if (!isIndexedLoadStoreLegal(LdSt))
1299 return false;
1301 MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
1302 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
1303 return false;
1305 if (auto *St = dyn_cast<GStore>(&LdSt)) {
1306 // Would require a copy.
1307 if (Base == St->getValueReg())
1308 return false;
1310 // We're expecting one use of Addr in MI, but it could also be the
1311 // value stored, which isn't actually dominated by the instruction.
1312 if (St->getValueReg() == Addr)
1313 return false;
1316 // Avoid increasing cross-block register pressure.
1317 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr))
1318 if (AddrUse.getParent() != LdSt.getParent())
1319 return false;
1321 // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1322 // That might allow us to end base's liveness here by adjusting the constant.
1323 bool RealUse = false;
1324 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) {
1325 if (!dominates(LdSt, AddrUse))
1326 return false; // All uses must be dominated by the load/store.
1328 // If Ptr may be folded into the addressing mode of another use, then it's
1329 // not profitable to do this transformation.
1330 if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) {
1331 if (!canFoldInAddressingMode(UseLdSt, TLI, MRI))
1332 RealUse = true;
1333 } else {
1334 RealUse = true;
1337 return RealUse;
1340 bool CombinerHelper::matchCombineExtractedVectorLoad(
1341 MachineInstr &MI, BuildFnTy &MatchInfo) const {
1342 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
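  // The aim (illustrative) is to narrow
  //   %vec:_(<4 x s32>) = G_LOAD %ptr
  //   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, %idx
  // into a scalar load of only the addressed element:
  //   %eltptr:_(p0) = ... pointer into %ptr computed from %idx ...
  //   %elt:_(s32) = G_LOAD %eltptr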
1344 // Check if there is a load that defines the vector being extracted from.
1345 auto *LoadMI = getOpcodeDef<GLoad>(MI.getOperand(1).getReg(), MRI);
1346 if (!LoadMI)
1347 return false;
1349 Register Vector = MI.getOperand(1).getReg();
1350 LLT VecEltTy = MRI.getType(Vector).getElementType();
1352 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);
1354 // Checking whether we should reduce the load width.
1355 if (!MRI.hasOneNonDBGUse(Vector))
1356 return false;
1358 // Check if the defining load is simple.
1359 if (!LoadMI->isSimple())
1360 return false;
1362 // If the vector element type is not a multiple of a byte then we are unable
1363 // to correctly compute an address to load only the extracted element as a
1364 // scalar.
1365 if (!VecEltTy.isByteSized())
1366 return false;
1368 // Check for load fold barriers between the extraction and the load.
1369 if (MI.getParent() != LoadMI->getParent())
1370 return false;
1371 const unsigned MaxIter = 20;
1372 unsigned Iter = 0;
1373 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) {
1374 if (II->isLoadFoldBarrier())
1375 return false;
1376 if (Iter++ == MaxIter)
1377 return false;
1380 // Check if the new load that we are going to create is legal
1381 // if we are in the post-legalization phase.
1382 MachineMemOperand MMO = LoadMI->getMMO();
1383 Align Alignment = MMO.getAlign();
1384 MachinePointerInfo PtrInfo;
1385 uint64_t Offset;
1387 // Finding the appropriate PtrInfo if offset is a known constant.
1388 // This is required to create the memory operand for the narrowed load.
1389 // This machine memory operand object lets us check legality
1390 // before we proceed to combine the instruction.
1391 if (auto CVal = getIConstantVRegVal(Vector, MRI)) {
1392 int Elt = CVal->getZExtValue();
1393 // FIXME: should be (ABI size)*Elt.
1394 Offset = VecEltTy.getSizeInBits() * Elt / 8;
1395 PtrInfo = MMO.getPointerInfo().getWithOffset(Offset);
1396 } else {
1397 // Discard the pointer info except the address space because the memory
1398 // operand can't represent this new access since the offset is variable.
1399 Offset = VecEltTy.getSizeInBits() / 8;
1400 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
1403 Alignment = commonAlignment(Alignment, Offset);
1405 Register VecPtr = LoadMI->getPointerReg();
1406 LLT PtrTy = MRI.getType(VecPtr);
1408 MachineFunction &MF = *MI.getMF();
1409 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, VecEltTy);
1411 LegalityQuery::MemDesc MMDesc(*NewMMO);
1413 LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}};
1415 if (!isLegalOrBeforeLegalizer(Q))
1416 return false;
1418 // Load must be allowed and fast on the target.
1419 LLVMContext &C = MF.getFunction().getContext();
1420 auto &DL = MF.getDataLayout();
1421 unsigned Fast = 0;
1422 if (!getTargetLowering().allowsMemoryAccess(C, DL, VecEltTy, *NewMMO,
1423 &Fast) ||
1424 !Fast)
1425 return false;
1427 Register Result = MI.getOperand(0).getReg();
1428 Register Index = MI.getOperand(2).getReg();
1430 MatchInfo = [=](MachineIRBuilder &B) {
1431 GISelObserverWrapper DummyObserver;
1432 LegalizerHelper Helper(B.getMF(), DummyObserver, B);
1433 // Get pointer to the vector element.
1434 Register finalPtr = Helper.getVectorElementPointer(
1435 LoadMI->getPointerReg(), MRI.getType(LoadMI->getOperand(0).getReg()),
1436 Index);
1437 // New G_LOAD instruction.
1438 B.buildLoad(Result, finalPtr, PtrInfo, Alignment);
1439 // Remove the original G_LOAD instruction.
1440 LoadMI->eraseFromParent();
1443 return true;
1446 bool CombinerHelper::matchCombineIndexedLoadStore(
1447 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1448 auto &LdSt = cast<GLoadStore>(MI);
1450 if (LdSt.isAtomic())
1451 return false;
1453 MatchInfo.IsPre = findPreIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
1454 MatchInfo.Offset);
1455 if (!MatchInfo.IsPre &&
1456 !findPostIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
1457 MatchInfo.Offset, MatchInfo.RematOffset))
1458 return false;
1460 return true;
1463 void CombinerHelper::applyCombineIndexedLoadStore(
1464 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1465 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr);
1466 unsigned Opcode = MI.getOpcode();
1467 bool IsStore = Opcode == TargetOpcode::G_STORE;
1468 unsigned NewOpcode = getIndexedOpc(Opcode);
1470 // If the offset constant didn't happen to dominate the load/store, we can
1471 // just clone it as needed.
1472 if (MatchInfo.RematOffset) {
1473 auto *OldCst = MRI.getVRegDef(MatchInfo.Offset);
1474 auto NewCst = Builder.buildConstant(MRI.getType(MatchInfo.Offset),
1475 *OldCst->getOperand(1).getCImm());
1476 MatchInfo.Offset = NewCst.getReg(0);
1479 auto MIB = Builder.buildInstr(NewOpcode);
1480 if (IsStore) {
1481 MIB.addDef(MatchInfo.Addr);
1482 MIB.addUse(MI.getOperand(0).getReg());
1483 } else {
1484 MIB.addDef(MI.getOperand(0).getReg());
1485 MIB.addDef(MatchInfo.Addr);
1488 MIB.addUse(MatchInfo.Base);
1489 MIB.addUse(MatchInfo.Offset);
1490 MIB.addImm(MatchInfo.IsPre);
1491 MIB->cloneMemRefs(*MI.getMF(), MI);
1492 MI.eraseFromParent();
1493 AddrDef.eraseFromParent();
1495 LLVM_DEBUG(dbgs() << " Combined to indexed operation");
1498 bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1499 MachineInstr *&OtherMI) const {
1500 unsigned Opcode = MI.getOpcode();
1501 bool IsDiv, IsSigned;
1503 switch (Opcode) {
1504 default:
1505 llvm_unreachable("Unexpected opcode!");
1506 case TargetOpcode::G_SDIV:
1507 case TargetOpcode::G_UDIV: {
1508 IsDiv = true;
1509 IsSigned = Opcode == TargetOpcode::G_SDIV;
1510 break;
1512 case TargetOpcode::G_SREM:
1513 case TargetOpcode::G_UREM: {
1514 IsDiv = false;
1515 IsSigned = Opcode == TargetOpcode::G_SREM;
1516 break;
1520 Register Src1 = MI.getOperand(1).getReg();
1521 unsigned DivOpcode, RemOpcode, DivremOpcode;
1522 if (IsSigned) {
1523 DivOpcode = TargetOpcode::G_SDIV;
1524 RemOpcode = TargetOpcode::G_SREM;
1525 DivremOpcode = TargetOpcode::G_SDIVREM;
1526 } else {
1527 DivOpcode = TargetOpcode::G_UDIV;
1528 RemOpcode = TargetOpcode::G_UREM;
1529 DivremOpcode = TargetOpcode::G_UDIVREM;
1532 if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}}))
1533 return false;
1535 // Combine:
1536 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1537 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1538 // into:
1539 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1541 // Combine:
1542 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1544 // into:
1545 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1547 for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) {
1548 if (MI.getParent() == UseMI.getParent() &&
1549 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1550 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1551 matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) &&
1552 matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) {
1553 OtherMI = &UseMI;
1554 return true;
1558 return false;
1561 void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1562 MachineInstr *&OtherMI) const {
1563 unsigned Opcode = MI.getOpcode();
1564 assert(OtherMI && "OtherMI shouldn't be empty.");
1566 Register DestDivReg, DestRemReg;
1567 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1568 DestDivReg = MI.getOperand(0).getReg();
1569 DestRemReg = OtherMI->getOperand(0).getReg();
1570 } else {
1571 DestDivReg = OtherMI->getOperand(0).getReg();
1572 DestRemReg = MI.getOperand(0).getReg();
1575 bool IsSigned =
1576 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1578 // Check which instruction is first in the block so we don't break def-use
1579 // deps by "moving" the instruction incorrectly. Also keep track of which
1580 // instruction is first so we pick its operands, avoiding use-before-def
1581 // bugs.
1582 MachineInstr *FirstInst = dominates(MI, *OtherMI) ? &MI : OtherMI;
1583 Builder.setInstrAndDebugLoc(*FirstInst);
1585 Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM
1586 : TargetOpcode::G_UDIVREM,
1587 {DestDivReg, DestRemReg},
1588 { FirstInst->getOperand(1), FirstInst->getOperand(2) });
1589 MI.eraseFromParent();
1590 OtherMI->eraseFromParent();
1593 bool CombinerHelper::matchOptBrCondByInvertingCond(
1594 MachineInstr &MI, MachineInstr *&BrCond) const {
1595 assert(MI.getOpcode() == TargetOpcode::G_BR);
1597 // Try to match the following:
1598 // bb1:
1599 // G_BRCOND %c1, %bb2
1600 // G_BR %bb3
1601 // bb2:
1602 // ...
1603 // bb3:
1605 // The above pattern does not fall through to the successor bb2; it always
1606 // results in a branch no matter which path is taken. Here we try to find and
1607 // replace that pattern with a conditional branch to bb3 that otherwise falls
1608 // through to bb2. This is generally better for branch predictors.
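// As an illustrative sketch (block and register names are made up; %true is
// whatever getICmpTrueVal returns for the target), the rewritten form is:
//   bb1:
//     %inv:_ = G_XOR %c1, %true
//     G_BRCOND %inv, %bb3
//     G_BR %bb2   ; now targets the layout successor and can later be dropped
//   bb2:
//     ...
//   bb3: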
1610 MachineBasicBlock *MBB = MI.getParent();
1611 MachineBasicBlock::iterator BrIt(MI);
1612 if (BrIt == MBB->begin())
1613 return false;
1614 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1616 BrCond = &*std::prev(BrIt);
1617 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1618 return false;
1620 // Check that the next block is the conditional branch target. Also make sure
1621 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1622 MachineBasicBlock *BrCondTarget = BrCond->getOperand(1).getMBB();
1623 return BrCondTarget != MI.getOperand(0).getMBB() &&
1624 MBB->isLayoutSuccessor(BrCondTarget);
1627 void CombinerHelper::applyOptBrCondByInvertingCond(
1628 MachineInstr &MI, MachineInstr *&BrCond) const {
1629 MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB();
1630 Builder.setInstrAndDebugLoc(*BrCond);
1631 LLT Ty = MRI.getType(BrCond->getOperand(0).getReg());
1632 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1633 // this to i1 only since we might not know for sure what kind of
1634 // compare generated the condition value.
1635 auto True = Builder.buildConstant(
1636 Ty, getICmpTrueVal(getTargetLowering(), false, false));
1637 auto Xor = Builder.buildXor(Ty, BrCond->getOperand(0), True);
1639 auto *FallthroughBB = BrCond->getOperand(1).getMBB();
1640 Observer.changingInstr(MI);
1641 MI.getOperand(0).setMBB(FallthroughBB);
1642 Observer.changedInstr(MI);
1644 // Change the conditional branch to use the inverted condition and
1645 // new target block.
1646 Observer.changingInstr(*BrCond);
1647 BrCond->getOperand(0).setReg(Xor.getReg(0));
1648 BrCond->getOperand(1).setMBB(BrTarget);
1649 Observer.changedInstr(*BrCond);
1652 bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const {
1653 MachineIRBuilder HelperBuilder(MI);
1654 GISelObserverWrapper DummyObserver;
1655 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1656 return Helper.lowerMemcpyInline(MI) ==
1657 LegalizerHelper::LegalizeResult::Legalized;
1660 bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI,
1661 unsigned MaxLen) const {
1662 MachineIRBuilder HelperBuilder(MI);
1663 GISelObserverWrapper DummyObserver;
1664 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1665 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1666 LegalizerHelper::LegalizeResult::Legalized;
1669 static APFloat constantFoldFpUnary(const MachineInstr &MI,
1670 const MachineRegisterInfo &MRI,
1671 const APFloat &Val) {
1672 APFloat Result(Val);
1673 switch (MI.getOpcode()) {
1674 default:
1675 llvm_unreachable("Unexpected opcode!");
1676 case TargetOpcode::G_FNEG: {
1677 Result.changeSign();
1678 return Result;
1680 case TargetOpcode::G_FABS: {
1681 Result.clearSign();
1682 return Result;
1684 case TargetOpcode::G_FPTRUNC: {
1685 bool Unused;
1686 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1687 Result.convert(getFltSemanticForLLT(DstTy), APFloat::rmNearestTiesToEven,
1688 &Unused);
1689 return Result;
1691 case TargetOpcode::G_FSQRT: {
1692 bool Unused;
1693 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
1694 &Unused);
1695 Result = APFloat(sqrt(Result.convertToDouble()));
1696 break;
1698 case TargetOpcode::G_FLOG2: {
1699 bool Unused;
1700 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
1701 &Unused);
1702 Result = APFloat(log2(Result.convertToDouble()));
1703 break;
1706 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise,
1707 // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and
1708 // `G_FLOG2` reach here.
1709 bool Unused;
1710 Result.convert(Val.getSemantics(), APFloat::rmNearestTiesToEven, &Unused);
1711 return Result;
1714 void CombinerHelper::applyCombineConstantFoldFpUnary(
1715 MachineInstr &MI, const ConstantFP *Cst) const {
1716 APFloat Folded = constantFoldFpUnary(MI, MRI, Cst->getValue());
1717 const ConstantFP *NewCst = ConstantFP::get(Builder.getContext(), Folded);
1718 Builder.buildFConstant(MI.getOperand(0), *NewCst);
1719 MI.eraseFromParent();
1722 bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1723 PtrAddChain &MatchInfo) const {
1724 // We're trying to match the following pattern:
1725 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1726 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1727 // -->
1728 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
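// For example (illustrative; the constants are arbitrary):
//   %t1 = G_PTR_ADD %base, 16
//   %root = G_PTR_ADD %t1, 8
// -->
//   %root = G_PTR_ADD %base, 24
// subject to the addressing-mode legality check below.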
1730 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1731 return false;
1733 Register Add2 = MI.getOperand(1).getReg();
1734 Register Imm1 = MI.getOperand(2).getReg();
1735 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1736 if (!MaybeImmVal)
1737 return false;
1739 MachineInstr *Add2Def = MRI.getVRegDef(Add2);
1740 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1741 return false;
1743 Register Base = Add2Def->getOperand(1).getReg();
1744 Register Imm2 = Add2Def->getOperand(2).getReg();
1745 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1746 if (!MaybeImm2Val)
1747 return false;
1749 // Check if the new combined immediate forms an illegal addressing mode.
1750 // Do not combine if it was legal before but would become illegal.
1751 // To do so, we need to find a load/store user of the pointer to get
1752 // the access type.
1753 Type *AccessTy = nullptr;
1754 auto &MF = *MI.getMF();
1755 for (auto &UseMI : MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) {
1756 if (auto *LdSt = dyn_cast<GLoadStore>(&UseMI)) {
1757 AccessTy = getTypeForLLT(MRI.getType(LdSt->getReg(0)),
1758 MF.getFunction().getContext());
1759 break;
1762 TargetLoweringBase::AddrMode AMNew;
1763 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1764 AMNew.BaseOffs = CombinedImm.getSExtValue();
1765 if (AccessTy) {
1766 AMNew.HasBaseReg = true;
1767 TargetLoweringBase::AddrMode AMOld;
1768 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1769 AMOld.HasBaseReg = true;
1770 unsigned AS = MRI.getType(Add2).getAddressSpace();
1771 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1772 if (TLI.isLegalAddressingMode(MF.getDataLayout(), AMOld, AccessTy, AS) &&
1773 !TLI.isLegalAddressingMode(MF.getDataLayout(), AMNew, AccessTy, AS))
1774 return false;
1777 // Pass the combined immediate to the apply function.
1778 MatchInfo.Imm = AMNew.BaseOffs;
1779 MatchInfo.Base = Base;
1780 MatchInfo.Bank = getRegBank(Imm2);
1781 return true;
1784 void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1785 PtrAddChain &MatchInfo) const {
1786 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1787 MachineIRBuilder MIB(MI);
1788 LLT OffsetTy = MRI.getType(MI.getOperand(2).getReg());
1789 auto NewOffset = MIB.buildConstant(OffsetTy, MatchInfo.Imm);
1790 setRegBank(NewOffset.getReg(0), MatchInfo.Bank);
1791 Observer.changingInstr(MI);
1792 MI.getOperand(1).setReg(MatchInfo.Base);
1793 MI.getOperand(2).setReg(NewOffset.getReg(0));
1794 Observer.changedInstr(MI);
1797 bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1798 RegisterImmPair &MatchInfo) const {
1799 // We're trying to match the following pattern with any of
1800 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1801 // %t1 = SHIFT %base, G_CONSTANT imm1
1802 // %root = SHIFT %t1, G_CONSTANT imm2
1803 // -->
1804 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1806 unsigned Opcode = MI.getOpcode();
1807 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1808 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1809 Opcode == TargetOpcode::G_USHLSAT) &&
1810 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1812 Register Shl2 = MI.getOperand(1).getReg();
1813 Register Imm1 = MI.getOperand(2).getReg();
1814 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1815 if (!MaybeImmVal)
1816 return false;
1818 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Shl2);
1819 if (Shl2Def->getOpcode() != Opcode)
1820 return false;
1822 Register Base = Shl2Def->getOperand(1).getReg();
1823 Register Imm2 = Shl2Def->getOperand(2).getReg();
1824 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1825 if (!MaybeImm2Val)
1826 return false;
1828 // Pass the combined immediate to the apply function.
1829 MatchInfo.Imm =
1830 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1831 MatchInfo.Reg = Base;
1833 // There is no simple replacement for a saturating unsigned left shift that
1834 // exceeds the scalar size.
1835 if (Opcode == TargetOpcode::G_USHLSAT &&
1836 MatchInfo.Imm >= MRI.getType(Shl2).getScalarSizeInBits())
1837 return false;
1839 return true;
1842 void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1843 RegisterImmPair &MatchInfo) const {
1844 unsigned Opcode = MI.getOpcode();
1845 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1846 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1847 Opcode == TargetOpcode::G_USHLSAT) &&
1848 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1850 LLT Ty = MRI.getType(MI.getOperand(1).getReg());
1851 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1852 auto Imm = MatchInfo.Imm;
1854 if (Imm >= ScalarSizeInBits) {
1855 // Any logical shift that exceeds scalar size will produce zero.
1856 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1857 Builder.buildConstant(MI.getOperand(0), 0);
1858 MI.eraseFromParent();
1859 return;
1861 // Arithmetic shift and saturating signed left shift have no effect beyond
1862 // scalar size.
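// For example (illustrative, assuming an s32 value): chained G_ASHR amounts of
// 20 and 15 combine to 35, which is clamped to 31 here.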
1863 Imm = ScalarSizeInBits - 1;
1866 LLT ImmTy = MRI.getType(MI.getOperand(2).getReg());
1867 Register NewImm = Builder.buildConstant(ImmTy, Imm).getReg(0);
1868 Observer.changingInstr(MI);
1869 MI.getOperand(1).setReg(MatchInfo.Reg);
1870 MI.getOperand(2).setReg(NewImm);
1871 Observer.changedInstr(MI);
1874 bool CombinerHelper::matchShiftOfShiftedLogic(
1875 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1876 // We're trying to match the following pattern with any of
1877 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1878 // with any of G_AND/G_OR/G_XOR logic instructions.
1879 // %t1 = SHIFT %X, G_CONSTANT C0
1880 // %t2 = LOGIC %t1, %Y
1881 // %root = SHIFT %t2, G_CONSTANT C1
1882 // -->
1883 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1884 // %t4 = SHIFT %Y, G_CONSTANT C1
1885 // %root = LOGIC %t3, %t4
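// A concrete instance (illustrative, with made-up constants):
//   %t1 = G_SHL %X, 2
//   %t2 = G_AND %t1, %Y
//   %root = G_SHL %t2, 3
// -->
//   %t3 = G_SHL %X, 5
//   %t4 = G_SHL %Y, 3
//   %root = G_AND %t3, %t4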
1886 unsigned ShiftOpcode = MI.getOpcode();
1887 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1888 ShiftOpcode == TargetOpcode::G_ASHR ||
1889 ShiftOpcode == TargetOpcode::G_LSHR ||
1890 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1891 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1892 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1894 // Match a one-use bitwise logic op.
1895 Register LogicDest = MI.getOperand(1).getReg();
1896 if (!MRI.hasOneNonDBGUse(LogicDest))
1897 return false;
1899 MachineInstr *LogicMI = MRI.getUniqueVRegDef(LogicDest);
1900 unsigned LogicOpcode = LogicMI->getOpcode();
1901 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1902 LogicOpcode != TargetOpcode::G_XOR)
1903 return false;
1905 // Find a matching one-use shift by constant.
1906 const Register C1 = MI.getOperand(2).getReg();
1907 auto MaybeImmVal = getIConstantVRegValWithLookThrough(C1, MRI);
1908 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1909 return false;
1911 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1913 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1914 // Shift should match the previous one and should have only one use.
1915 if (MI->getOpcode() != ShiftOpcode ||
1916 !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
1917 return false;
1919 // Must be a constant.
1920 auto MaybeImmVal =
1921 getIConstantVRegValWithLookThrough(MI->getOperand(2).getReg(), MRI);
1922 if (!MaybeImmVal)
1923 return false;
1925 ShiftVal = MaybeImmVal->Value.getSExtValue();
1926 return true;
1929 // Logic ops are commutative, so check each operand for a match.
1930 Register LogicMIReg1 = LogicMI->getOperand(1).getReg();
1931 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(LogicMIReg1);
1932 Register LogicMIReg2 = LogicMI->getOperand(2).getReg();
1933 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(LogicMIReg2);
1934 uint64_t C0Val;
1936 if (matchFirstShift(LogicMIOp1, C0Val)) {
1937 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1938 MatchInfo.Shift2 = LogicMIOp1;
1939 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1940 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1941 MatchInfo.Shift2 = LogicMIOp2;
1942 } else
1943 return false;
1945 MatchInfo.ValSum = C0Val + C1Val;
1947 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1948 if (MatchInfo.ValSum >= MRI.getType(LogicDest).getScalarSizeInBits())
1949 return false;
1951 MatchInfo.Logic = LogicMI;
1952 return true;
1955 void CombinerHelper::applyShiftOfShiftedLogic(
1956 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1957 unsigned Opcode = MI.getOpcode();
1958 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1959 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
1960 Opcode == TargetOpcode::G_SSHLSAT) &&
1961 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1963 LLT ShlType = MRI.getType(MI.getOperand(2).getReg());
1964 LLT DestType = MRI.getType(MI.getOperand(0).getReg());
1966 Register Const = Builder.buildConstant(ShlType, MatchInfo.ValSum).getReg(0);
1968 Register Shift1Base = MatchInfo.Shift2->getOperand(1).getReg();
1969 Register Shift1 =
1970 Builder.buildInstr(Opcode, {DestType}, {Shift1Base, Const}).getReg(0);
1972 // If LogicNonShiftReg is the same as Shift1Base and the shift1 constant is the
1973 // same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the old
1974 // shift1 when building shift2. So if we erased MatchInfo.Shift2 at the end, we
1975 // would actually remove the old shift1 and crash later. Erase it earlier to
1976 // avoid the crash.
1977 MatchInfo.Shift2->eraseFromParent();
1979 Register Shift2Const = MI.getOperand(2).getReg();
1980 Register Shift2 = Builder
1981 .buildInstr(Opcode, {DestType},
1982 {MatchInfo.LogicNonShiftReg, Shift2Const})
1983 .getReg(0);
1985 Register Dest = MI.getOperand(0).getReg();
1986 Builder.buildInstr(MatchInfo.Logic->getOpcode(), {Dest}, {Shift1, Shift2});
1988 // This was one use so it's safe to remove it.
1989 MatchInfo.Logic->eraseFromParent();
1991 MI.eraseFromParent();
1994 bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
1995 BuildFnTy &MatchInfo) const {
1996 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
1997 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
1998 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
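// For example (illustrative): (shl (add %x, 12), 3) becomes
// (add (shl %x, 3), (shl 12, 3)), i.e. (add (shl %x, 3), 96).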
1999 auto &Shl = cast<GenericMachineInstr>(MI);
2000 Register DstReg = Shl.getReg(0);
2001 Register SrcReg = Shl.getReg(1);
2002 Register ShiftReg = Shl.getReg(2);
2003 Register X, C1;
2005 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, !isPreLegalize()))
2006 return false;
2008 if (!mi_match(SrcReg, MRI,
2009 m_OneNonDBGUse(m_any_of(m_GAdd(m_Reg(X), m_Reg(C1)),
2010 m_GOr(m_Reg(X), m_Reg(C1))))))
2011 return false;
2013 APInt C1Val, C2Val;
2014 if (!mi_match(C1, MRI, m_ICstOrSplat(C1Val)) ||
2015 !mi_match(ShiftReg, MRI, m_ICstOrSplat(C2Val)))
2016 return false;
2018 auto *SrcDef = MRI.getVRegDef(SrcReg);
2019 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2020 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2021 LLT SrcTy = MRI.getType(SrcReg);
2022 MatchInfo = [=](MachineIRBuilder &B) {
2023 auto S1 = B.buildShl(SrcTy, X, ShiftReg);
2024 auto S2 = B.buildShl(SrcTy, C1, ShiftReg);
2025 B.buildInstr(SrcDef->getOpcode(), {DstReg}, {S1, S2});
2027 return true;
2030 bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2031 unsigned &ShiftVal) const {
2032 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2033 auto MaybeImmVal =
2034 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
2035 if (!MaybeImmVal)
2036 return false;
2038 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2039 return (static_cast<int32_t>(ShiftVal) != -1);
2042 void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2043 unsigned &ShiftVal) const {
2044 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2045 MachineIRBuilder MIB(MI);
2046 LLT ShiftTy = MRI.getType(MI.getOperand(0).getReg());
2047 auto ShiftCst = MIB.buildConstant(ShiftTy, ShiftVal);
2048 Observer.changingInstr(MI);
2049 MI.setDesc(MIB.getTII().get(TargetOpcode::G_SHL));
2050 MI.getOperand(2).setReg(ShiftCst.getReg(0));
2051 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2052 MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
2053 Observer.changedInstr(MI);
2056 bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2057 BuildFnTy &MatchInfo) const {
2058 GSub &Sub = cast<GSub>(MI);
2060 LLT Ty = MRI.getType(Sub.getReg(0));
2062 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {Ty}}))
2063 return false;
2065 if (!isConstantLegalOrBeforeLegalizer(Ty))
2066 return false;
2068 APInt Imm = getIConstantFromReg(Sub.getRHSReg(), MRI);
2070 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2071 auto NegCst = B.buildConstant(Ty, -Imm);
2072 Observer.changingInstr(MI);
2073 MI.setDesc(B.getTII().get(TargetOpcode::G_ADD));
2074 MI.getOperand(2).setReg(NegCst.getReg(0));
2075 MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
2076 Observer.changedInstr(MI);
2078 return true;
2081 // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
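// For example (illustrative; register names are made up): given
//   %e:_(s64) = G_ZEXT %x:_(s32)
//   %s:_(s64) = G_SHL %e, 4
// and at least 4 known-zero leading bits in %x, this becomes
//   %n:_(s32) = G_SHL %x, 4
//   %s:_(s64) = G_ZEXT %n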
2082 bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2083 RegisterImmPair &MatchData) const {
2084 assert(MI.getOpcode() == TargetOpcode::G_SHL && KB);
2085 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2086 return false;
2088 Register LHS = MI.getOperand(1).getReg();
2090 Register ExtSrc;
2091 if (!mi_match(LHS, MRI, m_GAnyExt(m_Reg(ExtSrc))) &&
2092 !mi_match(LHS, MRI, m_GZExt(m_Reg(ExtSrc))) &&
2093 !mi_match(LHS, MRI, m_GSExt(m_Reg(ExtSrc))))
2094 return false;
2096 Register RHS = MI.getOperand(2).getReg();
2097 MachineInstr *MIShiftAmt = MRI.getVRegDef(RHS);
2098 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(*MIShiftAmt, MRI);
2099 if (!MaybeShiftAmtVal)
2100 return false;
2102 if (LI) {
2103 LLT SrcTy = MRI.getType(ExtSrc);
2105 // We only really care about the legality with the shifted value. We can
2106 // pick any type for the constant shift amount, so ask the target what to
2107 // use. Otherwise we would have to guess and hope it is reported as legal.
2108 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(SrcTy);
2109 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
2110 return false;
2113 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
2114 MatchData.Reg = ExtSrc;
2115 MatchData.Imm = ShiftAmt;
2117 unsigned MinLeadingZeros = KB->getKnownZeroes(ExtSrc).countl_one();
2118 unsigned SrcTySize = MRI.getType(ExtSrc).getScalarSizeInBits();
2119 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
2122 void CombinerHelper::applyCombineShlOfExtend(
2123 MachineInstr &MI, const RegisterImmPair &MatchData) const {
2124 Register ExtSrcReg = MatchData.Reg;
2125 int64_t ShiftAmtVal = MatchData.Imm;
2127 LLT ExtSrcTy = MRI.getType(ExtSrcReg);
2128 auto ShiftAmt = Builder.buildConstant(ExtSrcTy, ShiftAmtVal);
2129 auto NarrowShift =
2130 Builder.buildShl(ExtSrcTy, ExtSrcReg, ShiftAmt, MI.getFlags());
2131 Builder.buildZExt(MI.getOperand(0), NarrowShift);
2132 MI.eraseFromParent();
2135 bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
2136 Register &MatchInfo) const {
2137 GMerge &Merge = cast<GMerge>(MI);
2138 SmallVector<Register, 16> MergedValues;
2139 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
2140 MergedValues.emplace_back(Merge.getSourceReg(I));
2142 auto *Unmerge = getOpcodeDef<GUnmerge>(MergedValues[0], MRI);
2143 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
2144 return false;
2146 for (unsigned I = 0; I < MergedValues.size(); ++I)
2147 if (MergedValues[I] != Unmerge->getReg(I))
2148 return false;
2150 MatchInfo = Unmerge->getSourceReg();
2151 return true;
2154 static Register peekThroughBitcast(Register Reg,
2155 const MachineRegisterInfo &MRI) {
2156 while (mi_match(Reg, MRI, m_GBitcast(m_Reg(Reg))))
2159 return Reg;
2162 bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
2163 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2164 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2165 "Expected an unmerge");
2166 auto &Unmerge = cast<GUnmerge>(MI);
2167 Register SrcReg = peekThroughBitcast(Unmerge.getSourceReg(), MRI);
2169 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(SrcReg, MRI);
2170 if (!SrcInstr)
2171 return false;
2173 // Check the source type of the merge.
2174 LLT SrcMergeTy = MRI.getType(SrcInstr->getSourceReg(0));
2175 LLT Dst0Ty = MRI.getType(Unmerge.getReg(0));
2176 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
2177 if (SrcMergeTy != Dst0Ty && !SameSize)
2178 return false;
2179 // They are the same now (modulo a bitcast).
2180 // We can collect all the src registers.
2181 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
2182 Operands.push_back(SrcInstr->getSourceReg(Idx));
2183 return true;
2186 void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
2187 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2188 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2189 "Expected an unmerge");
2190 assert((MI.getNumOperands() - 1 == Operands.size()) &&
2191 "Not enough operands to replace all defs");
2192 unsigned NumElems = MI.getNumOperands() - 1;
2194 LLT SrcTy = MRI.getType(Operands[0]);
2195 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2196 bool CanReuseInputDirectly = DstTy == SrcTy;
2197 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2198 Register DstReg = MI.getOperand(Idx).getReg();
2199 Register SrcReg = Operands[Idx];
2201 // This combine may run after RegBankSelect, so we need to be aware of
2202 // register banks.
2203 const auto &DstCB = MRI.getRegClassOrRegBank(DstReg);
2204 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(SrcReg)) {
2205 SrcReg = Builder.buildCopy(MRI.getType(SrcReg), SrcReg).getReg(0);
2206 MRI.setRegClassOrRegBank(SrcReg, DstCB);
2209 if (CanReuseInputDirectly)
2210 replaceRegWith(MRI, DstReg, SrcReg);
2211 else
2212 Builder.buildCast(DstReg, SrcReg);
2214 MI.eraseFromParent();
2217 bool CombinerHelper::matchCombineUnmergeConstant(
2218 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2219 unsigned SrcIdx = MI.getNumOperands() - 1;
2220 Register SrcReg = MI.getOperand(SrcIdx).getReg();
2221 MachineInstr *SrcInstr = MRI.getVRegDef(SrcReg);
2222 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2223 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2224 return false;
2225 // Break down the big constant into smaller ones.
2226 const MachineOperand &CstVal = SrcInstr->getOperand(1);
2227 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2228 ? CstVal.getCImm()->getValue()
2229 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2231 LLT Dst0Ty = MRI.getType(MI.getOperand(0).getReg());
2232 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2233 // Unmerge a constant.
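// For example (illustrative): unmerging the s32 constant 0xAABBCCDD into two
// s16 pieces yields 0xCCDD for the first definition and 0xAABB for the second
// (lowest bits first).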
2234 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2235 Csts.emplace_back(Val.trunc(ShiftAmt));
2236 Val = Val.lshr(ShiftAmt);
2239 return true;
2242 void CombinerHelper::applyCombineUnmergeConstant(
2243 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2244 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2245 "Expected an unmerge");
2246 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2247 "Not enough operands to replace all defs");
2248 unsigned NumElems = MI.getNumOperands() - 1;
2249 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2250 Register DstReg = MI.getOperand(Idx).getReg();
2251 Builder.buildConstant(DstReg, Csts[Idx]);
2254 MI.eraseFromParent();
2257 bool CombinerHelper::matchCombineUnmergeUndef(
2258 MachineInstr &MI,
2259 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
2260 unsigned SrcIdx = MI.getNumOperands() - 1;
2261 Register SrcReg = MI.getOperand(SrcIdx).getReg();
2262 MatchInfo = [&MI](MachineIRBuilder &B) {
2263 unsigned NumElems = MI.getNumOperands() - 1;
2264 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2265 Register DstReg = MI.getOperand(Idx).getReg();
2266 B.buildUndef(DstReg);
2269 return isa<GImplicitDef>(MRI.getVRegDef(SrcReg));
2272 bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(
2273 MachineInstr &MI) const {
2274 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2275 "Expected an unmerge");
2276 if (MRI.getType(MI.getOperand(0).getReg()).isVector() ||
2277 MRI.getType(MI.getOperand(MI.getNumDefs()).getReg()).isVector())
2278 return false;
2279 // Check that all the lanes are dead except the first one.
2280 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2281 if (!MRI.use_nodbg_empty(MI.getOperand(Idx).getReg()))
2282 return false;
2284 return true;
2287 void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(
2288 MachineInstr &MI) const {
2289 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
2290 Register Dst0Reg = MI.getOperand(0).getReg();
2291 Builder.buildTrunc(Dst0Reg, SrcReg);
2292 MI.eraseFromParent();
2295 bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2296 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2297 "Expected an unmerge");
2298 Register Dst0Reg = MI.getOperand(0).getReg();
2299 LLT Dst0Ty = MRI.getType(Dst0Reg);
2300 // G_ZEXT on a vector applies to each lane, so it will
2301 // affect all destinations. Therefore we won't be able
2302 // to simplify the unmerge to just the first definition.
2303 if (Dst0Ty.isVector())
2304 return false;
2305 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
2306 LLT SrcTy = MRI.getType(SrcReg);
2307 if (SrcTy.isVector())
2308 return false;
2310 Register ZExtSrcReg;
2311 if (!mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZExtSrcReg))))
2312 return false;
2314 // Finally we can replace the first definition with
2315 // a zext of the source if the definition is big enough to hold
2316 // all of ZExtSrc's bits.
2317 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
2318 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2321 void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2322 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2323 "Expected an unmerge");
2325 Register Dst0Reg = MI.getOperand(0).getReg();
2327 MachineInstr *ZExtInstr =
2328 MRI.getVRegDef(MI.getOperand(MI.getNumDefs()).getReg());
2329 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2330 "Expecting a G_ZEXT");
2332 Register ZExtSrcReg = ZExtInstr->getOperand(1).getReg();
2333 LLT Dst0Ty = MRI.getType(Dst0Reg);
2334 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
2336 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2337 Builder.buildZExt(Dst0Reg, ZExtSrcReg);
2338 } else {
2339 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2340 "ZExt src doesn't fit in destination");
2341 replaceRegWith(MRI, Dst0Reg, ZExtSrcReg);
2344 Register ZeroReg;
2345 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2346 if (!ZeroReg)
2347 ZeroReg = Builder.buildConstant(Dst0Ty, 0).getReg(0);
2348 replaceRegWith(MRI, MI.getOperand(Idx).getReg(), ZeroReg);
2350 MI.eraseFromParent();
2353 bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2354 unsigned TargetShiftSize,
2355 unsigned &ShiftVal) const {
2356 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2357 MI.getOpcode() == TargetOpcode::G_LSHR ||
2358 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2360 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
2361 if (Ty.isVector()) // TODO: Handle vector shifts.
2362 return false;
2364 // Don't narrow further than the requested size.
2365 unsigned Size = Ty.getSizeInBits();
2366 if (Size <= TargetShiftSize)
2367 return false;
2369 auto MaybeImmVal =
2370 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
2371 if (!MaybeImmVal)
2372 return false;
2374 ShiftVal = MaybeImmVal->Value.getSExtValue();
2375 return ShiftVal >= Size / 2 && ShiftVal < Size;
2378 void CombinerHelper::applyCombineShiftToUnmerge(
2379 MachineInstr &MI, const unsigned &ShiftVal) const {
2380 Register DstReg = MI.getOperand(0).getReg();
2381 Register SrcReg = MI.getOperand(1).getReg();
2382 LLT Ty = MRI.getType(SrcReg);
2383 unsigned Size = Ty.getSizeInBits();
2384 unsigned HalfSize = Size / 2;
2385 assert(ShiftVal >= HalfSize);
2387 LLT HalfTy = LLT::scalar(HalfSize);
2389 auto Unmerge = Builder.buildUnmerge(HalfTy, SrcReg);
2390 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2392 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2393 Register Narrowed = Unmerge.getReg(1);
2395 // dst = G_LSHR s64:x, C for C >= 32
2396 // =>
2397 // lo, hi = G_UNMERGE_VALUES x
2398 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2400 if (NarrowShiftAmt != 0) {
2401 Narrowed = Builder.buildLShr(HalfTy, Narrowed,
2402 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
2405 auto Zero = Builder.buildConstant(HalfTy, 0);
2406 Builder.buildMergeLikeInstr(DstReg, {Narrowed, Zero});
2407 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2408 Register Narrowed = Unmerge.getReg(0);
2409 // dst = G_SHL s64:x, C for C >= 32
2410 // =>
2411 // lo, hi = G_UNMERGE_VALUES x
2412 // dst = G_MERGE_VALUES 0, (G_SHL lo, C - 32)
2413 if (NarrowShiftAmt != 0) {
2414 Narrowed = Builder.buildShl(HalfTy, Narrowed,
2415 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
2418 auto Zero = Builder.buildConstant(HalfTy, 0);
2419 Builder.buildMergeLikeInstr(DstReg, {Zero, Narrowed});
2420 } else {
2421 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2422 auto Hi = Builder.buildAShr(
2423 HalfTy, Unmerge.getReg(1),
2424 Builder.buildConstant(HalfTy, HalfSize - 1));
2426 if (ShiftVal == HalfSize) {
2427 // (G_ASHR i64:x, 32) ->
2428 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2429 Builder.buildMergeLikeInstr(DstReg, {Unmerge.getReg(1), Hi});
2430 } else if (ShiftVal == Size - 1) {
2431 // Don't need a second shift.
2432 // (G_ASHR i64:x, 63) ->
2433 // %narrowed = (G_ASHR hi_32(x), 31)
2434 // G_MERGE_VALUES %narrowed, %narrowed
2435 Builder.buildMergeLikeInstr(DstReg, {Hi, Hi});
2436 } else {
2437 auto Lo = Builder.buildAShr(
2438 HalfTy, Unmerge.getReg(1),
2439 Builder.buildConstant(HalfTy, ShiftVal - HalfSize));
2441 // (G_ASHR i64:x, C) ->, for C >= 32
2442 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2443 Builder.buildMergeLikeInstr(DstReg, {Lo, Hi});
2447 MI.eraseFromParent();
2450 bool CombinerHelper::tryCombineShiftToUnmerge(
2451 MachineInstr &MI, unsigned TargetShiftAmount) const {
2452 unsigned ShiftAmt;
2453 if (matchCombineShiftToUnmerge(MI, TargetShiftAmount, ShiftAmt)) {
2454 applyCombineShiftToUnmerge(MI, ShiftAmt);
2455 return true;
2458 return false;
2461 bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI,
2462 Register &Reg) const {
2463 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2464 Register DstReg = MI.getOperand(0).getReg();
2465 LLT DstTy = MRI.getType(DstReg);
2466 Register SrcReg = MI.getOperand(1).getReg();
2467 return mi_match(SrcReg, MRI,
2468 m_GPtrToInt(m_all_of(m_SpecificType(DstTy), m_Reg(Reg))));
2471 void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI,
2472 Register &Reg) const {
2473 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2474 Register DstReg = MI.getOperand(0).getReg();
2475 Builder.buildCopy(DstReg, Reg);
2476 MI.eraseFromParent();
2479 void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI,
2480 Register &Reg) const {
2481 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2482 Register DstReg = MI.getOperand(0).getReg();
2483 Builder.buildZExtOrTrunc(DstReg, Reg);
2484 MI.eraseFromParent();
2487 bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2488 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2489 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2490 Register LHS = MI.getOperand(1).getReg();
2491 Register RHS = MI.getOperand(2).getReg();
2492 LLT IntTy = MRI.getType(LHS);
2494 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2495 // instruction.
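// For example (illustrative, assuming 64-bit pointers):
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %a:_(s64) = G_ADD %i, %off
// -->
//   %q:_(p0) = G_PTR_ADD %p, %off
//   %a:_(s64) = G_PTRTOINT %q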
2496 PtrReg.second = false;
2497 for (Register SrcReg : {LHS, RHS}) {
2498 if (mi_match(SrcReg, MRI, m_GPtrToInt(m_Reg(PtrReg.first)))) {
2499 // Don't handle cases where the integer is implicitly converted to the
2500 // pointer width.
2501 LLT PtrTy = MRI.getType(PtrReg.first);
2502 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2503 return true;
2506 PtrReg.second = true;
2509 return false;
2512 void CombinerHelper::applyCombineAddP2IToPtrAdd(
2513 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2514 Register Dst = MI.getOperand(0).getReg();
2515 Register LHS = MI.getOperand(1).getReg();
2516 Register RHS = MI.getOperand(2).getReg();
2518 const bool DoCommute = PtrReg.second;
2519 if (DoCommute)
2520 std::swap(LHS, RHS);
2521 LHS = PtrReg.first;
2523 LLT PtrTy = MRI.getType(LHS);
2525 auto PtrAdd = Builder.buildPtrAdd(PtrTy, LHS, RHS);
2526 Builder.buildPtrToInt(Dst, PtrAdd);
2527 MI.eraseFromParent();
2530 bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2531 APInt &NewCst) const {
2532 auto &PtrAdd = cast<GPtrAdd>(MI);
2533 Register LHS = PtrAdd.getBaseReg();
2534 Register RHS = PtrAdd.getOffsetReg();
2535 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2537 if (auto RHSCst = getIConstantVRegVal(RHS, MRI)) {
2538 APInt Cst;
2539 if (mi_match(LHS, MRI, m_GIntToPtr(m_ICst(Cst)))) {
2540 auto DstTy = MRI.getType(PtrAdd.getReg(0));
2541 // G_INTTOPTR uses zero-extension
2542 NewCst = Cst.zextOrTrunc(DstTy.getSizeInBits());
2543 NewCst += RHSCst->sextOrTrunc(DstTy.getSizeInBits());
2544 return true;
2548 return false;
2551 void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2552 APInt &NewCst) const {
2553 auto &PtrAdd = cast<GPtrAdd>(MI);
2554 Register Dst = PtrAdd.getReg(0);
2556 Builder.buildConstant(Dst, NewCst);
2557 PtrAdd.eraseFromParent();
2560 bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI,
2561 Register &Reg) const {
2562 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2563 Register DstReg = MI.getOperand(0).getReg();
2564 Register SrcReg = MI.getOperand(1).getReg();
2565 Register OriginalSrcReg = getSrcRegIgnoringCopies(SrcReg, MRI);
2566 if (OriginalSrcReg.isValid())
2567 SrcReg = OriginalSrcReg;
2568 LLT DstTy = MRI.getType(DstReg);
2569 return mi_match(SrcReg, MRI,
2570 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))));
2573 bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI,
2574 Register &Reg) const {
2575 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2576 Register DstReg = MI.getOperand(0).getReg();
2577 Register SrcReg = MI.getOperand(1).getReg();
2578 LLT DstTy = MRI.getType(DstReg);
2579 if (mi_match(SrcReg, MRI,
2580 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))))) {
2581 unsigned DstSize = DstTy.getScalarSizeInBits();
2582 unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits();
2583 return KB->getKnownBits(Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2585 return false;
2588 static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2589 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2590 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2592 // ShiftTy > 32 > TruncTy -> 32
2593 if (ShiftSize > 32 && TruncSize < 32)
2594 return ShiftTy.changeElementSize(32);
2596 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2597 // Some targets like it, some don't, some only like it under certain
2598 // conditions/processor versions, etc.
2599 // A TargetLowering hook might be needed for this.
2601 // Don't combine
2602 return ShiftTy;
2605 bool CombinerHelper::matchCombineTruncOfShift(
2606 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2607 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2608 Register DstReg = MI.getOperand(0).getReg();
2609 Register SrcReg = MI.getOperand(1).getReg();
2611 if (!MRI.hasOneNonDBGUse(SrcReg))
2612 return false;
2614 LLT SrcTy = MRI.getType(SrcReg);
2615 LLT DstTy = MRI.getType(DstReg);
2617 MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI);
2618 const auto &TL = getTargetLowering();
2620 LLT NewShiftTy;
2621 switch (SrcMI->getOpcode()) {
2622 default:
2623 return false;
2624 case TargetOpcode::G_SHL: {
2625 NewShiftTy = DstTy;
2627 // Make sure new shift amount is legal.
2628 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2629 if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits()))
2630 return false;
2631 break;
2633 case TargetOpcode::G_LSHR:
2634 case TargetOpcode::G_ASHR: {
2635 // For right shifts, we conservatively do not do the transform if the TRUNC
2636 // has any STORE users. The reason is that if we change the type of the
2637 // shift, we may break the truncstore combine.
2639 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2640 for (auto &User : MRI.use_instructions(DstReg))
2641 if (User.getOpcode() == TargetOpcode::G_STORE)
2642 return false;
2644 NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy);
2645 if (NewShiftTy == SrcTy)
2646 return false;
2648 // Make sure we won't lose information by truncating the high bits.
2649 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2650 if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() -
2651 DstTy.getScalarSizeInBits()))
2652 return false;
2653 break;
2657 if (!isLegalOrBeforeLegalizer(
2658 {SrcMI->getOpcode(),
2659 {NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}}))
2660 return false;
2662 MatchInfo = std::make_pair(SrcMI, NewShiftTy);
2663 return true;
2666 void CombinerHelper::applyCombineTruncOfShift(
2667 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2668 MachineInstr *ShiftMI = MatchInfo.first;
2669 LLT NewShiftTy = MatchInfo.second;
2671 Register Dst = MI.getOperand(0).getReg();
2672 LLT DstTy = MRI.getType(Dst);
2674 Register ShiftAmt = ShiftMI->getOperand(2).getReg();
2675 Register ShiftSrc = ShiftMI->getOperand(1).getReg();
2676 ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);
2678 Register NewShift =
2679 Builder
2680 .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt})
2681 .getReg(0);
2683 if (NewShiftTy == DstTy)
2684 replaceRegWith(MRI, Dst, NewShift);
2685 else
2686 Builder.buildTrunc(Dst, NewShift);
2688 eraseInst(MI);
2691 bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const {
2692 return any_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2693 return MO.isReg() &&
2694 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2698 bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const {
2699 return all_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2700 return !MO.isReg() ||
2701 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2705 bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const {
2706 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2707 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
2708 return all_of(Mask, [](int Elt) { return Elt < 0; });
2711 bool CombinerHelper::matchUndefStore(MachineInstr &MI) const {
2712 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2713 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(0).getReg(),
2714 MRI);
2717 bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const {
2718 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2719 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(),
2720 MRI);
2723 bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
2724 MachineInstr &MI) const {
2725 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2726 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2727 "Expected an insert/extract element op");
2728 LLT VecTy = MRI.getType(MI.getOperand(1).getReg());
2729 if (VecTy.isScalableVector())
2730 return false;
2732 unsigned IdxIdx =
2733 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2734 auto Idx = getIConstantVRegVal(MI.getOperand(IdxIdx).getReg(), MRI);
2735 if (!Idx)
2736 return false;
2737 return Idx->getZExtValue() >= VecTy.getNumElements();
2740 bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI,
2741 unsigned &OpIdx) const {
2742 GSelect &SelMI = cast<GSelect>(MI);
2743 auto Cst =
2744 isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI);
2745 if (!Cst)
2746 return false;
2747 OpIdx = Cst->isZero() ? 3 : 2;
2748 return true;
2751 void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); }
2753 bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2754 const MachineOperand &MOP2) const {
2755 if (!MOP1.isReg() || !MOP2.isReg())
2756 return false;
2757 auto InstAndDef1 = getDefSrcRegIgnoringCopies(MOP1.getReg(), MRI);
2758 if (!InstAndDef1)
2759 return false;
2760 auto InstAndDef2 = getDefSrcRegIgnoringCopies(MOP2.getReg(), MRI);
2761 if (!InstAndDef2)
2762 return false;
2763 MachineInstr *I1 = InstAndDef1->MI;
2764 MachineInstr *I2 = InstAndDef2->MI;
2766 // Handle a case like this:
2768 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2770 // Even though %0 and %1 are produced by the same instruction, they are not
2771 // the same value.
2772 if (I1 == I2)
2773 return MOP1.getReg() == MOP2.getReg();
2775 // If we have an instruction which loads or stores, we can't guarantee that
2776 // it is identical.
2778 // For example, we may have
2780 // %x1 = G_LOAD %addr (load N from @somewhere)
2781 // ...
2782 // call @foo
2783 // ...
2784 // %x2 = G_LOAD %addr (load N from @somewhere)
2785 // ...
2786 // %or = G_OR %x1, %x2
2788 // It's possible that @foo will modify whatever lives at the address we're
2789 // loading from. To be safe, let's just assume that all loads and stores
2790 // are different (unless we have something which is guaranteed to not
2791 // change.)
2792 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2793 return false;
2795 // If both instructions are loads or stores, they are equal only if both
2796 // are dereferenceable invariant loads with the same number of bits.
2797 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2798 GLoadStore *LS1 = dyn_cast<GLoadStore>(I1);
2799 GLoadStore *LS2 = dyn_cast<GLoadStore>(I2);
2800 if (!LS1 || !LS2)
2801 return false;
2803 if (!I2->isDereferenceableInvariantLoad() ||
2804 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2805 return false;
2808 // Check for physical registers on the instructions first to avoid cases
2809 // like this:
2811 // %a = COPY $physreg
2812 // ...
2813 // SOMETHING implicit-def $physreg
2814 // ...
2815 // %b = COPY $physreg
2817 // These copies are not equivalent.
2818 if (any_of(I1->uses(), [](const MachineOperand &MO) {
2819 return MO.isReg() && MO.getReg().isPhysical();
2820 })) {
2821 // Check if we have a case like this:
2823 // %a = COPY $physreg
2824 // %b = COPY %a
2826 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2827 // From that, we know that they must have the same value, since they must
2828 // have come from the same COPY.
2829 return I1->isIdenticalTo(*I2);
2832 // We don't have any physical registers, so we don't necessarily need the
2833 // same vreg defs.
2835 // On the off-chance that there's some target instruction feeding into the
2836 // instruction, let's use produceSameValue instead of isIdenticalTo.
2837 if (Builder.getTII().produceSameValue(*I1, *I2, &MRI)) {
2838 // Handle instructions with multiple defs that produce the same values. The
2839 // values are the same for operands with the same index.
2840 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2841 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2842 // I1 and I2 are different instructions but produce the same values;
2843 // %1 and %6 are the same, while %1 and %7 are not the same value.
2844 return I1->findRegisterDefOperandIdx(InstAndDef1->Reg, /*TRI=*/nullptr) ==
2845 I2->findRegisterDefOperandIdx(InstAndDef2->Reg, /*TRI=*/nullptr);
2847 return false;
2850 bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2851 int64_t C) const {
2852 if (!MOP.isReg())
2853 return false;
2854 auto *MI = MRI.getVRegDef(MOP.getReg());
2855 auto MaybeCst = isConstantOrConstantSplatVector(*MI, MRI);
2856 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2857 MaybeCst->getSExtValue() == C;
2860 bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2861 double C) const {
2862 if (!MOP.isReg())
2863 return false;
2864 std::optional<FPValueAndVReg> MaybeCst;
2865 if (!mi_match(MOP.getReg(), MRI, m_GFCstOrSplat(MaybeCst)))
2866 return false;
2868 return MaybeCst->Value.isExactlyValue(C);
2871 void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2872 unsigned OpIdx) const {
2873 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2874 Register OldReg = MI.getOperand(0).getReg();
2875 Register Replacement = MI.getOperand(OpIdx).getReg();
2876 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2877 replaceRegWith(MRI, OldReg, Replacement);
2878 MI.eraseFromParent();
2881 void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2882 Register Replacement) const {
2883 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2884 Register OldReg = MI.getOperand(0).getReg();
2885 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2886 replaceRegWith(MRI, OldReg, Replacement);
2887 MI.eraseFromParent();
2890 bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
2891 unsigned ConstIdx) const {
2892 Register ConstReg = MI.getOperand(ConstIdx).getReg();
2893 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2895 // Get the shift amount
2896 auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI);
2897 if (!VRegAndVal)
2898 return false;
2900 // Return true if the shift amount >= bitwidth.
2901 return (VRegAndVal->Value.uge(DstTy.getSizeInBits()));
2904 void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
2905 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
2906 MI.getOpcode() == TargetOpcode::G_FSHR) &&
2907 "This is not a funnel shift operation");
2909 Register ConstReg = MI.getOperand(3).getReg();
2910 LLT ConstTy = MRI.getType(ConstReg);
2911 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2913 auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI);
2914 assert((VRegAndVal) && "Value is not a constant");
2916 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
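// For example (illustrative, assuming an s32 destination): a constant shift
// amount of 37 becomes 37 % 32 == 5.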
2917 APInt NewConst = VRegAndVal->Value.urem(
2918 APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
2920 auto NewConstInstr = Builder.buildConstant(ConstTy, NewConst.getZExtValue());
2921 Builder.buildInstr(
2922 MI.getOpcode(), {MI.getOperand(0)},
2923 {MI.getOperand(1), MI.getOperand(2), NewConstInstr.getReg(0)});
2925 MI.eraseFromParent();
2928 bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
2929 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2930 // Match (cond ? x : x)
2931 return matchEqualDefs(MI.getOperand(2), MI.getOperand(3)) &&
2932 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(2).getReg(),
2933 MRI);
2936 bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const {
2937 return matchEqualDefs(MI.getOperand(1), MI.getOperand(2)) &&
2938 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(),
2939 MRI);
2942 bool CombinerHelper::matchOperandIsZero(MachineInstr &MI,
2943 unsigned OpIdx) const {
2944 return matchConstantOp(MI.getOperand(OpIdx), 0) &&
2945 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(OpIdx).getReg(),
2946 MRI);
2949 bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI,
2950 unsigned OpIdx) const {
2951 MachineOperand &MO = MI.getOperand(OpIdx);
2952 return MO.isReg() &&
2953 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2956 bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
2957 unsigned OpIdx) const {
2958 MachineOperand &MO = MI.getOperand(OpIdx);
2959 return isKnownToBeAPowerOfTwo(MO.getReg(), MRI, KB);
2962 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
2963 double C) const {
2964 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2965 Builder.buildFConstant(MI.getOperand(0), C);
2966 MI.eraseFromParent();
2969 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI,
2970 int64_t C) const {
2971 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2972 Builder.buildConstant(MI.getOperand(0), C);
2973 MI.eraseFromParent();
2976 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const {
2977 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2978 Builder.buildConstant(MI.getOperand(0), C);
2979 MI.eraseFromParent();
2982 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
2983 ConstantFP *CFP) const {
2984 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2985 Builder.buildFConstant(MI.getOperand(0), CFP->getValueAPF());
2986 MI.eraseFromParent();
2989 void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const {
2990 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2991 Builder.buildUndef(MI.getOperand(0));
2992 MI.eraseFromParent();
2995 bool CombinerHelper::matchSimplifyAddToSub(
2996 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
2997 Register LHS = MI.getOperand(1).getReg();
2998 Register RHS = MI.getOperand(2).getReg();
2999 Register &NewLHS = std::get<0>(MatchInfo);
3000 Register &NewRHS = std::get<1>(MatchInfo);
3002 // Helper lambda to check for opportunities for
3003 // ((0-A) + B) -> B - A
3004 // (A + (0-B)) -> A - B
3005 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
3006 if (!mi_match(MaybeSub, MRI, m_Neg(m_Reg(NewRHS))))
3007 return false;
3008 NewLHS = MaybeNewLHS;
3009 return true;
3012 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
3015 bool CombinerHelper::matchCombineInsertVecElts(
3016 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3017 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
3018 "Invalid opcode");
3019 Register DstReg = MI.getOperand(0).getReg();
3020 LLT DstTy = MRI.getType(DstReg);
3021 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
3023 if (DstTy.isScalableVector())
3024 return false;
3026 unsigned NumElts = DstTy.getNumElements();
3027 // If this MI is part of a sequence of insert_vec_elts, then
3028 // don't do the combine in the middle of the sequence.
3029 if (MRI.hasOneUse(DstReg) && MRI.use_instr_begin(DstReg)->getOpcode() ==
3030 TargetOpcode::G_INSERT_VECTOR_ELT)
3031 return false;
3032 MachineInstr *CurrInst = &MI;
3033 MachineInstr *TmpInst;
3034 int64_t IntImm;
3035 Register TmpReg;
3036 MatchInfo.resize(NumElts);
3037 while (mi_match(
3038 CurrInst->getOperand(0).getReg(), MRI,
3039 m_GInsertVecElt(m_MInstr(TmpInst), m_Reg(TmpReg), m_ICst(IntImm)))) {
3040 if (IntImm >= NumElts || IntImm < 0)
3041 return false;
3042 if (!MatchInfo[IntImm])
3043 MatchInfo[IntImm] = TmpReg;
3044 CurrInst = TmpInst;
3046 // Variable index.
3047 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
3048 return false;
3049 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
3050 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
3051 if (!MatchInfo[I - 1].isValid())
3052 MatchInfo[I - 1] = TmpInst->getOperand(I).getReg();
3054 return true;
3056 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully
3057 // overwritten, bail out.
3058 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF ||
3059 all_of(MatchInfo, [](Register Reg) { return !!Reg; });
3062 void CombinerHelper::applyCombineInsertVecElts(
3063 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3064 Register UndefReg;
3065 auto GetUndef = [&]() {
3066 if (UndefReg)
3067 return UndefReg;
3068 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
3069 UndefReg = Builder.buildUndef(DstTy.getScalarType()).getReg(0);
3070 return UndefReg;
3072 for (Register &Reg : MatchInfo) {
3073 if (!Reg)
3074 Reg = GetUndef();
3076 Builder.buildBuildVector(MI.getOperand(0).getReg(), MatchInfo);
3077 MI.eraseFromParent();
3080 void CombinerHelper::applySimplifyAddToSub(
3081 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3082 Register SubLHS, SubRHS;
3083 std::tie(SubLHS, SubRHS) = MatchInfo;
3084 Builder.buildSub(MI.getOperand(0).getReg(), SubLHS, SubRHS);
3085 MI.eraseFromParent();
3088 bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
3089 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3090 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
3092 // Creates the new hand + logic instructions (but does not insert them).
3094 // On success, MatchInfo is populated with the new instructions. These are
3095 // inserted in applyHoistLogicOpWithSameOpcodeHands.
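// A concrete instance (illustrative; register names are made up):
//   G_AND (G_ZEXT %x), (G_ZEXT %y) --> G_ZEXT (G_AND %x, %y)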
3096 unsigned LogicOpcode = MI.getOpcode();
3097 assert(LogicOpcode == TargetOpcode::G_AND ||
3098 LogicOpcode == TargetOpcode::G_OR ||
3099 LogicOpcode == TargetOpcode::G_XOR);
3100 MachineIRBuilder MIB(MI);
3101 Register Dst = MI.getOperand(0).getReg();
3102 Register LHSReg = MI.getOperand(1).getReg();
3103 Register RHSReg = MI.getOperand(2).getReg();
3105 // Don't recompute anything.
3106 if (!MRI.hasOneNonDBGUse(LHSReg) || !MRI.hasOneNonDBGUse(RHSReg))
3107 return false;
3109 // Make sure we have (hand x, ...), (hand y, ...)
3110 MachineInstr *LeftHandInst = getDefIgnoringCopies(LHSReg, MRI);
3111 MachineInstr *RightHandInst = getDefIgnoringCopies(RHSReg, MRI);
3112 if (!LeftHandInst || !RightHandInst)
3113 return false;
3114 unsigned HandOpcode = LeftHandInst->getOpcode();
3115 if (HandOpcode != RightHandInst->getOpcode())
3116 return false;
3117 if (LeftHandInst->getNumOperands() < 2 ||
3118 !LeftHandInst->getOperand(1).isReg() ||
3119 RightHandInst->getNumOperands() < 2 ||
3120 !RightHandInst->getOperand(1).isReg())
3121 return false;
3123 // Make sure the types match up, and if we're doing this post-legalization,
3124 // we end up with legal types.
3125 Register X = LeftHandInst->getOperand(1).getReg();
3126 Register Y = RightHandInst->getOperand(1).getReg();
3127 LLT XTy = MRI.getType(X);
3128 LLT YTy = MRI.getType(Y);
3129 if (!XTy.isValid() || XTy != YTy)
3130 return false;
3132 // Optional extra source register.
3133 Register ExtraHandOpSrcReg;
3134 switch (HandOpcode) {
3135 default:
3136 return false;
3137 case TargetOpcode::G_ANYEXT:
3138 case TargetOpcode::G_SEXT:
3139 case TargetOpcode::G_ZEXT: {
3140 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3141 break;
3143 case TargetOpcode::G_TRUNC: {
3144 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y)
3145 const MachineFunction *MF = MI.getMF();
3146 LLVMContext &Ctx = MF->getFunction().getContext();
3148 LLT DstTy = MRI.getType(Dst);
3149 const TargetLowering &TLI = getTargetLowering();
3151 // Be extra careful when sinking a truncate. If it's free, there's no benefit in
3152 // widening a binop.
3153 if (TLI.isZExtFree(DstTy, XTy, Ctx) && TLI.isTruncateFree(XTy, DstTy, Ctx))
3154 return false;
3155 break;
3157 case TargetOpcode::G_AND:
3158 case TargetOpcode::G_ASHR:
3159 case TargetOpcode::G_LSHR:
3160 case TargetOpcode::G_SHL: {
3161 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3162 MachineOperand &ZOp = LeftHandInst->getOperand(2);
3163 if (!matchEqualDefs(ZOp, RightHandInst->getOperand(2)))
3164 return false;
3165 ExtraHandOpSrcReg = ZOp.getReg();
3166 break;
3170 if (!isLegalOrBeforeLegalizer({LogicOpcode, {XTy, YTy}}))
3171 return false;
3173 // Record the steps to build the new instructions.
3175 // Steps to build (logic x, y)
3176 auto NewLogicDst = MRI.createGenericVirtualRegister(XTy);
3177 OperandBuildSteps LogicBuildSteps = {
3178 [=](MachineInstrBuilder &MIB) { MIB.addDef(NewLogicDst); },
3179 [=](MachineInstrBuilder &MIB) { MIB.addReg(X); },
3180 [=](MachineInstrBuilder &MIB) { MIB.addReg(Y); }};
3181 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3183 // Steps to build hand (logic x, y), ...z
3184 OperandBuildSteps HandBuildSteps = {
3185 [=](MachineInstrBuilder &MIB) { MIB.addDef(Dst); },
3186 [=](MachineInstrBuilder &MIB) { MIB.addReg(NewLogicDst); }};
3187 if (ExtraHandOpSrcReg.isValid())
3188 HandBuildSteps.push_back(
3189 [=](MachineInstrBuilder &MIB) { MIB.addReg(ExtraHandOpSrcReg); });
3190 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3192 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3193 return true;
3196 void CombinerHelper::applyBuildInstructionSteps(
3197 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3198 assert(MatchInfo.InstrsToBuild.size() &&
3199 "Expected at least one instr to build?");
3200 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3201 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3202 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3203 MachineInstrBuilder Instr = Builder.buildInstr(InstrToBuild.Opcode);
3204 for (auto &OperandFn : InstrToBuild.OperandFns)
3205 OperandFn(Instr);
3207 MI.eraseFromParent();
3210 bool CombinerHelper::matchAshrShlToSextInreg(
3211 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3212 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
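// Fold (ashr (shl x, C), C) -> (sext_inreg x, ScalarSizeInBits - C). For
// illustration (assumed s32 types, not taken from the tests):
//   %s:_(s32) = G_SHL %x, 24
//   %a:_(s32) = G_ASHR %s, 24
// sign-extends the low 8 bits of %x, i.e. %a:_(s32) = G_SEXT_INREG %x, 8.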
3213 int64_t ShlCst, AshrCst;
3214 Register Src;
3215 if (!mi_match(MI.getOperand(0).getReg(), MRI,
3216 m_GAShr(m_GShl(m_Reg(Src), m_ICstOrSplat(ShlCst)),
3217 m_ICstOrSplat(AshrCst))))
3218 return false;
3219 if (ShlCst != AshrCst)
3220 return false;
3221 if (!isLegalOrBeforeLegalizer(
3222 {TargetOpcode::G_SEXT_INREG, {MRI.getType(Src)}}))
3223 return false;
3224 MatchInfo = std::make_tuple(Src, ShlCst);
3225 return true;
3228 void CombinerHelper::applyAshShlToSextInreg(
3229 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3230 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3231 Register Src;
3232 int64_t ShiftAmt;
3233 std::tie(Src, ShiftAmt) = MatchInfo;
3234 unsigned Size = MRI.getType(Src).getScalarSizeInBits();
3235 Builder.buildSExtInReg(MI.getOperand(0).getReg(), Src, Size - ShiftAmt);
3236 MI.eraseFromParent();
3239 /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
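/// For illustration (assumed constants): and(and(x, 0x0F), 0xFC) becomes
/// and(x, 0x0C), while and(and(x, 0x0F), 0xF0) folds to 0.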
3240 bool CombinerHelper::matchOverlappingAnd(
3241 MachineInstr &MI,
3242 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
3243 assert(MI.getOpcode() == TargetOpcode::G_AND);
3245 Register Dst = MI.getOperand(0).getReg();
3246 LLT Ty = MRI.getType(Dst);
3248 Register R;
3249 int64_t C1;
3250 int64_t C2;
3251 if (!mi_match(
3252 Dst, MRI,
3253 m_GAnd(m_GAnd(m_Reg(R), m_ICst(C1)), m_ICst(C2))))
3254 return false;
3256 MatchInfo = [=](MachineIRBuilder &B) {
3257 if (C1 & C2) {
3258 B.buildAnd(Dst, R, B.buildConstant(Ty, C1 & C2));
3259 return;
3261 auto Zero = B.buildConstant(Ty, 0);
3262 replaceRegWith(MRI, Dst, Zero->getOperand(0).getReg());
3264 return true;
3267 bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3268 Register &Replacement) const {
3269 // Given
3271 // %y:_(sN) = G_SOMETHING
3272 // %x:_(sN) = G_SOMETHING
3273 // %res:_(sN) = G_AND %x, %y
3275 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3277 // Patterns like this can appear as a result of legalization. E.g.
3279 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3280 // %one:_(s32) = G_CONSTANT i32 1
3281 // %and:_(s32) = G_AND %cmp, %one
3283 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3284 assert(MI.getOpcode() == TargetOpcode::G_AND);
3285 if (!KB)
3286 return false;
3288 Register AndDst = MI.getOperand(0).getReg();
3289 Register LHS = MI.getOperand(1).getReg();
3290 Register RHS = MI.getOperand(2).getReg();
3292 // Check the RHS (maybe a constant) first, and if we have no KnownBits there,
3293 // we can't do anything. If we do, then it depends on whether we have
3294 // KnownBits on the LHS.
3295 KnownBits RHSBits = KB->getKnownBits(RHS);
3296 if (RHSBits.isUnknown())
3297 return false;
3299 KnownBits LHSBits = KB->getKnownBits(LHS);
3301 // Check that x & Mask == x.
3302 // x & 1 == x, always
3303 // x & 0 == x, only if x is also 0
3304 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
3306 // Check if we can replace AndDst with the LHS of the G_AND
3307 if (canReplaceReg(AndDst, LHS, MRI) &&
3308 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3309 Replacement = LHS;
3310 return true;
3313 // Check if we can replace AndDst with the RHS of the G_AND
3314 if (canReplaceReg(AndDst, RHS, MRI) &&
3315 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3316 Replacement = RHS;
3317 return true;
3320 return false;
3323 bool CombinerHelper::matchRedundantOr(MachineInstr &MI,
3324 Register &Replacement) const {
3325 // Given
3327 // %y:_(sN) = G_SOMETHING
3328 // %x:_(sN) = G_SOMETHING
3329 // %res:_(sN) = G_OR %x, %y
3331 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3332 assert(MI.getOpcode() == TargetOpcode::G_OR);
3333 if (!KB)
3334 return false;
3336 Register OrDst = MI.getOperand(0).getReg();
3337 Register LHS = MI.getOperand(1).getReg();
3338 Register RHS = MI.getOperand(2).getReg();
3340 KnownBits LHSBits = KB->getKnownBits(LHS);
3341 KnownBits RHSBits = KB->getKnownBits(RHS);
3343 // Check that x | Mask == x.
3344 // x | 0 == x, always
3345 // x | 1 == x, only if x is also 1
3346 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
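// For illustration: if KnownBits proves the RHS is all zeros, then
// x | 0 == x and the G_OR can simply be replaced by its LHS.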
3348 // Check if we can replace OrDst with the LHS of the G_OR
3349 if (canReplaceReg(OrDst, LHS, MRI) &&
3350 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3351 Replacement = LHS;
3352 return true;
3355 // Check if we can replace OrDst with the RHS of the G_OR
3356 if (canReplaceReg(OrDst, RHS, MRI) &&
3357 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3358 Replacement = RHS;
3359 return true;
3362 return false;
3365 bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const {
3366 // If the input is already sign extended, just drop the extension.
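// For illustration (assumed types): if %x:_(s32) = G_SEXT %y:_(s8), then
// G_SEXT_INREG %x, 16 is redundant, since %x already has 25 sign bits,
// more than the 32 - 16 + 1 = 17 required below.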
3367 Register Src = MI.getOperand(1).getReg();
3368 unsigned ExtBits = MI.getOperand(2).getImm();
3369 unsigned TypeSize = MRI.getType(Src).getScalarSizeInBits();
3370 return KB->computeNumSignBits(Src) >= (TypeSize - ExtBits + 1);
3373 static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3374 int64_t Cst, bool IsVector, bool IsFP) {
3375 // For i1, Cst will always be -1 regardless of boolean contents.
3376 return (ScalarSizeBits == 1 && Cst == -1) ||
3377 isConstTrueVal(TLI, Cst, IsVector, IsFP);
3380 // This combine tries to reduce the number of scalarised G_TRUNC instructions by
3381 // using vector truncates instead.
3383 // EXAMPLE:
3384 // %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>)
3385 // %T_a(i16) = G_TRUNC %a(i32)
3386 // %T_b(i16) = G_TRUNC %b(i32)
3387 // %Undef(i16) = G_IMPLICIT_DEF(i16)
3388 // %dst(<4 x s16>) = G_BUILD_VECTOR %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16)
3390 // ===>
3391 // %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>)
3392 // %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>)
3393 // %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3395 // Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs
3396 bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI,
3397 Register &MatchInfo) const {
3398 auto BuildMI = cast<GBuildVector>(&MI);
3399 unsigned NumOperands = BuildMI->getNumSources();
3400 LLT DstTy = MRI.getType(BuildMI->getReg(0));
3402 // Check the G_BUILD_VECTOR sources
3403 unsigned I;
3404 MachineInstr *UnmergeMI = nullptr;
3406 // Check all source TRUNCs come from the same UNMERGE instruction
3407 for (I = 0; I < NumOperands; ++I) {
3408 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I));
3409 auto SrcMIOpc = SrcMI->getOpcode();
3411 // Check if the G_TRUNC instructions all come from the same MI
3412 if (SrcMIOpc == TargetOpcode::G_TRUNC) {
3413 if (!UnmergeMI) {
3414 UnmergeMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg());
3415 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
3416 return false;
3417 } else {
3418 auto UnmergeSrcMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg());
3419 if (UnmergeMI != UnmergeSrcMI)
3420 return false;
3422 } else {
3423 break;
3426 if (I < 2)
3427 return false;
3429 // Check the remaining source elements are only G_IMPLICIT_DEF
3430 for (; I < NumOperands; ++I) {
3431 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I));
3432 auto SrcMIOpc = SrcMI->getOpcode();
3434 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF)
3435 return false;
3439 // Check the size of the unmerge source
3439 MatchInfo = cast<GUnmerge>(UnmergeMI)->getSourceReg();
3440 LLT UnmergeSrcTy = MRI.getType(MatchInfo);
3441 if (!DstTy.getElementCount().isKnownMultipleOf(UnmergeSrcTy.getNumElements()))
3442 return false;
3444 // Only generate legal instructions post-legalizer
3445 if (!IsPreLegalize) {
3446 LLT MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType());
3448 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
3449 !isLegal({TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
3450 return false;
3452 if (!isLegal({TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
3453 return false;
3456 return true;
3459 void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
3460 Register &MatchInfo) const {
3461 Register MidReg;
3462 auto BuildMI = cast<GBuildVector>(&MI);
3463 Register DstReg = BuildMI->getReg(0);
3464 LLT DstTy = MRI.getType(DstReg);
3465 LLT UnmergeSrcTy = MRI.getType(MatchInfo);
3466 unsigned DstTyNumElt = DstTy.getNumElements();
3467 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();
3469 // No need to pad vector if only G_TRUNC is needed
3470 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
3471 MidReg = MatchInfo;
3472 } else {
3473 Register UndefReg = Builder.buildUndef(UnmergeSrcTy).getReg(0);
3474 SmallVector<Register> ConcatRegs = {MatchInfo};
3475 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
3476 ConcatRegs.push_back(UndefReg);
3478 auto MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType());
3479 MidReg = Builder.buildConcatVectors(MidTy, ConcatRegs).getReg(0);
3482 Builder.buildTrunc(DstReg, MidReg);
3483 MI.eraseFromParent();
3486 bool CombinerHelper::matchNotCmp(
3487 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3488 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3489 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
3490 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3491 Register XorSrc;
3492 Register CstReg;
3493 // We match xor(src, true) here.
3494 if (!mi_match(MI.getOperand(0).getReg(), MRI,
3495 m_GXor(m_Reg(XorSrc), m_Reg(CstReg))))
3496 return false;
3498 if (!MRI.hasOneNonDBGUse(XorSrc))
3499 return false;
3501 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3502 // and ORs. The suffix of RegsToNegate starting from index I is used as a work
3503 // list of tree nodes to visit.
3504 RegsToNegate.push_back(XorSrc);
3505 // Remember whether the comparisons are all integer or all floating point.
3506 bool IsInt = false;
3507 bool IsFP = false;
3508 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3509 Register Reg = RegsToNegate[I];
3510 if (!MRI.hasOneNonDBGUse(Reg))
3511 return false;
3512 MachineInstr *Def = MRI.getVRegDef(Reg);
3513 switch (Def->getOpcode()) {
3514 default:
3515 // Don't match if the tree contains anything other than ANDs, ORs and
3516 // comparisons.
3517 return false;
3518 case TargetOpcode::G_ICMP:
3519 if (IsFP)
3520 return false;
3521 IsInt = true;
3522 // When we apply the combine we will invert the predicate.
3523 break;
3524 case TargetOpcode::G_FCMP:
3525 if (IsInt)
3526 return false;
3527 IsFP = true;
3528 // When we apply the combine we will invert the predicate.
3529 break;
3530 case TargetOpcode::G_AND:
3531 case TargetOpcode::G_OR:
3532 // Implement De Morgan's laws:
3533 // ~(x & y) -> ~x | ~y
3534 // ~(x | y) -> ~x & ~y
3535 // When we apply the combine we will change the opcode and recursively
3536 // negate the operands.
3537 RegsToNegate.push_back(Def->getOperand(1).getReg());
3538 RegsToNegate.push_back(Def->getOperand(2).getReg());
3539 break;
3543 // Now we know whether the comparisons are integer or floating point, check
3544 // the constant in the xor.
3545 int64_t Cst;
3546 if (Ty.isVector()) {
3547 MachineInstr *CstDef = MRI.getVRegDef(CstReg);
3548 auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
3549 if (!MaybeCst)
3550 return false;
3551 if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
3552 return false;
3553 } else {
3554 if (!mi_match(CstReg, MRI, m_ICst(Cst)))
3555 return false;
3556 if (!isConstValidTrue(TLI, Ty.getSizeInBits(), Cst, false, IsFP))
3557 return false;
3560 return true;
3563 void CombinerHelper::applyNotCmp(
3564 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3565 for (Register Reg : RegsToNegate) {
3566 MachineInstr *Def = MRI.getVRegDef(Reg);
3567 Observer.changingInstr(*Def);
3568 // For each comparison, invert the opcode. For each AND and OR, change the
3569 // opcode.
3570 switch (Def->getOpcode()) {
3571 default:
3572 llvm_unreachable("Unexpected opcode");
3573 case TargetOpcode::G_ICMP:
3574 case TargetOpcode::G_FCMP: {
3575 MachineOperand &PredOp = Def->getOperand(1);
3576 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3577 (CmpInst::Predicate)PredOp.getPredicate());
3578 PredOp.setPredicate(NewP);
3579 break;
3581 case TargetOpcode::G_AND:
3582 Def->setDesc(Builder.getTII().get(TargetOpcode::G_OR));
3583 break;
3584 case TargetOpcode::G_OR:
3585 Def->setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3586 break;
3588 Observer.changedInstr(*Def);
3591 replaceRegWith(MRI, MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
3592 MI.eraseFromParent();
3595 bool CombinerHelper::matchXorOfAndWithSameReg(
3596 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3597 // Match (xor (and x, y), y) (or any of its commuted cases)
3598 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3599 Register &X = MatchInfo.first;
3600 Register &Y = MatchInfo.second;
3601 Register AndReg = MI.getOperand(1).getReg();
3602 Register SharedReg = MI.getOperand(2).getReg();
3604 // Find a G_AND on either side of the G_XOR.
3605 // Look for one of
3607 // (xor (and x, y), SharedReg)
3608 // (xor SharedReg, (and x, y))
3609 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) {
3610 std::swap(AndReg, SharedReg);
3611 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y))))
3612 return false;
3615 // Only do this if we'll eliminate the G_AND.
3616 if (!MRI.hasOneNonDBGUse(AndReg))
3617 return false;
3619 // We can combine if SharedReg is the same as either the LHS or RHS of the
3620 // G_AND.
3621 if (Y != SharedReg)
3622 std::swap(X, Y);
3623 return Y == SharedReg;
3626 void CombinerHelper::applyXorOfAndWithSameReg(
3627 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3628 // Fold (xor (and x, y), y) -> (and (not x), y)
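// This holds bit by bit: where y is 0 both sides are 0; where y is 1 the
// left side is (x & 1) ^ 1 and the right side is ~x & 1, which agree.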
3629 Register X, Y;
3630 std::tie(X, Y) = MatchInfo;
3631 auto Not = Builder.buildNot(MRI.getType(X), X);
3632 Observer.changingInstr(MI);
3633 MI.setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3634 MI.getOperand(1).setReg(Not->getOperand(0).getReg());
3635 MI.getOperand(2).setReg(Y);
3636 Observer.changedInstr(MI);
3639 bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const {
3640 auto &PtrAdd = cast<GPtrAdd>(MI);
3641 Register DstReg = PtrAdd.getReg(0);
3642 LLT Ty = MRI.getType(DstReg);
3643 const DataLayout &DL = Builder.getMF().getDataLayout();
3645 if (DL.isNonIntegralAddressSpace(Ty.getScalarType().getAddressSpace()))
3646 return false;
3648 if (Ty.isPointer()) {
3649 auto ConstVal = getIConstantVRegVal(PtrAdd.getBaseReg(), MRI);
3650 return ConstVal && *ConstVal == 0;
3653 assert(Ty.isVector() && "Expecting a vector type");
3654 const MachineInstr *VecMI = MRI.getVRegDef(PtrAdd.getBaseReg());
3655 return isBuildVectorAllZeros(*VecMI, MRI);
3658 void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const {
3659 auto &PtrAdd = cast<GPtrAdd>(MI);
3660 Builder.buildIntToPtr(PtrAdd.getReg(0), PtrAdd.getOffsetReg());
3661 PtrAdd.eraseFromParent();
3664 /// The second source operand is known to be a power of 2.
3665 void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const {
3666 Register DstReg = MI.getOperand(0).getReg();
3667 Register Src0 = MI.getOperand(1).getReg();
3668 Register Pow2Src1 = MI.getOperand(2).getReg();
3669 LLT Ty = MRI.getType(DstReg);
3671 // Fold (urem x, pow2) -> (and x, pow2-1)
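// For illustration: urem x, 8 == and x, 7, since 8 - 1 == 0b111 masks
// exactly the low three bits.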
3672 auto NegOne = Builder.buildConstant(Ty, -1);
3673 auto Add = Builder.buildAdd(Ty, Pow2Src1, NegOne);
3674 Builder.buildAnd(DstReg, Src0, Add);
3675 MI.eraseFromParent();
3678 bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3679 unsigned &SelectOpNo) const {
3680 Register LHS = MI.getOperand(1).getReg();
3681 Register RHS = MI.getOperand(2).getReg();
3683 Register OtherOperandReg = RHS;
3684 SelectOpNo = 1;
3685 MachineInstr *Select = MRI.getVRegDef(LHS);
3687 // Don't do this unless the old select is going away. We want to eliminate the
3688 // binary operator, not replace a binop with a select.
3689 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3690 !MRI.hasOneNonDBGUse(LHS)) {
3691 OtherOperandReg = LHS;
3692 SelectOpNo = 2;
3693 Select = MRI.getVRegDef(RHS);
3694 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3695 !MRI.hasOneNonDBGUse(RHS))
3696 return false;
3699 MachineInstr *SelectLHS = MRI.getVRegDef(Select->getOperand(2).getReg());
3700 MachineInstr *SelectRHS = MRI.getVRegDef(Select->getOperand(3).getReg());
3702 if (!isConstantOrConstantVector(*SelectLHS, MRI,
3703 /*AllowFP*/ true,
3704 /*AllowOpaqueConstants*/ false))
3705 return false;
3706 if (!isConstantOrConstantVector(*SelectRHS, MRI,
3707 /*AllowFP*/ true,
3708 /*AllowOpaqueConstants*/ false))
3709 return false;
3711 unsigned BinOpcode = MI.getOpcode();
3713 // We know that one of the operands is a select of constants. Now verify that
3714 // the other binary operator operand is either a constant, or we can handle a
3715 // variable.
3716 bool CanFoldNonConst =
3717 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3718 (isNullOrNullSplat(*SelectLHS, MRI) ||
3719 isAllOnesOrAllOnesSplat(*SelectLHS, MRI)) &&
3720 (isNullOrNullSplat(*SelectRHS, MRI) ||
3721 isAllOnesOrAllOnesSplat(*SelectRHS, MRI));
3722 if (CanFoldNonConst)
3723 return true;
3725 return isConstantOrConstantVector(*MRI.getVRegDef(OtherOperandReg), MRI,
3726 /*AllowFP*/ true,
3727 /*AllowOpaqueConstants*/ false);
3730 /// \p SelectOperand is the operand in binary operator \p MI that is the select
3731 /// to fold.
3732 void CombinerHelper::applyFoldBinOpIntoSelect(
3733 MachineInstr &MI, const unsigned &SelectOperand) const {
3734 Register Dst = MI.getOperand(0).getReg();
3735 Register LHS = MI.getOperand(1).getReg();
3736 Register RHS = MI.getOperand(2).getReg();
3737 MachineInstr *Select = MRI.getVRegDef(MI.getOperand(SelectOperand).getReg());
3739 Register SelectCond = Select->getOperand(1).getReg();
3740 Register SelectTrue = Select->getOperand(2).getReg();
3741 Register SelectFalse = Select->getOperand(3).getReg();
3743 LLT Ty = MRI.getType(Dst);
3744 unsigned BinOpcode = MI.getOpcode();
3746 Register FoldTrue, FoldFalse;
3748 // We have a select-of-constants followed by a binary operator with a
3749 // constant. Eliminate the binop by pulling the constant math into the select.
3750 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3751 if (SelectOperand == 1) {
3752 // TODO: SelectionDAG verifies this actually constant folds before
3753 // committing to the combine.
3755 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {SelectTrue, RHS}).getReg(0);
3756 FoldFalse =
3757 Builder.buildInstr(BinOpcode, {Ty}, {SelectFalse, RHS}).getReg(0);
3758 } else {
3759 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectTrue}).getReg(0);
3760 FoldFalse =
3761 Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectFalse}).getReg(0);
3764 Builder.buildSelect(Dst, SelectCond, FoldTrue, FoldFalse, MI.getFlags());
3765 MI.eraseFromParent();
3768 std::optional<SmallVector<Register, 8>>
3769 CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3770 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3771 // We want to detect if Root is part of a tree which represents a bunch
3772 // of loads being merged into a larger load. We'll try to recognize patterns
3773 // like, for example:
3775 // Reg Reg
3776 // \ /
3777 // OR_1 Reg
3778 // \ /
3779 // OR_2
3780 // \ Reg
3781 // .. /
3782 // Root
3784 // Reg Reg Reg Reg
3785 // \ / \ /
3786 // OR_1 OR_2
3787 // \ /
3788 // \ /
3789 // ...
3790 // Root
3792 // Each "Reg" may have been produced by a load + some arithmetic. This
3793 // function will save each of them.
3794 SmallVector<Register, 8> RegsToVisit;
3795 SmallVector<const MachineInstr *, 7> Ors = {Root};
3797 // In the "worst" case, we're dealing with a load for each byte. So, there
3798 // are at most #bytes - 1 ORs.
3799 const unsigned MaxIter =
3800 MRI.getType(Root->getOperand(0).getReg()).getSizeInBytes() - 1;
3801 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
3802 if (Ors.empty())
3803 break;
3804 const MachineInstr *Curr = Ors.pop_back_val();
3805 Register OrLHS = Curr->getOperand(1).getReg();
3806 Register OrRHS = Curr->getOperand(2).getReg();
3808 // In the combine, we want to eliminate the entire tree.
3809 if (!MRI.hasOneNonDBGUse(OrLHS) || !MRI.hasOneNonDBGUse(OrRHS))
3810 return std::nullopt;
3812 // If it's a G_OR, save it and continue to walk. If it's not, then it's
3813 // something that may be a load + arithmetic.
3814 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrLHS, MRI))
3815 Ors.push_back(Or);
3816 else
3817 RegsToVisit.push_back(OrLHS);
3818 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrRHS, MRI))
3819 Ors.push_back(Or);
3820 else
3821 RegsToVisit.push_back(OrRHS);
3824 // We're going to try and merge each register into a wider power-of-2 type,
3825 // so we ought to have an even number of registers.
3826 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
3827 return std::nullopt;
3828 return RegsToVisit;
3831 /// Helper function for findLoadOffsetsForLoadOrCombine.
3833 /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
3834 /// and then moving that value into a specific byte offset.
3836 /// e.g. x[i] << 24
3838 /// \returns The load instruction and the byte offset it is moved into.
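/// For illustration (8-bit loads being combined into a 32-bit value):
/// x[i] << 24 yields a byte offset of 24 / 8 == 3.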
3839 static std::optional<std::pair<GZExtLoad *, int64_t>>
3840 matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
3841 const MachineRegisterInfo &MRI) {
3842 assert(MRI.hasOneNonDBGUse(Reg) &&
3843 "Expected Reg to only have one non-debug use?");
3844 Register MaybeLoad;
3845 int64_t Shift;
3846 if (!mi_match(Reg, MRI,
3847 m_OneNonDBGUse(m_GShl(m_Reg(MaybeLoad), m_ICst(Shift))))) {
3848 Shift = 0;
3849 MaybeLoad = Reg;
3852 if (Shift % MemSizeInBits != 0)
3853 return std::nullopt;
3855 // TODO: Handle other types of loads.
3856 auto *Load = getOpcodeDef<GZExtLoad>(MaybeLoad, MRI);
3857 if (!Load)
3858 return std::nullopt;
3860 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
3861 return std::nullopt;
3863 return std::make_pair(Load, Shift / MemSizeInBits);
3866 std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
3867 CombinerHelper::findLoadOffsetsForLoadOrCombine(
3868 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
3869 const SmallVector<Register, 8> &RegsToVisit,
3870 const unsigned MemSizeInBits) const {
3872 // Each load found for the pattern. There should be one for each RegsToVisit.
3873 SmallSetVector<const MachineInstr *, 8> Loads;
3875 // The lowest index used in any load. (The lowest "i" for each x[i].)
3876 int64_t LowestIdx = INT64_MAX;
3878 // The load which uses the lowest index.
3879 GZExtLoad *LowestIdxLoad = nullptr;
3881 // Keeps track of the load indices we see. We shouldn't see any indices twice.
3882 SmallSet<int64_t, 8> SeenIdx;
3884 // Ensure each load is in the same MBB.
3885 // TODO: Support multiple MachineBasicBlocks.
3886 MachineBasicBlock *MBB = nullptr;
3887 const MachineMemOperand *MMO = nullptr;
3889 // Earliest instruction-order load in the pattern.
3890 GZExtLoad *EarliestLoad = nullptr;
3892 // Latest instruction-order load in the pattern.
3893 GZExtLoad *LatestLoad = nullptr;
3895 // Base pointer which every load should share.
3896 Register BasePtr;
3898 // We want to find a load for each register. Each load should have some
3899 // appropriate bit twiddling arithmetic. During this loop, we will also keep
3900 // track of the load which uses the lowest index. Later, we will check if we
3901 // can use its pointer in the final, combined load.
3902 for (auto Reg : RegsToVisit) {
3903 // Find the load, and find the byte position that its value will end up at
3904 // in the (possibly shifted) combined value.
3905 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
3906 if (!LoadAndPos)
3907 return std::nullopt;
3908 GZExtLoad *Load;
3909 int64_t DstPos;
3910 std::tie(Load, DstPos) = *LoadAndPos;
3912 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
3913 // it is difficult to check for stores/calls/etc between loads.
3914 MachineBasicBlock *LoadMBB = Load->getParent();
3915 if (!MBB)
3916 MBB = LoadMBB;
3917 if (LoadMBB != MBB)
3918 return std::nullopt;
3920 // Make sure that the MachineMemOperands of every seen load are compatible.
3921 auto &LoadMMO = Load->getMMO();
3922 if (!MMO)
3923 MMO = &LoadMMO;
3924 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
3925 return std::nullopt;
3927 // Find out what the base pointer and index for the load is.
3928 Register LoadPtr;
3929 int64_t Idx;
3930 if (!mi_match(Load->getOperand(1).getReg(), MRI,
3931 m_GPtrAdd(m_Reg(LoadPtr), m_ICst(Idx)))) {
3932 LoadPtr = Load->getOperand(1).getReg();
3933 Idx = 0;
3936 // Don't combine things like a[i], a[i] -> a bigger load.
3937 if (!SeenIdx.insert(Idx).second)
3938 return std::nullopt;
3940 // Every load must share the same base pointer; don't combine things like:
3942 // a[i], b[i + 1] -> a bigger load.
3943 if (!BasePtr.isValid())
3944 BasePtr = LoadPtr;
3945 if (BasePtr != LoadPtr)
3946 return std::nullopt;
3948 if (Idx < LowestIdx) {
3949 LowestIdx = Idx;
3950 LowestIdxLoad = Load;
3953 // Keep track of the byte offset that this load ends up at. If we have seen
3954 // the byte offset, then stop here. We do not want to combine:
3956 // a[i] << 16, a[i + k] << 16 -> a bigger load.
3957 if (!MemOffset2Idx.try_emplace(DstPos, Idx).second)
3958 return std::nullopt;
3959 Loads.insert(Load);
3961 // Keep track of the position of the earliest/latest loads in the pattern.
3962 // We will check that there are no load fold barriers between them later
3963 // on.
3965 // FIXME: Is there a better way to check for load fold barriers?
3966 if (!EarliestLoad || dominates(*Load, *EarliestLoad))
3967 EarliestLoad = Load;
3968 if (!LatestLoad || dominates(*LatestLoad, *Load))
3969 LatestLoad = Load;
3972 // We found a load for each register. Let's check if each load satisfies the
3973 // pattern.
3974 assert(Loads.size() == RegsToVisit.size() &&
3975 "Expected to find a load for each register?");
3976 assert(EarliestLoad != LatestLoad && EarliestLoad &&
3977 LatestLoad && "Expected at least two loads?");
3979 // Check if there are any stores, calls, etc. between any of the loads. If
3980 // there are, then we can't safely perform the combine.
3982 // MaxIter is chosen based on the (worst case) number of iterations it
3983 // typically takes to succeed in the LLVM test suite plus some padding.
3985 // FIXME: Is there a better way to check for load fold barriers?
3986 const unsigned MaxIter = 20;
3987 unsigned Iter = 0;
3988 for (const auto &MI : instructionsWithoutDebug(EarliestLoad->getIterator(),
3989 LatestLoad->getIterator())) {
3990 if (Loads.count(&MI))
3991 continue;
3992 if (MI.isLoadFoldBarrier())
3993 return std::nullopt;
3994 if (Iter++ == MaxIter)
3995 return std::nullopt;
3998 return std::make_tuple(LowestIdxLoad, LowestIdx, LatestLoad);
4001 bool CombinerHelper::matchLoadOrCombine(
4002 MachineInstr &MI,
4003 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4004 assert(MI.getOpcode() == TargetOpcode::G_OR);
4005 MachineFunction &MF = *MI.getMF();
4006 // Assuming a little-endian target, transform:
4007 // s8 *a = ...
4008 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
4009 // =>
4010 // s32 val = *((i32)a)
4012 // s8 *a = ...
4013 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
4014 // =>
4015 // s32 val = BSWAP(*((s32)a))
4016 Register Dst = MI.getOperand(0).getReg();
4017 LLT Ty = MRI.getType(Dst);
4018 if (Ty.isVector())
4019 return false;
4021 // We need to combine at least two loads into this type. Since the smallest
4022 // possible load is into a byte, we need at least a 16-bit wide type.
4023 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
4024 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
4025 return false;
4027 // Match a collection of non-OR instructions in the pattern.
4028 auto RegsToVisit = findCandidatesForLoadOrCombine(&MI);
4029 if (!RegsToVisit)
4030 return false;
4032 // We have a collection of non-OR instructions. Figure out how wide each of
4033 // the small loads should be based off of the number of potential loads we
4034 // found.
4035 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
4036 if (NarrowMemSizeInBits % 8 != 0)
4037 return false;
4039 // Check if each register feeding into each OR is a load from the same
4040 // base pointer + some arithmetic.
4042 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
4044 // Also verify that each of these ends up putting a[i] into the same memory
4045 // offset as a load into a wide type would.
4046 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
4047 GZExtLoad *LowestIdxLoad, *LatestLoad;
4048 int64_t LowestIdx;
4049 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
4050 MemOffset2Idx, *RegsToVisit, NarrowMemSizeInBits);
4051 if (!MaybeLoadInfo)
4052 return false;
4053 std::tie(LowestIdxLoad, LowestIdx, LatestLoad) = *MaybeLoadInfo;
4055 // We have a bunch of loads being OR'd together. Using the addresses + offsets
4056 // we found before, check if this corresponds to a big or little endian byte
4057 // pattern. If it does, then we can represent it using a load + possibly a
4058 // BSWAP.
4059 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
4060 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
4061 if (!IsBigEndian)
4062 return false;
4063 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
4064 if (NeedsBSwap && !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {Ty}}))
4065 return false;
4067 // Make sure that the load from the lowest index produces offset 0 in the
4068 // final value.
4070 // This ensures that we won't combine something like this:
4072 // load x[i] -> byte 2
4073 // load x[i+1] -> byte 0 ---> wide_load x[i]
4074 // load x[i+2] -> byte 1
4075 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
4076 const unsigned ZeroByteOffset =
4077 *IsBigEndian
4078 ? bigEndianByteAt(NumLoadsInTy, 0)
4079 : littleEndianByteAt(NumLoadsInTy, 0);
4080 auto ZeroOffsetIdx = MemOffset2Idx.find(ZeroByteOffset);
4081 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
4082 ZeroOffsetIdx->second != LowestIdx)
4083 return false;
4085 // We will reuse the pointer from the load which ends up at byte offset 0. It
4086 // may not use index 0.
4087 Register Ptr = LowestIdxLoad->getPointerReg();
4088 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
4089 LegalityQuery::MemDesc MMDesc(MMO);
4090 MMDesc.MemoryTy = Ty;
4091 if (!isLegalOrBeforeLegalizer(
4092 {TargetOpcode::G_LOAD, {Ty, MRI.getType(Ptr)}, {MMDesc}}))
4093 return false;
4094 auto PtrInfo = MMO.getPointerInfo();
4095 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, WideMemSizeInBits / 8);
4097 // Load must be allowed and fast on the target.
4098 LLVMContext &C = MF.getFunction().getContext();
4099 auto &DL = MF.getDataLayout();
4100 unsigned Fast = 0;
4101 if (!getTargetLowering().allowsMemoryAccess(C, DL, Ty, *NewMMO, &Fast) ||
4102 !Fast)
4103 return false;
4105 MatchInfo = [=](MachineIRBuilder &MIB) {
4106 MIB.setInstrAndDebugLoc(*LatestLoad);
4107 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(Dst) : Dst;
4108 MIB.buildLoad(LoadDst, Ptr, *NewMMO);
4109 if (NeedsBSwap)
4110 MIB.buildBSwap(Dst, LoadDst);
4112 return true;
4115 bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
4116 MachineInstr *&ExtMI) const {
4117 auto &PHI = cast<GPhi>(MI);
4118 Register DstReg = PHI.getReg(0);
4120 // TODO: Extending a vector may be expensive, don't do this until heuristics
4121 // are better.
4122 if (MRI.getType(DstReg).isVector())
4123 return false;
4125 // Try to match a phi, whose only use is an extend.
4126 if (!MRI.hasOneNonDBGUse(DstReg))
4127 return false;
4128 ExtMI = &*MRI.use_instr_nodbg_begin(DstReg);
4129 switch (ExtMI->getOpcode()) {
4130 case TargetOpcode::G_ANYEXT:
4131 return true; // G_ANYEXT is usually free.
4132 case TargetOpcode::G_ZEXT:
4133 case TargetOpcode::G_SEXT:
4134 break;
4135 default:
4136 return false;
4139 // If the target is likely to fold this extend away, don't propagate.
4140 if (Builder.getTII().isExtendLikelyToBeFolded(*ExtMI, MRI))
4141 return false;
4143 // We don't want to propagate the extends unless there's a good chance that
4144 // they'll be optimized in some way.
4145 // Collect the unique incoming values.
4146 SmallPtrSet<MachineInstr *, 4> InSrcs;
4147 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4148 auto *DefMI = getDefIgnoringCopies(PHI.getIncomingValue(I), MRI);
4149 switch (DefMI->getOpcode()) {
4150 case TargetOpcode::G_LOAD:
4151 case TargetOpcode::G_TRUNC:
4152 case TargetOpcode::G_SEXT:
4153 case TargetOpcode::G_ZEXT:
4154 case TargetOpcode::G_ANYEXT:
4155 case TargetOpcode::G_CONSTANT:
4156 InSrcs.insert(DefMI);
4157 // Don't try to propagate if there are too many places to create new
4158 // extends, chances are it'll increase code size.
4159 if (InSrcs.size() > 2)
4160 return false;
4161 break;
4162 default:
4163 return false;
4166 return true;
4169 void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
4170 MachineInstr *&ExtMI) const {
4171 auto &PHI = cast<GPhi>(MI);
4172 Register DstReg = ExtMI->getOperand(0).getReg();
4173 LLT ExtTy = MRI.getType(DstReg);
4175 // Propagate the extension into each incoming value's defining block.
4176 // Use a SetVector here because PHIs can have duplicate edges, and we want
4177 // deterministic iteration order.
4178 SmallSetVector<MachineInstr *, 8> SrcMIs;
4179 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
4180 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4181 auto SrcReg = PHI.getIncomingValue(I);
4182 auto *SrcMI = MRI.getVRegDef(SrcReg);
4183 if (!SrcMIs.insert(SrcMI))
4184 continue;
4186 // Build an extend after each src inst.
4187 auto *MBB = SrcMI->getParent();
4188 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
4189 if (InsertPt != MBB->end() && InsertPt->isPHI())
4190 InsertPt = MBB->getFirstNonPHI();
4192 Builder.setInsertPt(*SrcMI->getParent(), InsertPt);
4193 Builder.setDebugLoc(MI.getDebugLoc());
4194 auto NewExt = Builder.buildExtOrTrunc(ExtMI->getOpcode(), ExtTy, SrcReg);
4195 OldToNewSrcMap[SrcMI] = NewExt;
4198 // Create a new phi with the extended inputs.
4199 Builder.setInstrAndDebugLoc(MI);
4200 auto NewPhi = Builder.buildInstrNoInsert(TargetOpcode::G_PHI);
4201 NewPhi.addDef(DstReg);
4202 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
4203 if (!MO.isReg()) {
4204 NewPhi.addMBB(MO.getMBB());
4205 continue;
4207 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(MO.getReg())];
4208 NewPhi.addUse(NewSrc->getOperand(0).getReg());
4210 Builder.insertInstr(NewPhi);
4211 ExtMI->eraseFromParent();
4214 bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4215 Register &Reg) const {
4216 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4217 // If we have a constant index, look for a G_BUILD_VECTOR source
4218 // and find the source register that the index maps to.
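// For illustration (an assumed example):
//   %vec:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d
//   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, 2
// can be replaced by a use of %c.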
4219 Register SrcVec = MI.getOperand(1).getReg();
4220 LLT SrcTy = MRI.getType(SrcVec);
4221 if (SrcTy.isScalableVector())
4222 return false;
4224 auto Cst = getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
4225 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
4226 return false;
4228 unsigned VecIdx = Cst->Value.getZExtValue();
4230 // Check if we have a build_vector or build_vector_trunc with an optional
4231 // trunc in front.
4232 MachineInstr *SrcVecMI = MRI.getVRegDef(SrcVec);
4233 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4234 SrcVecMI = MRI.getVRegDef(SrcVecMI->getOperand(1).getReg());
4237 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4238 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4239 return false;
4241 EVT Ty(getMVTForLLT(SrcTy));
4242 if (!MRI.hasOneNonDBGUse(SrcVec) &&
4243 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
4244 return false;
4246 Reg = SrcVecMI->getOperand(VecIdx + 1).getReg();
4247 return true;
4250 void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4251 Register &Reg) const {
4252 // Check the type of the register, since it may have come from a
4253 // G_BUILD_VECTOR_TRUNC.
4254 LLT ScalarTy = MRI.getType(Reg);
4255 Register DstReg = MI.getOperand(0).getReg();
4256 LLT DstTy = MRI.getType(DstReg);
4258 if (ScalarTy != DstTy) {
4259 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4260 Builder.buildTrunc(DstReg, Reg);
4261 MI.eraseFromParent();
4262 return;
4264 replaceSingleDefInstWithReg(MI, Reg);
4267 bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4268 MachineInstr &MI,
4269 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4270 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4271 // This combine tries to find build_vector's which have every source element
4272 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4273 // the masked load scalarization is run late in the pipeline. There's already
4274 // a combine for a similar pattern starting from the extract, but that
4275 // doesn't attempt to do it if there are multiple uses of the build_vector,
4276 // which in this case is true. Starting the combine from the build_vector
4277 // feels more natural than trying to find sibling nodes of extracts.
4278 // E.g.
4279 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4280 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4281 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4282 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4283 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4284 // ==>
4285 // replace ext{1,2,3,4} with %s{1,2,3,4}
4287 Register DstReg = MI.getOperand(0).getReg();
4288 LLT DstTy = MRI.getType(DstReg);
4289 unsigned NumElts = DstTy.getNumElements();
4291 SmallBitVector ExtractedElts(NumElts);
4292 for (MachineInstr &II : MRI.use_nodbg_instructions(DstReg)) {
4293 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4294 return false;
4295 auto Cst = getIConstantVRegVal(II.getOperand(2).getReg(), MRI);
4296 if (!Cst)
4297 return false;
4298 unsigned Idx = Cst->getZExtValue();
4299 if (Idx >= NumElts)
4300 return false; // Out of range.
4301 ExtractedElts.set(Idx);
4302 SrcDstPairs.emplace_back(
4303 std::make_pair(MI.getOperand(Idx + 1).getReg(), &II));
4305 // Match if every element was extracted.
4306 return ExtractedElts.all();
4309 void CombinerHelper::applyExtractAllEltsFromBuildVector(
4310 MachineInstr &MI,
4311 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4312 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4313 for (auto &Pair : SrcDstPairs) {
4314 auto *ExtMI = Pair.second;
4315 replaceRegWith(MRI, ExtMI->getOperand(0).getReg(), Pair.first);
4316 ExtMI->eraseFromParent();
4318 MI.eraseFromParent();
4321 void CombinerHelper::applyBuildFn(
4322 MachineInstr &MI,
4323 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4324 applyBuildFnNoErase(MI, MatchInfo);
4325 MI.eraseFromParent();
4328 void CombinerHelper::applyBuildFnNoErase(
4329 MachineInstr &MI,
4330 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4331 MatchInfo(Builder);
4334 bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4335 BuildFnTy &MatchInfo) const {
4336 assert(MI.getOpcode() == TargetOpcode::G_OR);
4338 Register Dst = MI.getOperand(0).getReg();
4339 LLT Ty = MRI.getType(Dst);
4340 unsigned BitWidth = Ty.getScalarSizeInBits();
4342 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4343 unsigned FshOpc = 0;
4345 // Match (or (shl ...), (lshr ...)).
4346 if (!mi_match(Dst, MRI,
4347 // m_GOr() handles the commuted version as well.
4348 m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
4349 m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)))))
4350 return false;
4352 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4353 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
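// For illustration (assumed 32-bit scalars): (or (shl x, 8), (lshr y, 24))
// can be built as G_FSHR x, y, 24, or equivalently G_FSHL x, y, 8.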
4354 int64_t CstShlAmt, CstLShrAmt;
4355 if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
4356 mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
4357 CstShlAmt + CstLShrAmt == BitWidth) {
4358 FshOpc = TargetOpcode::G_FSHR;
4359 Amt = LShrAmt;
4361 } else if (mi_match(LShrAmt, MRI,
4362 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4363 ShlAmt == Amt) {
4364 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4365 FshOpc = TargetOpcode::G_FSHL;
4367 } else if (mi_match(ShlAmt, MRI,
4368 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4369 LShrAmt == Amt) {
4370 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4371 FshOpc = TargetOpcode::G_FSHR;
4373 } else {
4374 return false;
4377 LLT AmtTy = MRI.getType(Amt);
4378 if (!isLegalOrBeforeLegalizer({FshOpc, {Ty, AmtTy}}))
4379 return false;
4381 MatchInfo = [=](MachineIRBuilder &B) {
4382 B.buildInstr(FshOpc, {Dst}, {ShlSrc, LShrSrc, Amt});
4384 return true;
4387 /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
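/// This relies on the identities fshl(x, x, amt) == rotl(x, amt) and
/// fshr(x, x, amt) == rotr(x, amt), which hold because both halves of the
/// concatenated input are the same value.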
4388 bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const {
4389 unsigned Opc = MI.getOpcode();
4390 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4391 Register X = MI.getOperand(1).getReg();
4392 Register Y = MI.getOperand(2).getReg();
4393 if (X != Y)
4394 return false;
4395 unsigned RotateOpc =
4396 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4397 return isLegalOrBeforeLegalizer({RotateOpc, {MRI.getType(X), MRI.getType(Y)}});
4400 void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const {
4401 unsigned Opc = MI.getOpcode();
4402 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4403 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4404 Observer.changingInstr(MI);
4405 MI.setDesc(Builder.getTII().get(IsFSHL ? TargetOpcode::G_ROTL
4406 : TargetOpcode::G_ROTR));
4407 MI.removeOperand(2);
4408 Observer.changedInstr(MI);
4411 // Fold (rot x, c) -> (rot x, c % BitSize)
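// For illustration (assumed s32): (rotl x, 33) behaves as (rotl x, 1), so
// the amount can be reduced modulo 32.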
4412 bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const {
4413 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4414 MI.getOpcode() == TargetOpcode::G_ROTR);
4415 unsigned Bitsize =
4416 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4417 Register AmtReg = MI.getOperand(2).getReg();
4418 bool OutOfRange = false;
4419 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4420 if (auto *CI = dyn_cast<ConstantInt>(C))
4421 OutOfRange |= CI->getValue().uge(Bitsize);
4422 return true;
4424 return matchUnaryPredicate(MRI, AmtReg, MatchOutOfRange) && OutOfRange;
4427 void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const {
4428 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4429 MI.getOpcode() == TargetOpcode::G_ROTR);
4430 unsigned Bitsize =
4431 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4432 Register Amt = MI.getOperand(2).getReg();
4433 LLT AmtTy = MRI.getType(Amt);
4434 auto Bits = Builder.buildConstant(AmtTy, Bitsize);
4435 Amt = Builder.buildURem(AmtTy, MI.getOperand(2).getReg(), Bits).getReg(0);
4436 Observer.changingInstr(MI);
4437 MI.getOperand(2).setReg(Amt);
4438 Observer.changedInstr(MI);
4441 bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4442 int64_t &MatchInfo) const {
4443 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4444 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4446 // We want to avoid calling KnownBits on the LHS if possible, as this combine
4447 // has no filter and runs on every G_ICMP instruction. We can avoid calling
4448 // KnownBits on the LHS in two cases:
4450 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown
4451 // we cannot do any transforms so we can safely bail out early.
4452 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and
4453 // >=0.
4454 auto KnownRHS = KB->getKnownBits(MI.getOperand(3).getReg());
4455 if (KnownRHS.isUnknown())
4456 return false;
4458 std::optional<bool> KnownVal;
4459 if (KnownRHS.isZero()) {
4460 // ? uge 0 -> always true
4461 // ? ult 0 -> always false
4462 if (Pred == CmpInst::ICMP_UGE)
4463 KnownVal = true;
4464 else if (Pred == CmpInst::ICMP_ULT)
4465 KnownVal = false;
4468 if (!KnownVal) {
4469 auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
4470 KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred);
4473 if (!KnownVal)
4474 return false;
4475 MatchInfo =
4476 *KnownVal
4477 ? getICmpTrueVal(getTargetLowering(),
4478 /*IsVector = */
4479 MRI.getType(MI.getOperand(0).getReg()).isVector(),
4480 /* IsFP = */ false)
4481 : 0;
4482 return true;
4485 bool CombinerHelper::matchICmpToLHSKnownBits(
4486 MachineInstr &MI,
4487 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4488 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4489 // Given:
4491 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4492 // %cmp = G_ICMP ne %x, 0
4494 // Or:
4496 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4497 // %cmp = G_ICMP eq %x, 1
4499 // We can replace %cmp with %x assuming true is 1 on the target.
4500 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4501 if (!CmpInst::isEquality(Pred))
4502 return false;
4503 Register Dst = MI.getOperand(0).getReg();
4504 LLT DstTy = MRI.getType(Dst);
4505 if (getICmpTrueVal(getTargetLowering(), DstTy.isVector(),
4506 /* IsFP = */ false) != 1)
4507 return false;
4508 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4509 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(OneOrZero)))
4510 return false;
4511 Register LHS = MI.getOperand(2).getReg();
4512 auto KnownLHS = KB->getKnownBits(LHS);
4513 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4514 return false;
4515 // Make sure replacing Dst with the LHS is a legal operation.
4516 LLT LHSTy = MRI.getType(LHS);
4517 unsigned LHSSize = LHSTy.getSizeInBits();
4518 unsigned DstSize = DstTy.getSizeInBits();
4519 unsigned Op = TargetOpcode::COPY;
4520 if (DstSize != LHSSize)
4521 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4522 if (!isLegalOrBeforeLegalizer({Op, {DstTy, LHSTy}}))
4523 return false;
4524 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Op, {Dst}, {LHS}); };
4525 return true;
4528 // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
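// For illustration: (and (or x, 0xF0), 0x0F) -> (and x, 0x0F), because the
// OR can only set bits that the AND mask clears anyway.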
4529 bool CombinerHelper::matchAndOrDisjointMask(
4530 MachineInstr &MI,
4531 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4532 assert(MI.getOpcode() == TargetOpcode::G_AND);
4534 // Ignore vector types to simplify matching the two constants.
4535 // TODO: do this for vectors and scalars via a demanded bits analysis.
4536 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
4537 if (Ty.isVector())
4538 return false;
4540 Register Src;
4541 Register AndMaskReg;
4542 int64_t AndMaskBits;
4543 int64_t OrMaskBits;
4544 if (!mi_match(MI, MRI,
4545 m_GAnd(m_GOr(m_Reg(Src), m_ICst(OrMaskBits)),
4546 m_all_of(m_ICst(AndMaskBits), m_Reg(AndMaskReg)))))
4547 return false;
4549 // Check if OrMask could turn on any bits in Src.
4550 if (AndMaskBits & OrMaskBits)
4551 return false;
4553 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4554 Observer.changingInstr(MI);
4555 // Canonicalize the result to have the constant on the RHS.
4556 if (MI.getOperand(1).getReg() == AndMaskReg)
4557 MI.getOperand(2).setReg(AndMaskReg);
4558 MI.getOperand(1).setReg(Src);
4559 Observer.changedInstr(MI);
4561 return true;
4564 /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
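/// For illustration (assumed s32): (sext_inreg (lshr x, 8), 8) sign-extends
/// bits [15:8] of x, i.e. G_SBFX x, /*lsb=*/8, /*width=*/8.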
4565 bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4566 MachineInstr &MI,
4567 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4568 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4569 Register Dst = MI.getOperand(0).getReg();
4570 Register Src = MI.getOperand(1).getReg();
4571 LLT Ty = MRI.getType(Src);
4572 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4573 if (!LI || !LI->isLegalOrCustom({TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4574 return false;
4575 int64_t Width = MI.getOperand(2).getImm();
4576 Register ShiftSrc;
4577 int64_t ShiftImm;
4578 if (!mi_match(
4579 Src, MRI,
4580 m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)),
4581 m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm))))))
4582 return false;
4583 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4584 return false;
4586 MatchInfo = [=](MachineIRBuilder &B) {
4587 auto Cst1 = B.buildConstant(ExtractTy, ShiftImm);
4588 auto Cst2 = B.buildConstant(ExtractTy, Width);
4589 B.buildSbfx(Dst, ShiftSrc, Cst1, Cst2);
4591 return true;
4594 /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
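/// For illustration (assumed s32): (and (lshr x, 4), 0xFF) zero-extends
/// bits [11:4] of x, i.e. G_UBFX x, /*lsb=*/4, /*width=*/8.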
4595 bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI,
4596 BuildFnTy &MatchInfo) const {
4597 GAnd *And = cast<GAnd>(&MI);
4598 Register Dst = And->getReg(0);
4599 LLT Ty = MRI.getType(Dst);
4600 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4601 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom
4602 // into account.
4603 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4604 return false;
4606 int64_t AndImm, LSBImm;
4607 Register ShiftSrc;
4608 const unsigned Size = Ty.getScalarSizeInBits();
4609 if (!mi_match(And->getReg(0), MRI,
4610 m_GAnd(m_OneNonDBGUse(m_GLShr(m_Reg(ShiftSrc), m_ICst(LSBImm))),
4611 m_ICst(AndImm))))
4612 return false;
4614 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
4615 auto MaybeMask = static_cast<uint64_t>(AndImm);
4616 if (MaybeMask & (MaybeMask + 1))
4617 return false;
4619 // LSB must fit within the register.
4620 if (static_cast<uint64_t>(LSBImm) >= Size)
4621 return false;
4623 uint64_t Width = APInt(Size, AndImm).countr_one();
4624 MatchInfo = [=](MachineIRBuilder &B) {
4625 auto WidthCst = B.buildConstant(ExtractTy, Width);
4626 auto LSBCst = B.buildConstant(ExtractTy, LSBImm);
4627 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {ShiftSrc, LSBCst, WidthCst});
4629 return true;
4632 bool CombinerHelper::matchBitfieldExtractFromShr(
4633 MachineInstr &MI,
4634 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4635 const unsigned Opcode = MI.getOpcode();
4636 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4638 const Register Dst = MI.getOperand(0).getReg();
4640 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4641 ? TargetOpcode::G_SBFX
4642 : TargetOpcode::G_UBFX;
4644 // Check if the type we would use for the extract is legal
4645 LLT Ty = MRI.getType(Dst);
4646 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4647 if (!LI || !LI->isLegalOrCustom({ExtrOpcode, {Ty, ExtractTy}}))
4648 return false;
4650 Register ShlSrc;
4651 int64_t ShrAmt;
4652 int64_t ShlAmt;
4653 const unsigned Size = Ty.getScalarSizeInBits();
4655 // Try to match shr (shl x, c1), c2
4656 if (!mi_match(Dst, MRI,
4657 m_BinOp(Opcode,
4658 m_OneNonDBGUse(m_GShl(m_Reg(ShlSrc), m_ICst(ShlAmt))),
4659 m_ICst(ShrAmt))))
4660 return false;
4662 // Make sure that the shift sizes can fit a bitfield extract
4663 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4664 return false;
4666 // Skip this combine if the G_SEXT_INREG combine could handle it
4667 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4668 return false;
4670 // Calculate start position and width of the extract
4671 const int64_t Pos = ShrAmt - ShlAmt;
4672 const int64_t Width = Size - ShrAmt;
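// For illustration (assumed s32): (lshr (shl x, 8), 16) extracts bits
// [23:8] of x, i.e. G_UBFX x, /*Pos=*/8, /*Width=*/16.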
4674 MatchInfo = [=](MachineIRBuilder &B) {
4675 auto WidthCst = B.buildConstant(ExtractTy, Width);
4676 auto PosCst = B.buildConstant(ExtractTy, Pos);
4677 B.buildInstr(ExtrOpcode, {Dst}, {ShlSrc, PosCst, WidthCst});
4679 return true;
4682 bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4683 MachineInstr &MI,
4684 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4685 const unsigned Opcode = MI.getOpcode();
4686 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4688 const Register Dst = MI.getOperand(0).getReg();
4689 LLT Ty = MRI.getType(Dst);
4690 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4691 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4692 return false;
4694 // Try to match shr (and x, c1), c2
4695 Register AndSrc;
4696 int64_t ShrAmt;
4697 int64_t SMask;
4698 if (!mi_match(Dst, MRI,
4699 m_BinOp(Opcode,
4700 m_OneNonDBGUse(m_GAnd(m_Reg(AndSrc), m_ICst(SMask))),
4701 m_ICst(ShrAmt))))
4702 return false;
4704 const unsigned Size = Ty.getScalarSizeInBits();
4705 if (ShrAmt < 0 || ShrAmt >= Size)
4706 return false;
4708 // If the shift subsumes the mask, emit the 0 directly.
4709 if (0 == (SMask >> ShrAmt)) {
4710 MatchInfo = [=](MachineIRBuilder &B) {
4711 B.buildConstant(Dst, 0);
4713 return true;
4716 // Check that ubfx can do the extraction, with no holes in the mask.
4717 uint64_t UMask = SMask;
4718 UMask |= maskTrailingOnes<uint64_t>(ShrAmt);
4719 UMask &= maskTrailingOnes<uint64_t>(Size);
4720 if (!isMask_64(UMask))
4721 return false;
4723 // Calculate start position and width of the extract.
4724 const int64_t Pos = ShrAmt;
4725 const int64_t Width = llvm::countr_one(UMask) - ShrAmt;
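// For illustration (assumed s32): (lshr (and x, 0x0FF0), 4) extracts bits
// [11:4] of x, i.e. G_UBFX x, /*Pos=*/4, /*Width=*/8.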
4727 // It's preferable to keep the shift, rather than form G_SBFX.
4728 // TODO: remove the G_AND via demanded bits analysis.
4729 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4730 return false;
4732 MatchInfo = [=](MachineIRBuilder &B) {
4733 auto WidthCst = B.buildConstant(ExtractTy, Width);
4734 auto PosCst = B.buildConstant(ExtractTy, Pos);
4735 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {AndSrc, PosCst, WidthCst});
4737 return true;
4740 bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4741 MachineInstr &MI) const {
4742 auto &PtrAdd = cast<GPtrAdd>(MI);
4744 Register Src1Reg = PtrAdd.getBaseReg();
4745 auto *Src1Def = getOpcodeDef<GPtrAdd>(Src1Reg, MRI);
4746 if (!Src1Def)
4747 return false;
4749 Register Src2Reg = PtrAdd.getOffsetReg();
4751 if (MRI.hasOneNonDBGUse(Src1Reg))
4752 return false;
4754 auto C1 = getIConstantVRegVal(Src1Def->getOffsetReg(), MRI);
4755 if (!C1)
4756 return false;
4757 auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4758 if (!C2)
4759 return false;
4761 const APInt &C1APIntVal = *C1;
4762 const APInt &C2APIntVal = *C2;
4763 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4765 for (auto &UseMI : MRI.use_nodbg_instructions(PtrAdd.getReg(0))) {
4766 // This combine may end up running before ptrtoint/inttoptr combines
4767 // manage to eliminate redundant conversions, so try to look through them.
4768 MachineInstr *ConvUseMI = &UseMI;
4769 unsigned ConvUseOpc = ConvUseMI->getOpcode();
4770 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4771 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4772 Register DefReg = ConvUseMI->getOperand(0).getReg();
4773 if (!MRI.hasOneNonDBGUse(DefReg))
4774 break;
4775 ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg);
4776 ConvUseOpc = ConvUseMI->getOpcode();
4778 auto *LdStMI = dyn_cast<GLoadStore>(ConvUseMI);
4779 if (!LdStMI)
4780 continue;
4781 // Is x[offset2] already not a legal addressing mode? If so then
4782 // reassociating the constants breaks nothing (we test offset2 because
4783 // that's the one we hope to fold into the load or store).
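// For example, on a hypothetical target that can only fold immediate offsets
// in [0, 4095]: if the use already folds offset2 = 200 (legal) but the
// combined offset1 + offset2 = 4200 would not be legal, reassociating the
// constants would lose the addressing-mode match, so we report true.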
4784 TargetLoweringBase::AddrMode AM;
4785 AM.HasBaseReg = true;
4786 AM.BaseOffs = C2APIntVal.getSExtValue();
4787 unsigned AS = MRI.getType(LdStMI->getPointerReg()).getAddressSpace();
4788 Type *AccessTy = getTypeForLLT(LdStMI->getMMO().getMemoryType(),
4789 PtrAdd.getMF()->getFunction().getContext());
4790 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4791 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4792 AccessTy, AS))
4793 continue;
4795 // Would x[offset1+offset2] still be a legal addressing mode?
4796 AM.BaseOffs = CombinedValue;
4797 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4798 AccessTy, AS))
4799 return true;
4802 return false;
4805 bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4806 MachineInstr *RHS,
4807 BuildFnTy &MatchInfo) const {
4808 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4809 Register Src1Reg = MI.getOperand(1).getReg();
4810 if (RHS->getOpcode() != TargetOpcode::G_ADD)
4811 return false;
4812 auto C2 = getIConstantVRegVal(RHS->getOperand(2).getReg(), MRI);
4813 if (!C2)
4814 return false;
4816 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4817 LLT PtrTy = MRI.getType(MI.getOperand(0).getReg());
4819 auto NewBase =
4820 Builder.buildPtrAdd(PtrTy, Src1Reg, RHS->getOperand(1).getReg());
4821 Observer.changingInstr(MI);
4822 MI.getOperand(1).setReg(NewBase.getReg(0));
4823 MI.getOperand(2).setReg(RHS->getOperand(2).getReg());
4824 Observer.changedInstr(MI);
4826 return !reassociationCanBreakAddressingModePattern(MI);
4829 bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
4830 MachineInstr *LHS,
4831 MachineInstr *RHS,
4832 BuildFnTy &MatchInfo) const {
4833 // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
4834 // if and only if (G_PTR_ADD X, C) has one use.
4835 Register LHSBase;
4836 std::optional<ValueAndVReg> LHSCstOff;
4837 if (!mi_match(MI.getBaseReg(), MRI,
4838 m_OneNonDBGUse(m_GPtrAdd(m_Reg(LHSBase), m_GCst(LHSCstOff)))))
4839 return false;
4841 auto *LHSPtrAdd = cast<GPtrAdd>(LHS);
4842 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4843 // When we change LHSPtrAdd's offset register we might cause it to use a reg
4844 // before its def. Sink it to just before the outer PTR_ADD to ensure this
4845 // doesn't happen.
4846 LHSPtrAdd->moveBefore(&MI);
4847 Register RHSReg = MI.getOffsetReg();
4848 // Reusing LHSCstOff's vreg directly could cause a type mismatch if it came through an extend/trunc, so rebuild the constant in the offset register's type.
4849 auto NewCst = B.buildConstant(MRI.getType(RHSReg), LHSCstOff->Value);
4850 Observer.changingInstr(MI);
4851 MI.getOperand(2).setReg(NewCst.getReg(0));
4852 Observer.changedInstr(MI);
4853 Observer.changingInstr(*LHSPtrAdd);
4854 LHSPtrAdd->getOperand(2).setReg(RHSReg);
4855 Observer.changedInstr(*LHSPtrAdd);
4857 return !reassociationCanBreakAddressingModePattern(MI);
4860 bool CombinerHelper::matchReassocFoldConstantsInSubTree(
4861 GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
4862 BuildFnTy &MatchInfo) const {
4863 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4864 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(LHS);
4865 if (!LHSPtrAdd)
4866 return false;
4868 Register Src2Reg = MI.getOperand(2).getReg();
4869 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
4870 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
4871 auto C1 = getIConstantVRegVal(LHSSrc2, MRI);
4872 if (!C1)
4873 return false;
4874 auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4875 if (!C2)
4876 return false;
4878 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4879 auto NewCst = B.buildConstant(MRI.getType(Src2Reg), *C1 + *C2);
4880 Observer.changingInstr(MI);
4881 MI.getOperand(1).setReg(LHSSrc1);
4882 MI.getOperand(2).setReg(NewCst.getReg(0));
4883 Observer.changedInstr(MI);
4885 return !reassociationCanBreakAddressingModePattern(MI);
4888 bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
4889 BuildFnTy &MatchInfo) const {
4890 auto &PtrAdd = cast<GPtrAdd>(MI);
4891 // We're trying to match a few pointer computation patterns here for
4892 // re-association opportunities.
4893 // 1) Isolating a constant operand to be on the RHS, e.g.:
4894 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4896 // 2) Folding two constants in each sub-tree as long as such folding
4897 // doesn't break a legal addressing mode.
4898 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4900 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
4901 // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
4902 // iff (G_PTR_ADD X, C) has one use.
4903 MachineInstr *LHS = MRI.getVRegDef(PtrAdd.getBaseReg());
4904 MachineInstr *RHS = MRI.getVRegDef(PtrAdd.getOffsetReg());
4906 // Try to match example 2.
4907 if (matchReassocFoldConstantsInSubTree(PtrAdd, LHS, RHS, MatchInfo))
4908 return true;
4910 // Try to match example 3.
4911 if (matchReassocConstantInnerLHS(PtrAdd, LHS, RHS, MatchInfo))
4912 return true;
4914 // Try to match example 1.
4915 if (matchReassocConstantInnerRHS(PtrAdd, RHS, MatchInfo))
4916 return true;
4918 return false;
4920 bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
4921 Register OpLHS, Register OpRHS,
4922 BuildFnTy &MatchInfo) const {
4923 LLT OpRHSTy = MRI.getType(OpRHS);
4924 MachineInstr *OpLHSDef = MRI.getVRegDef(OpLHS);
4926 if (OpLHSDef->getOpcode() != Opc)
4927 return false;
4929 MachineInstr *OpRHSDef = MRI.getVRegDef(OpRHS);
4930 Register OpLHSLHS = OpLHSDef->getOperand(1).getReg();
4931 Register OpLHSRHS = OpLHSDef->getOperand(2).getReg();
4933 // If the inner op is (X op C), pull the constant out so it can be folded with
4934 // other constants in the expression tree. Folding is not guaranteed so we
4935 // might have (C1 op C2). In that case do not pull a constant out because it
4936 // won't help and can lead to infinite loops.
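// For example, (G_MUL (G_MUL x, 4), 8) can become (G_MUL x, (G_MUL 4, 8)) and
// the inner constants folded later, whereas an all-constant inner op such as
// (G_MUL 4, 8) is left alone since pulling 8 out would not help.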
4937 if (isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSRHS), MRI) &&
4938 !isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSLHS), MRI)) {
4939 if (isConstantOrConstantSplatVector(*OpRHSDef, MRI)) {
4940 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
4941 MatchInfo = [=](MachineIRBuilder &B) {
4942 auto NewCst = B.buildInstr(Opc, {OpRHSTy}, {OpLHSRHS, OpRHS});
4943 B.buildInstr(Opc, {DstReg}, {OpLHSLHS, NewCst});
4945 return true;
4947 if (getTargetLowering().isReassocProfitable(MRI, OpLHS, OpRHS)) {
4948 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
4949 // iff (op x, c1) has one use
4950 MatchInfo = [=](MachineIRBuilder &B) {
4951 auto NewLHSLHS = B.buildInstr(Opc, {OpRHSTy}, {OpLHSLHS, OpRHS});
4952 B.buildInstr(Opc, {DstReg}, {NewLHSLHS, OpLHSRHS});
4954 return true;
4958 return false;
4961 bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
4962 BuildFnTy &MatchInfo) const {
4963 // We don't check if the reassociation will break a legal addressing mode
4964 // here since pointer arithmetic is handled by G_PTR_ADD.
4965 unsigned Opc = MI.getOpcode();
4966 Register DstReg = MI.getOperand(0).getReg();
4967 Register LHSReg = MI.getOperand(1).getReg();
4968 Register RHSReg = MI.getOperand(2).getReg();
4970 if (tryReassocBinOp(Opc, DstReg, LHSReg, RHSReg, MatchInfo))
4971 return true;
4972 if (tryReassocBinOp(Opc, DstReg, RHSReg, LHSReg, MatchInfo))
4973 return true;
4974 return false;
4977 bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
4978 APInt &MatchInfo) const {
4979 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
4980 Register SrcOp = MI.getOperand(1).getReg();
4982 if (auto MaybeCst = ConstantFoldCastOp(MI.getOpcode(), DstTy, SrcOp, MRI)) {
4983 MatchInfo = *MaybeCst;
4984 return true;
4987 return false;
4990 bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
4991 APInt &MatchInfo) const {
4992 Register Op1 = MI.getOperand(1).getReg();
4993 Register Op2 = MI.getOperand(2).getReg();
4994 auto MaybeCst = ConstantFoldBinOp(MI.getOpcode(), Op1, Op2, MRI);
4995 if (!MaybeCst)
4996 return false;
4997 MatchInfo = *MaybeCst;
4998 return true;
5001 bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
5002 ConstantFP *&MatchInfo) const {
5003 Register Op1 = MI.getOperand(1).getReg();
5004 Register Op2 = MI.getOperand(2).getReg();
5005 auto MaybeCst = ConstantFoldFPBinOp(MI.getOpcode(), Op1, Op2, MRI);
5006 if (!MaybeCst)
5007 return false;
5008 MatchInfo =
5009 ConstantFP::get(MI.getMF()->getFunction().getContext(), *MaybeCst);
5010 return true;
5013 bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
5014 ConstantFP *&MatchInfo) const {
5015 assert(MI.getOpcode() == TargetOpcode::G_FMA ||
5016 MI.getOpcode() == TargetOpcode::G_FMAD);
5017 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();
5019 const ConstantFP *Op3Cst = getConstantFPVRegVal(Op3, MRI);
5020 if (!Op3Cst)
5021 return false;
5023 const ConstantFP *Op2Cst = getConstantFPVRegVal(Op2, MRI);
5024 if (!Op2Cst)
5025 return false;
5027 const ConstantFP *Op1Cst = getConstantFPVRegVal(Op1, MRI);
5028 if (!Op1Cst)
5029 return false;
5031 APFloat Op1F = Op1Cst->getValueAPF();
5032 Op1F.fusedMultiplyAdd(Op2Cst->getValueAPF(), Op3Cst->getValueAPF(),
5033 APFloat::rmNearestTiesToEven);
5034 MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F);
5035 return true;
5038 bool CombinerHelper::matchNarrowBinopFeedingAnd(
5039 MachineInstr &MI,
5040 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5041 // Look for a binop feeding into an AND with a mask:
5043 // %add = G_ADD %lhs, %rhs
5044 // %and = G_AND %add, 000...11111111
5046 // Check if it's possible to perform the binop at a narrower width and zext
5047 // back to the original width like so:
5049 // %narrow_lhs = G_TRUNC %lhs
5050 // %narrow_rhs = G_TRUNC %rhs
5051 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
5052 // %new_add = G_ZEXT %narrow_add
5053 // %and = G_AND %new_add, 000...11111111
5055 // This can allow later combines to eliminate the G_AND if it turns out
5056 // that the mask is irrelevant.
5057 assert(MI.getOpcode() == TargetOpcode::G_AND);
5058 Register Dst = MI.getOperand(0).getReg();
5059 Register AndLHS = MI.getOperand(1).getReg();
5060 Register AndRHS = MI.getOperand(2).getReg();
5061 LLT WideTy = MRI.getType(Dst);
5063 // If the potential binop has more than one use, then it's possible that one
5064 // of those uses will need its full width.
5065 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(AndLHS))
5066 return false;
5068 // Check if the LHS feeding the AND is impacted by the high bits that we're
5069 // masking out.
5071 // e.g. for 64-bit x, y:
5073 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
5074 MachineInstr *LHSInst = getDefIgnoringCopies(AndLHS, MRI);
5075 if (!LHSInst)
5076 return false;
5077 unsigned LHSOpc = LHSInst->getOpcode();
5078 switch (LHSOpc) {
5079 default:
5080 return false;
5081 case TargetOpcode::G_ADD:
5082 case TargetOpcode::G_SUB:
5083 case TargetOpcode::G_MUL:
5084 case TargetOpcode::G_AND:
5085 case TargetOpcode::G_OR:
5086 case TargetOpcode::G_XOR:
5087 break;
5090 // Find the mask on the RHS.
5091 auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI);
5092 if (!Cst)
5093 return false;
5094 auto Mask = Cst->Value;
5095 if (!Mask.isMask())
5096 return false;
5098 // No point in combining if there's nothing to truncate.
5099 unsigned NarrowWidth = Mask.countr_one();
5100 if (NarrowWidth == WideTy.getSizeInBits())
5101 return false;
5102 LLT NarrowTy = LLT::scalar(NarrowWidth);
5104 // Check if adding the zext + truncates could be harmful.
5105 auto &MF = *MI.getMF();
5106 const auto &TLI = getTargetLowering();
5107 LLVMContext &Ctx = MF.getFunction().getContext();
5108 if (!TLI.isTruncateFree(WideTy, NarrowTy, Ctx) ||
5109 !TLI.isZExtFree(NarrowTy, WideTy, Ctx))
5110 return false;
5111 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
5112 !isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
5113 return false;
5114 Register BinOpLHS = LHSInst->getOperand(1).getReg();
5115 Register BinOpRHS = LHSInst->getOperand(2).getReg();
5116 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5117 auto NarrowLHS = Builder.buildTrunc(NarrowTy, BinOpLHS);
5118 auto NarrowRHS = Builder.buildTrunc(NarrowTy, BinOpRHS);
5119 auto NarrowBinOp =
5120 Builder.buildInstr(LHSOpc, {NarrowTy}, {NarrowLHS, NarrowRHS});
5121 auto Ext = Builder.buildZExt(WideTy, NarrowBinOp);
5122 Observer.changingInstr(MI);
5123 MI.getOperand(1).setReg(Ext.getReg(0));
5124 Observer.changedInstr(MI);
5126 return true;
5129 bool CombinerHelper::matchMulOBy2(MachineInstr &MI,
5130 BuildFnTy &MatchInfo) const {
5131 unsigned Opc = MI.getOpcode();
5132 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
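// (G_*MULO x, 2) -> (G_*ADDO x, x): doubling a value overflows exactly when
// adding the value to itself does, so the carry/overflow result is unchanged.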
5134 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2)))
5135 return false;
5137 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5138 Observer.changingInstr(MI);
5139 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
5140 : TargetOpcode::G_SADDO;
5141 MI.setDesc(Builder.getTII().get(NewOpc));
5142 MI.getOperand(3).setReg(MI.getOperand(2).getReg());
5143 Observer.changedInstr(MI);
5145 return true;
5148 bool CombinerHelper::matchMulOBy0(MachineInstr &MI,
5149 BuildFnTy &MatchInfo) const {
5150 // (G_*MULO x, 0) -> 0 + no carry out
5151 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
5152 MI.getOpcode() == TargetOpcode::G_SMULO);
5153 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
5154 return false;
5155 Register Dst = MI.getOperand(0).getReg();
5156 Register Carry = MI.getOperand(1).getReg();
5157 if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Dst)) ||
5158 !isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
5159 return false;
5160 MatchInfo = [=](MachineIRBuilder &B) {
5161 B.buildConstant(Dst, 0);
5162 B.buildConstant(Carry, 0);
5164 return true;
5167 bool CombinerHelper::matchAddEToAddO(MachineInstr &MI,
5168 BuildFnTy &MatchInfo) const {
5169 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
5170 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
5171 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
5172 MI.getOpcode() == TargetOpcode::G_SADDE ||
5173 MI.getOpcode() == TargetOpcode::G_USUBE ||
5174 MI.getOpcode() == TargetOpcode::G_SSUBE);
5175 if (!mi_match(MI.getOperand(4).getReg(), MRI, m_SpecificICstOrSplat(0)))
5176 return false;
5177 MatchInfo = [&](MachineIRBuilder &B) {
5178 unsigned NewOpcode;
5179 switch (MI.getOpcode()) {
5180 case TargetOpcode::G_UADDE:
5181 NewOpcode = TargetOpcode::G_UADDO;
5182 break;
5183 case TargetOpcode::G_SADDE:
5184 NewOpcode = TargetOpcode::G_SADDO;
5185 break;
5186 case TargetOpcode::G_USUBE:
5187 NewOpcode = TargetOpcode::G_USUBO;
5188 break;
5189 case TargetOpcode::G_SSUBE:
5190 NewOpcode = TargetOpcode::G_SSUBO;
5191 break;
5193 Observer.changingInstr(MI);
5194 MI.setDesc(B.getTII().get(NewOpcode));
5195 MI.removeOperand(4);
5196 Observer.changedInstr(MI);
5198 return true;
5201 bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
5202 BuildFnTy &MatchInfo) const {
5203 assert(MI.getOpcode() == TargetOpcode::G_SUB);
5204 Register Dst = MI.getOperand(0).getReg();
5205 // (x + y) - z -> x (if y == z)
5206 // (x + y) - z -> y (if x == z)
5207 Register X, Y, Z;
5208 if (mi_match(Dst, MRI, m_GSub(m_GAdd(m_Reg(X), m_Reg(Y)), m_Reg(Z)))) {
5209 Register ReplaceReg;
5210 int64_t CstX, CstY;
5211 if (Y == Z || (mi_match(Y, MRI, m_ICstOrSplat(CstY)) &&
5212 mi_match(Z, MRI, m_SpecificICstOrSplat(CstY))))
5213 ReplaceReg = X;
5214 else if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
5215 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
5216 ReplaceReg = Y;
5217 if (ReplaceReg) {
5218 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, ReplaceReg); };
5219 return true;
5223 // x - (y + z) -> 0 - y (if x == z)
5224 // x - (y + z) -> 0 - z (if x == y)
5225 if (mi_match(Dst, MRI, m_GSub(m_Reg(X), m_GAdd(m_Reg(Y), m_Reg(Z))))) {
5226 Register ReplaceReg;
5227 int64_t CstX;
5228 if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
5229 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
5230 ReplaceReg = Y;
5231 else if (X == Y || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
5232 mi_match(Y, MRI, m_SpecificICstOrSplat(CstX))))
5233 ReplaceReg = Z;
5234 if (ReplaceReg) {
5235 MatchInfo = [=](MachineIRBuilder &B) {
5236 auto Zero = B.buildConstant(MRI.getType(Dst), 0);
5237 B.buildSub(Dst, Zero, ReplaceReg);
5239 return true;
5242 return false;
5245 MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) const {
5246 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5247 auto &UDiv = cast<GenericMachineInstr>(MI);
5248 Register Dst = UDiv.getReg(0);
5249 Register LHS = UDiv.getReg(1);
5250 Register RHS = UDiv.getReg(2);
5251 LLT Ty = MRI.getType(Dst);
5252 LLT ScalarTy = Ty.getScalarType();
5253 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5254 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5255 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5257 auto &MIB = Builder;
5259 bool UseSRL = false;
5260 SmallVector<Register, 16> Shifts, Factors;
5261 auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5262 bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();
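// For an exact division the odd part of the divisor has a multiplicative
// inverse modulo 2^BW, so x /u (d << k) == (x >> k) * inverse(d). For example,
// a 32-bit exact udiv by 6 becomes (x lshr 1) * 0xAAAAAAAB, since
// 3 * 0xAAAAAAAB == 1 (mod 2^32).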
5264 auto BuildExactUDIVPattern = [&](const Constant *C) {
5265 // Don't recompute inverses for each splat element.
5266 if (IsSplat && !Factors.empty()) {
5267 Shifts.push_back(Shifts[0]);
5268 Factors.push_back(Factors[0]);
5269 return true;
5272 auto *CI = cast<ConstantInt>(C);
5273 APInt Divisor = CI->getValue();
5274 unsigned Shift = Divisor.countr_zero();
5275 if (Shift) {
5276 Divisor.lshrInPlace(Shift);
5277 UseSRL = true;
5280 // Calculate the multiplicative inverse modulo 2^BW.
5281 APInt Factor = Divisor.multiplicativeInverse();
5282 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5283 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5284 return true;
5287 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5288 // Collect all magic values from the build vector.
5289 if (!matchUnaryPredicate(MRI, RHS, BuildExactUDIVPattern))
5290 llvm_unreachable("Expected unary predicate match to succeed");
5292 Register Shift, Factor;
5293 if (Ty.isVector()) {
5294 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5295 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5296 } else {
5297 Shift = Shifts[0];
5298 Factor = Factors[0];
5301 Register Res = LHS;
5303 if (UseSRL)
5304 Res = MIB.buildLShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5306 return MIB.buildMul(Ty, Res, Factor);
5309 unsigned KnownLeadingZeros =
5310 KB ? KB->getKnownBits(LHS).countMinLeadingZeros() : 0;
5312 bool UseNPQ = false;
5313 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
5314 auto BuildUDIVPattern = [&](const Constant *C) {
5315 auto *CI = cast<ConstantInt>(C);
5316 const APInt &Divisor = CI->getValue();
5318 bool SelNPQ = false;
5319 APInt Magic(Divisor.getBitWidth(), 0);
5320 unsigned PreShift = 0, PostShift = 0;
5322 // Magic algorithm doesn't work for division by 1. We need to emit a select
5323 // at the end.
5324 // TODO: Use undef values for divisor of 1.
5325 if (!Divisor.isOne()) {
5327 // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros
5328 // in the dividend exceed the leading zeros of the divisor.
5329 UnsignedDivisionByConstantInfo magics =
5330 UnsignedDivisionByConstantInfo::get(
5331 Divisor, std::min(KnownLeadingZeros, Divisor.countl_zero()));
5333 Magic = std::move(magics.Magic);
5335 assert(magics.PreShift < Divisor.getBitWidth() &&
5336 "We shouldn't generate an undefined shift!");
5337 assert(magics.PostShift < Divisor.getBitWidth() &&
5338 "We shouldn't generate an undefined shift!");
5339 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
5340 PreShift = magics.PreShift;
5341 PostShift = magics.PostShift;
5342 SelNPQ = magics.IsAdd;
5345 PreShifts.push_back(
5346 MIB.buildConstant(ScalarShiftAmtTy, PreShift).getReg(0));
5347 MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magic).getReg(0));
5348 NPQFactors.push_back(
5349 MIB.buildConstant(ScalarTy,
5350 SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1)
5351 : APInt::getZero(EltBits))
5352 .getReg(0));
5353 PostShifts.push_back(
5354 MIB.buildConstant(ScalarShiftAmtTy, PostShift).getReg(0));
5355 UseNPQ |= SelNPQ;
5356 return true;
5359 // Collect the shifts/magic values from each element.
5360 bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern);
5361 (void)Matched;
5362 assert(Matched && "Expected unary predicate match to succeed");
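// For example, with no known leading zeros in the dividend, a 32-bit udiv by
// 5 yields Magic = 0xCCCCCCCD, PreShift = 0, PostShift = 2 and no NPQ fixup,
// so the division lowers to (G_UMULH x, 0xCCCCCCCD) >> 2.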
5364 Register PreShift, PostShift, MagicFactor, NPQFactor;
5365 auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5366 if (RHSDef) {
5367 PreShift = MIB.buildBuildVector(ShiftAmtTy, PreShifts).getReg(0);
5368 MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
5369 NPQFactor = MIB.buildBuildVector(Ty, NPQFactors).getReg(0);
5370 PostShift = MIB.buildBuildVector(ShiftAmtTy, PostShifts).getReg(0);
5371 } else {
5372 assert(MRI.getType(RHS).isScalar() &&
5373 "Non-build_vector operation should have been a scalar");
5374 PreShift = PreShifts[0];
5375 MagicFactor = MagicFactors[0];
5376 PostShift = PostShifts[0];
5379 Register Q = LHS;
5380 Q = MIB.buildLShr(Ty, Q, PreShift).getReg(0);
5382 // Multiply the numerator (operand 0) by the magic value.
5383 Q = MIB.buildUMulH(Ty, Q, MagicFactor).getReg(0);
5385 if (UseNPQ) {
5386 Register NPQ = MIB.buildSub(Ty, LHS, Q).getReg(0);
5388 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5389 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5390 if (Ty.isVector())
5391 NPQ = MIB.buildUMulH(Ty, NPQ, NPQFactor).getReg(0);
5392 else
5393 NPQ = MIB.buildLShr(Ty, NPQ, MIB.buildConstant(ShiftAmtTy, 1)).getReg(0);
5395 Q = MIB.buildAdd(Ty, NPQ, Q).getReg(0);
5398 Q = MIB.buildLShr(Ty, Q, PostShift).getReg(0);
5399 auto One = MIB.buildConstant(Ty, 1);
5400 auto IsOne = MIB.buildICmp(
5401 CmpInst::Predicate::ICMP_EQ,
5402 Ty.isScalar() ? LLT::scalar(1) : Ty.changeElementSize(1), RHS, One);
5403 return MIB.buildSelect(Ty, IsOne, LHS, Q);
5406 bool CombinerHelper::matchUDivByConst(MachineInstr &MI) const {
5407 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5408 Register Dst = MI.getOperand(0).getReg();
5409 Register RHS = MI.getOperand(2).getReg();
5410 LLT DstTy = MRI.getType(Dst);
5412 auto &MF = *MI.getMF();
5413 AttributeList Attr = MF.getFunction().getAttributes();
5414 const auto &TLI = getTargetLowering();
5415 LLVMContext &Ctx = MF.getFunction().getContext();
5416 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr))
5417 return false;
5419 // Don't do this for minsize because the instruction sequence is usually
5420 // larger.
5421 if (MF.getFunction().hasMinSize())
5422 return false;
5424 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5425 return matchUnaryPredicate(
5426 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
5429 auto *RHSDef = MRI.getVRegDef(RHS);
5430 if (!isConstantOrConstantVector(*RHSDef, MRI))
5431 return false;
5433 // Don't do this if the types are not going to be legal.
5434 if (LI) {
5435 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5436 return false;
5437 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMULH, {DstTy}}))
5438 return false;
5439 if (!isLegalOrBeforeLegalizer(
5440 {TargetOpcode::G_ICMP,
5441 {DstTy.isVector() ? DstTy.changeElementSize(1) : LLT::scalar(1),
5442 DstTy}}))
5443 return false;
5446 return matchUnaryPredicate(
5447 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
5450 void CombinerHelper::applyUDivByConst(MachineInstr &MI) const {
5451 auto *NewMI = buildUDivUsingMul(MI);
5452 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5455 bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
5456 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5457 Register Dst = MI.getOperand(0).getReg();
5458 Register RHS = MI.getOperand(2).getReg();
5459 LLT DstTy = MRI.getType(Dst);
5461 auto &MF = *MI.getMF();
5462 AttributeList Attr = MF.getFunction().getAttributes();
5463 const auto &TLI = getTargetLowering();
5464 LLVMContext &Ctx = MF.getFunction().getContext();
5465 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr))
5466 return false;
5468 // Don't do this for minsize because the instruction sequence is usually
5469 // larger.
5470 if (MF.getFunction().hasMinSize())
5471 return false;
5473 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5474 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5475 return matchUnaryPredicate(
5476 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
5479 // Don't support the general case for now.
5480 return false;
5483 void CombinerHelper::applySDivByConst(MachineInstr &MI) const {
5484 auto *NewMI = buildSDivUsingMul(MI);
5485 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5488 MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
5489 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5490 auto &SDiv = cast<GenericMachineInstr>(MI);
5491 Register Dst = SDiv.getReg(0);
5492 Register LHS = SDiv.getReg(1);
5493 Register RHS = SDiv.getReg(2);
5494 LLT Ty = MRI.getType(Dst);
5495 LLT ScalarTy = Ty.getScalarType();
5496 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5497 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5498 auto &MIB = Builder;
5500 bool UseSRA = false;
5501 SmallVector<Register, 16> Shifts, Factors;
5503 auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5504 bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
5506 auto BuildSDIVPattern = [&](const Constant *C) {
5507 // Don't recompute inverses for each splat element.
5508 if (IsSplat && !Factors.empty()) {
5509 Shifts.push_back(Shifts[0]);
5510 Factors.push_back(Factors[0]);
5511 return true;
5514 auto *CI = cast<ConstantInt>(C);
5515 APInt Divisor = CI->getValue();
5516 unsigned Shift = Divisor.countr_zero();
5517 if (Shift) {
5518 Divisor.ashrInPlace(Shift);
5519 UseSRA = true;
5522 // Calculate the multiplicative inverse modulo 2^BW; the shifted divisor is
5523 // odd, so the inverse exists.
5524 APInt Factor = Divisor.multiplicativeInverse();
5525 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5526 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5527 return true;
5530 // Collect all magic values from the build vector.
5531 bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
5532 (void)Matched;
5533 assert(Matched && "Expected unary predicate match to succeed");
5535 Register Shift, Factor;
5536 if (Ty.isVector()) {
5537 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5538 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5539 } else {
5540 Shift = Shifts[0];
5541 Factor = Factors[0];
5544 Register Res = LHS;
5546 if (UseSRA)
5547 Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5549 return MIB.buildMul(Ty, Res, Factor);
5552 bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
5553 assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5554 MI.getOpcode() == TargetOpcode::G_UDIV) &&
5555 "Expected SDIV or UDIV");
5556 auto &Div = cast<GenericMachineInstr>(MI);
5557 Register RHS = Div.getReg(2);
5558 auto MatchPow2 = [&](const Constant *C) {
5559 auto *CI = dyn_cast<ConstantInt>(C);
5560 return CI && (CI->getValue().isPowerOf2() ||
5561 (IsSigned && CI->getValue().isNegatedPowerOf2()));
5563 return matchUnaryPredicate(MRI, RHS, MatchPow2, /*AllowUndefs=*/false);
5566 void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
5567 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5568 auto &SDiv = cast<GenericMachineInstr>(MI);
5569 Register Dst = SDiv.getReg(0);
5570 Register LHS = SDiv.getReg(1);
5571 Register RHS = SDiv.getReg(2);
5572 LLT Ty = MRI.getType(Dst);
5573 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5574 LLT CCVT =
5575 Ty.isVector() ? LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);
5577 // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
5578 // to the following version:
5580 // %c1 = G_CTTZ %rhs
5581 // %inexact = G_SUB $bitwidth, %c1
5582 // %sign = G_ASHR %lhs, $(bitwidth - 1)
5583 // %lshr = G_LSHR %sign, %inexact
5584 // %add = G_ADD %lhs, %lshr
5585 // %ashr = G_ASHR %add, %c1
5586 // %ashr = G_SELECT %isoneorminusone, %lhs, %ashr
5587 // %zero = G_CONSTANT $0
5588 // %neg = G_NEG %ashr
5589 // %isneg = G_ICMP SLT %rhs, %zero
5590 // %res = G_SELECT %isneg, %neg, %ashr
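// For example, for x:s32 and rhs = 8: %c1 = 3 and %inexact = 29. With x = -1
// the sign splat is 0xFFFFFFFF, shifting it right logically by 29 gives 7, and
// (-1 + 7) >> 3 == 0, matching the round-toward-zero result of -1 / 8.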
5592 unsigned BitWidth = Ty.getScalarSizeInBits();
5593 auto Zero = Builder.buildConstant(Ty, 0);
5595 auto Bits = Builder.buildConstant(ShiftAmtTy, BitWidth);
5596 auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5597 auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
5598 // Splat the sign bit into the register
5599 auto Sign = Builder.buildAShr(
5600 Ty, LHS, Builder.buildConstant(ShiftAmtTy, BitWidth - 1));
5602 // Add (LHS < 0) ? |divisor| - 1 : 0 so the arithmetic shift rounds toward zero.
5603 auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
5604 auto Add = Builder.buildAdd(Ty, LHS, LSrl);
5605 auto AShr = Builder.buildAShr(Ty, Add, C1);
5607 // Special case: (sdiv X, 1) -> X
5608 // Special case: (sdiv X, -1) -> 0-X
5609 auto One = Builder.buildConstant(Ty, 1);
5610 auto MinusOne = Builder.buildConstant(Ty, -1);
5611 auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
5612 auto IsMinusOne =
5613 Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, MinusOne);
5614 auto IsOneOrMinusOne = Builder.buildOr(CCVT, IsOne, IsMinusOne);
5615 AShr = Builder.buildSelect(Ty, IsOneOrMinusOne, LHS, AShr);
5617 // If divided by a positive value, we're done. Otherwise, the result must be
5618 // negated.
5619 auto Neg = Builder.buildNeg(Ty, AShr);
5620 auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, RHS, Zero);
5621 Builder.buildSelect(MI.getOperand(0).getReg(), IsNeg, Neg, AShr);
5622 MI.eraseFromParent();
5625 void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
5626 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5627 auto &UDiv = cast<GenericMachineInstr>(MI);
5628 Register Dst = UDiv.getReg(0);
5629 Register LHS = UDiv.getReg(1);
5630 Register RHS = UDiv.getReg(2);
5631 LLT Ty = MRI.getType(Dst);
5632 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5634 auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5635 Builder.buildLShr(MI.getOperand(0).getReg(), LHS, C1);
5636 MI.eraseFromParent();
5639 bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
5640 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
5641 Register RHS = MI.getOperand(2).getReg();
5642 Register Dst = MI.getOperand(0).getReg();
5643 LLT Ty = MRI.getType(Dst);
5644 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5645 auto MatchPow2ExceptOne = [&](const Constant *C) {
5646 if (auto *CI = dyn_cast<ConstantInt>(C))
5647 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
5648 return false;
5650 if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne, false))
5651 return false;
5652 return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}});
5655 void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
5656 Register LHS = MI.getOperand(1).getReg();
5657 Register RHS = MI.getOperand(2).getReg();
5658 Register Dst = MI.getOperand(0).getReg();
5659 LLT Ty = MRI.getType(Dst);
5660 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5661 unsigned NumEltBits = Ty.getScalarSizeInBits();
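// G_UMULH by a power of two 2^k yields the high half of the product, which is
// x >> (BitWidth - k). For example, for x:s32, umulh(x, 16) == (x * 16) >> 32
// == x >> 28.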
5663 auto LogBase2 = buildLogBase2(RHS, Builder);
5664 auto ShiftAmt =
5665 Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2);
5666 auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt);
5667 Builder.buildLShr(Dst, LHS, Trunc);
5668 MI.eraseFromParent();
5671 bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
5672 BuildFnTy &MatchInfo) const {
5673 unsigned Opc = MI.getOpcode();
5674 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
5675 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5676 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
5678 Register Dst = MI.getOperand(0).getReg();
5679 Register X = MI.getOperand(1).getReg();
5680 Register Y = MI.getOperand(2).getReg();
5681 LLT Type = MRI.getType(Dst);
5683 // fold (fadd x, fneg(y)) -> (fsub x, y)
5684 // fold (fadd fneg(y), x) -> (fsub x, y)
5685 // G_FADD is commutative so both cases are checked by m_GFAdd
5686 if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5687 isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) {
5688 Opc = TargetOpcode::G_FSUB;
5690 // fold (fsub x, fneg(y)) -> (fadd x, y)
5691 else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5692 isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) {
5693 Opc = TargetOpcode::G_FADD;
5695 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
5696 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
5697 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
5698 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
5699 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5700 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
5701 mi_match(X, MRI, m_GFNeg(m_Reg(X))) &&
5702 mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) {
5703 // no opcode change
5704 } else
5705 return false;
5707 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5708 Observer.changingInstr(MI);
5709 MI.setDesc(B.getTII().get(Opc));
5710 MI.getOperand(1).setReg(X);
5711 MI.getOperand(2).setReg(Y);
5712 Observer.changedInstr(MI);
5714 return true;
5717 bool CombinerHelper::matchFsubToFneg(MachineInstr &MI,
5718 Register &MatchInfo) const {
5719 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5721 Register LHS = MI.getOperand(1).getReg();
5722 MatchInfo = MI.getOperand(2).getReg();
5723 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
5725 const auto LHSCst = Ty.isVector()
5726 ? getFConstantSplat(LHS, MRI, /* allowUndef */ true)
5727 : getFConstantVRegValWithLookThrough(LHS, MRI);
5728 if (!LHSCst)
5729 return false;
5731 // -0.0 is always allowed
5732 if (LHSCst->Value.isNegZero())
5733 return true;
5735 // +0.0 is only allowed if nsz is set.
5736 if (LHSCst->Value.isPosZero())
5737 return MI.getFlag(MachineInstr::FmNsz);
5739 return false;
5742 void CombinerHelper::applyFsubToFneg(MachineInstr &MI,
5743 Register &MatchInfo) const {
5744 Register Dst = MI.getOperand(0).getReg();
5745 Builder.buildFNeg(
5746 Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0));
5747 eraseInst(MI);
5750 /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
5751 /// due to global flags or MachineInstr flags.
5752 static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
5753 if (MI.getOpcode() != TargetOpcode::G_FMUL)
5754 return false;
5755 return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract);
5758 static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
5759 const MachineRegisterInfo &MRI) {
5760 return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()),
5761 MRI.use_instr_nodbg_end()) >
5762 std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()),
5763 MRI.use_instr_nodbg_end());
5766 bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
5767 bool &AllowFusionGlobally,
5768 bool &HasFMAD, bool &Aggressive,
5769 bool CanReassociate) const {
5771 auto *MF = MI.getMF();
5772 const auto &TLI = *MF->getSubtarget().getTargetLowering();
5773 const TargetOptions &Options = MF->getTarget().Options;
5774 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5776 if (CanReassociate &&
5777 !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc)))
5778 return false;
5780 // Floating-point multiply-add with intermediate rounding.
5781 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType));
5782 // Floating-point multiply-add without intermediate rounding.
5783 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) &&
5784 isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}});
5785 // No valid opcode, do not combine.
5786 if (!HasFMAD && !HasFMA)
5787 return false;
5789 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast ||
5790 Options.UnsafeFPMath || HasFMAD;
5791 // If the addition is not contractable, do not combine.
5792 if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract))
5793 return false;
5795 Aggressive = TLI.enableAggressiveFMAFusion(DstType);
5796 return true;
5799 bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
5800 MachineInstr &MI,
5801 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5802 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5804 bool AllowFusionGlobally, HasFMAD, Aggressive;
5805 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5806 return false;
5808 Register Op1 = MI.getOperand(1).getReg();
5809 Register Op2 = MI.getOperand(2).getReg();
5810 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5811 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5812 unsigned PreferredFusedOpcode =
5813 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5815 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5816 // prefer to fold the multiply with fewer uses.
5817 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5818 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5819 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5820 std::swap(LHS, RHS);
5823 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
5824 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5825 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) {
5826 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5827 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5828 {LHS.MI->getOperand(1).getReg(),
5829 LHS.MI->getOperand(2).getReg(), RHS.Reg});
5831 return true;
5834 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
5835 if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5836 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) {
5837 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5838 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5839 {RHS.MI->getOperand(1).getReg(),
5840 RHS.MI->getOperand(2).getReg(), LHS.Reg});
5842 return true;
5845 return false;
5848 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
5849 MachineInstr &MI,
5850 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5851 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5853 bool AllowFusionGlobally, HasFMAD, Aggressive;
5854 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5855 return false;
5857 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5858 Register Op1 = MI.getOperand(1).getReg();
5859 Register Op2 = MI.getOperand(2).getReg();
5860 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5861 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5862 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5864 unsigned PreferredFusedOpcode =
5865 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5867 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5868 // prefer to fold the multiply with fewer uses.
5869 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5870 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5871 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5872 std::swap(LHS, RHS);
5875 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
5876 MachineInstr *FpExtSrc;
5877 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5878 isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5879 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5880 MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5881 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5882 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5883 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5884 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5885 {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg});
5887 return true;
5890 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
5891 // Note: Commutes FADD operands.
5892 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5893 isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5894 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5895 MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5896 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5897 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5898 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5899 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5900 {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg});
5902 return true;
5905 return false;
5908 bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
5909 MachineInstr &MI,
5910 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5911 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5913 bool AllowFusionGlobally, HasFMAD, Aggressive;
5914 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, true))
5915 return false;
5917 Register Op1 = MI.getOperand(1).getReg();
5918 Register Op2 = MI.getOperand(2).getReg();
5919 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5920 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5921 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5923 unsigned PreferredFusedOpcode =
5924 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5926 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5927 // prefer to fold the multiply with fewer uses.
5928 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5929 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5930 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5931 std::swap(LHS, RHS);
5934 MachineInstr *FMA = nullptr;
5935 Register Z;
5936 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
5937 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5938 (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() ==
5939 TargetOpcode::G_FMUL) &&
5940 MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) &&
5941 MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) {
5942 FMA = LHS.MI;
5943 Z = RHS.Reg;
5945 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
5946 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5947 (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() ==
5948 TargetOpcode::G_FMUL) &&
5949 MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) &&
5950 MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) {
5951 Z = LHS.Reg;
5952 FMA = RHS.MI;
5955 if (FMA) {
5956 MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg());
5957 Register X = FMA->getOperand(1).getReg();
5958 Register Y = FMA->getOperand(2).getReg();
5959 Register U = FMulMI->getOperand(1).getReg();
5960 Register V = FMulMI->getOperand(2).getReg();
5962 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5963 Register InnerFMA = MRI.createGenericVirtualRegister(DstTy);
5964 B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z});
5965 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5966 {X, Y, InnerFMA});
5968 return true;
5971 return false;
5974 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
5975 MachineInstr &MI,
5976 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5977 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5979 bool AllowFusionGlobally, HasFMAD, Aggressive;
5980 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5981 return false;
5983 if (!Aggressive)
5984 return false;
5986 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5987 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5988 Register Op1 = MI.getOperand(1).getReg();
5989 Register Op2 = MI.getOperand(2).getReg();
5990 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5991 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5993 unsigned PreferredFusedOpcode =
5994 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5996 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5997 // prefer to fold the multiply with fewer uses.
5998 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5999 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
6000 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
6001 std::swap(LHS, RHS);
6004 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
6005 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
6006 Register Y, MachineIRBuilder &B) {
6007 Register FpExtU = B.buildFPExt(DstType, U).getReg(0);
6008 Register FpExtV = B.buildFPExt(DstType, V).getReg(0);
6009 Register InnerFMA =
6010 B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z})
6011 .getReg(0);
6012 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6013 {X, Y, InnerFMA});
6016 MachineInstr *FMulMI, *FMAMI;
6017 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
6018 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6019 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6020 mi_match(LHS.MI->getOperand(3).getReg(), MRI,
6021 m_GFPExt(m_MInstr(FMulMI))) &&
6022 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6023 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
6024 MRI.getType(FMulMI->getOperand(0).getReg()))) {
6025 MatchInfo = [=](MachineIRBuilder &B) {
6026 buildMatchInfo(FMulMI->getOperand(1).getReg(),
6027 FMulMI->getOperand(2).getReg(), RHS.Reg,
6028 LHS.MI->getOperand(1).getReg(),
6029 LHS.MI->getOperand(2).getReg(), B);
6031 return true;
6034 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
6035 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6036 // FIXME: This turns two single-precision and one double-precision
6037 // operation into two double-precision operations, which might not be
6038 // interesting for all targets, especially GPUs.
6039 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
6040 FMAMI->getOpcode() == PreferredFusedOpcode) {
6041 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
6042 if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6043 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
6044 MRI.getType(FMAMI->getOperand(0).getReg()))) {
6045 MatchInfo = [=](MachineIRBuilder &B) {
6046 Register X = FMAMI->getOperand(1).getReg();
6047 Register Y = FMAMI->getOperand(2).getReg();
6048 X = B.buildFPExt(DstType, X).getReg(0);
6049 Y = B.buildFPExt(DstType, Y).getReg(0);
6050 buildMatchInfo(FMulMI->getOperand(1).getReg(),
6051 FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B);
6054 return true;
6058 // fold (fadd z, (fma x, y, (fpext (fmul u, v))))
6059 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6060 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6061 mi_match(RHS.MI->getOperand(3).getReg(), MRI,
6062 m_GFPExt(m_MInstr(FMulMI))) &&
6063 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6064 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
6065 MRI.getType(FMulMI->getOperand(0).getReg()))) {
6066 MatchInfo = [=](MachineIRBuilder &B) {
6067 buildMatchInfo(FMulMI->getOperand(1).getReg(),
6068 FMulMI->getOperand(2).getReg(), LHS.Reg,
6069 RHS.MI->getOperand(1).getReg(),
6070 RHS.MI->getOperand(2).getReg(), B);
6072 return true;
6075 // fold (fadd z, (fpext (fma x, y, (fmul u, v))))
6076 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6077 // FIXME: This turns two single-precision and one double-precision
6078 // operation into two double-precision operations, which might not be
6079 // interesting for all targets, especially GPUs.
6080 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
6081 FMAMI->getOpcode() == PreferredFusedOpcode) {
6082 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
6083 if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6084 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
6085 MRI.getType(FMAMI->getOperand(0).getReg()))) {
6086 MatchInfo = [=](MachineIRBuilder &B) {
6087 Register X = FMAMI->getOperand(1).getReg();
6088 Register Y = FMAMI->getOperand(2).getReg();
6089 X = B.buildFPExt(DstType, X).getReg(0);
6090 Y = B.buildFPExt(DstType, Y).getReg(0);
6091 buildMatchInfo(FMulMI->getOperand(1).getReg(),
6092 FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B);
6094 return true;
6098 return false;
6101 bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
6102 MachineInstr &MI,
6103 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6104 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6106 bool AllowFusionGlobally, HasFMAD, Aggressive;
6107 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6108 return false;
6110 Register Op1 = MI.getOperand(1).getReg();
6111 Register Op2 = MI.getOperand(2).getReg();
6112 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
6113 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
6114 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
6116 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6117 // prefer to fold the multiply with fewer uses.
6118 bool FirstMulHasFewerUses = true;
6119 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
6120 isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
6121 hasMoreUses(*LHS.MI, *RHS.MI, MRI))
6122 FirstMulHasFewerUses = false;
6124 unsigned PreferredFusedOpcode =
6125 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6127 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
6128 if (FirstMulHasFewerUses &&
6129 (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
6130 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) {
6131 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6132 Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0);
6133 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6134 {LHS.MI->getOperand(1).getReg(),
6135 LHS.MI->getOperand(2).getReg(), NegZ});
6137 return true;
6139 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
6140 else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
6141 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) {
6142 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6143 Register NegY =
6144 B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0);
6145 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6146 {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg});
6148 return true;
6151 return false;
6154 bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
6155 MachineInstr &MI,
6156 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6157 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6159 bool AllowFusionGlobally, HasFMAD, Aggressive;
6160 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6161 return false;
6163 Register LHSReg = MI.getOperand(1).getReg();
6164 Register RHSReg = MI.getOperand(2).getReg();
6165 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
6167 unsigned PreferredFusedOpcode =
6168 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6170 MachineInstr *FMulMI;
6171 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
6172 if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
6173 (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) &&
6174 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
6175 isContractableFMul(*FMulMI, AllowFusionGlobally)) {
6176 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6177 Register NegX =
6178 B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
6179 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
6180 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6181 {NegX, FMulMI->getOperand(2).getReg(), NegZ});
6183 return true;
6186 // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
6187 if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
6188 (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) &&
6189 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
6190 isContractableFMul(*FMulMI, AllowFusionGlobally)) {
6191 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6192 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6193 {FMulMI->getOperand(1).getReg(),
6194 FMulMI->getOperand(2).getReg(), LHSReg});
6196 return true;
6199 return false;
6202 bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
6203 MachineInstr &MI,
6204 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6205 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6207 bool AllowFusionGlobally, HasFMAD, Aggressive;
6208 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6209 return false;
6211 Register LHSReg = MI.getOperand(1).getReg();
6212 Register RHSReg = MI.getOperand(2).getReg();
6213 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
6215 unsigned PreferredFusedOpcode =
6216 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6218 MachineInstr *FMulMI;
6219 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
6220 if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
6221 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6222 (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
6223 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6224 Register FpExtX =
6225 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
6226 Register FpExtY =
6227 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
6228 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
6229 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6230 {FpExtX, FpExtY, NegZ});
6232 return true;
6235 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
6236 if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
6237 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6238 (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
6239 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6240 Register FpExtY =
6241 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
6242 Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0);
6243 Register FpExtZ =
6244 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
6245 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
6246 {NegY, FpExtZ, LHSReg});
6248 return true;
6251 return false;
6254 bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
6255 MachineInstr &MI,
6256 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6257 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6259 bool AllowFusionGlobally, HasFMAD, Aggressive;
6260 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6261 return false;
6263 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6264 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
6265 Register LHSReg = MI.getOperand(1).getReg();
6266 Register RHSReg = MI.getOperand(2).getReg();
6268 unsigned PreferredFusedOpcode =
6269 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6271 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
6272 MachineIRBuilder &B) {
6273 Register FpExtX = B.buildFPExt(DstTy, X).getReg(0);
6274 Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0);
6275 B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z});
6278 MachineInstr *FMulMI;
6279 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
6280 // (fneg (fma (fpext x), (fpext y), z))
6281 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
6282 // (fneg (fma (fpext x), (fpext y), z))
6283 if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
6284 mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
6285 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6286 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
6287 MRI.getType(FMulMI->getOperand(0).getReg()))) {
6288 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6289 Register FMAReg = MRI.createGenericVirtualRegister(DstTy);
6290 buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(),
6291 FMulMI->getOperand(2).getReg(), RHSReg, B);
6292 B.buildFNeg(MI.getOperand(0).getReg(), FMAReg);
6294 return true;
6297 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6298 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6299 if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
6300 mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
6301 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
6302 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
6303 MRI.getType(FMulMI->getOperand(0).getReg()))) {
6304 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6305 buildMatchInfo(MI.getOperand(0).getReg(), FMulMI->getOperand(1).getReg(),
6306 FMulMI->getOperand(2).getReg(), LHSReg, B);
6308 return true;
6311 return false;
6314 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
6315 unsigned &IdxToPropagate) const {
6316 bool PropagateNaN;
6317 switch (MI.getOpcode()) {
6318 default:
6319 return false;
6320 case TargetOpcode::G_FMINNUM:
6321 case TargetOpcode::G_FMAXNUM:
6322 PropagateNaN = false;
6323 break;
6324 case TargetOpcode::G_FMINIMUM:
6325 case TargetOpcode::G_FMAXIMUM:
6326 PropagateNaN = true;
6327 break;
6330 auto MatchNaN = [&](unsigned Idx) {
6331 Register MaybeNaNReg = MI.getOperand(Idx).getReg();
6332 const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI);
6333 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
6334 return false;
6335 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
6336 return true;
6339 return MatchNaN(1) || MatchNaN(2);
6342 bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const {
6343 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
6344 Register LHS = MI.getOperand(1).getReg();
6345 Register RHS = MI.getOperand(2).getReg();
6347 // Helper lambda to check for opportunities for
6348 // A + (B - A) -> B
6349 // (B - A) + A -> B
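// For example (a sketch with assumed registers and types):
//   %s:_(s32) = G_SUB %b, %a
//   %d:_(s32) = G_ADD %a, %s
// simplifies so that %d can be replaced by %b.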
6350 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
6351 Register Reg;
6352 return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) &&
6353 Reg == MaybeSameReg;
6355 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
6358 bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
6359 Register &MatchInfo) const {
6360 // This combine folds the following patterns:
6362 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
6363 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6364 // into
6365 // x
6366 // if
6367 // k == sizeof(VecEltTy)/2
6368 // type(x) == type(dst)
6370 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6371 // into
6372 // x
6373 // if
6374 // type(x) == type(dst)
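// A minimal sketch of the first pattern, assuming a <2 x s32> destination:
//   %x:_(<2 x s32>) = ...
//   %cast:_(s64) = G_BITCAST %x
//   %hi:_(s64) = G_LSHR %cast, 32
//   %dst:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %cast, %hi
// folds to %x, since k == 32 (the destination element width, half the
// bitcast scalar) and type(%x) == type(%dst).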
6376 LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg());
6377 LLT DstEltTy = DstVecTy.getElementType();
6379 Register Lo, Hi;
6381 if (mi_match(
6382 MI, MRI,
6383 m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) {
6384 MatchInfo = Lo;
6385 return MRI.getType(MatchInfo) == DstVecTy;
6388 std::optional<ValueAndVReg> ShiftAmount;
6389 const auto LoPattern = m_GBitcast(m_Reg(Lo));
6390 const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount));
6391 if (mi_match(
6392 MI, MRI,
6393 m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern),
6394 m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) {
6395 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
6396 MatchInfo = Lo;
6397 return MRI.getType(MatchInfo) == DstVecTy;
6401 return false;
6404 bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
6405 Register &MatchInfo) const {
6406 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x
6407 // if type(x) == type(G_TRUNC)
6408 if (!mi_match(MI.getOperand(1).getReg(), MRI,
6409 m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg()))))
6410 return false;
6412 return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg());
6415 bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
6416 Register &MatchInfo) const {
6417 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
6418 // y if K == size of vector element type
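// Sketch with assumed types:
//   %bv:_(<2 x s32>) = G_BUILD_VECTOR %x:_(s32), %y:_(s32)
//   %cast:_(s64) = G_BITCAST %bv
//   %shr:_(s64) = G_LSHR %cast, 32
//   %t:_(s32) = G_TRUNC %shr
// folds so that %t can be replaced by %y.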
6419 std::optional<ValueAndVReg> ShiftAmt;
6420 if (!mi_match(MI.getOperand(1).getReg(), MRI,
6421 m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))),
6422 m_GCst(ShiftAmt))))
6423 return false;
6425 LLT MatchTy = MRI.getType(MatchInfo);
6426 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
6427 MatchTy == MRI.getType(MI.getOperand(0).getReg());
6430 unsigned CombinerHelper::getFPMinMaxOpcForSelect(
6431 CmpInst::Predicate Pred, LLT DstTy,
6432 SelectPatternNaNBehaviour VsNaNRetVal) const {
6433 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
6434 "Expected a NaN behaviour?");
6435 // Choose an opcode based on legality, or on the behaviour when one of
6436 // LHS/RHS may be NaN.
6437 switch (Pred) {
6438 default:
6439 return 0;
6440 case CmpInst::FCMP_UGT:
6441 case CmpInst::FCMP_UGE:
6442 case CmpInst::FCMP_OGT:
6443 case CmpInst::FCMP_OGE:
6444 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6445 return TargetOpcode::G_FMAXNUM;
6446 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6447 return TargetOpcode::G_FMAXIMUM;
6448 if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}}))
6449 return TargetOpcode::G_FMAXNUM;
6450 if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}}))
6451 return TargetOpcode::G_FMAXIMUM;
6452 return 0;
6453 case CmpInst::FCMP_ULT:
6454 case CmpInst::FCMP_ULE:
6455 case CmpInst::FCMP_OLT:
6456 case CmpInst::FCMP_OLE:
6457 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6458 return TargetOpcode::G_FMINNUM;
6459 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6460 return TargetOpcode::G_FMINIMUM;
6461 if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}}))
6462 return TargetOpcode::G_FMINNUM;
6463 if (!isLegal({TargetOpcode::G_FMINIMUM, {DstTy}}))
6464 return 0;
6465 return TargetOpcode::G_FMINIMUM;
6469 CombinerHelper::SelectPatternNaNBehaviour
6470 CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
6471 bool IsOrderedComparison) const {
6472 bool LHSSafe = isKnownNeverNaN(LHS, MRI);
6473 bool RHSSafe = isKnownNeverNaN(RHS, MRI);
6474 // Completely unsafe.
6475 if (!LHSSafe && !RHSSafe)
6476 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
6477 if (LHSSafe && RHSSafe)
6478 return SelectPatternNaNBehaviour::RETURNS_ANY;
6479 // An ordered comparison will return false when given a NaN, so it
6480 // returns the RHS.
6481 if (IsOrderedComparison)
6482 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
6483 : SelectPatternNaNBehaviour::RETURNS_OTHER;
6484 // An unordered comparison will return true when given a NaN, so it
6485 // returns the LHS.
6486 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
6487 : SelectPatternNaNBehaviour::RETURNS_NAN;
6490 bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
6491 Register TrueVal, Register FalseVal,
6492 BuildFnTy &MatchInfo) const {
6493 // Match: select (fcmp cond x, y) x, y
6494 // select (fcmp cond x, y) y, x
6495 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
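// For instance (a sketch; whether the fold fires also depends on the NaN and
// signed-zero reasoning below and on target legality):
//   %c:_(s1) = G_FCMP floatpred(olt), %x:_(s32), %y:_(s32)
//   %d:_(s32) = G_SELECT %c, %x, %y
// may become
//   %d:_(s32) = G_FMINNUM %x, %y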
6496 LLT DstTy = MRI.getType(Dst);
6497 // Bail out early on pointers, since we'll never want to fold to a min/max.
6498 if (DstTy.isPointer())
6499 return false;
6500 // Match a floating point compare with a less-than/greater-than predicate.
6501 // TODO: Allow multiple users of the compare if they are all selects.
6502 CmpInst::Predicate Pred;
6503 Register CmpLHS, CmpRHS;
6504 if (!mi_match(Cond, MRI,
6505 m_OneNonDBGUse(
6506 m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) ||
6507 CmpInst::isEquality(Pred))
6508 return false;
6509 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
6510 computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred));
6511 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
6512 return false;
6513 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
6514 std::swap(CmpLHS, CmpRHS);
6515 Pred = CmpInst::getSwappedPredicate(Pred);
6516 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
6517 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
6518 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
6519 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
6521 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
6522 return false;
6523 // Decide what type of max/min this should be based on the predicate.
6524 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo);
6525 if (!Opc || !isLegal({Opc, {DstTy}}))
6526 return false;
6527 // Comparisons between signed zero and zero may have different results...
6528 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
6529 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
6530 // We don't know if a comparison between two 0s will give us a consistent
6531 // result. Be conservative and only proceed if at least one side is
6532 // non-zero.
6533 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI);
6534 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
6535 KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI);
6536 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
6537 return false;
6540 MatchInfo = [=](MachineIRBuilder &B) {
6541 B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS});
6543 return true;
6546 bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
6547 BuildFnTy &MatchInfo) const {
6548 // TODO: Handle integer cases.
6549 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
6550 // Condition may be fed by a truncated compare.
6551 Register Cond = MI.getOperand(1).getReg();
6552 Register MaybeTrunc;
6553 if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc)))))
6554 Cond = MaybeTrunc;
6555 Register Dst = MI.getOperand(0).getReg();
6556 Register TrueVal = MI.getOperand(2).getReg();
6557 Register FalseVal = MI.getOperand(3).getReg();
6558 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
6561 bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
6562 BuildFnTy &MatchInfo) const {
6563 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
6564 // (X + Y) == X --> Y == 0
6565 // (X + Y) != X --> Y != 0
6566 // (X - Y) == X --> Y == 0
6567 // (X - Y) != X --> Y != 0
6568 // (X ^ Y) == X --> Y == 0
6569 // (X ^ Y) != X --> Y != 0
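// Sketch of the add case (assumed registers and types):
//   %sum:_(s32) = G_ADD %x, %y
//   %c:_(s1) = G_ICMP intpred(eq), %sum, %x
// becomes
//   %zero:_(s32) = G_CONSTANT i32 0
//   %c:_(s1) = G_ICMP intpred(eq), %y, %zero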
6570 Register Dst = MI.getOperand(0).getReg();
6571 CmpInst::Predicate Pred;
6572 Register X, Y, OpLHS, OpRHS;
6573 bool MatchedSub = mi_match(
6574 Dst, MRI,
6575 m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y))));
6576 if (MatchedSub && X != OpLHS)
6577 return false;
6578 if (!MatchedSub) {
6579 if (!mi_match(Dst, MRI,
6580 m_c_GICmp(m_Pred(Pred), m_Reg(X),
6581 m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)),
6582 m_GXor(m_Reg(OpLHS), m_Reg(OpRHS))))))
6583 return false;
6584 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
6586 MatchInfo = [=](MachineIRBuilder &B) {
6587 auto Zero = B.buildConstant(MRI.getType(Y), 0);
6588 B.buildICmp(Pred, Dst, Y, Zero);
6590 return CmpInst::isEquality(Pred) && Y.isValid();
6593 bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) const {
6594 Register ShiftReg = MI.getOperand(2).getReg();
6595 LLT ResTy = MRI.getType(MI.getOperand(0).getReg());
6596 auto IsShiftTooBig = [&](const Constant *C) {
6597 auto *CI = dyn_cast<ConstantInt>(C);
6598 return CI && CI->uge(ResTy.getScalarSizeInBits());
6600 return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig);
6603 bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
6604 unsigned LHSOpndIdx = 1;
6605 unsigned RHSOpndIdx = 2;
6606 switch (MI.getOpcode()) {
6607 case TargetOpcode::G_UADDO:
6608 case TargetOpcode::G_SADDO:
6609 case TargetOpcode::G_UMULO:
6610 case TargetOpcode::G_SMULO:
6611 LHSOpndIdx = 2;
6612 RHSOpndIdx = 3;
6613 break;
6614 default:
6615 break;
6617 Register LHS = MI.getOperand(LHSOpndIdx).getReg();
6618 Register RHS = MI.getOperand(RHSOpndIdx).getReg();
6619 if (!getIConstantVRegVal(LHS, MRI)) {
6620 // Skip commuting if LHS is not a constant. However, LHS may be a
6621 // G_CONSTANT_FOLD_BARRIER; if so, we commute as long as we don't already
6622 // have a constant on the RHS.
6623 if (MRI.getVRegDef(LHS)->getOpcode() !=
6624 TargetOpcode::G_CONSTANT_FOLD_BARRIER)
6625 return false;
6627 // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
6628 return MRI.getVRegDef(RHS)->getOpcode() !=
6629 TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
6630 !getIConstantVRegVal(RHS, MRI);
6633 bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const {
6634 Register LHS = MI.getOperand(1).getReg();
6635 Register RHS = MI.getOperand(2).getReg();
6636 std::optional<FPValueAndVReg> ValAndVReg;
6637 if (!mi_match(LHS, MRI, m_GFCstOrSplat(ValAndVReg)))
6638 return false;
6639 return !mi_match(RHS, MRI, m_GFCstOrSplat(ValAndVReg));
6642 void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const {
6643 Observer.changingInstr(MI);
6644 unsigned LHSOpndIdx = 1;
6645 unsigned RHSOpndIdx = 2;
6646 switch (MI.getOpcode()) {
6647 case TargetOpcode::G_UADDO:
6648 case TargetOpcode::G_SADDO:
6649 case TargetOpcode::G_UMULO:
6650 case TargetOpcode::G_SMULO:
6651 LHSOpndIdx = 2;
6652 RHSOpndIdx = 3;
6653 break;
6654 default:
6655 break;
6657 Register LHSReg = MI.getOperand(LHSOpndIdx).getReg();
6658 Register RHSReg = MI.getOperand(RHSOpndIdx).getReg();
6659 MI.getOperand(LHSOpndIdx).setReg(RHSReg);
6660 MI.getOperand(RHSOpndIdx).setReg(LHSReg);
6661 Observer.changedInstr(MI);
6664 bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const {
6665 LLT SrcTy = MRI.getType(Src);
6666 if (SrcTy.isFixedVector())
6667 return isConstantSplatVector(Src, 1, AllowUndefs);
6668 if (SrcTy.isScalar()) {
6669 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
6670 return true;
6671 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
6672 return IConstant && IConstant->Value == 1;
6674 return false; // scalable vector
6677 bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const {
6678 LLT SrcTy = MRI.getType(Src);
6679 if (SrcTy.isFixedVector())
6680 return isConstantSplatVector(Src, 0, AllowUndefs);
6681 if (SrcTy.isScalar()) {
6682 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
6683 return true;
6684 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
6685 return IConstant && IConstant->Value == 0;
6687 return false; // scalable vector
6690 // Ignores COPYs during conformance checks.
6691 // FIXME scalable vectors.
6692 bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
6693 bool AllowUndefs) const {
6694 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
6695 if (!BuildVector)
6696 return false;
6697 unsigned NumSources = BuildVector->getNumSources();
6699 for (unsigned I = 0; I < NumSources; ++I) {
6700 GImplicitDef *ImplicitDef =
6701 getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
6702 if (ImplicitDef && AllowUndefs)
6703 continue;
6704 if (ImplicitDef && !AllowUndefs)
6705 return false;
6706 std::optional<ValueAndVReg> IConstant =
6707 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
6708 if (IConstant && IConstant->Value == SplatValue)
6709 continue;
6710 return false;
6712 return true;
6715 // Ignores COPYs during lookups.
6716 // FIXME scalable vectors
6717 std::optional<APInt>
6718 CombinerHelper::getConstantOrConstantSplatVector(Register Src) const {
6719 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
6720 if (IConstant)
6721 return IConstant->Value;
6723 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
6724 if (!BuildVector)
6725 return std::nullopt;
6726 unsigned NumSources = BuildVector->getNumSources();
6728 std::optional<APInt> Value = std::nullopt;
6729 for (unsigned I = 0; I < NumSources; ++I) {
6730 std::optional<ValueAndVReg> IConstant =
6731 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
6732 if (!IConstant)
6733 return std::nullopt;
6734 if (!Value)
6735 Value = IConstant->Value;
6736 else if (*Value != IConstant->Value)
6737 return std::nullopt;
6739 return Value;
6742 // FIXME G_SPLAT_VECTOR
6743 bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
6744 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
6745 if (IConstant)
6746 return true;
6748 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
6749 if (!BuildVector)
6750 return false;
6752 unsigned NumSources = BuildVector->getNumSources();
6753 for (unsigned I = 0; I < NumSources; ++I) {
6754 std::optional<ValueAndVReg> IConstant =
6755 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
6756 if (!IConstant)
6757 return false;
6759 return true;
6762 // TODO: use knownbits to determine zeros
6763 bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
6764 BuildFnTy &MatchInfo) const {
6765 uint32_t Flags = Select->getFlags();
6766 Register Dest = Select->getReg(0);
6767 Register Cond = Select->getCondReg();
6768 Register True = Select->getTrueReg();
6769 Register False = Select->getFalseReg();
6770 LLT CondTy = MRI.getType(Select->getCondReg());
6771 LLT TrueTy = MRI.getType(Select->getTrueReg());
6773 // We only do this combine for scalar boolean conditions.
6774 if (CondTy != LLT::scalar(1))
6775 return false;
6777 if (TrueTy.isPointer())
6778 return false;
6780 // Both are scalars.
6781 std::optional<ValueAndVReg> TrueOpt =
6782 getIConstantVRegValWithLookThrough(True, MRI);
6783 std::optional<ValueAndVReg> FalseOpt =
6784 getIConstantVRegValWithLookThrough(False, MRI);
6786 if (!TrueOpt || !FalseOpt)
6787 return false;
6789 APInt TrueValue = TrueOpt->Value;
6790 APInt FalseValue = FalseOpt->Value;
6792 // select Cond, 1, 0 --> zext (Cond)
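// E.g. (a sketch, assuming an s32 destination):
//   %one:_(s32) = G_CONSTANT i32 1
//   %zero:_(s32) = G_CONSTANT i32 0
//   %d:_(s32) = G_SELECT %cond:_(s1), %one, %zero
// becomes
//   %d:_(s32) = G_ZEXT %cond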
6793 if (TrueValue.isOne() && FalseValue.isZero()) {
6794 MatchInfo = [=](MachineIRBuilder &B) {
6795 B.setInstrAndDebugLoc(*Select);
6796 B.buildZExtOrTrunc(Dest, Cond);
6798 return true;
6801 // select Cond, -1, 0 --> sext (Cond)
6802 if (TrueValue.isAllOnes() && FalseValue.isZero()) {
6803 MatchInfo = [=](MachineIRBuilder &B) {
6804 B.setInstrAndDebugLoc(*Select);
6805 B.buildSExtOrTrunc(Dest, Cond);
6807 return true;
6810 // select Cond, 0, 1 --> zext (!Cond)
6811 if (TrueValue.isZero() && FalseValue.isOne()) {
6812 MatchInfo = [=](MachineIRBuilder &B) {
6813 B.setInstrAndDebugLoc(*Select);
6814 Register Inner = MRI.createGenericVirtualRegister(CondTy);
6815 B.buildNot(Inner, Cond);
6816 B.buildZExtOrTrunc(Dest, Inner);
6818 return true;
6821 // select Cond, 0, -1 --> sext (!Cond)
6822 if (TrueValue.isZero() && FalseValue.isAllOnes()) {
6823 MatchInfo = [=](MachineIRBuilder &B) {
6824 B.setInstrAndDebugLoc(*Select);
6825 Register Inner = MRI.createGenericVirtualRegister(CondTy);
6826 B.buildNot(Inner, Cond);
6827 B.buildSExtOrTrunc(Dest, Inner);
6829 return true;
6832 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
6833 if (TrueValue - 1 == FalseValue) {
6834 MatchInfo = [=](MachineIRBuilder &B) {
6835 B.setInstrAndDebugLoc(*Select);
6836 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6837 B.buildZExtOrTrunc(Inner, Cond);
6838 B.buildAdd(Dest, Inner, False);
6840 return true;
6843 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
6844 if (TrueValue + 1 == FalseValue) {
6845 MatchInfo = [=](MachineIRBuilder &B) {
6846 B.setInstrAndDebugLoc(*Select);
6847 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6848 B.buildSExtOrTrunc(Inner, Cond);
6849 B.buildAdd(Dest, Inner, False);
6851 return true;
6854 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
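// E.g. with an assumed power-of-two constant of 8:
//   select %cond, 8, 0
// becomes, roughly,
//   %z:_(s32) = G_ZEXT %cond
//   %shamt:_(s32) = G_CONSTANT i32 3   ; log2(8)
//   %d:_(s32) = G_SHL %z, %shamt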
6855 if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
6856 MatchInfo = [=](MachineIRBuilder &B) {
6857 B.setInstrAndDebugLoc(*Select);
6858 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6859 B.buildZExtOrTrunc(Inner, Cond);
6860 // The shift amount must be scalar.
6861 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
6862 auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
6863 B.buildShl(Dest, Inner, ShAmtC, Flags);
6865 return true;
6868 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2)
6869 if (FalseValue.isPowerOf2() && TrueValue.isZero()) {
6870 MatchInfo = [=](MachineIRBuilder &B) {
6871 B.setInstrAndDebugLoc(*Select);
6872 Register Not = MRI.createGenericVirtualRegister(CondTy);
6873 B.buildNot(Not, Cond);
6874 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6875 B.buildZExtOrTrunc(Inner, Not);
6876 // The shift amount must be scalar.
6877 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
6878 auto ShAmtC = B.buildConstant(ShiftTy, FalseValue.exactLogBase2());
6879 B.buildShl(Dest, Inner, ShAmtC, Flags);
6881 return true;
6884 // select Cond, -1, C --> or (sext Cond), C
6885 if (TrueValue.isAllOnes()) {
6886 MatchInfo = [=](MachineIRBuilder &B) {
6887 B.setInstrAndDebugLoc(*Select);
6888 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6889 B.buildSExtOrTrunc(Inner, Cond);
6890 B.buildOr(Dest, Inner, False, Flags);
6892 return true;
6895 // select Cond, C, -1 --> or (sext (not Cond)), C
6896 if (FalseValue.isAllOnes()) {
6897 MatchInfo = [=](MachineIRBuilder &B) {
6898 B.setInstrAndDebugLoc(*Select);
6899 Register Not = MRI.createGenericVirtualRegister(CondTy);
6900 B.buildNot(Not, Cond);
6901 Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6902 B.buildSExtOrTrunc(Inner, Not);
6903 B.buildOr(Dest, Inner, True, Flags);
6905 return true;
6908 return false;
6911 // TODO: use knownbits to determine zeros
6912 bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
6913 BuildFnTy &MatchInfo) const {
6914 uint32_t Flags = Select->getFlags();
6915 Register DstReg = Select->getReg(0);
6916 Register Cond = Select->getCondReg();
6917 Register True = Select->getTrueReg();
6918 Register False = Select->getFalseReg();
6919 LLT CondTy = MRI.getType(Select->getCondReg());
6920 LLT TrueTy = MRI.getType(Select->getTrueReg());
6922 // Boolean or fixed vector of booleans.
6923 if (CondTy.isScalableVector() ||
6924 (CondTy.isFixedVector() &&
6925 CondTy.getElementType().getScalarSizeInBits() != 1) ||
6926 CondTy.getScalarSizeInBits() != 1)
6927 return false;
6929 if (CondTy != TrueTy)
6930 return false;
6932 // select Cond, Cond, F --> or Cond, F
6933 // select Cond, 1, F --> or Cond, F
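// Sketch for the s1 case (the false operand is frozen first):
//   %d:_(s1) = G_SELECT %c:_(s1), %c, %f
// becomes, roughly,
//   %ff:_(s1) = G_FREEZE %f
//   %d:_(s1) = G_OR %c, %ff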
6934 if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
6935 MatchInfo = [=](MachineIRBuilder &B) {
6936 B.setInstrAndDebugLoc(*Select);
6937 Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6938 B.buildZExtOrTrunc(Ext, Cond);
6939 auto FreezeFalse = B.buildFreeze(TrueTy, False);
6940 B.buildOr(DstReg, Ext, FreezeFalse, Flags);
6942 return true;
6945 // select Cond, T, Cond --> and Cond, T
6946 // select Cond, T, 0 --> and Cond, T
6947 if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
6948 MatchInfo = [=](MachineIRBuilder &B) {
6949 B.setInstrAndDebugLoc(*Select);
6950 Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6951 B.buildZExtOrTrunc(Ext, Cond);
6952 auto FreezeTrue = B.buildFreeze(TrueTy, True);
6953 B.buildAnd(DstReg, Ext, FreezeTrue);
6955 return true;
6958 // select Cond, T, 1 --> or (not Cond), T
6959 if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
6960 MatchInfo = [=](MachineIRBuilder &B) {
6961 B.setInstrAndDebugLoc(*Select);
6962 // First the not.
6963 Register Inner = MRI.createGenericVirtualRegister(CondTy);
6964 B.buildNot(Inner, Cond);
6965 // Then an ext to match the destination register.
6966 Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6967 B.buildZExtOrTrunc(Ext, Inner);
6968 auto FreezeTrue = B.buildFreeze(TrueTy, True);
6969 B.buildOr(DstReg, Ext, FreezeTrue, Flags);
6971 return true;
6974 // select Cond, 0, F --> and (not Cond), F
6975 if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
6976 MatchInfo = [=](MachineIRBuilder &B) {
6977 B.setInstrAndDebugLoc(*Select);
6978 // First the not.
6979 Register Inner = MRI.createGenericVirtualRegister(CondTy);
6980 B.buildNot(Inner, Cond);
6981 // Then an ext to match the destination register.
6982 Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6983 B.buildZExtOrTrunc(Ext, Inner);
6984 auto FreezeFalse = B.buildFreeze(TrueTy, False);
6985 B.buildAnd(DstReg, Ext, FreezeFalse);
6987 return true;
6990 return false;
6993 bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO,
6994 BuildFnTy &MatchInfo) const {
6995 GSelect *Select = cast<GSelect>(MRI.getVRegDef(MO.getReg()));
6996 GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Select->getCondReg()));
6998 Register DstReg = Select->getReg(0);
6999 Register True = Select->getTrueReg();
7000 Register False = Select->getFalseReg();
7001 LLT DstTy = MRI.getType(DstReg);
7003 if (DstTy.isPointer())
7004 return false;
7006 // We want to fold the icmp and replace the select.
7007 if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
7008 return false;
7010 CmpInst::Predicate Pred = Cmp->getCond();
7011 // We need an ordering (greater-than/less-than) predicate for this
7012 // canonicalization; equality predicates cannot map to min/max.
7013 if (CmpInst::isEquality(Pred))
7014 return false;
7016 Register CmpLHS = Cmp->getLHSReg();
7017 Register CmpRHS = Cmp->getRHSReg();
7019 // We can swap CmpLHS and CmpRHS for a higher hit rate.
7020 if (True == CmpRHS && False == CmpLHS) {
7021 std::swap(CmpLHS, CmpRHS);
7022 Pred = CmpInst::getSwappedPredicate(Pred);
7025 // (icmp X, Y) ? X : Y -> integer minmax.
7026 // see matchSelectPattern in ValueTracking.
7027 // Legality between G_SELECT and integer minmax can differ.
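// Sketch (assumed s32 operands):
//   %c:_(s1) = G_ICMP intpred(ult), %x, %y
//   %d:_(s32) = G_SELECT %c, %x, %y
// becomes
//   %d:_(s32) = G_UMIN %x, %y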
7028 if (True != CmpLHS || False != CmpRHS)
7029 return false;
7031 switch (Pred) {
7032 case ICmpInst::ICMP_UGT:
7033 case ICmpInst::ICMP_UGE: {
7034 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMAX, DstTy}))
7035 return false;
7036 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(DstReg, True, False); };
7037 return true;
7039 case ICmpInst::ICMP_SGT:
7040 case ICmpInst::ICMP_SGE: {
7041 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMAX, DstTy}))
7042 return false;
7043 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(DstReg, True, False); };
7044 return true;
7046 case ICmpInst::ICMP_ULT:
7047 case ICmpInst::ICMP_ULE: {
7048 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMIN, DstTy}))
7049 return false;
7050 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(DstReg, True, False); };
7051 return true;
7053 case ICmpInst::ICMP_SLT:
7054 case ICmpInst::ICMP_SLE: {
7055 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMIN, DstTy}))
7056 return false;
7057 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(DstReg, True, False); };
7058 return true;
7060 default:
7061 return false;
7065 // (neg (min/max x, (neg x))) --> (max/min x, (neg x))
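// Sketch (assumed s32 registers):
//   %neg:_(s32) = G_SUB %zero, %x
//   %min:_(s32) = G_SMIN %x, %neg
//   %d:_(s32) = G_SUB %zero, %min
// becomes
//   %d:_(s32) = G_SMAX %x, %neg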
7066 bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI,
7067 BuildFnTy &MatchInfo) const {
7068 assert(MI.getOpcode() == TargetOpcode::G_SUB);
7069 Register DestReg = MI.getOperand(0).getReg();
7070 LLT DestTy = MRI.getType(DestReg);
7072 Register X;
7073 Register Sub0;
7074 auto NegPattern = m_all_of(m_Neg(m_DeferredReg(X)), m_Reg(Sub0));
7075 if (mi_match(DestReg, MRI,
7076 m_Neg(m_OneUse(m_any_of(m_GSMin(m_Reg(X), NegPattern),
7077 m_GSMax(m_Reg(X), NegPattern),
7078 m_GUMin(m_Reg(X), NegPattern),
7079 m_GUMax(m_Reg(X), NegPattern)))))) {
7080 MachineInstr *MinMaxMI = MRI.getVRegDef(MI.getOperand(2).getReg());
7081 unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxMI->getOpcode());
7082 if (isLegal({NewOpc, {DestTy}})) {
7083 MatchInfo = [=](MachineIRBuilder &B) {
7084 B.buildInstr(NewOpc, {DestReg}, {X, Sub0});
7086 return true;
7090 return false;
7093 bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7094 GSelect *Select = cast<GSelect>(&MI);
7096 if (tryFoldSelectOfConstants(Select, MatchInfo))
7097 return true;
7099 if (tryFoldBoolSelectToLogic(Select, MatchInfo))
7100 return true;
7102 return false;
7105 /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
7106 /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
7107 /// into a single comparison using range-based reasoning.
7108 /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
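/// For example (constants assumed for illustration), with unsigned %x:
///   %a:_(s1) = G_ICMP intpred(uge), %x:_(s32), 5
///   %b:_(s1) = G_ICMP intpred(ult), %x:_(s32), 10
///   %d:_(s1) = G_AND %a, %b
/// can be rewritten as the single range check
///   %t:_(s32) = G_ADD %x, -5
///   %d:_(s1) = G_ICMP intpred(ult), %t, 5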
7109 bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(
7110 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const {
7111 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7112 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7113 Register DstReg = Logic->getReg(0);
7114 Register LHS = Logic->getLHSReg();
7115 Register RHS = Logic->getRHSReg();
7116 unsigned Flags = Logic->getFlags();
7118 // We need a G_ICMP on the LHS register.
7119 GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
7120 if (!Cmp1)
7121 return false;
7124 // We need a G_ICMP on the RHS register.
7124 GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
7125 if (!Cmp2)
7126 return false;
7128 // We want to fold the icmps.
7129 if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
7130 !MRI.hasOneNonDBGUse(Cmp2->getReg(0)))
7131 return false;
7133 APInt C1;
7134 APInt C2;
7135 std::optional<ValueAndVReg> MaybeC1 =
7136 getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
7137 if (!MaybeC1)
7138 return false;
7139 C1 = MaybeC1->Value;
7141 std::optional<ValueAndVReg> MaybeC2 =
7142 getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
7143 if (!MaybeC2)
7144 return false;
7145 C2 = MaybeC2->Value;
7147 Register R1 = Cmp1->getLHSReg();
7148 Register R2 = Cmp2->getLHSReg();
7149 CmpInst::Predicate Pred1 = Cmp1->getCond();
7150 CmpInst::Predicate Pred2 = Cmp2->getCond();
7151 LLT CmpTy = MRI.getType(Cmp1->getReg(0));
7152 LLT CmpOperandTy = MRI.getType(R1);
7154 if (CmpOperandTy.isPointer())
7155 return false;
7157 // We build ands, adds, and constants of type CmpOperandTy.
7158 // They must be legal to build.
7159 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
7160 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
7161 !isConstantLegalOrBeforeLegalizer(CmpOperandTy))
7162 return false;
7164 // Look through an add of a constant offset on R1, R2, or both operands.
7165 // This lets us turn the "R + C' < C''" idiom into a proper range.
7166 std::optional<APInt> Offset1;
7167 std::optional<APInt> Offset2;
7168 if (R1 != R2) {
7169 if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
7170 std::optional<ValueAndVReg> MaybeOffset1 =
7171 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
7172 if (MaybeOffset1) {
7173 R1 = Add->getLHSReg();
7174 Offset1 = MaybeOffset1->Value;
7177 if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
7178 std::optional<ValueAndVReg> MaybeOffset2 =
7179 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
7180 if (MaybeOffset2) {
7181 R2 = Add->getLHSReg();
7182 Offset2 = MaybeOffset2->Value;
7187 if (R1 != R2)
7188 return false;
7190 // We calculate the icmp ranges including maybe offsets.
7191 ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
7192 IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
7193 if (Offset1)
7194 CR1 = CR1.subtract(*Offset1);
7196 ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
7197 IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
7198 if (Offset2)
7199 CR2 = CR2.subtract(*Offset2);
7201 bool CreateMask = false;
7202 APInt LowerDiff;
7203 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
7204 if (!CR) {
7205 // We need non-wrapping ranges.
7206 if (CR1.isWrappedSet() || CR2.isWrappedSet())
7207 return false;
7209 // Check whether we have equal-size ranges that only differ by one bit.
7210 // In that case we can apply a mask to map one range onto the other.
7211 LowerDiff = CR1.getLower() ^ CR2.getLower();
7212 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
7213 APInt CR1Size = CR1.getUpper() - CR1.getLower();
7214 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
7215 CR1Size != CR2.getUpper() - CR2.getLower())
7216 return false;
7218 CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
7219 CreateMask = true;
7222 if (IsAnd)
7223 CR = CR->inverse();
7225 CmpInst::Predicate NewPred;
7226 APInt NewC, Offset;
7227 CR->getEquivalentICmp(NewPred, NewC, Offset);
7229 // We take the result type of one of the original icmps, CmpTy, for
7230 // the icmp that we are about to build. The operand type, CmpOperandTy,
7231 // is used for the other instructions and constants to be built. The
7232 // parameter and result types are the same for G_ADD and G_AND. CmpTy
7233 // and the type of DstReg might differ, which is why we zext or trunc
7234 // the icmp result into the destination register.
7236 MatchInfo = [=](MachineIRBuilder &B) {
7237 if (CreateMask && Offset != 0) {
7238 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
7239 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
7240 auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
7241 auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags);
7242 auto NewCon = B.buildConstant(CmpOperandTy, NewC);
7243 auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
7244 B.buildZExtOrTrunc(DstReg, ICmp);
7245 } else if (CreateMask && Offset == 0) {
7246 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
7247 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
7248 auto NewCon = B.buildConstant(CmpOperandTy, NewC);
7249 auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
7250 B.buildZExtOrTrunc(DstReg, ICmp);
7251 } else if (!CreateMask && Offset != 0) {
7252 auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
7253 auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
7254 auto NewCon = B.buildConstant(CmpOperandTy, NewC);
7255 auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
7256 B.buildZExtOrTrunc(DstReg, ICmp);
7257 } else if (!CreateMask && Offset == 0) {
7258 auto NewCon = B.buildConstant(CmpOperandTy, NewC);
7259 auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
7260 B.buildZExtOrTrunc(DstReg, ICmp);
7261 } else {
7262 llvm_unreachable("unexpected configuration of CreateMask and Offset");
7265 return true;
7268 bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
7269 BuildFnTy &MatchInfo) const {
7270 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7271 Register DestReg = Logic->getReg(0);
7272 Register LHS = Logic->getLHSReg();
7273 Register RHS = Logic->getRHSReg();
7274 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7276 // We need a compare on the LHS register.
7277 GFCmp *Cmp1 = getOpcodeDef<GFCmp>(LHS, MRI);
7278 if (!Cmp1)
7279 return false;
7281 // We need a compare on the RHS register.
7282 GFCmp *Cmp2 = getOpcodeDef<GFCmp>(RHS, MRI);
7283 if (!Cmp2)
7284 return false;
7286 LLT CmpTy = MRI.getType(Cmp1->getReg(0));
7287 LLT CmpOperandTy = MRI.getType(Cmp1->getLHSReg());
7289 // We build one fcmp, want to fold the fcmps, replace the logic op,
7290 // and the fcmps must have the same shape.
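// Sketch of the same-operand case (assumed registers):
//   %a:_(s1) = G_FCMP floatpred(olt), %x, %y
//   %b:_(s1) = G_FCMP floatpred(ogt), %x, %y
//   %d:_(s1) = G_OR %a, %b
// can become
//   %d:_(s1) = G_FCMP floatpred(one), %x, %y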
7291 if (!isLegalOrBeforeLegalizer(
7292 {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
7293 !MRI.hasOneNonDBGUse(Logic->getReg(0)) ||
7294 !MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
7295 !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) ||
7296 MRI.getType(Cmp1->getLHSReg()) != MRI.getType(Cmp2->getLHSReg()))
7297 return false;
7299 CmpInst::Predicate PredL = Cmp1->getCond();
7300 CmpInst::Predicate PredR = Cmp2->getCond();
7301 Register LHS0 = Cmp1->getLHSReg();
7302 Register LHS1 = Cmp1->getRHSReg();
7303 Register RHS0 = Cmp2->getLHSReg();
7304 Register RHS1 = Cmp2->getRHSReg();
7306 if (LHS0 == RHS1 && LHS1 == RHS0) {
7307 // Swap RHS operands to match LHS.
7308 PredR = CmpInst::getSwappedPredicate(PredR);
7309 std::swap(RHS0, RHS1);
7312 if (LHS0 == RHS0 && LHS1 == RHS1) {
7313 // We determine the new predicate.
7314 unsigned CmpCodeL = getFCmpCode(PredL);
7315 unsigned CmpCodeR = getFCmpCode(PredR);
7316 unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
7317 unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
7318 MatchInfo = [=](MachineIRBuilder &B) {
7319 // The fcmp predicates fill the lower part of the enum.
7320 FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
7321 if (Pred == FCmpInst::FCMP_FALSE &&
7322 isConstantLegalOrBeforeLegalizer(CmpTy)) {
7323 auto False = B.buildConstant(CmpTy, 0);
7324 B.buildZExtOrTrunc(DestReg, False);
7325 } else if (Pred == FCmpInst::FCMP_TRUE &&
7326 isConstantLegalOrBeforeLegalizer(CmpTy)) {
7327 auto True =
7328 B.buildConstant(CmpTy, getICmpTrueVal(getTargetLowering(),
7329 CmpTy.isVector() /*isVector*/,
7330 true /*isFP*/));
7331 B.buildZExtOrTrunc(DestReg, True);
7332 } else { // We take the predicate without predicate optimizations.
7333 auto Cmp = B.buildFCmp(Pred, CmpTy, LHS0, LHS1, Flags);
7334 B.buildZExtOrTrunc(DestReg, Cmp);
7337 return true;
7340 return false;
7343 bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7344 GAnd *And = cast<GAnd>(&MI);
7346 if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo))
7347 return true;
7349 if (tryFoldLogicOfFCmps(And, MatchInfo))
7350 return true;
7352 return false;
7355 bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7356 GOr *Or = cast<GOr>(&MI);
7358 if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
7359 return true;
7361 if (tryFoldLogicOfFCmps(Or, MatchInfo))
7362 return true;
7364 return false;
7367 bool CombinerHelper::matchAddOverflow(MachineInstr &MI,
7368 BuildFnTy &MatchInfo) const {
7369 GAddCarryOut *Add = cast<GAddCarryOut>(&MI);
7371 // Addo has no flags
7372 Register Dst = Add->getReg(0);
7373 Register Carry = Add->getReg(1);
7374 Register LHS = Add->getLHSReg();
7375 Register RHS = Add->getRHSReg();
7376 bool IsSigned = Add->isSigned();
7377 LLT DstTy = MRI.getType(Dst);
7378 LLT CarryTy = MRI.getType(Carry);
7380 // Fold addo, if the carry is dead -> add, undef.
7381 if (MRI.use_nodbg_empty(Carry) &&
7382 isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) {
7383 MatchInfo = [=](MachineIRBuilder &B) {
7384 B.buildAdd(Dst, LHS, RHS);
7385 B.buildUndef(Carry);
7387 return true;
7390 // Canonicalize constant to RHS.
7391 if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) {
7392 if (IsSigned) {
7393 MatchInfo = [=](MachineIRBuilder &B) {
7394 B.buildSAddo(Dst, Carry, RHS, LHS);
7396 return true;
7398 // !IsSigned
7399 MatchInfo = [=](MachineIRBuilder &B) {
7400 B.buildUAddo(Dst, Carry, RHS, LHS);
7402 return true;
7405 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS);
7406 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS);
7408 // Fold addo(c1, c2) -> c3, carry.
7409 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) &&
7410 isConstantLegalOrBeforeLegalizer(CarryTy)) {
7411 bool Overflow;
7412 APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow)
7413 : MaybeLHS->uadd_ov(*MaybeRHS, Overflow);
7414 MatchInfo = [=](MachineIRBuilder &B) {
7415 B.buildConstant(Dst, Result);
7416 B.buildConstant(Carry, Overflow);
7418 return true;
7421 // Fold (addo x, 0) -> x, no carry
7422 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) {
7423 MatchInfo = [=](MachineIRBuilder &B) {
7424 B.buildCopy(Dst, LHS);
7425 B.buildConstant(Carry, 0);
7427 return true;
7430 // Given 2 constant operands whose sum does not overflow:
7431 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
7432 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1
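// Sketch of the unsigned case (constants assumed):
//   %a:_(s32) = nuw G_ADD %x, %c0
//   %d:_(s32), %o:_(s1) = G_UADDO %a, %c1
// becomes, provided c0 + c1 does not overflow,
//   %k:_(s32) = G_CONSTANT (c0 + c1)
//   %d:_(s32), %o:_(s1) = G_UADDO %x, %k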
7433 GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI);
7434 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) &&
7435 ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
7436 (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
7437 std::optional<APInt> MaybeAddRHS =
7438 getConstantOrConstantSplatVector(AddLHS->getRHSReg());
7439 if (MaybeAddRHS) {
7440 bool Overflow;
7441 APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow)
7442 : MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow);
7443 if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) {
7444 if (IsSigned) {
7445 MatchInfo = [=](MachineIRBuilder &B) {
7446 auto ConstRHS = B.buildConstant(DstTy, NewC);
7447 B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
7449 return true;
7451 // !IsSigned
7452 MatchInfo = [=](MachineIRBuilder &B) {
7453 auto ConstRHS = B.buildConstant(DstTy, NewC);
7454 B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
7456 return true;
7461 // We try to combine addo to non-overflowing add.
7462 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) ||
7463 !isConstantLegalOrBeforeLegalizer(CarryTy))
7464 return false;
7466 // We try to combine uaddo to non-overflowing add.
7467 if (!IsSigned) {
7468 ConstantRange CRLHS =
7469 ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/false);
7470 ConstantRange CRRHS =
7471 ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/false);
7473 switch (CRLHS.unsignedAddMayOverflow(CRRHS)) {
7474 case ConstantRange::OverflowResult::MayOverflow:
7475 return false;
7476 case ConstantRange::OverflowResult::NeverOverflows: {
7477 MatchInfo = [=](MachineIRBuilder &B) {
7478 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
7479 B.buildConstant(Carry, 0);
7481 return true;
7483 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7484 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7485 MatchInfo = [=](MachineIRBuilder &B) {
7486 B.buildAdd(Dst, LHS, RHS);
7487 B.buildConstant(Carry, 1);
7489 return true;
7492 return false;
7495 // We try to combine saddo to non-overflowing add.
7497 // If LHS and RHS each have at least two sign bits, then there is no signed
7498 // overflow.
7499 if (KB->computeNumSignBits(RHS) > 1 && KB->computeNumSignBits(LHS) > 1) {
7500 MatchInfo = [=](MachineIRBuilder &B) {
7501 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
7502 B.buildConstant(Carry, 0);
7504 return true;
7507 ConstantRange CRLHS =
7508 ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/true);
7509 ConstantRange CRRHS =
7510 ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/true);
7512 switch (CRLHS.signedAddMayOverflow(CRRHS)) {
7513 case ConstantRange::OverflowResult::MayOverflow:
7514 return false;
7515 case ConstantRange::OverflowResult::NeverOverflows: {
7516 MatchInfo = [=](MachineIRBuilder &B) {
7517 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
7518 B.buildConstant(Carry, 0);
7520 return true;
7522 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7523 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7524 MatchInfo = [=](MachineIRBuilder &B) {
7525 B.buildAdd(Dst, LHS, RHS);
7526 B.buildConstant(Carry, 1);
7528 return true;
7532 return false;
7535 void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
7536 BuildFnTy &MatchInfo) const {
7537 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
7538 MatchInfo(Builder);
7539 Root->eraseFromParent();
7542 bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI,
7543 int64_t Exponent) const {
7544 bool OptForSize = MI.getMF()->getFunction().hasOptSize();
7545 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize);
7548 void CombinerHelper::applyExpandFPowI(MachineInstr &MI,
7549 int64_t Exponent) const {
7550 auto [Dst, Base] = MI.getFirst2Regs();
7551 LLT Ty = MRI.getType(Dst);
7552 int64_t ExpVal = Exponent;
7554 if (ExpVal == 0) {
7555 Builder.buildFConstant(Dst, 1.0);
7556 MI.removeFromParent();
7557 return;
7560 if (ExpVal < 0)
7561 ExpVal = -ExpVal;
7563 // We use the simple binary decomposition method from SelectionDAG ExpandPowI
7564 // to generate the multiply sequence. There are more optimal ways to do this
7565 // (for example, powi(x,15) generates one more multiply than it should), but
7566 // this has the benefit of being both really simple and much better than a
7567 // libcall.
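// A worked sketch for powi(x, 5), i.e. ExpVal = 0b101:
//   bit 0 set:   Res = x,          CurSquare = x*x
//   bit 1 clear:                   CurSquare = x^4
//   bit 2 set:   Res = x * x^4 = x^5
// (one trailing square of CurSquare is emitted but left unused).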
7568 std::optional<SrcOp> Res;
7569 SrcOp CurSquare = Base;
7570 while (ExpVal > 0) {
7571 if (ExpVal & 1) {
7572 if (!Res)
7573 Res = CurSquare;
7574 else
7575 Res = Builder.buildFMul(Ty, *Res, CurSquare);
7578 CurSquare = Builder.buildFMul(Ty, CurSquare, CurSquare);
7579 ExpVal >>= 1;
7582 // If the original exponent was negative, invert the result, producing
7583 // 1/(x*x*x).
7584 if (Exponent < 0)
7585 Res = Builder.buildFDiv(Ty, Builder.buildFConstant(Ty, 1.0), *Res,
7586 MI.getFlags());
7588 Builder.buildCopy(Dst, *Res);
7589 MI.eraseFromParent();
7592 bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
7593 BuildFnTy &MatchInfo) const {
7594 // fold (A+C1)-C2 -> A+(C1-C2)
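// Sketch (constant registers assumed):
//   %a:_(s32) = G_ADD %x, %c1
//   %d:_(s32) = G_SUB %a, %c2
// becomes
//   %k:_(s32) = G_CONSTANT (C1 - C2)
//   %d:_(s32) = G_ADD %x, %k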
7595 const GSub *Sub = cast<GSub>(&MI);
7596 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg()));
7598 if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
7599 return false;
7601 APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI);
7602 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);
7604 Register Dst = Sub->getReg(0);
7605 LLT DstTy = MRI.getType(Dst);
7607 MatchInfo = [=](MachineIRBuilder &B) {
7608 auto Const = B.buildConstant(DstTy, C1 - C2);
7609 B.buildAdd(Dst, Add->getLHSReg(), Const);
7612 return true;
7615 bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
7616 BuildFnTy &MatchInfo) const {
7617 // fold C2-(A+C1) -> (C2-C1)-A
7618 const GSub *Sub = cast<GSub>(&MI);
7619 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg()));
7621 if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
7622 return false;
7624 APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI);
7625 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);
7627 Register Dst = Sub->getReg(0);
7628 LLT DstTy = MRI.getType(Dst);
7630 MatchInfo = [=](MachineIRBuilder &B) {
7631 auto Const = B.buildConstant(DstTy, C2 - C1);
7632 B.buildSub(Dst, Const, Add->getLHSReg());
7635 return true;
7638 bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
7639 BuildFnTy &MatchInfo) const {
7640 // fold (A-C1)-C2 -> A-(C1+C2)
7641 const GSub *Sub1 = cast<GSub>(&MI);
7642 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));
7644 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
7645 return false;
7647 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
7648 APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI);
7650 Register Dst = Sub1->getReg(0);
7651 LLT DstTy = MRI.getType(Dst);
7653 MatchInfo = [=](MachineIRBuilder &B) {
7654 auto Const = B.buildConstant(DstTy, C1 + C2);
7655 B.buildSub(Dst, Sub2->getLHSReg(), Const);
7658 return true;
7661 bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
7662 BuildFnTy &MatchInfo) const {
7663 // fold (C1-A)-C2 -> (C1-C2)-A
7664 const GSub *Sub1 = cast<GSub>(&MI);
7665 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));
7667 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
7668 return false;
7670 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
7671 APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI);
7673 Register Dst = Sub1->getReg(0);
7674 LLT DstTy = MRI.getType(Dst);
7676 MatchInfo = [=](MachineIRBuilder &B) {
7677 auto Const = B.buildConstant(DstTy, C1 - C2);
7678 B.buildSub(Dst, Const, Sub2->getRHSReg());
7681 return true;
7684 bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
7685 BuildFnTy &MatchInfo) const {
7686 // fold ((A-C1)+C2) -> (A+(C2-C1))
7687 const GAdd *Add = cast<GAdd>(&MI);
7688 GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg()));
7690 if (!MRI.hasOneNonDBGUse(Sub->getReg(0)))
7691 return false;
7693 APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI);
7694 APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI);
7696 Register Dst = Add->getReg(0);
7697 LLT DstTy = MRI.getType(Dst);
7699 MatchInfo = [=](MachineIRBuilder &B) {
7700 auto Const = B.buildConstant(DstTy, C2 - C1);
7701 B.buildAdd(Dst, Sub->getLHSReg(), Const);
7704 return true;
7707 bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
7708 const MachineInstr &MI, BuildFnTy &MatchInfo) const {
7709 const GUnmerge *Unmerge = cast<GUnmerge>(&MI);
7711 if (!MRI.hasOneNonDBGUse(Unmerge->getSourceReg()))
7712 return false;
7714 const MachineInstr *Source = MRI.getVRegDef(Unmerge->getSourceReg());
7716 LLT DstTy = MRI.getType(Unmerge->getReg(0));
7718 // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
7719 // $any:_(<8 x s16>) = G_ANYEXT $bv
7720 // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
7722 // ->
7724 // $any:_(s16) = G_ANYEXT $bv[0]
7725 // $any1:_(s16) = G_ANYEXT $bv[1]
7726 // $any2:_(s16) = G_ANYEXT $bv[2]
7727 // $any3:_(s16) = G_ANYEXT $bv[3]
7728 // $any4:_(s16) = G_ANYEXT $bv[4]
7729 // $any5:_(s16) = G_ANYEXT $bv[5]
7730 // $any6:_(s16) = G_ANYEXT $bv[6]
7731 // $any7:_(s16) = G_ANYEXT $bv[7]
7732 // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
7733 // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7
7735 // We want to unmerge into vectors.
7736 if (!DstTy.isFixedVector())
7737 return false;
7739 const GAnyExt *Any = dyn_cast<GAnyExt>(Source);
7740 if (!Any)
7741 return false;
7743 const MachineInstr *NextSource = MRI.getVRegDef(Any->getSrcReg());
7745 if (const GBuildVector *BV = dyn_cast<GBuildVector>(NextSource)) {
7746 // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR
7748 if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
7749 return false;
7751 // FIXME: check element types?
7752 if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
7753 return false;
7755 LLT BigBvTy = MRI.getType(BV->getReg(0));
7756 LLT SmallBvTy = DstTy;
7757 LLT SmallBvElemenTy = SmallBvTy.getElementType();
7759 if (!isLegalOrBeforeLegalizer(
7760 {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElemenTy}}))
7761 return false;
7763 // We check the legality of scalar anyext.
7764 if (!isLegalOrBeforeLegalizer(
7765 {TargetOpcode::G_ANYEXT,
7766 {SmallBvElemenTy, BigBvTy.getElementType()}}))
7767 return false;
7769 MatchInfo = [=](MachineIRBuilder &B) {
7770 // Build into each G_UNMERGE_VALUES def
7771 // a small build vector with anyext from the source build vector.
7772 for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
7773 SmallVector<Register> Ops;
7774 for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
7775 Register SourceArray =
7776 BV->getSourceReg(I * SmallBvTy.getNumElements() + J);
7777 auto AnyExt = B.buildAnyExt(SmallBvElemenTy, SourceArray);
7778 Ops.push_back(AnyExt.getReg(0));
7780 B.buildBuildVector(Unmerge->getOperand(I).getReg(), Ops);
7783 return true;
7786 return false;
7789 bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
7790 BuildFnTy &MatchInfo) const {
7792 bool Changed = false;
7793 auto &Shuffle = cast<GShuffleVector>(MI);
7794 ArrayRef<int> OrigMask = Shuffle.getMask();
7795 SmallVector<int, 16> NewMask;
7796 const LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
7797 const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
7798 const unsigned NumDstElts = OrigMask.size();
7799 for (unsigned i = 0; i != NumDstElts; ++i) {
7800 int Idx = OrigMask[i];
7801 if (Idx >= (int)NumSrcElems) {
7802 Idx = -1;
7803 Changed = true;
7805 NewMask.push_back(Idx);
7808 if (!Changed)
7809 return false;
7811 MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
7812 B.buildShuffleVector(MI.getOperand(0), MI.getOperand(1), MI.getOperand(2),
7813 std::move(NewMask));
7816 return true;
7819 static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
7820 const unsigned MaskSize = Mask.size();
7821 for (unsigned I = 0; I < MaskSize; ++I) {
7822 int Idx = Mask[I];
7823 if (Idx < 0)
7824 continue;
7826 if (Idx < (int)NumElems)
7827 Mask[I] = Idx + NumElems;
7828 else
7829 Mask[I] = Idx - NumElems;
7833 bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
7834 BuildFnTy &MatchInfo) const {
7836 auto &Shuffle = cast<GShuffleVector>(MI);
7837 // If either input is already undef, don't check the mask again, to
7838 // prevent an infinite combine loop.
7839 if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc1Reg(), MRI))
7840 return false;
7842 if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc2Reg(), MRI))
7843 return false;
7845 const LLT DstTy = MRI.getType(Shuffle.getReg(0));
7846 const LLT Src1Ty = MRI.getType(Shuffle.getSrc1Reg());
7847 if (!isLegalOrBeforeLegalizer(
7848 {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
7849 return false;
7851 ArrayRef<int> Mask = Shuffle.getMask();
7852 const unsigned NumSrcElems = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;
7854 bool TouchesSrc1 = false;
7855 bool TouchesSrc2 = false;
7856 const unsigned NumElems = Mask.size();
7857 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
7858 if (Mask[Idx] < 0)
7859 continue;
7861 if (Mask[Idx] < (int)NumSrcElems)
7862 TouchesSrc1 = true;
7863 else
7864 TouchesSrc2 = true;
7867 if (TouchesSrc1 == TouchesSrc2)
7868 return false;
7870 Register NewSrc1 = Shuffle.getSrc1Reg();
7871 SmallVector<int, 16> NewMask(Mask);
7872 if (TouchesSrc2) {
7873 NewSrc1 = Shuffle.getSrc2Reg();
7874 commuteMask(NewMask, NumSrcElems);
7877 MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
7878 auto Undef = B.buildUndef(Src1Ty);
7879 B.buildShuffleVector(Shuffle.getReg(0), NewSrc1, Undef, NewMask);
7882 return true;
7885 bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
7886 BuildFnTy &MatchInfo) const {
7887 const GSubCarryOut *Subo = cast<GSubCarryOut>(&MI);
7889 Register Dst = Subo->getReg(0);
7890 Register LHS = Subo->getLHSReg();
7891 Register RHS = Subo->getRHSReg();
7892 Register Carry = Subo->getCarryOutReg();
7893 LLT DstTy = MRI.getType(Dst);
7894 LLT CarryTy = MRI.getType(Carry);
7896 // Check legality before known bits.
7897 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy}}) ||
7898 !isConstantLegalOrBeforeLegalizer(CarryTy))
7899 return false;
7901 ConstantRange KBLHS =
7902 ConstantRange::fromKnownBits(KB->getKnownBits(LHS),
7903 /* IsSigned=*/Subo->isSigned());
7904 ConstantRange KBRHS =
7905 ConstantRange::fromKnownBits(KB->getKnownBits(RHS),
7906 /* IsSigned=*/Subo->isSigned());
7908 if (Subo->isSigned()) {
7909 // G_SSUBO
7910 switch (KBLHS.signedSubMayOverflow(KBRHS)) {
7911 case ConstantRange::OverflowResult::MayOverflow:
7912 return false;
7913 case ConstantRange::OverflowResult::NeverOverflows: {
7914 MatchInfo = [=](MachineIRBuilder &B) {
7915 B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
7916 B.buildConstant(Carry, 0);
7918 return true;
7920 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7921 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7922 MatchInfo = [=](MachineIRBuilder &B) {
7923 B.buildSub(Dst, LHS, RHS);
7924 B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
7925 /*isVector=*/CarryTy.isVector(),
7926 /*isFP=*/false));
7928 return true;
7931 return false;
7934 // G_USUBO
7935 switch (KBLHS.unsignedSubMayOverflow(KBRHS)) {
7936 case ConstantRange::OverflowResult::MayOverflow:
7937 return false;
7938 case ConstantRange::OverflowResult::NeverOverflows: {
7939 MatchInfo = [=](MachineIRBuilder &B) {
7940 B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
7941 B.buildConstant(Carry, 0);
7943 return true;
7945 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7946 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7947 MatchInfo = [=](MachineIRBuilder &B) {
7948 B.buildSub(Dst, LHS, RHS);
7949 B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
7950 /*isVector=*/CarryTy.isVector(),
7951 /*isFP=*/false));
7953 return true;
7957 return false;