1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// \file Pass to configure the shapes of AMX physical registers.
10 /// AMX registers need to be configured before use. In the X86PreTileConfig
11 /// pass the pldtilecfg instruction is inserted; however, at that time we
12 /// don't know the shape of each physical tile register, because register
13 /// allocation is not done yet. This pass runs after the register allocation
14 /// pass. It collects the shape information of each physical tile register
15 /// and stores the shapes in the stack slot that is allocated for loading the
16 /// tile config register.
18 //===----------------------------------------------------------------------===//
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/CodeGen/LiveIntervals.h"
26 #include "llvm/CodeGen/MachineFrameInfo.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineInstr.h"
29 #include "llvm/CodeGen/MachineRegisterInfo.h"
30 #include "llvm/CodeGen/Passes.h"
31 #include "llvm/CodeGen/TargetInstrInfo.h"
32 #include "llvm/CodeGen/TargetRegisterInfo.h"
33 #include "llvm/CodeGen/TileShapeInfo.h"
34 #include "llvm/CodeGen/VirtRegMap.h"
35 #include "llvm/InitializePasses.h"
39 #define DEBUG_TYPE "tile-config"
/// Machine function pass that runs after register allocation and writes the
/// shape (rows and bytes per row) of each allocated AMX tile register into the
/// stack slot from which the LDTILECFG instruction loads the tile config.
/// NOTE(review): this excerpt does not show the closing braces of
/// getAnalysisUsage, getRequiredProperties and the struct, nor the
/// `static char ID;` declaration the constructor refers to; verify against
/// the complete file.
43 struct X86TileConfig
: public MachineFunctionPass
{
45 X86TileConfig() : MachineFunctionPass(ID
) {}
47 /// Return the pass name.
48 StringRef
getPassName() const override
{ return "Tile Register Configure"; }
50 /// Require VirtRegMap (virtual-to-physical tile assignments and their
/// recorded shapes) and LiveIntervals (kept up to date for the stores this
/// pass inserts).
51 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
53 AU
.addRequired
<VirtRegMap
>();
54 AU
.addRequired
<LiveIntervals
>();
55 MachineFunctionPass::getAnalysisUsage(AU
);
58 /// Store the shape of each allocated tile register into the tile config
/// stack slot. This pass does not allocate registers itself; it consumes
/// the results of register allocation.
59 bool runOnMachineFunction(MachineFunction
&mf
) override
;
/// Declare that this pass requires a function without PHI nodes.
61 MachineFunctionProperties
getRequiredProperties() const override
{
62 return MachineFunctionProperties().set(
63 MachineFunctionProperties::Property::NoPHIs
);
69 } // end anonymous namespace
// Definition of the pass ID that the constructor passes to MachineFunctionPass.
71 char X86TileConfig::ID
= 0;
// Register the pass as "tileconfig" ("Tile Register Configure"), declaring a
// dependency on VirtRegMap.
// NOTE(review): getAnalysisUsage also requires LiveIntervals, but it is not
// listed via INITIALIZE_PASS_DEPENDENCY here; consider adding it.
// NOTE(review): the trailing arguments and closing parentheses of
// INITIALIZE_PASS_BEGIN / INITIALIZE_PASS_END are not visible in this excerpt.
73 INITIALIZE_PASS_BEGIN(X86TileConfig
, "tileconfig", "Tile Register Configure",
75 INITIALIZE_PASS_DEPENDENCY(VirtRegMap
)
76 INITIALIZE_PASS_END(X86TileConfig
, "tileconfig", "Tile Register Configure",
/// Fill the tile config stack slot with the shape of every allocated AMX tile
/// register. Bytes-per-row (colsb) values go in 16-bit fields at offset
/// 16 + 2 * I, and row counts go in 8-bit fields at offset 48 + I, where I is
/// the tile index.
79 bool X86TileConfig::runOnMachineFunction(MachineFunction
&MF
) {
// Target hooks, plus the analyses produced by register allocation.
80 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
81 const TargetRegisterInfo
*TRI
= ST
.getRegisterInfo();
82 const TargetInstrInfo
*TII
= ST
.getInstrInfo();
83 MachineRegisterInfo
&MRI
= MF
.getRegInfo();
84 LiveIntervals
&LIS
= getAnalysis
<LiveIntervals
>();
85 VirtRegMap
&VRM
= getAnalysis
<VirtRegMap
>();
// Check whether any tile shapes were recorded in VirtRegMap.
// NOTE(review): the statement guarded by this check is not visible in this
// excerpt; presumably the pass returns early when there are no tile shapes.
87 if (VRM
.isShapeMapEmpty())
// Find the LDTILECFG instruction and record its frame-index operand (operand
// 0) in SS. That stack slot holds the tile config this pass fills in.
// NOTE(review): SS is declared outside this excerpt.
91 for (MachineBasicBlock
&MBB
: MF
) {
92 for (MachineInstr
&MI
: MBB
) {
93 if (MI
.getOpcode() == X86::LDTILECFG
) {
94 SS
= MI
.getOperand(0).getIndex();
102 // Try to find a point to insert MIs for constant shapes.
103 // Here we are leveraging the palette id inserted in PreRA pass.
// Scan the entry block for a MOV8mi whose frame-index operand is SS. Per the
// comment above, this is the palette-id store from the pre-RA pass. Stores
// of constant shapes are later inserted right after ConstMI.
// NOTE(review): the assignments to ConstMI/ConstPos inside this loop are not
// visible here. Presumably ConstMI becomes that MOV8mi and ConstPos its
// position in the entry block, since ConstPos is later compared against
// std::distance(MBB.instr_begin(), ...).
104 unsigned ConstPos
= 0;
105 MachineInstr
*ConstMI
= nullptr;
106 for (MachineInstr
&MI
: MF
.front()) {
107 if (MI
.getOpcode() == X86::MOV8mi
&& SS
== MI
.getOperand(0).getIndex()) {
// Constant-shape stores are anchored at ConstMI, so it must exist.
113 assert(ConstMI
&& "Cannot find an insertion point");
// Map each physical tile register (indexed relative to X86::TMM0) to the
// first virtual tile register assigned to it. An entry of 0 means no
// virtual register was assigned to that tile.
115 unsigned AMXRegNum
= TRI
->getRegClass(X86::TILERegClassID
)->getNumRegs();
116 SmallVector
<Register
, 8> Phys2Virt(AMXRegNum
, 0);
117 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
118 Register VirtReg
= Register::index2VirtReg(I
);
// Ignore registers with no non-debug operands and registers outside the
// tile class. NOTE(review): the statements guarded by these two checks are
// not visible here; presumably they skip to the next register.
119 if (MRI
.reg_nodbg_empty(VirtReg
))
121 if (MRI
.getRegClass(VirtReg
)->getID() != X86::TILERegClassID
)
// This index assumes the tile registers are numbered consecutively starting
// at X86::TMM0.
123 unsigned Index
= VRM
.getPhys(VirtReg
) - X86::TMM0
;
// Keep only the first virtual register seen for this physical tile.
124 if (!Phys2Virt
[Index
])
125 Phys2Virt
[Index
] = VirtReg
;
128 // Fill in the shape of each tile physical register.
129 for (unsigned I
= 0; I
< AMXRegNum
; ++I
) {
// NOTE(review): no guard for tiles without an assigned virtual register
// (Phys2Virt[I] == 0) is visible before the getShape call below; confirm
// that one exists in the complete source.
134 MachineInstr
*NewMI
= nullptr;
// Shape recorded in VirtRegMap for the virtual register assigned to tile I.
135 ShapeT Shape
= VRM
.getShape(Phys2Virt
[I
]);
// Visit the row register first, then the column register.
// NOTE(review): IsRow is declared and updated outside this excerpt. The
// offsets and opcodes below imply it is true while R is the row register.
136 for (auto &R
: {Shape
.getRow()->getReg(), Shape
.getCol()->getReg()}) {
137 // Here is the data format for the tile config.
// (Bytes 0-1 are not described in this excerpt of the table.)
140 // 2-15 reserved, must be zero
141 // 16-17 tile0.colsb Tile 0 bytes per row.
142 // 18-19 tile1.colsb Tile 1 bytes per row.
143 // 20-21 tile2.colsb Tile 2 bytes per row.
144 // ... (sequence continues)
145 // 30-31 tile7.colsb Tile 7 bytes per row.
146 // 32-47 reserved, must be zero
147 // 48 tile0.rows Tile 0 rows.
148 // 49 tile1.rows Tile 1 rows.
149 // 50 tile2.rows Tile 2 rows.
150 // ... (sequence continues)
151 // 55 tile7.rows Tile 7 rows.
152 // 56-63 reserved, must be zero
// Immediate already stored for R; INT64_MAX means none has been seen yet.
153 int64_t Imm
= INT64_MAX
;
// Byte offset of tile I's field in the config: rows at 48 + I (1 byte),
// colsb at 16 + 2 * I (2 bytes), matching the table above.
154 int Offset
= IsRow
? 48 + I
: 16 + I
* 2;
// Handle every definition of R; see the two cases below.
155 for (auto &DefMI
: MRI
.def_instructions(R
)) {
156 MachineBasicBlock
&MBB
= *DefMI
.getParent();
// Constant shape: store the immediate right after ConstMI in the entry
// block, not next to DefMI.
157 if (DefMI
.isMoveImmediate()) {
// A second immediate definition must carry the same value; only the
// assert checks this. NOTE(review): the rest of this inner branch is
// not visible in this excerpt.
158 if (Imm
!= INT64_MAX
) {
159 // FIXME: We should handle this case in future.
160 assert(Imm
== DefMI
.getOperand(1).getImm() &&
161 "Cannot initialize with different shapes");
164 Imm
= DefMI
.getOperand(1).getImm();
// NOTE(review): the remaining arguments of this addFrameReference call
// (presumably the slot SS, Offset and Imm) are not visible here.
165 NewMI
= addFrameReference(
166 BuildMI(MF
.front(), ++ConstMI
->getIterator(), DL
,
167 TII
->get(IsRow
? X86::MOV8mi
: X86::MOV16mi
)),
// Assign the new instruction a slot index in LiveIntervals.
171 LIS
.InsertMachineInstrInMaps(*NewMI
);
// Register-defined shape. This is presumably the non-immediate path; the
// end of the branch above is not visible. Store the 8-bit (row) or
// 16-bit (colsb) subregister of R right after its definition.
173 unsigned SubIdx
= IsRow
? X86::sub_8bit
: X86::sub_16bit
;
174 unsigned RegSize
= TRI
->getRegSizeInBits(*MRI
.getRegClass(R
));
// R already has exactly the field's width. NOTE(review): the statement
// guarded here is not visible; presumably it clears SubIdx.
175 if ((IsRow
&& RegSize
== 8) || (!IsRow
&& RegSize
== 16))
177 auto Iter
= DefMI
.getIterator();
// If R is defined in the entry block before position ConstPos, insert
// the store after ConstMI instead. NOTE(review): presumably this keeps
// the store from landing before the code that initializes the config
// slot; confirm.
178 if (&MBB
== &MF
.front() &&
179 (unsigned)std::distance(MBB
.instr_begin(), Iter
) < ConstPos
)
180 Iter
= ConstMI
->getIterator();
// NOTE(review): part of this addFrameReference call (presumably the slot
// SS and Offset) is not visible in this excerpt.
181 NewMI
= addFrameReference(
182 BuildMI(MBB
, ++Iter
, DL
,
183 TII
->get(IsRow
? X86::MOV8mr
: X86::MOV16mr
)),
185 .addReg(R
, 0, SubIdx
);
// Index the store in LiveIntervals and extend R's live interval to cover
// the store's use of R.
186 SlotIndex SIdx
= LIS
.InsertMachineInstrInMaps(*NewMI
);
187 LIS
.extendToIndices(LIS
.getInterval(R
), {SIdx
.getRegSlot()});
196 FunctionPass
*llvm::createX86TileConfigPass() { return new X86TileConfig(); }