Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / lib / Target / RISCV / RISCVInsertReadWriteCSR.cpp
blob75f5ac3fbe0dd5511063e800a961771079a80aa8
1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file implements a machine function pass that inserts reads and writes
9 // of the CSRs that RISC-V instructions depend on.
11 // Currently the pass implements:
12 // -Naive insertion of a write to vxrm before an RVV fixed-point instruction.
13 // -Saving and writing frm before an RVV floating-point instruction with a
14 // static rounding mode, and restoring the saved value afterwards.
16 //===----------------------------------------------------------------------===//
18 #include "MCTargetDesc/RISCVBaseInfo.h"
19 #include "RISCV.h"
20 #include "RISCVSubtarget.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 using namespace llvm;
24 #define DEBUG_TYPE "riscv-insert-read-write-csr"
25 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
27 namespace {
29 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
30 const TargetInstrInfo *TII;
32 public:
33 static char ID;
35 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {
36 initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry());
39 bool runOnMachineFunction(MachineFunction &MF) override;
41 void getAnalysisUsage(AnalysisUsage &AU) const override {
42 AU.setPreservesCFG();
43 MachineFunctionPass::getAnalysisUsage(AU);
46 StringRef getPassName() const override {
47 return RISCV_INSERT_READ_WRITE_CSR_NAME;
50 private:
51 bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54 } // end anonymous namespace
56 char RISCVInsertReadWriteCSR::ID = 0;
58 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
59 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
61 // This function inserts a write to vxrm when encountering an RVV fixed-point
62 // instruction. This function also swaps frm and restores it when encountering
63 // an RVV floating point instruction with a static rounding mode.
64 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
65 bool Changed = false;
66 for (MachineInstr &MI : MBB) {
67 int VXRMIdx = RISCVII::getVXRMOpNum(MI.getDesc());
68 if (VXRMIdx >= 0) {
69 unsigned VXRMImm = MI.getOperand(VXRMIdx).getImm();
71 Changed = true;
73 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
74 .addImm(VXRMImm);
75 MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
76 /*IsImp*/ true));
77 continue;
80 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
81 if (FRMIdx < 0)
82 continue;
84 unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
86 // The value is a hint to this pass to not alter the frm value.
87 if (FRMImm == RISCVFPRndMode::DYN)
88 continue;
90 Changed = true;
92 // Save
93 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
94 Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
95 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
96 SavedFRM)
97 .addImm(FRMImm);
98 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
99 /*IsImp*/ true));
100 // Restore
101 MachineInstrBuilder MIB =
102 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
103 .addReg(SavedFRM);
104 MBB.insertAfter(MI, MIB);
106 return Changed;
109 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
110 // Skip if the vector extension is not enabled.
111 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
112 if (!ST.hasVInstructions())
113 return false;
115 TII = ST.getInstrInfo();
117 bool Changed = false;
119 for (MachineBasicBlock &MBB : MF)
120 Changed |= emitWriteRoundingMode(MBB);
122 return Changed;
125 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
126 return new RISCVInsertReadWriteCSR();