//===-- X86LowerTileCopy.cpp - Expand Tile Copy Instructions---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the pass which lowers AMX tile copy instructions. Since
// there is no tile copy instruction, we store the source tile register to the
// stack and load it from the stack into the destination tile register. We
// need an extra GR64 register to hold the stride and a stack slot to hold the
// tile data. We run this pass after copy propagation, so that we don't miss
// any copy optimization, and before prolog/epilog insertion, so that we can
// still allocate the stack slot.
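//
// For illustration only (the pass picks any GR64 that is dead at the copy and
// creates the spill slot itself; RAX and %stack.0 below are assumptions), a
// copy such as
//
//   $tmm1 = COPY $tmm0
//
// is lowered to roughly:
//
//   $rax = MOV64ri 64                               (row stride in bytes)
//   TILESTORED %stack.0, 1, $rax, 0, $noreg, $tmm0  (spill the source tile)
//   $tmm1 = TILELOADD %stack.0, 1, $rax, 0, $noreg  (reload into the dest)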
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86InstrInfo.h"
#include "X86MachineFunctionInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/LiveRegUnits.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "x86-lower-tile-copy"

namespace {

class X86LowerTileCopy : public MachineFunctionPass {
public:
  static char ID;

  X86LowerTileCopy() : MachineFunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override;

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override { return "X86 Lower Tile Copy"; }
};

} // end anonymous namespace

char X86LowerTileCopy::ID = 0;

INITIALIZE_PASS_BEGIN(X86LowerTileCopy, "lowertilecopy", "Tile Copy Lowering",
                      false, false)
INITIALIZE_PASS_END(X86LowerTileCopy, "lowertilecopy", "Tile Copy Lowering",
                    false, false)

void X86LowerTileCopy::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  MachineFunctionPass::getAnalysisUsage(AU);
}

FunctionPass *llvm::createX86LowerTileCopyPass() {
  return new X86LowerTileCopy();
}

bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
  X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
  if (FuncInfo->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const X86InstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  BitVector GR64Regs =
      TRI->getAllocatableSet(MF, TRI->getRegClass(X86::GR64RegClassID));
  BitVector TILERegs =
      TRI->getAllocatableSet(MF, TRI->getRegClass(X86::TILERegClassID));
  bool Changed = false;

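  // Scan each block bottom-up, tracking liveness with LiveRegUnits, so that
  // at each tile copy we can pick a GR64 that is not live at that point to
  // hold the stride without a save/reload.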
  for (MachineBasicBlock &MBB : MF) {
    LiveRegUnits UsedRegs(*TRI);
    UsedRegs.addLiveOuts(MBB);
    for (MachineInstr &MI : llvm::make_early_inc_range(reverse(MBB))) {
      UsedRegs.stepBackward(MI);
      if (!MI.isCopy())
        continue;
      MachineOperand &DstMO = MI.getOperand(0);
      MachineOperand &SrcMO = MI.getOperand(1);
      Register SrcReg = SrcMO.getReg();
      Register DstReg = DstMO.getReg();
      if (!X86::TILERegClass.contains(DstReg, SrcReg))
        continue;

      // Allocate stack slot for tile register
      unsigned Size = TRI->getSpillSize(X86::TILERegClass);
      Align Alignment = TRI->getSpillAlign(X86::TILERegClass);
      int TileSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment);

      int StrideSS = 0;

      // Pick a killed register to avoid a save/reload.
      Register GR64Cand = X86::NoRegister;
      for (auto RegT : GR64Regs.set_bits()) {
        if (UsedRegs.available(RegT)) {
          GR64Cand = RegT;
          break;
        }
      }

      const DebugLoc &DL = MI.getDebugLoc();
      if (GR64Cand) {
        // mov 64, %reg: materialize the 64-byte row stride in the scratch GR64.
        BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64);
      } else {
        // No available register? Save RAX and reload it after use.

        // Allocate stack slot for stride register
        Size = TRI->getSpillSize(X86::GR64RegClass);
        Alignment = TRI->getSpillAlign(X86::GR64RegClass);
        StrideSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment);

        // mov %rax, (%sp): spill RAX into the stride stack slot.
        addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)),
                          StrideSS)
            .addReg(X86::RAX);
        // mov 64, %rax: materialize the 64-byte row stride in RAX.
        BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), X86::RAX).addImm(64);
      }
      // tilestored %tmm, (%sp, %idx)
#define GET_EGPR_IF_ENABLED(OPC) (ST.hasEGPR() ? OPC##_EVEX : OPC)
      unsigned Opc = GET_EGPR_IF_ENABLED(X86::TILESTORED);
      MachineInstr *NewMI =
          addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc)), TileSS)
              .addReg(SrcReg, getKillRegState(SrcMO.isKill()));
      MachineOperand *MO = &NewMI->getOperand(X86::AddrIndexReg);
      MO->setReg(GR64Cand ? GR64Cand : X86::RAX);
      // tileloadd (%sp, %idx), %tmm
      Opc = GET_EGPR_IF_ENABLED(X86::TILELOADD);
#undef GET_EGPR_IF_ENABLED
      NewMI = addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc), DstReg),
                                TileSS);
      // The memory operands follow the tile def here, hence the extra 1.
      MO = &NewMI->getOperand(1 + X86::AddrIndexReg);
      MO->setReg(GR64Cand ? GR64Cand : X86::RAX);
      MO->setIsKill(true);
      if (!GR64Cand) {
        // Restore RAX from the stride stack slot: mov (%sp), %rax
        addFrameReference(
            BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), X86::RAX), StrideSS);
      }
      MI.eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}