1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// \file Pass to pre-config the shapes of AMX registers
10 /// AMX register needs to be configured before use. The shapes of AMX register
11 /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
13 /// The instruction ldtilecfg is used to config the shapes. It must be reachable
14 /// for all variable shapes. ldtilecfg will be inserted more than once if we
15 /// cannot find a dominating point for all AMX instructions.
17 /// The configure register is caller saved according to ABI. We need to insert
18 /// ldtilecfg again after the call instruction if callee clobbers any AMX
21 /// This pass calculates all points that ldtilecfg need to be inserted to and
22 /// insert them. It reports error if the reachability conditions aren't met.
24 //===----------------------------------------------------------------------===//
27 #include "X86InstrBuilder.h"
28 #include "X86MachineFunctionInfo.h"
29 #include "X86RegisterInfo.h"
30 #include "X86Subtarget.h"
31 #include "llvm/ADT/SmallSet.h"
32 #include "llvm/CodeGen/MachineFunctionPass.h"
33 #include "llvm/CodeGen/MachineInstr.h"
34 #include "llvm/CodeGen/MachineLoopInfo.h"
35 #include "llvm/CodeGen/MachineModuleInfo.h"
36 #include "llvm/CodeGen/MachineRegisterInfo.h"
37 #include "llvm/CodeGen/Passes.h"
38 #include "llvm/CodeGen/TargetInstrInfo.h"
39 #include "llvm/CodeGen/TargetRegisterInfo.h"
40 #include "llvm/InitializePasses.h"
44 #define DEBUG_TYPE "tile-pre-config"
46 static void emitErrorMsg(MachineFunction
&MF
) {
47 LLVMContext
&Context
= MF
.getMMI().getModule()->getContext();
50 ": Failed to config tile register, please define the shape earlier");
56 MachineInstr
*MI
= nullptr;
57 MachineBasicBlock
*MBB
= nullptr;
58 // A virtual position for instruction that will be inserted after MI.
61 MIRef(MachineBasicBlock
*MBB
) : MBB(MBB
) {
62 for (auto I
= MBB
->begin(), E
= MBB
->end(); I
!= E
&& I
->isPHI();
66 MIRef(MachineInstr
*MI
)
67 : MI(MI
), MBB(MI
->getParent()),
68 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
69 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
)
71 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
72 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
, size_t Pos
)
73 : MI(MI
), MBB(MBB
), Pos(Pos
) {}
74 operator bool() const { return MBB
!= nullptr; }
75 bool operator==(const MIRef
&RHS
) const {
76 return MI
== RHS
.MI
&& MBB
== RHS
.MBB
;
78 bool operator!=(const MIRef
&RHS
) const { return !(*this == RHS
); }
79 bool operator<(const MIRef
&RHS
) const {
80 // Comparison between different BBs happens when inserting a MIRef into set.
81 // So we compare MBB first to make the insertion happy.
82 return MBB
< RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
< RHS
.Pos
);
84 bool operator>(const MIRef
&RHS
) const {
85 // Comparison between different BBs happens when inserting a MIRef into set.
86 // So we compare MBB first to make the insertion happy.
87 return MBB
> RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
> RHS
.Pos
);
94 bool HasAMXRegLiveIn
= false;
95 bool TileCfgForbidden
= false;
96 bool NeedTileCfgLiveIn
= false;
99 class X86PreTileConfig
: public MachineFunctionPass
{
100 MachineRegisterInfo
*MRI
= nullptr;
101 const MachineLoopInfo
*MLI
= nullptr;
102 SmallSet
<MachineInstr
*, 8> DefVisited
;
103 DenseMap
<MachineBasicBlock
*, BBInfo
> BBVisitedInfo
;
104 DenseMap
<MachineBasicBlock
*, SmallVector
<MIRef
, 8>> ShapeBBs
;
106 /// Check if the callee will clobber AMX registers.
107 bool isDestructiveCall(MachineInstr
&MI
, BitVector UsableRegs
) {
108 auto Iter
= llvm::find_if(
109 MI
.operands(), [](MachineOperand
&MO
) { return MO
.isRegMask(); });
110 if (Iter
== MI
.operands_end())
112 UsableRegs
.clearBitsInMask(Iter
->getRegMask());
113 return !UsableRegs
.none();
116 /// Check if MI is AMX pseudo instruction.
117 bool isAMXInstruction(MachineInstr
&MI
) {
118 if (MI
.isPHI() || MI
.isDebugInstr() || MI
.getNumOperands() < 3)
120 MachineOperand
&MO
= MI
.getOperand(0);
121 // We can simply check if it is AMX instruction by its def.
122 // But we should exclude old API which uses physical registers.
123 if (MO
.isReg() && MO
.getReg().isVirtual() &&
124 MRI
->getRegClass(MO
.getReg())->getID() == X86::TILERegClassID
) {
125 collectShapeInfo(MI
);
128 // PTILESTOREDV is the only exception that doesn't def a AMX register.
129 return MI
.getOpcode() == X86::PTILESTOREDV
;
132 /// Check if it is an edge from loop bottom to loop head.
133 bool isLoopBackEdge(MachineBasicBlock
*Header
, MachineBasicBlock
*Bottom
) {
134 if (!MLI
->isLoopHeader(Header
))
136 auto *ML
= MLI
->getLoopFor(Header
);
137 if (ML
->contains(Bottom
) && ML
->isLoopLatch(Bottom
))
143 /// Collect the shape def information for later use.
144 void collectShapeInfo(MachineInstr
&MI
);
146 /// Try to hoist shapes definded below AMX instructions.
147 bool hoistShapesInBB(MachineBasicBlock
*MBB
, SmallVectorImpl
<MIRef
> &Shapes
) {
148 MIRef
&FirstAMX
= BBVisitedInfo
[MBB
].FirstAMX
;
149 auto FirstShapeBelowAMX
= llvm::lower_bound(Shapes
, FirstAMX
);
150 auto InsertPoint
= FirstAMX
.MI
->getIterator();
151 for (auto I
= FirstShapeBelowAMX
, E
= Shapes
.end(); I
!= E
; ++I
) {
152 // Do not hoist instructions that access memory.
153 if (I
->MI
->mayLoadOrStore())
155 for (auto &MO
: I
->MI
->operands()) {
158 // Do not hoist instructions if the sources' def under AMX instruction.
159 // TODO: We can handle isMoveImmediate MI here.
160 if (MO
.isReg() && MIRef(MRI
->getVRegDef(MO
.getReg())) > FirstAMX
)
162 // TODO: Maybe need more checks here.
164 MBB
->insert(InsertPoint
, I
->MI
->removeFromParent());
166 // We only need to mark the last shape in the BB now.
168 Shapes
.push_back(MIRef(&*--InsertPoint
, MBB
));
173 X86PreTileConfig() : MachineFunctionPass(ID
) {}
175 /// Return the pass name.
176 StringRef
getPassName() const override
{
177 return "Tile Register Pre-configure";
180 /// X86PreTileConfig analysis usage.
181 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
182 AU
.setPreservesAll();
183 AU
.addRequired
<MachineLoopInfo
>();
184 MachineFunctionPass::getAnalysisUsage(AU
);
187 /// Clear MF related structures.
188 void releaseMemory() override
{
191 BBVisitedInfo
.clear();
194 /// Perform ldtilecfg instructions inserting.
195 bool runOnMachineFunction(MachineFunction
&MF
) override
;
200 } // end anonymous namespace
202 char X86PreTileConfig::ID
= 0;
204 INITIALIZE_PASS_BEGIN(X86PreTileConfig
, "tilepreconfig",
205 "Tile Register Pre-configure", false, false)
206 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
207 INITIALIZE_PASS_END(X86PreTileConfig
, "tilepreconfig",
208 "Tile Register Pre-configure", false, false)
210 void X86PreTileConfig::collectShapeInfo(MachineInstr
&MI
) {
211 auto RecordShape
= [&](MachineInstr
*MI
, MachineBasicBlock
*MBB
) {
213 auto I
= llvm::lower_bound(ShapeBBs
[MBB
], MIR
);
214 if (I
== ShapeBBs
[MBB
].end() || *I
!= MIR
)
215 ShapeBBs
[MBB
].insert(I
, MIR
);
218 SmallVector
<Register
, 8> WorkList(
219 {MI
.getOperand(1).getReg(), MI
.getOperand(2).getReg()});
220 while (!WorkList
.empty()) {
221 Register R
= WorkList
.pop_back_val();
222 MachineInstr
*DefMI
= MRI
->getVRegDef(R
);
223 assert(DefMI
&& "R must has one define instruction");
224 MachineBasicBlock
*DefMBB
= DefMI
->getParent();
225 if (DefMI
->isMoveImmediate() || !DefVisited
.insert(DefMI
).second
)
227 if (DefMI
->isPHI()) {
228 for (unsigned I
= 1; I
< DefMI
->getNumOperands(); I
+= 2)
229 if (isLoopBackEdge(DefMBB
, DefMI
->getOperand(I
+ 1).getMBB()))
230 RecordShape(DefMI
, DefMBB
); // In this case, PHI is also a shape def.
232 WorkList
.push_back(DefMI
->getOperand(I
).getReg());
234 RecordShape(DefMI
, DefMBB
);
239 bool X86PreTileConfig::runOnMachineFunction(MachineFunction
&MF
) {
240 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
241 const TargetInstrInfo
*TII
= ST
.getInstrInfo();
242 const TargetRegisterInfo
*TRI
= ST
.getRegisterInfo();
243 const TargetRegisterClass
*RC
= TRI
->getRegClass(X86::TILERegClassID
);
244 X86MachineFunctionInfo
*X86FI
= MF
.getInfo
<X86MachineFunctionInfo
>();
246 BitVector
AMXRegs(TRI
->getNumRegs());
247 for (unsigned I
= 0; I
< RC
->getNumRegs(); I
++)
248 AMXRegs
.set(X86::TMM0
+ I
);
250 // Iterate MF to collect information.
251 MRI
= &MF
.getRegInfo();
252 MLI
= &getAnalysis
<MachineLoopInfo
>();
253 SmallSet
<MIRef
, 8> CfgNeedInsert
;
254 SmallVector
<MachineBasicBlock
*, 8> CfgLiveInBBs
;
255 for (auto &MBB
: MF
) {
257 for (auto &MI
: MBB
) {
259 if (isAMXInstruction(MI
)) {
260 // If there's call before the AMX, we need to reload tile config.
261 if (BBVisitedInfo
[&MBB
].LastCall
)
262 CfgNeedInsert
.insert(BBVisitedInfo
[&MBB
].LastCall
);
263 else // Otherwise, we need tile config to live in this BB.
264 BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
= true;
265 // Always record the first AMX in case there's shape def after it.
266 if (!BBVisitedInfo
[&MBB
].FirstAMX
)
267 BBVisitedInfo
[&MBB
].FirstAMX
= MIRef(&MI
, &MBB
, Pos
);
268 } else if (MI
.isCall() && isDestructiveCall(MI
, AMXRegs
)) {
269 // Record the call only if the callee clobbers all AMX registers.
270 BBVisitedInfo
[&MBB
].LastCall
= MIRef(&MI
, &MBB
, Pos
);
273 if (BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
) {
274 if (&MBB
== &MF
.front())
275 CfgNeedInsert
.insert(MIRef(&MBB
));
277 CfgLiveInBBs
.push_back(&MBB
);
279 if (BBVisitedInfo
[&MBB
].FirstAMX
|| BBVisitedInfo
[&MBB
].HasAMXRegLiveIn
)
280 for (auto *Succ
: MBB
.successors())
281 if (!isLoopBackEdge(Succ
, &MBB
))
282 BBVisitedInfo
[Succ
].HasAMXRegLiveIn
= true;
285 // Update NeedTileCfgLiveIn for predecessors.
286 while (!CfgLiveInBBs
.empty()) {
287 MachineBasicBlock
*MBB
= CfgLiveInBBs
.pop_back_val();
288 for (auto *Pred
: MBB
->predecessors()) {
289 if (BBVisitedInfo
[Pred
].LastCall
) {
290 CfgNeedInsert
.insert(BBVisitedInfo
[Pred
].LastCall
);
291 } else if (!BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
) {
292 BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
= true;
293 if (Pred
== &MF
.front())
294 CfgNeedInsert
.insert(MIRef(Pred
));
296 CfgLiveInBBs
.push_back(Pred
);
301 // There's no AMX instruction if we didn't find a tile config live in point.
302 if (CfgNeedInsert
.empty())
304 X86FI
->setHasVirtualTileReg(true);
306 // Avoid to insert ldtilecfg before any shape defs.
307 SmallVector
<MachineBasicBlock
*, 8> WorkList
;
308 for (auto &I
: ShapeBBs
) {
309 // TODO: We can hoist shapes across BBs here.
310 if (BBVisitedInfo
[I
.first
].HasAMXRegLiveIn
) {
311 // We are not able to config tile registers since the shape to config
312 // is not defined yet. Emit error message and continue. The function
313 // would not config tile registers.
317 if (BBVisitedInfo
[I
.first
].FirstAMX
&&
318 BBVisitedInfo
[I
.first
].FirstAMX
< I
.second
.back() &&
319 !hoistShapesInBB(I
.first
, I
.second
)) {
323 WorkList
.push_back(I
.first
);
325 while (!WorkList
.empty()) {
326 MachineBasicBlock
*MBB
= WorkList
.pop_back_val();
327 for (auto *Pred
: MBB
->predecessors()) {
328 if (!BBVisitedInfo
[Pred
].TileCfgForbidden
&& !isLoopBackEdge(MBB
, Pred
)) {
329 BBVisitedInfo
[Pred
].TileCfgForbidden
= true;
330 WorkList
.push_back(Pred
);
336 SmallSet
<MIRef
, 8> VisitedOrInserted
;
337 int SS
= MF
.getFrameInfo().CreateStackObject(
338 ST
.getTileConfigSize(), ST
.getTileConfigAlignment(), false);
340 // Try to insert for the tile config live in points.
341 for (const auto &I
: CfgNeedInsert
) {
342 SmallSet
<MIRef
, 8> InsertPoints
;
343 SmallVector
<MIRef
, 8> WorkList({I
});
344 while (!WorkList
.empty()) {
345 MIRef I
= WorkList
.pop_back_val();
346 if (!VisitedOrInserted
.count(I
)) {
347 if (!BBVisitedInfo
[I
.MBB
].TileCfgForbidden
) {
348 // If the BB is all shapes reachable, stop sink and try to insert.
349 InsertPoints
.insert(I
);
351 // Avoid the BB to be multi visited.
352 VisitedOrInserted
.insert(I
);
353 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
354 // true when MBB isn't all shapes reachable.
355 for (auto *Succ
: I
.MBB
->successors())
356 if (BBVisitedInfo
[Succ
].NeedTileCfgLiveIn
)
357 WorkList
.push_back(MIRef(Succ
));
362 // A given point might be forked due to shape conditions are not met.
363 for (MIRef I
: InsertPoints
) {
364 // Make sure we insert ldtilecfg after the last shape def in MBB.
365 if (ShapeBBs
.count(I
.MBB
) && I
< ShapeBBs
[I
.MBB
].back())
366 I
= ShapeBBs
[I
.MBB
].back();
367 // There're chances the MBB is sunk more than once. Record it to avoid
369 if (VisitedOrInserted
.insert(I
).second
) {
370 auto II
= I
.MI
? I
.MI
->getIterator() : I
.MBB
->instr_begin();
371 addFrameReference(BuildMI(*I
.MBB
, ++II
, DL
, TII
->get(X86::PLDTILECFGV
)),
378 MachineBasicBlock
&MBB
= MF
.front();
379 MachineInstr
*MI
= &*MBB
.begin();
380 if (ST
.hasAVX512()) {
381 Register Zmm
= MRI
->createVirtualRegister(&X86::VR512RegClass
);
382 BuildMI(MBB
, MI
, DL
, TII
->get(X86::AVX512_512_SET0
), Zmm
);
383 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSZmr
)), SS
)
385 } else if (ST
.hasAVX2()) {
386 Register Ymm
= MRI
->createVirtualRegister(&X86::VR256RegClass
);
387 BuildMI(MBB
, MI
, DL
, TII
->get(X86::AVX_SET0
), Ymm
);
388 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
)
390 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
, 32)
393 assert(ST
.hasSSE2() && "AMX should assume SSE2 enabled");
394 unsigned StoreOpc
= ST
.hasAVX() ? X86::VMOVUPSmr
: X86::MOVUPSmr
;
395 Register Xmm
= MRI
->createVirtualRegister(&X86::VR128RegClass
);
396 BuildMI(MBB
, MI
, DL
, TII
->get(X86::V_SET0
), Xmm
);
397 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
).addReg(Xmm
);
398 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 16)
400 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 32)
402 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 48)
405 // Fill in the palette first.
406 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV8mi
)), SS
).addImm(1);
411 FunctionPass
*llvm::createX86PreTileConfigPass() {
412 return new X86PreTileConfig();