1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===---------------------------------------------------------------------===//
9 // This pass does some optimizations for *W instructions at the MI level.
11 // First it removes unneeded sext.w instructions, either because the sign
12 // extended bits aren't consumed or because the input was already sign extended
13 // by an earlier instruction.
16 // 1. Unless explicitly disabled or the target prefers instructions with W suffix,
17 // it removes the -w suffix from opw instructions whenever all users are
18 // dependent only on the lower word of the result of the instruction.
19 // The cases handled are:
20 // * addw because c.add has a larger register encoding than c.addw.
21 // * addiw because it helps reduce test differences between RV32 and RV64
22 // w/o being a pessimization.
23 // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24 // * slliw because c.slliw doesn't exist and c.slli does
26 // 2. Or if explicitly enabled or the target prefers instructions with W suffix,
27 // it adds the W suffix to the instruction whenever all users are dependent
28 // only on the lower word of the result of the instruction.
29 // The cases handled are:
30 // * add/addi/sub/mul.
31 // * slli with imm < 32.
33 //===---------------------------------------------------------------------===//
36 #include "RISCVMachineFunctionInfo.h"
37 #include "RISCVSubtarget.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/TargetInstrInfo.h"
45 #define DEBUG_TYPE "riscv-opt-w-instrs"
46 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
48 STATISTIC(NumRemovedSExtW
, "Number of removed sign-extensions");
49 STATISTIC(NumTransformedToWInstrs
,
50 "Number of instructions transformed to W-ops");
52 static cl::opt
<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53 cl::desc("Disable removal of sext.w"),
54 cl::init(false), cl::Hidden
);
55 static cl::opt
<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56 cl::desc("Disable strip W suffix"),
57 cl::init(false), cl::Hidden
);
61 class RISCVOptWInstrs
: public MachineFunctionPass
{
65 RISCVOptWInstrs() : MachineFunctionPass(ID
) {}
67 bool runOnMachineFunction(MachineFunction
&MF
) override
;
68 bool removeSExtWInstrs(MachineFunction
&MF
, const RISCVInstrInfo
&TII
,
69 const RISCVSubtarget
&ST
, MachineRegisterInfo
&MRI
);
70 bool stripWSuffixes(MachineFunction
&MF
, const RISCVInstrInfo
&TII
,
71 const RISCVSubtarget
&ST
, MachineRegisterInfo
&MRI
);
72 bool appendWSuffixes(MachineFunction
&MF
, const RISCVInstrInfo
&TII
,
73 const RISCVSubtarget
&ST
, MachineRegisterInfo
&MRI
);
75 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
77 MachineFunctionPass::getAnalysisUsage(AU
);
80 StringRef
getPassName() const override
{ return RISCV_OPT_W_INSTRS_NAME
; }
83 } // end anonymous namespace
85 char RISCVOptWInstrs::ID
= 0;
86 INITIALIZE_PASS(RISCVOptWInstrs
, DEBUG_TYPE
, RISCV_OPT_W_INSTRS_NAME
, false,
89 FunctionPass
*llvm::createRISCVOptWInstrsPass() {
90 return new RISCVOptWInstrs();
93 static bool vectorPseudoHasAllNBitUsers(const MachineOperand
&UserOp
,
95 const MachineInstr
&MI
= *UserOp
.getParent();
96 unsigned MCOpcode
= RISCV::getRVVMCOpcode(MI
.getOpcode());
101 const MCInstrDesc
&MCID
= MI
.getDesc();
102 const uint64_t TSFlags
= MCID
.TSFlags
;
103 if (!RISCVII::hasSEWOp(TSFlags
))
105 assert(RISCVII::hasVLOp(TSFlags
));
106 const unsigned Log2SEW
= MI
.getOperand(RISCVII::getSEWOpNum(MCID
)).getImm();
108 if (UserOp
.getOperandNo() == RISCVII::getVLOpNum(MCID
))
111 auto NumDemandedBits
=
112 RISCV::getVectorLowDemandedScalarBits(MCOpcode
, Log2SEW
);
113 return NumDemandedBits
&& Bits
>= *NumDemandedBits
;
116 // Checks if all users only demand the lower \p OrigBits of the original
117 // instruction's result.
118 // TODO: handle multiple interdependent transformations
119 static bool hasAllNBitUsers(const MachineInstr
&OrigMI
,
120 const RISCVSubtarget
&ST
,
121 const MachineRegisterInfo
&MRI
, unsigned OrigBits
) {
123 SmallSet
<std::pair
<const MachineInstr
*, unsigned>, 4> Visited
;
124 SmallVector
<std::pair
<const MachineInstr
*, unsigned>, 4> Worklist
;
126 Worklist
.push_back(std::make_pair(&OrigMI
, OrigBits
));
128 while (!Worklist
.empty()) {
129 auto P
= Worklist
.pop_back_val();
130 const MachineInstr
*MI
= P
.first
;
131 unsigned Bits
= P
.second
;
133 if (!Visited
.insert(P
).second
)
136 // Only handle instructions with one def.
137 if (MI
->getNumExplicitDefs() != 1)
140 Register DestReg
= MI
->getOperand(0).getReg();
141 if (!DestReg
.isVirtual())
144 for (auto &UserOp
: MRI
.use_nodbg_operands(DestReg
)) {
145 const MachineInstr
*UserMI
= UserOp
.getParent();
146 unsigned OpIdx
= UserOp
.getOperandNo();
148 switch (UserMI
->getOpcode()) {
150 if (vectorPseudoHasAllNBitUsers(UserOp
, Bits
))
176 case RISCV::FCVT_H_W
:
177 case RISCV::FCVT_H_W_INX
:
178 case RISCV::FCVT_H_WU
:
179 case RISCV::FCVT_H_WU_INX
:
180 case RISCV::FCVT_S_W
:
181 case RISCV::FCVT_S_W_INX
:
182 case RISCV::FCVT_S_WU
:
183 case RISCV::FCVT_S_WU_INX
:
184 case RISCV::FCVT_D_W
:
185 case RISCV::FCVT_D_W_INX
:
186 case RISCV::FCVT_D_WU
:
187 case RISCV::FCVT_D_WU_INX
:
198 case RISCV::ZEXT_H_RV32
:
199 case RISCV::ZEXT_H_RV64
:
206 if (Bits
>= (ST
.getXLen() / 2))
211 // If we are shifting right by less than Bits, and users don't demand
212 // any bits that were shifted into [Bits-1:0], then we can consider this
214 unsigned ShAmt
= UserMI
->getOperand(2).getImm();
216 Worklist
.push_back(std::make_pair(UserMI
, Bits
- ShAmt
));
222 // these overwrite higher input bits, otherwise the lower word of output
223 // depends only on the lower word of input. So check their uses read W.
225 unsigned ShAmt
= UserMI
->getOperand(2).getImm();
226 if (Bits
>= (ST
.getXLen() - ShAmt
))
228 Worklist
.push_back(std::make_pair(UserMI
, Bits
+ ShAmt
));
232 uint64_t Imm
= UserMI
->getOperand(2).getImm();
233 if (Bits
>= (unsigned)llvm::bit_width(Imm
))
235 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
239 uint64_t Imm
= UserMI
->getOperand(2).getImm();
240 if (Bits
>= (unsigned)llvm::bit_width
<uint64_t>(~Imm
))
242 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
250 // Operand 2 is the shift amount which uses log2(xlen) bits.
252 if (Bits
>= Log2_32(ST
.getXLen()))
256 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
263 // Operand 2 is the shift amount which uses log2(xlen) bits.
264 if (OpIdx
== 2 && Bits
>= Log2_32(ST
.getXLen()))
269 case RISCV::SH1ADD_UW
:
270 case RISCV::SH2ADD_UW
:
271 case RISCV::SH3ADD_UW
:
272 // Operand 1 is implicitly zero extended.
273 if (OpIdx
== 1 && Bits
>= 32)
275 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
279 if (UserMI
->getOperand(2).getImm() >= Bits
)
284 // The first argument is the value to store.
285 if (OpIdx
== 0 && Bits
>= 8)
289 // The first argument is the value to store.
290 if (OpIdx
== 0 && Bits
>= 16)
294 // The first argument is the value to store.
295 if (OpIdx
== 0 && Bits
>= 32)
299 // For these operations, the lower word of the output depends only on
300 // the lower word of the input. So we check that all uses only read the lower word.
325 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
328 case RISCV::PseudoCCMOVGPR
:
329 // Either operand 4 or operand 5 is returned by this instruction. If
330 // only the lower word of the result is used, then only the lower word
331 // of operand 4 and 5 is used.
332 if (OpIdx
!= 4 && OpIdx
!= 5)
334 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
337 case RISCV::CZERO_EQZ
:
338 case RISCV::CZERO_NEZ
:
339 case RISCV::VT_MASKC
:
340 case RISCV::VT_MASKCN
:
343 Worklist
.push_back(std::make_pair(UserMI
, Bits
));
352 static bool hasAllWUsers(const MachineInstr
&OrigMI
, const RISCVSubtarget
&ST
,
353 const MachineRegisterInfo
&MRI
) {
354 return hasAllNBitUsers(OrigMI
, ST
, MRI
, 32);
357 // This function returns true if the machine instruction always outputs a value
358 // where bits 63:32 match bit 31.
359 static bool isSignExtendingOpW(const MachineInstr
&MI
, unsigned OpNo
) {
360 uint64_t TSFlags
= MI
.getDesc().TSFlags
;
362 // Instructions that can be determined from opcode are marked in tablegen.
363 if (TSFlags
& RISCVII::IsSignExtendingOpWMask
)
366 // Special cases that require checking operands.
367 switch (MI
.getOpcode()) {
368 // shifting right sufficiently makes the value 32-bit sign-extended
370 return MI
.getOperand(2).getImm() >= 32;
372 return MI
.getOperand(2).getImm() > 32;
373 // The LI pattern ADDI rd, X0, imm is sign extended.
375 return MI
.getOperand(1).isReg() && MI
.getOperand(1).getReg() == RISCV::X0
;
376 // An ANDI with an 11 bit immediate will zero bits 63:11.
378 return isUInt
<11>(MI
.getOperand(2).getImm());
379 // An ORI with a >11 bit immediate (negative 12-bit) will set bits 63:11.
381 return !isUInt
<11>(MI
.getOperand(2).getImm());
382 // A bseti with X0 is sign extended if the immediate is less than 31.
384 return MI
.getOperand(2).getImm() < 31 &&
385 MI
.getOperand(1).getReg() == RISCV::X0
;
386 // Copying from X0 produces zero.
388 return MI
.getOperand(1).getReg() == RISCV::X0
;
389 // Ignore the scratch register destination.
390 case RISCV::PseudoAtomicLoadNand32
:
392 case RISCV::PseudoVMV_X_S
: {
393 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
394 int64_t Log2SEW
= MI
.getOperand(2).getImm();
395 assert(Log2SEW
>= 3 && Log2SEW
<= 6 && "Unexpected Log2SEW");
403 static bool isSignExtendedW(Register SrcReg
, const RISCVSubtarget
&ST
,
404 const MachineRegisterInfo
&MRI
,
405 SmallPtrSetImpl
<MachineInstr
*> &FixableDef
) {
406 SmallSet
<Register
, 4> Visited
;
407 SmallVector
<Register
, 4> Worklist
;
409 auto AddRegToWorkList
= [&](Register SrcReg
) {
410 if (!SrcReg
.isVirtual())
412 Worklist
.push_back(SrcReg
);
416 if (!AddRegToWorkList(SrcReg
))
419 while (!Worklist
.empty()) {
420 Register Reg
= Worklist
.pop_back_val();
422 // If we already visited this register, we don't need to check it again.
423 if (!Visited
.insert(Reg
).second
)
426 MachineInstr
*MI
= MRI
.getVRegDef(Reg
);
430 int OpNo
= MI
->findRegisterDefOperandIdx(Reg
, /*TRI=*/nullptr);
431 assert(OpNo
!= -1 && "Couldn't find register");
433 // If this is a sign extending operation we don't need to look any further.
434 if (isSignExtendingOpW(*MI
, OpNo
))
437 // Is this an instruction that propagates sign extend?
438 switch (MI
->getOpcode()) {
440 // Unknown opcode, give up.
443 const MachineFunction
*MF
= MI
->getMF();
444 const RISCVMachineFunctionInfo
*RVFI
=
445 MF
->getInfo
<RISCVMachineFunctionInfo
>();
447 // If this is the entry block and the register is livein, see if we know
448 // it is sign extended.
449 if (MI
->getParent() == &MF
->front()) {
450 Register VReg
= MI
->getOperand(0).getReg();
451 if (MF
->getRegInfo().isLiveIn(VReg
) && RVFI
->isSExt32Register(VReg
))
455 Register CopySrcReg
= MI
->getOperand(1).getReg();
456 if (CopySrcReg
== RISCV::X10
) {
457 // For a method return value, we check the ZExt/SExt flags in attribute.
458 // We assume the following code sequence for method call.
459 // PseudoCALL @bar, ...
460 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
461 // %0:gpr = COPY $x10
463 // We use the PseudoCall to look up the IR function being called to find
464 // its return attributes.
465 const MachineBasicBlock
*MBB
= MI
->getParent();
466 auto II
= MI
->getIterator();
467 if (II
== MBB
->instr_begin() ||
468 (--II
)->getOpcode() != RISCV::ADJCALLSTACKUP
)
471 const MachineInstr
&CallMI
= *(--II
);
472 if (!CallMI
.isCall() || !CallMI
.getOperand(0).isGlobal())
476 dyn_cast_if_present
<Function
>(CallMI
.getOperand(0).getGlobal());
480 auto *IntTy
= dyn_cast
<IntegerType
>(CalleeFn
->getReturnType());
484 const AttributeSet
&Attrs
= CalleeFn
->getAttributes().getRetAttrs();
485 unsigned BitWidth
= IntTy
->getBitWidth();
486 if ((BitWidth
<= 32 && Attrs
.hasAttribute(Attribute::SExt
)) ||
487 (BitWidth
< 32 && Attrs
.hasAttribute(Attribute::ZExt
)))
491 if (!AddRegToWorkList(CopySrcReg
))
497 // For these, we just need to check if the 1st operand is sign extended.
501 if (MI
->getOperand(2).getImm() >= 31)
508 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
509 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
510 // Logical operations use a sign extended 12-bit immediate.
511 if (!AddRegToWorkList(MI
->getOperand(1).getReg()))
515 case RISCV::PseudoCCADDW
:
516 case RISCV::PseudoCCADDIW
:
517 case RISCV::PseudoCCSUBW
:
518 case RISCV::PseudoCCSLLW
:
519 case RISCV::PseudoCCSRLW
:
520 case RISCV::PseudoCCSRAW
:
521 case RISCV::PseudoCCSLLIW
:
522 case RISCV::PseudoCCSRLIW
:
523 case RISCV::PseudoCCSRAIW
:
524 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
525 // need to check if operand 4 is sign extended.
526 if (!AddRegToWorkList(MI
->getOperand(4).getReg()))
540 case RISCV::PseudoCCMOVGPR
:
541 case RISCV::PseudoCCAND
:
542 case RISCV::PseudoCCOR
:
543 case RISCV::PseudoCCXOR
:
545 // If all incoming values are sign-extended, the output of AND, OR, XOR,
546 // MIN, MAX, or PHI is also sign-extended.
548 // The input registers for PHI are operand 1, 3, ...
549 // The input registers for PseudoCCMOVGPR are 4 and 5.
550 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
551 // The input registers for others are operand 1 and 2.
552 unsigned B
= 1, E
= 3, D
= 1;
553 switch (MI
->getOpcode()) {
555 E
= MI
->getNumOperands();
558 case RISCV::PseudoCCMOVGPR
:
562 case RISCV::PseudoCCAND
:
563 case RISCV::PseudoCCOR
:
564 case RISCV::PseudoCCXOR
:
570 for (unsigned I
= B
; I
!= E
; I
+= D
) {
571 if (!MI
->getOperand(I
).isReg())
574 if (!AddRegToWorkList(MI
->getOperand(I
).getReg()))
581 case RISCV::CZERO_EQZ
:
582 case RISCV::CZERO_NEZ
:
583 case RISCV::VT_MASKC
:
584 case RISCV::VT_MASKCN
:
585 // Instructions return zero or operand 1. Result is sign extended if
586 // operand 1 is sign extended.
587 if (!AddRegToWorkList(MI
->getOperand(1).getReg()))
591 // With these opcodes, we can "fix" them with the W-version
592 // if we know all users of the result only rely on bits 31:0
594 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
595 if (MI
->getOperand(2).getImm() >= 32)
604 if (hasAllWUsers(*MI
, ST
, MRI
)) {
605 FixableDef
.insert(MI
);
612 // If we get here, then every node we visited produces a sign extended value
613 // or propagated sign extended values. So the result must be sign extended.
617 static unsigned getWOp(unsigned Opcode
) {
633 llvm_unreachable("Unexpected opcode for replacement with W variant");
637 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction
&MF
,
638 const RISCVInstrInfo
&TII
,
639 const RISCVSubtarget
&ST
,
640 MachineRegisterInfo
&MRI
) {
641 if (DisableSExtWRemoval
)
644 bool MadeChange
= false;
645 for (MachineBasicBlock
&MBB
: MF
) {
646 for (MachineInstr
&MI
: llvm::make_early_inc_range(MBB
)) {
647 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
648 if (!RISCV::isSEXT_W(MI
))
651 Register SrcReg
= MI
.getOperand(1).getReg();
653 SmallPtrSet
<MachineInstr
*, 4> FixableDefs
;
655 // If all users only use the lower bits, this sext.w is redundant.
656 // Or if all definitions reaching MI sign-extend their output,
657 // then sext.w is redundant.
658 if (!hasAllWUsers(MI
, ST
, MRI
) &&
659 !isSignExtendedW(SrcReg
, ST
, MRI
, FixableDefs
))
662 Register DstReg
= MI
.getOperand(0).getReg();
663 if (!MRI
.constrainRegClass(SrcReg
, MRI
.getRegClass(DstReg
)))
666 // Convert Fixable instructions to their W versions.
667 for (MachineInstr
*Fixable
: FixableDefs
) {
668 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable
);
669 Fixable
->setDesc(TII
.get(getWOp(Fixable
->getOpcode())));
670 Fixable
->clearFlag(MachineInstr::MIFlag::NoSWrap
);
671 Fixable
->clearFlag(MachineInstr::MIFlag::NoUWrap
);
672 Fixable
->clearFlag(MachineInstr::MIFlag::IsExact
);
673 LLVM_DEBUG(dbgs() << " with " << *Fixable
);
674 ++NumTransformedToWInstrs
;
677 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
678 MRI
.replaceRegWith(DstReg
, SrcReg
);
679 MRI
.clearKillFlags(SrcReg
);
680 MI
.eraseFromParent();
689 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction
&MF
,
690 const RISCVInstrInfo
&TII
,
691 const RISCVSubtarget
&ST
,
692 MachineRegisterInfo
&MRI
) {
693 bool MadeChange
= false;
694 for (MachineBasicBlock
&MBB
: MF
) {
695 for (MachineInstr
&MI
: MBB
) {
697 switch (MI
.getOpcode()) {
700 case RISCV::ADDW
: Opc
= RISCV::ADD
; break;
701 case RISCV::ADDIW
: Opc
= RISCV::ADDI
; break;
702 case RISCV::MULW
: Opc
= RISCV::MUL
; break;
703 case RISCV::SLLIW
: Opc
= RISCV::SLLI
; break;
706 if (hasAllWUsers(MI
, ST
, MRI
)) {
707 MI
.setDesc(TII
.get(Opc
));
716 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction
&MF
,
717 const RISCVInstrInfo
&TII
,
718 const RISCVSubtarget
&ST
,
719 MachineRegisterInfo
&MRI
) {
720 bool MadeChange
= false;
721 for (MachineBasicBlock
&MBB
: MF
) {
722 for (MachineInstr
&MI
: MBB
) {
725 switch (MI
.getOpcode()) {
741 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
742 if (MI
.getOperand(2).getImm() >= 32)
752 if (hasAllWUsers(MI
, ST
, MRI
)) {
753 LLVM_DEBUG(dbgs() << "Replacing " << MI
);
754 MI
.setDesc(TII
.get(WOpc
));
755 MI
.clearFlag(MachineInstr::MIFlag::NoSWrap
);
756 MI
.clearFlag(MachineInstr::MIFlag::NoUWrap
);
757 MI
.clearFlag(MachineInstr::MIFlag::IsExact
);
758 LLVM_DEBUG(dbgs() << " with " << MI
);
759 ++NumTransformedToWInstrs
;
768 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction
&MF
) {
769 if (skipFunction(MF
.getFunction()))
772 MachineRegisterInfo
&MRI
= MF
.getRegInfo();
773 const RISCVSubtarget
&ST
= MF
.getSubtarget
<RISCVSubtarget
>();
774 const RISCVInstrInfo
&TII
= *ST
.getInstrInfo();
779 bool MadeChange
= false;
780 MadeChange
|= removeSExtWInstrs(MF
, TII
, ST
, MRI
);
782 if (!(DisableStripWSuffix
|| ST
.preferWInst()))
783 MadeChange
|= stripWSuffixes(MF
, TII
, ST
, MRI
);
785 if (ST
.preferWInst())
786 MadeChange
|= appendWSuffixes(MF
, TII
, ST
, MRI
);