1 //===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This pass tries to remove back-to-back (smstart, smstop) and
9 // (smstop, smstart) sequences. The pass is conservative when it cannot
10 // determine that it is safe to remove these sequences.
11 //===----------------------------------------------------------------------===//
#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;
24 #define DEBUG_TYPE "aarch64-sme-peephole-opt"
28 struct SMEPeepholeOpt
: public MachineFunctionPass
{
31 SMEPeepholeOpt() : MachineFunctionPass(ID
) {
32 initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
35 bool runOnMachineFunction(MachineFunction
&MF
) override
;
37 StringRef
getPassName() const override
{
38 return "SME Peephole Optimization pass";
41 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
43 MachineFunctionPass::getAnalysisUsage(AU
);
46 bool optimizeStartStopPairs(MachineBasicBlock
&MBB
,
47 bool &HasRemovedAllSMChanges
) const;
50 char SMEPeepholeOpt::ID
= 0;
52 } // end anonymous namespace
54 static bool isConditionalStartStop(const MachineInstr
*MI
) {
55 return MI
->getOpcode() == AArch64::MSRpstatePseudo
;
58 static bool isMatchingStartStopPair(const MachineInstr
*MI1
,
59 const MachineInstr
*MI2
) {
60 // We only consider the same type of streaming mode change here, i.e.
61 // start/stop SM, or start/stop ZA pairs.
62 if (MI1
->getOperand(0).getImm() != MI2
->getOperand(0).getImm())
65 // One must be 'start', the other must be 'stop'
66 if (MI1
->getOperand(1).getImm() == MI2
->getOperand(1).getImm())
69 bool IsConditional
= isConditionalStartStop(MI2
);
70 if (isConditionalStartStop(MI1
) != IsConditional
)
76 // Check to make sure the conditional start/stop pairs are identical.
77 if (MI1
->getOperand(2).getImm() != MI2
->getOperand(2).getImm())
80 // Ensure reg masks are identical.
81 if (MI1
->getOperand(4).getRegMask() != MI2
->getOperand(4).getRegMask())
84 // This optimisation is unlikely to happen in practice for conditional
85 // smstart/smstop pairs as the virtual registers for pstate.sm will always
87 // TODO: For this optimisation to apply to conditional smstart/smstop,
88 // this pass will need to do more work to remove redundant calls to
91 // Only consider conditional start/stop pairs which read the same register
92 // holding the original value of pstate.sm, as some conditional start/stops
93 // require the state on entry to the function.
94 if (MI1
->getOperand(3).isReg() && MI2
->getOperand(3).isReg()) {
95 Register Reg1
= MI1
->getOperand(3).getReg();
96 Register Reg2
= MI2
->getOperand(3).getReg();
97 if (Reg1
.isPhysical() || Reg2
.isPhysical() || Reg1
!= Reg2
)
104 static bool ChangesStreamingMode(const MachineInstr
*MI
) {
105 assert((MI
->getOpcode() == AArch64::MSRpstatesvcrImm1
||
106 MI
->getOpcode() == AArch64::MSRpstatePseudo
) &&
107 "Expected MI to be a smstart/smstop instruction");
108 return MI
->getOperand(0).getImm() == AArch64SVCR::SVCRSM
||
109 MI
->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA
;
112 static bool isSVERegOp(const TargetRegisterInfo
&TRI
,
113 const MachineRegisterInfo
&MRI
,
114 const MachineOperand
&MO
) {
118 Register R
= MO
.getReg();
120 return llvm::any_of(TRI
.subregs_inclusive(R
), [](const MCPhysReg
&SR
) {
121 return AArch64::ZPRRegClass
.contains(SR
) ||
122 AArch64::PPRRegClass
.contains(SR
);
125 const TargetRegisterClass
*RC
= MRI
.getRegClass(R
);
126 return TRI
.getCommonSubClass(&AArch64::ZPRRegClass
, RC
) ||
127 TRI
.getCommonSubClass(&AArch64::PPRRegClass
, RC
);
130 bool SMEPeepholeOpt::optimizeStartStopPairs(
131 MachineBasicBlock
&MBB
, bool &HasRemovedAllSMChanges
) const {
132 const MachineRegisterInfo
&MRI
= MBB
.getParent()->getRegInfo();
133 const TargetRegisterInfo
&TRI
=
134 *MBB
.getParent()->getSubtarget().getRegisterInfo();
136 bool Changed
= false;
137 MachineInstr
*Prev
= nullptr;
138 SmallVector
<MachineInstr
*, 4> ToBeRemoved
;
140 // Convenience function to reset the matching of a sequence.
146 // Walk through instructions in the block trying to find pairs of smstart
147 // and smstop nodes that cancel each other out. We only permit a limited
148 // set of instructions to appear between them, otherwise we reset our
150 unsigned NumSMChanges
= 0;
151 unsigned NumSMChangesRemoved
= 0;
152 for (MachineInstr
&MI
: make_early_inc_range(MBB
)) {
153 switch (MI
.getOpcode()) {
154 case AArch64::MSRpstatesvcrImm1
:
155 case AArch64::MSRpstatePseudo
: {
156 if (ChangesStreamingMode(&MI
))
161 else if (isMatchingStartStopPair(Prev
, &MI
)) {
162 // If they match, we can remove them, and possibly any instructions
163 // that we marked for deletion in between.
164 Prev
->eraseFromParent();
165 MI
.eraseFromParent();
166 for (MachineInstr
*TBR
: ToBeRemoved
)
167 TBR
->eraseFromParent();
171 NumSMChangesRemoved
+= 2;
180 // Avoid doing expensive checks when Prev is nullptr.
185 // Test if the instructions in between the start/stop sequence are agnostic
186 // of streaming mode. If not, the algorithm should reset.
187 switch (MI
.getOpcode()) {
191 case AArch64::COALESCER_BARRIER_FPR16
:
192 case AArch64::COALESCER_BARRIER_FPR32
:
193 case AArch64::COALESCER_BARRIER_FPR64
:
194 case AArch64::COALESCER_BARRIER_FPR128
:
196 // These instructions should be safe when executed on their own, but
197 // the code remains conservative when SVE registers are used. There may
198 // exist subtle cases where executing a COPY in a different mode results
199 // in different behaviour, even if we can't yet come up with any
200 // concrete example/test-case.
201 if (isSVERegOp(TRI
, MRI
, MI
.getOperand(0)) ||
202 isSVERegOp(TRI
, MRI
, MI
.getOperand(1)))
205 case AArch64::ADJCALLSTACKDOWN
:
206 case AArch64::ADJCALLSTACKUP
:
207 case AArch64::ANDXri
:
208 case AArch64::ADDXri
:
209 // We permit these as they don't generate SVE/NEON instructions.
211 case AArch64::VGRestorePseudo
:
212 case AArch64::VGSavePseudo
:
213 // When the smstart/smstop are removed, we should also remove
214 // the pseudos that save/restore the VG value for CFI info.
215 ToBeRemoved
.push_back(&MI
);
217 case AArch64::MSRpstatesvcrImm1
:
218 case AArch64::MSRpstatePseudo
:
219 llvm_unreachable("Should have been handled");
223 HasRemovedAllSMChanges
=
224 NumSMChanges
&& (NumSMChanges
== NumSMChangesRemoved
);
// Register the pass with the legacy pass manager: (class, command-line name,
// description, is-CFG-only, is-analysis).
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
231 bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction
&MF
) {
232 if (skipFunction(MF
.getFunction()))
235 if (!MF
.getSubtarget
<AArch64Subtarget
>().hasSME())
238 assert(MF
.getRegInfo().isSSA() && "Expected to be run on SSA form!");
240 bool Changed
= false;
241 bool FunctionHasAllSMChangesRemoved
= false;
243 // Even if the block lives in a function with no SME attributes attached we
244 // still have to analyze all the blocks because we may call a streaming
245 // function that requires smstart/smstop pairs.
246 for (MachineBasicBlock
&MBB
: MF
) {
247 bool BlockHasAllSMChangesRemoved
;
248 Changed
|= optimizeStartStopPairs(MBB
, BlockHasAllSMChangesRemoved
);
249 FunctionHasAllSMChangesRemoved
|= BlockHasAllSMChangesRemoved
;
252 AArch64FunctionInfo
*AFI
= MF
.getInfo
<AArch64FunctionInfo
>();
253 if (FunctionHasAllSMChangesRemoved
)
254 AFI
->setHasStreamingModeChanges(false);
259 FunctionPass
*llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }