//===-- RISCVExpandAtomicPseudoInsts.cpp - Expand atomic pseudo instrs. ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a pass that expands atomic pseudo instructions into
// target instructions. This pass should be run at the last possible moment,
// avoiding the possibility for other passes to break the requirements for
// forward progress in the LR/SC block.
//
//===----------------------------------------------------------------------===//
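
// Note on forward progress: the RISC-V unprivileged spec only guarantees
// eventual success for constrained LR/SC loops (roughly, at most 16
// instructions between the LR and the SC, with no other loads, stores, or
// taken backward branches), which is why later passes must not insert or
// reorder code inside the loops emitted below.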

#include "RISCV.h"
#include "RISCVInstrInfo.h"
#include "RISCVTargetMachine.h"

#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"

using namespace llvm;

#define RISCV_EXPAND_ATOMIC_PSEUDO_NAME                                        \
  "RISC-V atomic pseudo instruction expansion pass"

namespace {

class RISCVExpandAtomicPseudo : public MachineFunctionPass {
public:
  const RISCVSubtarget *STI;
  const RISCVInstrInfo *TII;
  static char ID;

  RISCVExpandAtomicPseudo() : MachineFunctionPass(ID) {
    initializeRISCVExpandAtomicPseudoPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return RISCV_EXPAND_ATOMIC_PSEUDO_NAME;
  }

private:
  bool expandMBB(MachineBasicBlock &MBB);
  bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicBinOp(MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI, AtomicRMWInst::BinOp,
                         bool IsMasked, int Width,
                         MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicMinMaxOp(MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            AtomicRMWInst::BinOp, bool IsMasked, int Width,
                            MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI, bool IsMasked,
                           int Width, MachineBasicBlock::iterator &NextMBBI);
#ifndef NDEBUG
  unsigned getInstSizeInBytes(const MachineFunction &MF) const {
    unsigned Size = 0;
    for (auto &MBB : MF)
      for (auto &MI : MBB)
        Size += TII->getInstSizeInBytes(MI);
    return Size;
  }
#endif
};

char RISCVExpandAtomicPseudo::ID = 0;

bool RISCVExpandAtomicPseudo::runOnMachineFunction(MachineFunction &MF) {
  STI = &MF.getSubtarget<RISCVSubtarget>();
  TII = STI->getInstrInfo();

#ifndef NDEBUG
  const unsigned OldSize = getInstSizeInBytes(MF);
#endif

  bool Modified = false;
  for (auto &MBB : MF)
    Modified |= expandMBB(MBB);

#ifndef NDEBUG
  const unsigned NewSize = getInstSizeInBytes(MF);
  assert(OldSize >= NewSize);
#endif
  return Modified;
}
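
// expandMI may split the current block and erase the pseudo it visits, so
// iteration in expandMBB grabs the next iterator up front and lets expandMI
// update it whenever instructions are spliced into a new block.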
bool RISCVExpandAtomicPseudo::expandMBB(MachineBasicBlock &MBB) {
  bool Modified = false;

  MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
  while (MBBI != E) {
    MachineBasicBlock::iterator NMBBI = std::next(MBBI);
    Modified |= expandMI(MBB, MBBI, NMBBI);
    MBBI = NMBBI;
  }

  return Modified;
}

bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI,
                                       MachineBasicBlock::iterator &NextMBBI) {
  // RISCVInstrInfo::getInstSizeInBytes expects that the total size of the
  // expanded instructions for each pseudo is correct in the Size field of the
  // tablegen definition for the pseudo.
  switch (MBBI->getOpcode()) {
  case RISCV::PseudoAtomicLoadNand32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 32,
                             NextMBBI);
  case RISCV::PseudoAtomicLoadNand64:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 64,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicSwap32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Xchg, true, 32,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadAdd32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Add, true, 32, NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadSub32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Sub, true, 32, NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadNand32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, true, 32,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadMax32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Max, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadMin32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Min, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadUMax32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMax, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadUMin32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, true, 32,
                                NextMBBI);
  case RISCV::PseudoCmpXchg32:
    return expandAtomicCmpXchg(MBB, MBBI, false, 32, NextMBBI);
  case RISCV::PseudoCmpXchg64:
    return expandAtomicCmpXchg(MBB, MBBI, false, 64, NextMBBI);
  case RISCV::PseudoMaskedCmpXchg32:
    return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
  }

  return false;
}
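
// Ordering-to-instruction helpers. Under the Ztso extension, plain loads and
// stores already provide TSO (acquire/release) ordering, so the .aq/.rl bits
// on LR/SC are only kept where seq_cst still requires them; the helpers below
// encode exactly that.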
static unsigned getLRForRMW32(AtomicOrdering Ordering,
                              const RISCVSubtarget *Subtarget) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::LR_W;
  case AtomicOrdering::Acquire:
    if (Subtarget->hasStdExtZtso())
      return RISCV::LR_W;
    return RISCV::LR_W_AQ;
  case AtomicOrdering::Release:
    return RISCV::LR_W;
  case AtomicOrdering::AcquireRelease:
    if (Subtarget->hasStdExtZtso())
      return RISCV::LR_W;
    return RISCV::LR_W_AQ;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::LR_W_AQ_RL;
  }
}

static unsigned getSCForRMW32(AtomicOrdering Ordering,
                              const RISCVSubtarget *Subtarget) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::SC_W;
  case AtomicOrdering::Acquire:
    return RISCV::SC_W;
  case AtomicOrdering::Release:
    if (Subtarget->hasStdExtZtso())
      return RISCV::SC_W;
    return RISCV::SC_W_RL;
  case AtomicOrdering::AcquireRelease:
    if (Subtarget->hasStdExtZtso())
      return RISCV::SC_W;
    return RISCV::SC_W_RL;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::SC_W_RL;
  }
}

static unsigned getLRForRMW64(AtomicOrdering Ordering,
                              const RISCVSubtarget *Subtarget) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::LR_D;
  case AtomicOrdering::Acquire:
    if (Subtarget->hasStdExtZtso())
      return RISCV::LR_D;
    return RISCV::LR_D_AQ;
  case AtomicOrdering::Release:
    return RISCV::LR_D;
  case AtomicOrdering::AcquireRelease:
    if (Subtarget->hasStdExtZtso())
      return RISCV::LR_D;
    return RISCV::LR_D_AQ;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::LR_D_AQ_RL;
  }
}

static unsigned getSCForRMW64(AtomicOrdering Ordering,
                              const RISCVSubtarget *Subtarget) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::SC_D;
  case AtomicOrdering::Acquire:
    return RISCV::SC_D;
  case AtomicOrdering::Release:
    if (Subtarget->hasStdExtZtso())
      return RISCV::SC_D;
    return RISCV::SC_D_RL;
  case AtomicOrdering::AcquireRelease:
    if (Subtarget->hasStdExtZtso())
      return RISCV::SC_D;
    return RISCV::SC_D_RL;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::SC_D_RL;
  }
}

static unsigned getLRForRMW(AtomicOrdering Ordering, int Width,
                            const RISCVSubtarget *Subtarget) {
  if (Width == 32)
    return getLRForRMW32(Ordering, Subtarget);
  if (Width == 64)
    return getLRForRMW64(Ordering, Subtarget);
  llvm_unreachable("Unexpected LR width\n");
}

static unsigned getSCForRMW(AtomicOrdering Ordering, int Width,
                            const RISCVSubtarget *Subtarget) {
  if (Width == 32)
    return getSCForRMW32(Ordering, Subtarget);
  if (Width == 64)
    return getSCForRMW64(Ordering, Subtarget);
  llvm_unreachable("Unexpected SC width\n");
}
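
// doAtomicBinOpExpansion emits the plain (unmasked) LR/SC loop. The pseudo's
// operands are (dest, scratch, addr, incr, ordering-imm); everything is
// emitted into LoopMBB, which branches back on itself until the SC succeeds.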
static void doAtomicBinOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
                                   DebugLoc DL, MachineBasicBlock *ThisMBB,
                                   MachineBasicBlock *LoopMBB,
                                   MachineBasicBlock *DoneMBB,
                                   AtomicRMWInst::BinOp BinOp, int Width,
                                   const RISCVSubtarget *STI) {
  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register IncrReg = MI.getOperand(3).getReg();
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(4).getImm());

  // .loop:
  //   lr.[w|d] dest, (addr)
  //   binop scratch, dest, val
  //   sc.[w|d] scratch, scratch, (addr)
  //   bnez scratch, loop
  BuildMI(LoopMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)), DestReg)
      .addReg(AddrReg);
  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Nand:
    BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
        .addReg(ScratchReg)
        .addImm(-1);
    break;
  }
  BuildMI(LoopMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)), ScratchReg)
      .addReg(AddrReg)
      .addReg(ScratchReg);
  BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
      .addReg(ScratchReg)
      .addReg(RISCV::X0)
      .addMBB(LoopMBB);
}
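
// Worked example of the masked-merge identity used below: with
// oldval = 0xAABBCCDD, newval = 0x000011FF and mask = 0x0000FFFF,
//   r = oldval ^ ((oldval ^ newval) & mask) = 0xAABB11FF,
// i.e. the masked halfword comes from newval and all other bits from oldval.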
static void insertMaskedMerge(const RISCVInstrInfo *TII, DebugLoc DL,
                              MachineBasicBlock *MBB, Register DestReg,
                              Register OldValReg, Register NewValReg,
                              Register MaskReg, Register ScratchReg) {
  assert(OldValReg != ScratchReg && "OldValReg and ScratchReg must be unique");
  assert(OldValReg != MaskReg && "OldValReg and MaskReg must be unique");
  assert(ScratchReg != MaskReg && "ScratchReg and MaskReg must be unique");

  // We select bits from newval and oldval using:
  // https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge
  // r = oldval ^ ((oldval ^ newval) & masktargetdata);
  BuildMI(MBB, DL, TII->get(RISCV::XOR), ScratchReg)
      .addReg(OldValReg)
      .addReg(NewValReg);
  BuildMI(MBB, DL, TII->get(RISCV::AND), ScratchReg)
      .addReg(ScratchReg)
      .addReg(MaskReg);
  BuildMI(MBB, DL, TII->get(RISCV::XOR), DestReg)
      .addReg(OldValReg)
      .addReg(ScratchReg);
}
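
// The masked expansions below implement part-word (8/16-bit) AMOs on the
// aligned 32-bit word containing the field: ISel hands the pseudo the aligned
// address, an increment already shifted into position, and a mask selecting
// the field, so the loop only ever updates the masked bits.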
static void doMaskedAtomicBinOpExpansion(const RISCVInstrInfo *TII,
                                         MachineInstr &MI, DebugLoc DL,
                                         MachineBasicBlock *ThisMBB,
                                         MachineBasicBlock *LoopMBB,
                                         MachineBasicBlock *DoneMBB,
                                         AtomicRMWInst::BinOp BinOp, int Width,
                                         const RISCVSubtarget *STI) {
  assert(Width == 32 && "Should never need to expand masked 64-bit operations");
  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register IncrReg = MI.getOperand(3).getReg();
  Register MaskReg = MI.getOperand(4).getReg();
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(5).getImm());

  // .loop:
  //   lr.w destreg, (alignedaddr)
  //   binop scratch, destreg, incr
  //   xor scratch, destreg, scratch
  //   and scratch, scratch, masktargetdata
  //   xor scratch, destreg, scratch
  //   sc.w scratch, scratch, (alignedaddr)
  //   bnez scratch, loop
  BuildMI(LoopMBB, DL, TII->get(getLRForRMW32(Ordering, STI)), DestReg)
      .addReg(AddrReg);
  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Xchg:
    BuildMI(LoopMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
        .addReg(IncrReg)
        .addImm(0);
    break;
  case AtomicRMWInst::Add:
    BuildMI(LoopMBB, DL, TII->get(RISCV::ADD), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    break;
  case AtomicRMWInst::Sub:
    BuildMI(LoopMBB, DL, TII->get(RISCV::SUB), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    break;
  case AtomicRMWInst::Nand:
    BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
        .addReg(ScratchReg)
        .addImm(-1);
    break;
  }

  insertMaskedMerge(TII, DL, LoopMBB, ScratchReg, DestReg, ScratchReg, MaskReg,
                    ScratchReg);

  BuildMI(LoopMBB, DL, TII->get(getSCForRMW32(Ordering, STI)), ScratchReg)
      .addReg(AddrReg)
      .addReg(ScratchReg);
  BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
      .addReg(ScratchReg)
      .addReg(RISCV::X0)
      .addMBB(LoopMBB);
}
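
// Shared driver for the masked and unmasked binary AMO expansions: split the
// block into MBB -> LoopMBB -> DoneMBB, emit the LR/SC loop, erase the
// pseudo, and recompute live-ins for the newly created blocks.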
bool RISCVExpandAtomicPseudo::expandAtomicBinOp(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
    MachineBasicBlock::iterator &NextMBBI) {
  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();

  MachineFunction *MF = MBB.getParent();
  auto LoopMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopMBB);
  MF->insert(++LoopMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopMBB->addSuccessor(LoopMBB);
  LoopMBB->addSuccessor(DoneMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopMBB);

  if (!IsMasked)
    doAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp, Width,
                           STI);
  else
    doMaskedAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp,
                                 Width, STI);

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}
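
// Sign-extend the part-word value in ValReg in place: shift it up so its sign
// bit becomes the MSB, then arithmetic-shift it back down by the same amount
// held in ShamtReg.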
static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
                       MachineBasicBlock *MBB, Register ValReg,
                       Register ShamtReg) {
  BuildMI(MBB, DL, TII->get(RISCV::SLL), ValReg)
      .addReg(ValReg)
      .addReg(ShamtReg);
  BuildMI(MBB, DL, TII->get(RISCV::SRA), ValReg)
      .addReg(ValReg)
      .addReg(ShamtReg);
}
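
// Min/max cannot be computed with a single ALU op on the masked field, so the
// expansion branches: the loop head extracts (and, for signed ops,
// sign-extends) the current field and skips to the loop tail when no update
// is needed; otherwise the if-body merges the new value before the SC. For
// signed ops, operand 6 carries the sext shift amount, which is why the
// ordering immediate moves to operand 7.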
bool RISCVExpandAtomicPseudo::expandAtomicMinMaxOp(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
    MachineBasicBlock::iterator &NextMBBI) {
  assert(IsMasked == true &&
         "Should only need to expand masked atomic max/min");
  assert(Width == 32 && "Should never need to expand masked 64-bit operations");

  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();
  MachineFunction *MF = MBB.getParent();
  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopIfBodyMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopHeadMBB);
  MF->insert(++LoopHeadMBB->getIterator(), LoopIfBodyMBB);
  MF->insert(++LoopIfBodyMBB->getIterator(), LoopTailMBB);
  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopHeadMBB->addSuccessor(LoopIfBodyMBB);
  LoopHeadMBB->addSuccessor(LoopTailMBB);
  LoopIfBodyMBB->addSuccessor(LoopTailMBB);
  LoopTailMBB->addSuccessor(LoopHeadMBB);
  LoopTailMBB->addSuccessor(DoneMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopHeadMBB);

  Register DestReg = MI.getOperand(0).getReg();
  Register Scratch1Reg = MI.getOperand(1).getReg();
  Register Scratch2Reg = MI.getOperand(2).getReg();
  Register AddrReg = MI.getOperand(3).getReg();
  Register IncrReg = MI.getOperand(4).getReg();
  Register MaskReg = MI.getOperand(5).getReg();
  bool IsSigned = BinOp == AtomicRMWInst::Min || BinOp == AtomicRMWInst::Max;
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(IsSigned ? 7 : 6).getImm());

  //
  // .loophead:
  //   lr.w destreg, (alignedaddr)
  //   and scratch2, destreg, mask
  //   mv scratch1, destreg
  //   [sext scratch2 if signed min/max]
  //   ifnochangeneeded scratch2, incr, .looptail
  BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering, STI)), DestReg)
      .addReg(AddrReg);
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), Scratch2Reg)
      .addReg(DestReg)
      .addReg(MaskReg);
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), Scratch1Reg)
      .addReg(DestReg)
      .addImm(0);

  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Max: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::Min: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::UMax:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  case AtomicRMWInst::UMin:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }

  // .loopifbody:
  //   xor scratch1, destreg, incr
  //   and scratch1, scratch1, mask
  //   xor scratch1, destreg, scratch1
  insertMaskedMerge(TII, DL, LoopIfBodyMBB, Scratch1Reg, DestReg, IncrReg,
                    MaskReg, Scratch1Reg);

  // .looptail:
  //   sc.w scratch1, scratch1, (addr)
  //   bnez scratch1, loop
  BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering, STI)), Scratch1Reg)
      .addReg(AddrReg)
      .addReg(Scratch1Reg);
  BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
      .addReg(Scratch1Reg)
      .addReg(RISCV::X0)
      .addMBB(LoopHeadMBB);

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
  computeAndAddLiveIns(LiveRegs, *LoopIfBodyMBB);
  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}

// If a BNE on the cmpxchg comparison result immediately follows the cmpxchg
// operation, it can be folded into the cmpxchg expansion by modifying the
// branch within 'LoopHead' (which performs the same comparison). This is a
// valid transformation because after altering the LoopHead's BNE destination,
// the BNE following the cmpxchg becomes redundant and can be deleted. In the
// case of a masked cmpxchg, an appropriate AND and BNE must be matched.
//
// On success, returns true and deletes the matching BNE or AND+BNE, sets the
// LoopHeadBNETarget argument to the target that should be used within the
// loop head, and removes that block as a successor to MBB.
bool tryToFoldBNEOnCmpXchgResult(MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI,
                                 Register DestReg, Register CmpValReg,
                                 Register MaskReg,
                                 MachineBasicBlock *&LoopHeadBNETarget) {
  SmallVector<MachineInstr *> ToErase;
  auto E = MBB.end();
  if (MBBI == E)
    return false;
  MBBI = skipDebugInstructionsForward(MBBI, E);

  // If we have a masked cmpxchg, match AND dst, DestReg, MaskReg.
  if (MaskReg.isValid()) {
    if (MBBI == E || MBBI->getOpcode() != RISCV::AND)
      return false;
    Register ANDOp1 = MBBI->getOperand(1).getReg();
    Register ANDOp2 = MBBI->getOperand(2).getReg();
    if (!(ANDOp1 == DestReg && ANDOp2 == MaskReg) &&
        !(ANDOp1 == MaskReg && ANDOp2 == DestReg))
      return false;
    // We now expect the BNE to use the result of the AND as an operand.
    DestReg = MBBI->getOperand(0).getReg();
    ToErase.push_back(&*MBBI);
    MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  }

  // Match BNE DestReg, CmpValReg.
  if (MBBI == E || MBBI->getOpcode() != RISCV::BNE)
    return false;
  Register BNEOp0 = MBBI->getOperand(0).getReg();
  Register BNEOp1 = MBBI->getOperand(1).getReg();
  if (!(BNEOp0 == DestReg && BNEOp1 == CmpValReg) &&
      !(BNEOp0 == CmpValReg && BNEOp1 == DestReg))
    return false;

  // Make sure the branch is the only user of the AND.
  if (MaskReg.isValid()) {
    if (BNEOp0 == DestReg && !MBBI->getOperand(0).isKill())
      return false;
    if (BNEOp1 == DestReg && !MBBI->getOperand(1).isKill())
      return false;
  }

  ToErase.push_back(&*MBBI);
  LoopHeadBNETarget = MBBI->getOperand(2).getMBB();
  MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  if (MBBI != E)
    return false;

  MBB.removeSuccessor(LoopHeadBNETarget);
  for (auto *MI : ToErase)
    MI->eraseFromParent();
  return true;
}
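
// For reference, the unmasked PseudoCmpXchg32 with seq_cst ordering expands
// to roughly:
//
// .loophead:
//   lr.w.aqrl dest, (addr)
//   bne dest, cmpval, <done, or the folded BNE target>
// .looptail:
//   sc.w.rl scratch, newval, (addr)
//   bnez scratch, loophead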
bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsMasked,
    int Width, MachineBasicBlock::iterator &NextMBBI) {
  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();
  MachineFunction *MF = MBB.getParent();
  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register CmpValReg = MI.getOperand(3).getReg();
  Register NewValReg = MI.getOperand(4).getReg();
  Register MaskReg = IsMasked ? MI.getOperand(5).getReg() : Register();

  MachineBasicBlock *LoopHeadBNETarget = DoneMBB;
  tryToFoldBNEOnCmpXchgResult(MBB, std::next(MBBI), DestReg, CmpValReg, MaskReg,
                              LoopHeadBNETarget);

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopHeadMBB);
  MF->insert(++LoopHeadMBB->getIterator(), LoopTailMBB);
  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopHeadMBB->addSuccessor(LoopTailMBB);
  LoopHeadMBB->addSuccessor(LoopHeadBNETarget);
  LoopTailMBB->addSuccessor(DoneMBB);
  LoopTailMBB->addSuccessor(LoopHeadMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopHeadMBB);

  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(IsMasked ? 6 : 5).getImm());

  if (!IsMasked) {
    // .loophead:
    //   lr.[w|d] dest, (addr)
    //   bne dest, cmpval, done
    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)),
            DestReg)
        .addReg(AddrReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
        .addReg(DestReg)
        .addReg(CmpValReg)
        .addMBB(LoopHeadBNETarget);
    // .looptail:
    //   sc.[w|d] scratch, newval, (addr)
    //   bnez scratch, loophead
    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)),
            ScratchReg)
        .addReg(AddrReg)
        .addReg(NewValReg);
    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(RISCV::X0)
        .addMBB(LoopHeadMBB);
  } else {
    // .loophead:
    //   lr.w dest, (addr)
    //   and scratch, dest, mask
    //   bne scratch, cmpval, done
    Register MaskReg = MI.getOperand(5).getReg();
    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)),
            DestReg)
        .addReg(AddrReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(MaskReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(CmpValReg)
        .addMBB(LoopHeadBNETarget);

    // .looptail:
    //   xor scratch, dest, newval
    //   and scratch, scratch, mask
    //   xor scratch, dest, scratch
    //   sc.w scratch, scratch, (addr)
    //   bnez scratch, loophead
    insertMaskedMerge(TII, DL, LoopTailMBB, ScratchReg, DestReg, NewValReg,
                      MaskReg, ScratchReg);
    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)),
            ScratchReg)
        .addReg(AddrReg)
        .addReg(ScratchReg);
    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(RISCV::X0)
        .addMBB(LoopHeadMBB);
  }

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}

} // end of anonymous namespace

INITIALIZE_PASS(RISCVExpandAtomicPseudo, "riscv-expand-atomic-pseudo",
                RISCV_EXPAND_ATOMIC_PSEUDO_NAME, false, false)

namespace llvm {

FunctionPass *createRISCVExpandAtomicPseudoPass() {
  return new RISCVExpandAtomicPseudo();
}

} // end of namespace llvm