[mlir] More fixes for 9fddaf6b14102963f12dbb9730f101fc52e662c1
[llvm-project.git] / llvm / lib / Target / AArch64 / SMEPeepholeOpt.cpp
blob4a0312d5b276f391318d9e4cb3fc332733dbf5da
1 //===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass tries to remove back-to-back (smstart, smstop) and
9 // (smstop, smstart) sequences. The pass is conservative when it cannot
10 // determine that it is safe to remove these sequences.
11 //===----------------------------------------------------------------------===//
13 #include "AArch64InstrInfo.h"
14 #include "AArch64MachineFunctionInfo.h"
15 #include "AArch64Subtarget.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/MachineBasicBlock.h"
18 #include "llvm/CodeGen/MachineFunctionPass.h"
19 #include "llvm/CodeGen/MachineRegisterInfo.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
22 using namespace llvm;
24 #define DEBUG_TYPE "aarch64-sme-peephole-opt"
26 namespace {
28 struct SMEPeepholeOpt : public MachineFunctionPass {
29 static char ID;
31 SMEPeepholeOpt() : MachineFunctionPass(ID) {
32 initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
35 bool runOnMachineFunction(MachineFunction &MF) override;
37 StringRef getPassName() const override {
38 return "SME Peephole Optimization pass";
41 void getAnalysisUsage(AnalysisUsage &AU) const override {
42 AU.setPreservesCFG();
43 MachineFunctionPass::getAnalysisUsage(AU);
46 bool optimizeStartStopPairs(MachineBasicBlock &MBB,
47 bool &HasRemovedAllSMChanges) const;
50 char SMEPeepholeOpt::ID = 0;
52 } // end anonymous namespace
54 static bool isConditionalStartStop(const MachineInstr *MI) {
55 return MI->getOpcode() == AArch64::MSRpstatePseudo;
58 static bool isMatchingStartStopPair(const MachineInstr *MI1,
59 const MachineInstr *MI2) {
60 // We only consider the same type of streaming mode change here, i.e.
61 // start/stop SM, or start/stop ZA pairs.
62 if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
63 return false;
65 // One must be 'start', the other must be 'stop'
66 if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
67 return false;
69 bool IsConditional = isConditionalStartStop(MI2);
70 if (isConditionalStartStop(MI1) != IsConditional)
71 return false;
73 if (!IsConditional)
74 return true;
76 // Check to make sure the conditional start/stop pairs are identical.
77 if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
78 return false;
80 // Ensure reg masks are identical.
81 if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
82 return false;
84 // This optimisation is unlikely to happen in practice for conditional
85 // smstart/smstop pairs as the virtual registers for pstate.sm will always
86 // be different.
87 // TODO: For this optimisation to apply to conditional smstart/smstop,
88 // this pass will need to do more work to remove redundant calls to
89 // __arm_sme_state.
91 // Only consider conditional start/stop pairs which read the same register
92 // holding the original value of pstate.sm, as some conditional start/stops
93 // require the state on entry to the function.
94 if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
95 Register Reg1 = MI1->getOperand(3).getReg();
96 Register Reg2 = MI2->getOperand(3).getReg();
97 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
98 return false;
101 return true;
104 static bool ChangesStreamingMode(const MachineInstr *MI) {
105 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
106 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
107 "Expected MI to be a smstart/smstop instruction");
108 return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
109 MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
112 static bool isSVERegOp(const TargetRegisterInfo &TRI,
113 const MachineRegisterInfo &MRI,
114 const MachineOperand &MO) {
115 if (!MO.isReg())
116 return false;
118 Register R = MO.getReg();
119 if (R.isPhysical())
120 return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
121 return AArch64::ZPRRegClass.contains(SR) ||
122 AArch64::PPRRegClass.contains(SR);
125 const TargetRegisterClass *RC = MRI.getRegClass(R);
126 return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
127 TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
130 bool SMEPeepholeOpt::optimizeStartStopPairs(
131 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
132 const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
133 const TargetRegisterInfo &TRI =
134 *MBB.getParent()->getSubtarget().getRegisterInfo();
136 bool Changed = false;
137 MachineInstr *Prev = nullptr;
138 SmallVector<MachineInstr *, 4> ToBeRemoved;
140 // Convenience function to reset the matching of a sequence.
141 auto Reset = [&]() {
142 Prev = nullptr;
143 ToBeRemoved.clear();
146 // Walk through instructions in the block trying to find pairs of smstart
147 // and smstop nodes that cancel each other out. We only permit a limited
148 // set of instructions to appear between them, otherwise we reset our
149 // tracking.
150 unsigned NumSMChanges = 0;
151 unsigned NumSMChangesRemoved = 0;
152 for (MachineInstr &MI : make_early_inc_range(MBB)) {
153 switch (MI.getOpcode()) {
154 case AArch64::MSRpstatesvcrImm1:
155 case AArch64::MSRpstatePseudo: {
156 if (ChangesStreamingMode(&MI))
157 NumSMChanges++;
159 if (!Prev)
160 Prev = &MI;
161 else if (isMatchingStartStopPair(Prev, &MI)) {
162 // If they match, we can remove them, and possibly any instructions
163 // that we marked for deletion in between.
164 Prev->eraseFromParent();
165 MI.eraseFromParent();
166 for (MachineInstr *TBR : ToBeRemoved)
167 TBR->eraseFromParent();
168 ToBeRemoved.clear();
169 Prev = nullptr;
170 Changed = true;
171 NumSMChangesRemoved += 2;
172 } else {
173 Reset();
174 Prev = &MI;
176 continue;
178 default:
179 if (!Prev)
180 // Avoid doing expensive checks when Prev is nullptr.
181 continue;
182 break;
185 // Test if the instructions in between the start/stop sequence are agnostic
186 // of streaming mode. If not, the algorithm should reset.
187 switch (MI.getOpcode()) {
188 default:
189 Reset();
190 break;
191 case AArch64::COALESCER_BARRIER_FPR16:
192 case AArch64::COALESCER_BARRIER_FPR32:
193 case AArch64::COALESCER_BARRIER_FPR64:
194 case AArch64::COALESCER_BARRIER_FPR128:
195 case AArch64::COPY:
196 // These instructions should be safe when executed on their own, but
197 // the code remains conservative when SVE registers are used. There may
198 // exist subtle cases where executing a COPY in a different mode results
199 // in different behaviour, even if we can't yet come up with any
200 // concrete example/test-case.
201 if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
202 isSVERegOp(TRI, MRI, MI.getOperand(1)))
203 Reset();
204 break;
205 case AArch64::ADJCALLSTACKDOWN:
206 case AArch64::ADJCALLSTACKUP:
207 case AArch64::ANDXri:
208 case AArch64::ADDXri:
209 // We permit these as they don't generate SVE/NEON instructions.
210 break;
211 case AArch64::VGRestorePseudo:
212 case AArch64::VGSavePseudo:
213 // When the smstart/smstop are removed, we should also remove
214 // the pseudos that save/restore the VG value for CFI info.
215 ToBeRemoved.push_back(&MI);
216 break;
217 case AArch64::MSRpstatesvcrImm1:
218 case AArch64::MSRpstatePseudo:
219 llvm_unreachable("Should have been handled");
223 HasRemovedAllSMChanges =
224 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
225 return Changed;
228 INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
229 "SME Peephole Optimization", false, false)
231 bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
232 if (skipFunction(MF.getFunction()))
233 return false;
235 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
236 return false;
238 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
240 bool Changed = false;
241 bool FunctionHasAllSMChangesRemoved = false;
243 // Even if the block lives in a function with no SME attributes attached we
244 // still have to analyze all the blocks because we may call a streaming
245 // function that requires smstart/smstop pairs.
246 for (MachineBasicBlock &MBB : MF) {
247 bool BlockHasAllSMChangesRemoved;
248 Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
249 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
252 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
253 if (FunctionHasAllSMChangesRemoved)
254 AFI->setHasStreamingModeChanges(false);
256 return Changed;
259 FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }