//===-- X86LowerTileCopy.cpp - Expand Tile Copy Instructions---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the pass that lowers AMX tile copy instructions. Since
// there is no tile-to-tile copy instruction, we store the source tile
// register to a stack slot and load it from there into the destination tile
// register. We need an extra GR64 register to hold the stride, and a stack
// slot to hold the tile data.
// This pass runs after copy propagation, so that no copy optimization is
// missed, and before prolog/epilog insertion, so that it can still allocate
// stack slots.
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/CodeGen/LiveRegUnits.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"

using namespace llvm;

#define DEBUG_TYPE "x86-lower-tile-copy"
41 class X86LowerTileCopy
: public MachineFunctionPass
{
45 X86LowerTileCopy() : MachineFunctionPass(ID
) {}
47 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
49 bool runOnMachineFunction(MachineFunction
&MF
) override
;
51 StringRef
getPassName() const override
{ return "X86 Lower Tile Copy"; }
56 char X86LowerTileCopy::ID
= 0;
58 INITIALIZE_PASS_BEGIN(X86LowerTileCopy
, "lowertilecopy", "Tile Copy Lowering",
60 INITIALIZE_PASS_END(X86LowerTileCopy
, "lowertilecopy", "Tile Copy Lowering",
63 void X86LowerTileCopy::getAnalysisUsage(AnalysisUsage
&AU
) const {
65 MachineFunctionPass::getAnalysisUsage(AU
);
68 FunctionPass
*llvm::createX86LowerTileCopyPass() {
69 return new X86LowerTileCopy();
72 bool X86LowerTileCopy::runOnMachineFunction(MachineFunction
&MF
) {
73 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
74 const X86InstrInfo
*TII
= ST
.getInstrInfo();
77 for (MachineBasicBlock
&MBB
: MF
) {
78 for (MachineInstr
&MI
: llvm::make_early_inc_range(MBB
)) {
81 MachineOperand
&DstMO
= MI
.getOperand(0);
82 MachineOperand
&SrcMO
= MI
.getOperand(1);
83 Register SrcReg
= SrcMO
.getReg();
84 Register DstReg
= DstMO
.getReg();
85 if (!X86::TILERegClass
.contains(DstReg
, SrcReg
))
88 const TargetRegisterInfo
*TRI
= ST
.getRegisterInfo();
89 // Allocate stack slot for tile register
90 unsigned Size
= TRI
->getSpillSize(X86::TILERegClass
);
91 Align Alignment
= TRI
->getSpillAlign(X86::TILERegClass
);
92 int TileSS
= MF
.getFrameInfo().CreateSpillStackObject(Size
, Alignment
);
93 // Allocate stack slot for stride register
94 Size
= TRI
->getSpillSize(X86::GR64RegClass
);
95 Alignment
= TRI
->getSpillAlign(X86::GR64RegClass
);
96 int StrideSS
= MF
.getFrameInfo().CreateSpillStackObject(Size
, Alignment
);
98 // TODO: Pick a killed regiter to avoid save/reload. There is problem
99 // to get live interval in this stage.
100 Register GR64Cand
= X86::RAX
;
102 const DebugLoc
&DL
= MI
.getDebugLoc();
104 BuildMI(MBB
, MI
, DL
, TII
->get(X86::IMPLICIT_DEF
), GR64Cand
);
105 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV64mr
)), StrideSS
)
108 BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV64ri
), GR64Cand
).addImm(64);
109 // tilestored %tmm, (%sp, %idx)
110 #define GET_EGPR_IF_ENABLED(OPC) (ST.hasEGPR() ? OPC##_EVEX : OPC)
111 unsigned Opc
= GET_EGPR_IF_ENABLED(X86::TILESTORED
);
112 MachineInstr
*NewMI
=
113 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(Opc
)), TileSS
)
114 .addReg(SrcReg
, getKillRegState(SrcMO
.isKill()));
115 MachineOperand
&MO
= NewMI
->getOperand(2);
118 // tileloadd (%sp, %idx), %tmm
119 Opc
= GET_EGPR_IF_ENABLED(X86::TILELOADD
);
120 #undef GET_EGPR_IF_ENABLED
121 NewMI
= addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(Opc
), DstReg
),
125 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV64rm
), GR64Cand
),
127 MI
.eraseFromParent();