//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
///
/// The ldtilecfg instruction is used to configure the shapes. It must be
/// reachable from the definitions of all variable shapes. ldtilecfg will be
/// inserted more than once if we cannot find a single dominating point for
/// all AMX instructions.
///
/// The tile configuration is caller-saved according to the ABI, so ldtilecfg
/// must be inserted again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include <iterator>

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"
45 static void emitErrorMsg(MachineFunction
&MF
) {
46 LLVMContext
&Context
= MF
.getMMI().getModule()->getContext();
49 ": Failed to config tile register, please define the shape earlier");
55 MachineInstr
*MI
= nullptr;
56 MachineBasicBlock
*MBB
= nullptr;
57 // A virtual position for instruction that will be inserted after MI.
60 MIRef(MachineBasicBlock
*MBB
) : MBB(MBB
) {
61 for (auto I
= MBB
->begin(), E
= MBB
->end(); I
!= E
&& I
->isPHI();
65 MIRef(MachineInstr
*MI
)
66 : MI(MI
), MBB(MI
->getParent()),
67 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
68 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
)
70 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
71 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
, size_t Pos
)
72 : MI(MI
), MBB(MBB
), Pos(Pos
) {}
73 operator bool() const { return MBB
!= nullptr; }
74 bool operator==(const MIRef
&RHS
) const {
75 return MI
== RHS
.MI
&& MBB
== RHS
.MBB
;
77 bool operator!=(const MIRef
&RHS
) const { return !(*this == RHS
); }
78 bool operator<(const MIRef
&RHS
) const {
79 // Comparison between different BBs happens when inserting a MIRef into set.
80 // So we compare MBB first to make the insertion happy.
81 return MBB
< RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
< RHS
.Pos
);
83 bool operator>(const MIRef
&RHS
) const {
84 // Comparison between different BBs happens when inserting a MIRef into set.
85 // So we compare MBB first to make the insertion happy.
86 return MBB
> RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
> RHS
.Pos
);
93 bool HasAMXRegLiveIn
= false;
94 bool TileCfgForbidden
= false;
95 bool NeedTileCfgLiveIn
= false;
98 class X86PreTileConfig
: public MachineFunctionPass
{
99 MachineRegisterInfo
*MRI
= nullptr;
100 const MachineLoopInfo
*MLI
= nullptr;
101 SmallSet
<MachineInstr
*, 8> DefVisited
;
102 DenseMap
<MachineBasicBlock
*, BBInfo
> BBVisitedInfo
;
103 DenseMap
<MachineBasicBlock
*, SmallVector
<MIRef
, 8>> ShapeBBs
;
105 /// Check if the callee will clobber AMX registers.
106 bool isDestructiveCall(MachineInstr
&MI
, BitVector UsableRegs
) {
107 auto Iter
= llvm::find_if(
108 MI
.operands(), [](MachineOperand
&MO
) { return MO
.isRegMask(); });
109 if (Iter
== MI
.operands_end())
111 UsableRegs
.clearBitsInMask(Iter
->getRegMask());
112 return !UsableRegs
.none();
115 /// Check if MI is AMX pseudo instruction.
116 bool isAMXInstruction(MachineInstr
&MI
) {
117 if (MI
.isPHI() || MI
.isDebugInstr() || MI
.getNumOperands() < 3)
119 MachineOperand
&MO
= MI
.getOperand(0);
120 // We can simply check if it is AMX instruction by its def.
121 // But we should exclude old API which uses physical registers.
122 if (MO
.isReg() && MO
.getReg().isVirtual() &&
123 MRI
->getRegClass(MO
.getReg())->getID() == X86::TILERegClassID
) {
124 collectShapeInfo(MI
);
127 // PTILESTOREDV is the only exception that doesn't def a AMX register.
128 return MI
.getOpcode() == X86::PTILESTOREDV
;
131 /// Check if it is an edge from loop bottom to loop head.
132 bool isLoopBackEdge(MachineBasicBlock
*Header
, MachineBasicBlock
*Bottom
) {
133 if (!MLI
->isLoopHeader(Header
))
135 auto *ML
= MLI
->getLoopFor(Header
);
136 if (ML
->contains(Bottom
) && ML
->isLoopLatch(Bottom
))
142 /// Collect the shape def information for later use.
143 void collectShapeInfo(MachineInstr
&MI
);
145 /// Try to hoist shapes definded below AMX instructions.
146 bool hoistShapesInBB(MachineBasicBlock
*MBB
, SmallVectorImpl
<MIRef
> &Shapes
) {
147 MIRef
&FirstAMX
= BBVisitedInfo
[MBB
].FirstAMX
;
148 auto FirstShapeBelowAMX
= llvm::lower_bound(Shapes
, FirstAMX
);
149 auto InsertPoint
= FirstAMX
.MI
->getIterator();
150 for (auto I
= FirstShapeBelowAMX
, E
= Shapes
.end(); I
!= E
; ++I
) {
151 // Do not hoist instructions that access memory.
152 if (I
->MI
->mayLoadOrStore())
154 for (auto &MO
: I
->MI
->operands()) {
157 // Do not hoist instructions if the sources' def under AMX instruction.
158 // TODO: We can handle isMoveImmediate MI here.
159 if (MO
.isReg() && MIRef(MRI
->getVRegDef(MO
.getReg())) > FirstAMX
)
161 // TODO: Maybe need more checks here.
163 MBB
->insert(InsertPoint
, I
->MI
->removeFromParent());
165 // We only need to mark the last shape in the BB now.
167 Shapes
.push_back(MIRef(&*--InsertPoint
, MBB
));
172 X86PreTileConfig() : MachineFunctionPass(ID
) {}
174 /// Return the pass name.
175 StringRef
getPassName() const override
{
176 return "Tile Register Pre-configure";
179 /// X86PreTileConfig analysis usage.
180 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
181 AU
.setPreservesAll();
182 AU
.addRequired
<MachineLoopInfo
>();
183 MachineFunctionPass::getAnalysisUsage(AU
);
186 /// Clear MF related structures.
187 void releaseMemory() override
{
190 BBVisitedInfo
.clear();
193 /// Perform ldtilecfg instructions inserting.
194 bool runOnMachineFunction(MachineFunction
&MF
) override
;
199 } // end anonymous namespace
201 char X86PreTileConfig::ID
= 0;
203 INITIALIZE_PASS_BEGIN(X86PreTileConfig
, "tilepreconfig",
204 "Tile Register Pre-configure", false, false)
205 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
206 INITIALIZE_PASS_END(X86PreTileConfig
, "tilepreconfig",
207 "Tile Register Pre-configure", false, false)
209 void X86PreTileConfig::collectShapeInfo(MachineInstr
&MI
) {
210 auto RecordShape
= [&](MachineInstr
*MI
, MachineBasicBlock
*MBB
) {
212 auto I
= llvm::lower_bound(ShapeBBs
[MBB
], MIR
);
213 if (I
== ShapeBBs
[MBB
].end() || *I
!= MIR
)
214 ShapeBBs
[MBB
].insert(I
, MIR
);
217 SmallVector
<Register
, 8> WorkList(
218 {MI
.getOperand(1).getReg(), MI
.getOperand(2).getReg()});
219 while (!WorkList
.empty()) {
220 Register R
= WorkList
.pop_back_val();
221 MachineInstr
*DefMI
= MRI
->getVRegDef(R
);
222 assert(DefMI
&& "R must has one define instruction");
223 MachineBasicBlock
*DefMBB
= DefMI
->getParent();
224 if (DefMI
->isMoveImmediate() || !DefVisited
.insert(DefMI
).second
)
226 if (DefMI
->isPHI()) {
227 for (unsigned I
= 1; I
< DefMI
->getNumOperands(); I
+= 2)
228 if (isLoopBackEdge(DefMBB
, DefMI
->getOperand(I
+ 1).getMBB()))
229 RecordShape(DefMI
, DefMBB
); // In this case, PHI is also a shape def.
231 WorkList
.push_back(DefMI
->getOperand(I
).getReg());
233 RecordShape(DefMI
, DefMBB
);
238 bool X86PreTileConfig::runOnMachineFunction(MachineFunction
&MF
) {
239 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
240 const TargetInstrInfo
*TII
= ST
.getInstrInfo();
241 const TargetRegisterInfo
*TRI
= ST
.getRegisterInfo();
242 const TargetRegisterClass
*RC
= TRI
->getRegClass(X86::TILERegClassID
);
243 X86MachineFunctionInfo
*X86FI
= MF
.getInfo
<X86MachineFunctionInfo
>();
245 BitVector
AMXRegs(TRI
->getNumRegs());
246 for (unsigned I
= 0; I
< RC
->getNumRegs(); I
++)
247 AMXRegs
.set(X86::TMM0
+ I
);
249 // Iterate MF to collect information.
250 MRI
= &MF
.getRegInfo();
251 MLI
= &getAnalysis
<MachineLoopInfo
>();
252 SmallSet
<MIRef
, 8> CfgNeedInsert
;
253 SmallVector
<MachineBasicBlock
*, 8> CfgLiveInBBs
;
254 for (auto &MBB
: MF
) {
256 for (auto &MI
: MBB
) {
258 if (isAMXInstruction(MI
)) {
259 // If there's call before the AMX, we need to reload tile config.
260 if (BBVisitedInfo
[&MBB
].LastCall
)
261 CfgNeedInsert
.insert(BBVisitedInfo
[&MBB
].LastCall
);
262 else // Otherwise, we need tile config to live in this BB.
263 BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
= true;
264 // Always record the first AMX in case there's shape def after it.
265 if (!BBVisitedInfo
[&MBB
].FirstAMX
)
266 BBVisitedInfo
[&MBB
].FirstAMX
= MIRef(&MI
, &MBB
, Pos
);
267 } else if (MI
.isCall() && isDestructiveCall(MI
, AMXRegs
)) {
268 // Record the call only if the callee clobbers all AMX registers.
269 BBVisitedInfo
[&MBB
].LastCall
= MIRef(&MI
, &MBB
, Pos
);
272 if (BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
) {
273 if (&MBB
== &MF
.front())
274 CfgNeedInsert
.insert(MIRef(&MBB
));
276 CfgLiveInBBs
.push_back(&MBB
);
278 if (BBVisitedInfo
[&MBB
].FirstAMX
|| BBVisitedInfo
[&MBB
].HasAMXRegLiveIn
)
279 for (auto *Succ
: MBB
.successors())
280 if (!isLoopBackEdge(Succ
, &MBB
))
281 BBVisitedInfo
[Succ
].HasAMXRegLiveIn
= true;
284 // Update NeedTileCfgLiveIn for predecessors.
285 while (!CfgLiveInBBs
.empty()) {
286 MachineBasicBlock
*MBB
= CfgLiveInBBs
.pop_back_val();
287 for (auto *Pred
: MBB
->predecessors()) {
288 if (BBVisitedInfo
[Pred
].LastCall
) {
289 CfgNeedInsert
.insert(BBVisitedInfo
[Pred
].LastCall
);
290 } else if (!BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
) {
291 BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
= true;
292 if (Pred
== &MF
.front())
293 CfgNeedInsert
.insert(MIRef(Pred
));
295 CfgLiveInBBs
.push_back(Pred
);
300 // There's no AMX instruction if we didn't find a tile config live in point.
301 if (CfgNeedInsert
.empty())
303 X86FI
->setHasVirtualTileReg(true);
305 // Avoid to insert ldtilecfg before any shape defs.
306 SmallVector
<MachineBasicBlock
*, 8> WorkList
;
307 for (auto &I
: ShapeBBs
) {
308 // TODO: We can hoist shapes across BBs here.
309 if (BBVisitedInfo
[I
.first
].HasAMXRegLiveIn
) {
310 // We are not able to config tile registers since the shape to config
311 // is not defined yet. Emit error message and continue. The function
312 // would not config tile registers.
316 if (BBVisitedInfo
[I
.first
].FirstAMX
&&
317 BBVisitedInfo
[I
.first
].FirstAMX
< I
.second
.back() &&
318 !hoistShapesInBB(I
.first
, I
.second
)) {
322 WorkList
.push_back(I
.first
);
324 while (!WorkList
.empty()) {
325 MachineBasicBlock
*MBB
= WorkList
.pop_back_val();
326 for (auto *Pred
: MBB
->predecessors()) {
327 if (!BBVisitedInfo
[Pred
].TileCfgForbidden
&& !isLoopBackEdge(MBB
, Pred
)) {
328 BBVisitedInfo
[Pred
].TileCfgForbidden
= true;
329 WorkList
.push_back(Pred
);
335 SmallSet
<MIRef
, 8> VisitedOrInserted
;
336 int SS
= MF
.getFrameInfo().CreateStackObject(
337 ST
.getTileConfigSize(), ST
.getTileConfigAlignment(), false);
339 // Try to insert for the tile config live in points.
340 for (const auto &I
: CfgNeedInsert
) {
341 SmallSet
<MIRef
, 8> InsertPoints
;
342 SmallVector
<MIRef
, 8> WorkList({I
});
343 while (!WorkList
.empty()) {
344 MIRef I
= WorkList
.pop_back_val();
345 if (!VisitedOrInserted
.count(I
)) {
346 if (!BBVisitedInfo
[I
.MBB
].TileCfgForbidden
) {
347 // If the BB is all shapes reachable, stop sink and try to insert.
348 InsertPoints
.insert(I
);
350 // Avoid the BB to be multi visited.
351 VisitedOrInserted
.insert(I
);
352 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
353 // true when MBB isn't all shapes reachable.
354 for (auto *Succ
: I
.MBB
->successors())
355 if (BBVisitedInfo
[Succ
].NeedTileCfgLiveIn
)
356 WorkList
.push_back(MIRef(Succ
));
361 // A given point might be forked due to shape conditions are not met.
362 for (MIRef I
: InsertPoints
) {
363 // Make sure we insert ldtilecfg after the last shape def in MBB.
364 if (ShapeBBs
.count(I
.MBB
) && I
< ShapeBBs
[I
.MBB
].back())
365 I
= ShapeBBs
[I
.MBB
].back();
366 // There're chances the MBB is sunk more than once. Record it to avoid
368 if (VisitedOrInserted
.insert(I
).second
) {
369 auto II
= I
.MI
? I
.MI
->getIterator() : I
.MBB
->instr_begin();
370 addFrameReference(BuildMI(*I
.MBB
, ++II
, DL
, TII
->get(X86::PLDTILECFGV
)),
377 MachineBasicBlock
&MBB
= MF
.front();
378 MachineInstr
*MI
= &*MBB
.begin();
379 if (ST
.hasAVX512()) {
380 Register Zmm
= MRI
->createVirtualRegister(&X86::VR512RegClass
);
381 BuildMI(MBB
, MI
, DL
, TII
->get(X86::AVX512_512_SET0
), Zmm
);
382 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSZmr
)), SS
)
384 } else if (ST
.hasAVX2()) {
385 Register Ymm
= MRI
->createVirtualRegister(&X86::VR256RegClass
);
386 BuildMI(MBB
, MI
, DL
, TII
->get(X86::AVX_SET0
), Ymm
);
387 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
)
389 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
, 32)
392 assert(ST
.hasSSE2() && "AMX should assume SSE2 enabled");
393 unsigned StoreOpc
= ST
.hasAVX() ? X86::VMOVUPSmr
: X86::MOVUPSmr
;
394 Register Xmm
= MRI
->createVirtualRegister(&X86::VR128RegClass
);
395 BuildMI(MBB
, MI
, DL
, TII
->get(X86::V_SET0
), Xmm
);
396 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
).addReg(Xmm
);
397 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 16)
399 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 32)
401 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(StoreOpc
)), SS
, 48)
404 // Fill in the palette first.
405 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV8mi
)), SS
).addImm(1);
410 FunctionPass
*llvm::createX86PreTileConfigPass() {
411 return new X86PreTileConfig();