[gn build] Port 077cc3deeebe
[llvm-project.git] / llvm / lib / Target / RISCV / RISCVInsertReadWriteCSR.cpp
blob7b9e9fb988bc68e25cf6256e247c718adfcc0236
1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file implements the machine function pass to insert read/write of CSR-s
9 // of the RISC-V instructions.
11 // Currently the pass implements:
12 // - Saving frm and writing the static rounding mode before an RVV
13 // floating-point instruction that uses one, then restoring frm afterwards.
15 //===----------------------------------------------------------------------===//
17 #include "MCTargetDesc/RISCVBaseInfo.h"
18 #include "RISCV.h"
19 #include "RISCVSubtarget.h"
20 #include "llvm/CodeGen/MachineFunctionPass.h"
21 using namespace llvm;
23 #define DEBUG_TYPE "riscv-insert-read-write-csr"
24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
// Hidden command-line switch, off by default. When it is set,
// runOnMachineFunction uses emitWriteRoundingMode (one save/restore pair per
// instruction) instead of emitWriteRoundingModeOpt for every basic block.
26 static cl::opt<bool>
27 DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
28 cl::Hidden,
29 cl::desc("Disable optimized frm insertion."));
31 namespace {
33 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
34 const TargetInstrInfo *TII;
36 public:
37 static char ID;
39 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
41 bool runOnMachineFunction(MachineFunction &MF) override;
43 void getAnalysisUsage(AnalysisUsage &AU) const override {
44 AU.setPreservesCFG();
45 MachineFunctionPass::getAnalysisUsage(AU);
48 StringRef getPassName() const override {
49 return RISCV_INSERT_READ_WRITE_CSR_NAME;
52 private:
53 bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54 bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
57 } // end anonymous namespace
// Definition of the static pass identifier declared in the class; the
// constructor passes it to MachineFunctionPass.
59 char RISCVInsertReadWriteCSR::ID = 0;
// Registers the pass using DEBUG_TYPE as its argument string and
// RISCV_INSERT_READ_WRITE_CSR_NAME as its display name.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG-only" and "is analysis" flags of INITIALIZE_PASS -- confirm against
// the macro definition.
61 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
62 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
64 // TODO: Use more accurate rounding mode at the start of MBB.
65 bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
66 bool Changed = false;
67 MachineInstr *LastFRMChanger = nullptr;
68 unsigned CurrentRM = RISCVFPRndMode::DYN;
69 Register SavedFRM;
71 for (MachineInstr &MI : MBB) {
72 if (MI.getOpcode() == RISCV::SwapFRMImm ||
73 MI.getOpcode() == RISCV::WriteFRMImm) {
74 CurrentRM = MI.getOperand(0).getImm();
75 SavedFRM = Register();
76 continue;
79 if (MI.getOpcode() == RISCV::WriteFRM) {
80 CurrentRM = RISCVFPRndMode::DYN;
81 SavedFRM = Register();
82 continue;
85 if (MI.isCall() || MI.isInlineAsm() ||
86 MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
87 // Restore FRM before unknown operations.
88 if (SavedFRM.isValid())
89 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
90 .addReg(SavedFRM);
91 CurrentRM = RISCVFPRndMode::DYN;
92 SavedFRM = Register();
93 continue;
96 assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) &&
97 "Expected that MI could not modify FRM.");
99 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
100 if (FRMIdx < 0)
101 continue;
102 unsigned InstrRM = MI.getOperand(FRMIdx).getImm();
104 LastFRMChanger = &MI;
106 // Make MI implicit use FRM.
107 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
108 /*IsImp*/ true));
109 Changed = true;
111 // Skip if MI uses same rounding mode as FRM.
112 if (InstrRM == CurrentRM)
113 continue;
115 if (!SavedFRM.isValid()) {
116 // Save current FRM value to SavedFRM.
117 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
118 SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
119 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
120 .addImm(InstrRM);
121 } else {
122 // Don't need to save current FRM when SavedFRM having value.
123 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
124 .addImm(InstrRM);
126 CurrentRM = InstrRM;
129 // Restore FRM if needed.
130 if (SavedFRM.isValid()) {
131 assert(LastFRMChanger && "Expected valid pointer.");
132 MachineInstrBuilder MIB =
133 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
134 .addReg(SavedFRM);
135 MBB.insertAfter(LastFRMChanger, MIB);
138 return Changed;
141 // This function also swaps frm and restores it when encountering an RVV
142 // floating point instruction with a static rounding mode.
143 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
144 bool Changed = false;
145 for (MachineInstr &MI : MBB) {
146 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
147 if (FRMIdx < 0)
148 continue;
150 unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
152 // The value is a hint to this pass to not alter the frm value.
153 if (FRMImm == RISCVFPRndMode::DYN)
154 continue;
156 Changed = true;
158 // Save
159 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
160 Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
161 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
162 SavedFRM)
163 .addImm(FRMImm);
164 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
165 /*IsImp*/ true));
166 // Restore
167 MachineInstrBuilder MIB =
168 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
169 .addReg(SavedFRM);
170 MBB.insertAfter(MI, MIB);
172 return Changed;
175 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
176 // Skip if the vector extension is not enabled.
177 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
178 if (!ST.hasVInstructions())
179 return false;
181 TII = ST.getInstrInfo();
183 bool Changed = false;
185 for (MachineBasicBlock &MBB : MF) {
186 if (DisableFRMInsertOpt)
187 Changed |= emitWriteRoundingMode(MBB);
188 else
189 Changed |= emitWriteRoundingModeOpt(MBB);
192 return Changed;
// Factory for the pass: returns a newly allocated RISCVInsertReadWriteCSR.
// NOTE(review): the caller (presumably the pass manager) is expected to take
// ownership of the returned object -- confirm at the call site.
195 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
196 return new RISCVInsertReadWriteCSR();