1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This file implements the machine function pass to insert read/write of CSR-s
9 // of the RISC-V instructions.
11 // Currently the pass implements:
12 // -Writing and saving frm before an RVV floating-point instruction with a
13 // static rounding mode and restores the value after.
15 //===----------------------------------------------------------------------===//
17 #include "MCTargetDesc/RISCVBaseInfo.h"
19 #include "RISCVSubtarget.h"
20 #include "llvm/CodeGen/MachineFunctionPass.h"
23 #define DEBUG_TYPE "riscv-insert-read-write-csr"
24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
27 DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
29 cl::desc("Disable optimized frm insertion."));
33 class RISCVInsertReadWriteCSR
: public MachineFunctionPass
{
34 const TargetInstrInfo
*TII
;
39 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID
) {}
41 bool runOnMachineFunction(MachineFunction
&MF
) override
;
43 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
45 MachineFunctionPass::getAnalysisUsage(AU
);
48 StringRef
getPassName() const override
{
49 return RISCV_INSERT_READ_WRITE_CSR_NAME
;
53 bool emitWriteRoundingMode(MachineBasicBlock
&MBB
);
54 bool emitWriteRoundingModeOpt(MachineBasicBlock
&MBB
);
57 } // end anonymous namespace
59 char RISCVInsertReadWriteCSR::ID
= 0;
61 INITIALIZE_PASS(RISCVInsertReadWriteCSR
, DEBUG_TYPE
,
62 RISCV_INSERT_READ_WRITE_CSR_NAME
, false, false)
64 // TODO: Use more accurate rounding mode at the start of MBB.
65 bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock
&MBB
) {
67 MachineInstr
*LastFRMChanger
= nullptr;
68 unsigned CurrentRM
= RISCVFPRndMode::DYN
;
71 for (MachineInstr
&MI
: MBB
) {
72 if (MI
.getOpcode() == RISCV::SwapFRMImm
||
73 MI
.getOpcode() == RISCV::WriteFRMImm
) {
74 CurrentRM
= MI
.getOperand(0).getImm();
75 SavedFRM
= Register();
79 if (MI
.getOpcode() == RISCV::WriteFRM
) {
80 CurrentRM
= RISCVFPRndMode::DYN
;
81 SavedFRM
= Register();
85 if (MI
.isCall() || MI
.isInlineAsm() ||
86 MI
.readsRegister(RISCV::FRM
, /*TRI=*/nullptr)) {
87 // Restore FRM before unknown operations.
88 if (SavedFRM
.isValid())
89 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::WriteFRM
))
91 CurrentRM
= RISCVFPRndMode::DYN
;
92 SavedFRM
= Register();
96 assert(!MI
.modifiesRegister(RISCV::FRM
, /*TRI=*/nullptr) &&
97 "Expected that MI could not modify FRM.");
99 int FRMIdx
= RISCVII::getFRMOpNum(MI
.getDesc());
102 unsigned InstrRM
= MI
.getOperand(FRMIdx
).getImm();
104 LastFRMChanger
= &MI
;
106 // Make MI implicit use FRM.
107 MI
.addOperand(MachineOperand::CreateReg(RISCV::FRM
, /*IsDef*/ false,
111 // Skip if MI uses same rounding mode as FRM.
112 if (InstrRM
== CurrentRM
)
115 if (!SavedFRM
.isValid()) {
116 // Save current FRM value to SavedFRM.
117 MachineRegisterInfo
*MRI
= &MBB
.getParent()->getRegInfo();
118 SavedFRM
= MRI
->createVirtualRegister(&RISCV::GPRRegClass
);
119 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::SwapFRMImm
), SavedFRM
)
122 // Don't need to save current FRM when SavedFRM having value.
123 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::WriteFRMImm
))
129 // Restore FRM if needed.
130 if (SavedFRM
.isValid()) {
131 assert(LastFRMChanger
&& "Expected valid pointer.");
132 MachineInstrBuilder MIB
=
133 BuildMI(*MBB
.getParent(), {}, TII
->get(RISCV::WriteFRM
))
135 MBB
.insertAfter(LastFRMChanger
, MIB
);
141 // This function also swaps frm and restores it when encountering an RVV
142 // floating point instruction with a static rounding mode.
143 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock
&MBB
) {
144 bool Changed
= false;
145 for (MachineInstr
&MI
: MBB
) {
146 int FRMIdx
= RISCVII::getFRMOpNum(MI
.getDesc());
150 unsigned FRMImm
= MI
.getOperand(FRMIdx
).getImm();
152 // The value is a hint to this pass to not alter the frm value.
153 if (FRMImm
== RISCVFPRndMode::DYN
)
159 MachineRegisterInfo
*MRI
= &MBB
.getParent()->getRegInfo();
160 Register SavedFRM
= MRI
->createVirtualRegister(&RISCV::GPRRegClass
);
161 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::SwapFRMImm
),
164 MI
.addOperand(MachineOperand::CreateReg(RISCV::FRM
, /*IsDef*/ false,
167 MachineInstrBuilder MIB
=
168 BuildMI(*MBB
.getParent(), {}, TII
->get(RISCV::WriteFRM
))
170 MBB
.insertAfter(MI
, MIB
);
175 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction
&MF
) {
176 // Skip if the vector extension is not enabled.
177 const RISCVSubtarget
&ST
= MF
.getSubtarget
<RISCVSubtarget
>();
178 if (!ST
.hasVInstructions())
181 TII
= ST
.getInstrInfo();
183 bool Changed
= false;
185 for (MachineBasicBlock
&MBB
: MF
) {
186 if (DisableFRMInsertOpt
)
187 Changed
|= emitWriteRoundingMode(MBB
);
189 Changed
|= emitWriteRoundingModeOpt(MBB
);
195 FunctionPass
*llvm::createRISCVInsertReadWriteCSRPass() {
196 return new RISCVInsertReadWriteCSR();