//===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file implements the machine function pass that inserts reads/writes of
// the CSRs used by RISC-V instructions.
//
// Currently the pass implements:
// - Naive insertion of a write to vxrm before an RVV fixed-point instruction.
// - Writing frm (while saving its previous value) before an RVV floating-point
//   instruction with a static rounding mode, and restoring the saved value
//   after it.
//
//===----------------------------------------------------------------------===//
#include "MCTargetDesc/RISCVBaseInfo.h"
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
using namespace llvm;
24 #define DEBUG_TYPE "riscv-insert-read-write-csr"
25 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
29 class RISCVInsertReadWriteCSR
: public MachineFunctionPass
{
30 const TargetInstrInfo
*TII
;
35 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID
) {
36 initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry());
39 bool runOnMachineFunction(MachineFunction
&MF
) override
;
41 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
43 MachineFunctionPass::getAnalysisUsage(AU
);
46 StringRef
getPassName() const override
{
47 return RISCV_INSERT_READ_WRITE_CSR_NAME
;
51 bool emitWriteRoundingMode(MachineBasicBlock
&MBB
);
54 } // end anonymous namespace
56 char RISCVInsertReadWriteCSR::ID
= 0;
58 INITIALIZE_PASS(RISCVInsertReadWriteCSR
, DEBUG_TYPE
,
59 RISCV_INSERT_READ_WRITE_CSR_NAME
, false, false)
61 // This function inserts a write to vxrm when encountering an RVV fixed-point
62 // instruction. This function also swaps frm and restores it when encountering
63 // an RVV floating point instruction with a static rounding mode.
64 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock
&MBB
) {
66 for (MachineInstr
&MI
: MBB
) {
67 int VXRMIdx
= RISCVII::getVXRMOpNum(MI
.getDesc());
69 unsigned VXRMImm
= MI
.getOperand(VXRMIdx
).getImm();
73 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::WriteVXRMImm
))
75 MI
.addOperand(MachineOperand::CreateReg(RISCV::VXRM
, /*IsDef*/ false,
80 int FRMIdx
= RISCVII::getFRMOpNum(MI
.getDesc());
84 unsigned FRMImm
= MI
.getOperand(FRMIdx
).getImm();
86 // The value is a hint to this pass to not alter the frm value.
87 if (FRMImm
== RISCVFPRndMode::DYN
)
93 MachineRegisterInfo
*MRI
= &MBB
.getParent()->getRegInfo();
94 Register SavedFRM
= MRI
->createVirtualRegister(&RISCV::GPRRegClass
);
95 BuildMI(MBB
, MI
, MI
.getDebugLoc(), TII
->get(RISCV::SwapFRMImm
),
98 MI
.addOperand(MachineOperand::CreateReg(RISCV::FRM
, /*IsDef*/ false,
101 MachineInstrBuilder MIB
=
102 BuildMI(*MBB
.getParent(), {}, TII
->get(RISCV::WriteFRM
))
104 MBB
.insertAfter(MI
, MIB
);
109 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction
&MF
) {
110 // Skip if the vector extension is not enabled.
111 const RISCVSubtarget
&ST
= MF
.getSubtarget
<RISCVSubtarget
>();
112 if (!ST
.hasVInstructions())
115 TII
= ST
.getInstrInfo();
117 bool Changed
= false;
119 for (MachineBasicBlock
&MBB
: MF
)
120 Changed
|= emitWriteRoundingMode(MBB
);
125 FunctionPass
*llvm::createRISCVInsertReadWriteCSRPass() {
126 return new RISCVInsertReadWriteCSR();