//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// An AMX register needs to be configured before use. The shapes of AMX
/// registers are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable for all variable shapes. ldtilecfg will be inserted more than
/// once if we cannot find a dominating point for all AMX instructions.
///
/// The tile configuration register is caller-saved according to the ABI. We
/// need to insert ldtilecfg again after a call instruction if the callee
/// clobbers any AMX registers.
///
/// This pass calculates all points at which ldtilecfg needs to be inserted and
/// inserts them. It reports an error if the reachability conditions aren't
/// met.
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/ErrorHandling.h"
#include <iterator>

using namespace llvm;
// Debug-output category of this pass (used by LLVM_DEBUG / -debug-only).
// The former REPORT_CONFIG_FAIL macro was truncated and is not referenced
// anywhere in this file, so it has been dropped.
#define DEBUG_TYPE "tile-pre-config"
50 MachineInstr
*MI
= nullptr;
51 MachineBasicBlock
*MBB
= nullptr;
52 // A virtual position for instruction that will be inserted after MI.
55 MIRef(MachineBasicBlock
*MBB
) : MBB(MBB
) {
56 for (auto I
= MBB
->begin(), E
= MBB
->end(); I
!= E
&& I
->isPHI();
60 MIRef(MachineInstr
*MI
)
61 : MI(MI
), MBB(MI
->getParent()),
62 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
63 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
)
65 Pos(std::distance(MBB
->instr_begin(), ++MI
->getIterator())) {}
66 MIRef(MachineInstr
*MI
, MachineBasicBlock
*MBB
, size_t Pos
)
67 : MI(MI
), MBB(MBB
), Pos(Pos
) {}
68 operator bool() const { return MBB
!= nullptr; }
69 bool operator==(const MIRef
&RHS
) const {
70 return MI
== RHS
.MI
&& MBB
== RHS
.MBB
;
72 bool operator!=(const MIRef
&RHS
) const { return !(*this == RHS
); }
73 bool operator<(const MIRef
&RHS
) const {
74 // Comparison between different BBs happens when inserting a MIRef into set.
75 // So we compare MBB first to make the insertion happy.
76 return MBB
< RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
< RHS
.Pos
);
78 bool operator>(const MIRef
&RHS
) const {
79 // Comparison between different BBs happens when inserting a MIRef into set.
80 // So we compare MBB first to make the insertion happy.
81 return MBB
> RHS
.MBB
|| (MBB
== RHS
.MBB
&& Pos
> RHS
.Pos
);
88 bool HasAMXRegLiveIn
= false;
89 bool TileCfgForbidden
= false;
90 bool NeedTileCfgLiveIn
= false;
93 class X86PreTileConfig
: public MachineFunctionPass
{
94 MachineRegisterInfo
*MRI
;
95 const MachineLoopInfo
*MLI
;
96 SmallSet
<MachineInstr
*, 8> DefVisited
;
97 DenseMap
<MachineBasicBlock
*, BBInfo
> BBVisitedInfo
;
98 DenseMap
<MachineBasicBlock
*, SmallVector
<MIRef
, 8>> ShapeBBs
;
100 /// Check if the callee will clobber AMX registers.
101 bool isDestructiveCall(MachineInstr
&MI
, BitVector UsableRegs
) {
102 auto Iter
= llvm::find_if(
103 MI
.operands(), [](MachineOperand
&MO
) { return MO
.isRegMask(); });
104 if (Iter
== MI
.operands_end())
106 UsableRegs
.clearBitsInMask(Iter
->getRegMask());
107 return !UsableRegs
.none();
110 /// Check if MI is AMX pseudo instruction.
111 bool isAMXInstruction(MachineInstr
&MI
) {
112 if (MI
.isPHI() || MI
.isDebugInstr() || MI
.getNumOperands() < 3)
114 MachineOperand
&MO
= MI
.getOperand(0);
115 // We can simply check if it is AMX instruction by its def.
116 // But we should exclude old API which uses physical registers.
117 if (MO
.isReg() && MO
.getReg().isVirtual() &&
118 MRI
->getRegClass(MO
.getReg())->getID() == X86::TILERegClassID
) {
119 collectShapeInfo(MI
);
122 // PTILESTOREDV is the only exception that doesn't def a AMX register.
123 return MI
.getOpcode() == X86::PTILESTOREDV
;
126 /// Check if it is an edge from loop bottom to loop head.
127 bool isLoopBackEdge(MachineBasicBlock
*Header
, MachineBasicBlock
*Bottom
) {
128 if (!MLI
->isLoopHeader(Header
))
130 auto *ML
= MLI
->getLoopFor(Header
);
131 if (ML
->contains(Bottom
) && ML
->isLoopLatch(Bottom
))
137 /// Collect the shape def information for later use.
138 void collectShapeInfo(MachineInstr
&MI
);
140 /// Try to hoist shapes definded below AMX instructions.
141 bool hoistShapesInBB(MachineBasicBlock
*MBB
, SmallVectorImpl
<MIRef
> &Shapes
) {
142 MIRef
&FirstAMX
= BBVisitedInfo
[MBB
].FirstAMX
;
143 auto FirstShapeBelowAMX
= llvm::lower_bound(Shapes
, FirstAMX
);
144 auto InsertPoint
= FirstAMX
.MI
->getIterator();
145 for (auto I
= FirstShapeBelowAMX
, E
= Shapes
.end(); I
!= E
; ++I
) {
146 // Do not hoist instructions that access memory.
147 if (I
->MI
->mayLoadOrStore())
149 for (auto &MO
: I
->MI
->operands()) {
152 // Do not hoist instructions if the sources' def under AMX instruction.
153 // TODO: We can handle isMoveImmediate MI here.
154 if (MO
.isReg() && MIRef(MRI
->getVRegDef(MO
.getReg())) > FirstAMX
)
156 // TODO: Maybe need more checks here.
158 MBB
->insert(InsertPoint
, I
->MI
->removeFromParent());
160 // We only need to mark the last shape in the BB now.
162 Shapes
.push_back(MIRef(&*--InsertPoint
, MBB
));
167 X86PreTileConfig() : MachineFunctionPass(ID
) {}
169 /// Return the pass name.
170 StringRef
getPassName() const override
{
171 return "Tile Register Pre-configure";
174 /// X86PreTileConfig analysis usage.
175 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
176 AU
.setPreservesAll();
177 AU
.addRequired
<MachineLoopInfo
>();
178 MachineFunctionPass::getAnalysisUsage(AU
);
181 /// Clear MF related structures.
182 void releaseMemory() override
{
185 BBVisitedInfo
.clear();
188 /// Perform ldtilecfg instructions inserting.
189 bool runOnMachineFunction(MachineFunction
&MF
) override
;
194 } // end anonymous namespace
196 char X86PreTileConfig::ID
= 0;
198 INITIALIZE_PASS_BEGIN(X86PreTileConfig
, "tilepreconfig",
199 "Tile Register Pre-configure", false, false)
200 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
201 INITIALIZE_PASS_END(X86PreTileConfig
, "tilepreconfig",
202 "Tile Register Pre-configure", false, false)
204 void X86PreTileConfig::collectShapeInfo(MachineInstr
&MI
) {
205 auto RecordShape
= [&](MachineInstr
*MI
, MachineBasicBlock
*MBB
) {
207 auto I
= llvm::lower_bound(ShapeBBs
[MBB
], MIR
);
208 if (I
== ShapeBBs
[MBB
].end() || *I
!= MIR
)
209 ShapeBBs
[MBB
].insert(I
, MIR
);
212 SmallVector
<Register
, 8> WorkList(
213 {MI
.getOperand(1).getReg(), MI
.getOperand(2).getReg()});
214 while (!WorkList
.empty()) {
215 Register R
= WorkList
.pop_back_val();
216 MachineInstr
*DefMI
= MRI
->getVRegDef(R
);
217 assert(DefMI
&& "R must has one define instruction");
218 MachineBasicBlock
*DefMBB
= DefMI
->getParent();
219 if (DefMI
->isMoveImmediate() || !DefVisited
.insert(DefMI
).second
)
221 if (DefMI
->isPHI()) {
222 for (unsigned I
= 1; I
< DefMI
->getNumOperands(); I
+= 2)
223 if (isLoopBackEdge(DefMBB
, DefMI
->getOperand(I
+ 1).getMBB()))
224 RecordShape(DefMI
, DefMBB
); // In this case, PHI is also a shape def.
226 WorkList
.push_back(DefMI
->getOperand(I
).getReg());
228 RecordShape(DefMI
, DefMBB
);
233 bool X86PreTileConfig::runOnMachineFunction(MachineFunction
&MF
) {
234 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
235 const TargetInstrInfo
*TII
= ST
.getInstrInfo();
236 const TargetRegisterInfo
*TRI
= ST
.getRegisterInfo();
237 const TargetRegisterClass
*RC
= TRI
->getRegClass(X86::TILERegClassID
);
239 BitVector
AMXRegs(TRI
->getNumRegs());
240 for (unsigned I
= 0; I
< RC
->getNumRegs(); I
++)
241 AMXRegs
.set(X86::TMM0
+ I
);
243 // Iterate MF to collect information.
244 MRI
= &MF
.getRegInfo();
245 MLI
= &getAnalysis
<MachineLoopInfo
>();
246 SmallSet
<MIRef
, 8> CfgNeedInsert
;
247 SmallVector
<MachineBasicBlock
*, 8> CfgLiveInBBs
;
248 for (auto &MBB
: MF
) {
250 for (auto &MI
: MBB
) {
252 if (isAMXInstruction(MI
)) {
253 // If there's call before the AMX, we need to reload tile config.
254 if (BBVisitedInfo
[&MBB
].LastCall
)
255 CfgNeedInsert
.insert(BBVisitedInfo
[&MBB
].LastCall
);
256 else // Otherwise, we need tile config to live in this BB.
257 BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
= true;
258 // Always record the first AMX in case there's shape def after it.
259 if (!BBVisitedInfo
[&MBB
].FirstAMX
)
260 BBVisitedInfo
[&MBB
].FirstAMX
= MIRef(&MI
, &MBB
, Pos
);
261 } else if (MI
.isCall() && isDestructiveCall(MI
, AMXRegs
)) {
262 // Record the call only if the callee clobbers all AMX registers.
263 BBVisitedInfo
[&MBB
].LastCall
= MIRef(&MI
, &MBB
, Pos
);
266 if (BBVisitedInfo
[&MBB
].NeedTileCfgLiveIn
) {
267 if (&MBB
== &MF
.front())
268 CfgNeedInsert
.insert(MIRef(&MBB
));
270 CfgLiveInBBs
.push_back(&MBB
);
272 if (BBVisitedInfo
[&MBB
].FirstAMX
|| BBVisitedInfo
[&MBB
].HasAMXRegLiveIn
)
273 for (auto *Succ
: MBB
.successors())
274 if (!isLoopBackEdge(Succ
, &MBB
))
275 BBVisitedInfo
[Succ
].HasAMXRegLiveIn
= true;
278 // Update NeedTileCfgLiveIn for predecessors.
279 while (!CfgLiveInBBs
.empty()) {
280 MachineBasicBlock
*MBB
= CfgLiveInBBs
.pop_back_val();
281 for (auto *Pred
: MBB
->predecessors()) {
282 if (BBVisitedInfo
[Pred
].LastCall
) {
283 CfgNeedInsert
.insert(BBVisitedInfo
[Pred
].LastCall
);
284 } else if (!BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
) {
285 BBVisitedInfo
[Pred
].NeedTileCfgLiveIn
= true;
286 if (Pred
== &MF
.front())
287 CfgNeedInsert
.insert(MIRef(Pred
));
289 CfgLiveInBBs
.push_back(Pred
);
294 // There's no AMX instruction if we didn't find a tile config live in point.
295 if (CfgNeedInsert
.empty())
298 // Avoid to insert ldtilecfg before any shape defs.
299 SmallVector
<MachineBasicBlock
*, 8> WorkList
;
300 for (auto &I
: ShapeBBs
) {
301 // TODO: We can hoist shapes across BBs here.
302 if (BBVisitedInfo
[I
.first
].HasAMXRegLiveIn
)
304 if (BBVisitedInfo
[I
.first
].FirstAMX
&&
305 BBVisitedInfo
[I
.first
].FirstAMX
< I
.second
.back() &&
306 !hoistShapesInBB(I
.first
, I
.second
))
308 WorkList
.push_back(I
.first
);
310 while (!WorkList
.empty()) {
311 MachineBasicBlock
*MBB
= WorkList
.pop_back_val();
312 for (auto *Pred
: MBB
->predecessors()) {
313 if (!BBVisitedInfo
[Pred
].TileCfgForbidden
&& !isLoopBackEdge(MBB
, Pred
)) {
314 BBVisitedInfo
[Pred
].TileCfgForbidden
= true;
315 WorkList
.push_back(Pred
);
321 SmallSet
<MIRef
, 8> VisitedOrInserted
;
322 int SS
= MF
.getFrameInfo().CreateStackObject(
323 ST
.getTileConfigSize(), ST
.getTileConfigAlignment(), false);
325 // Try to insert for the tile config live in points.
326 for (auto I
: CfgNeedInsert
) {
327 SmallSet
<MIRef
, 8> InsertPoints
;
328 SmallVector
<MIRef
, 8> WorkList({I
});
329 while (!WorkList
.empty()) {
330 MIRef I
= WorkList
.pop_back_val();
331 if (!VisitedOrInserted
.count(I
)) {
332 if (!BBVisitedInfo
[I
.MBB
].TileCfgForbidden
) {
333 // If the BB is all shapes reachable, stop sink and try to insert.
334 InsertPoints
.insert(I
);
336 // Avoid the BB to be multi visited.
337 VisitedOrInserted
.insert(I
);
338 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
339 // true when MBB isn't all shapes reachable.
340 for (auto *Succ
: I
.MBB
->successors())
341 if (BBVisitedInfo
[Succ
].NeedTileCfgLiveIn
)
342 WorkList
.push_back(MIRef(Succ
));
347 // A given point might be forked due to shape conditions are not met.
348 for (MIRef I
: InsertPoints
) {
349 // Make sure we insert ldtilecfg after the last shape def in MBB.
350 if (ShapeBBs
.count(I
.MBB
) && I
< ShapeBBs
[I
.MBB
].back())
351 I
= ShapeBBs
[I
.MBB
].back();
352 // There're chances the MBB is sunk more than once. Record it to avoid
354 if (VisitedOrInserted
.insert(I
).second
) {
355 auto II
= I
.MI
? I
.MI
->getIterator() : I
.MBB
->instr_begin();
356 addFrameReference(BuildMI(*I
.MBB
, ++II
, DL
, TII
->get(X86::LDTILECFG
)),
363 MachineBasicBlock
&MBB
= MF
.front();
364 MachineInstr
*MI
= &*MBB
.begin();
365 if (ST
.hasAVX512()) {
366 Register Zmm
= MRI
->createVirtualRegister(&X86::VR512RegClass
);
367 BuildMI(MBB
, MI
, DL
, TII
->get(X86::VPXORDZrr
), Zmm
)
368 .addReg(Zmm
, RegState::Undef
)
369 .addReg(Zmm
, RegState::Undef
);
370 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSZmr
)), SS
)
372 } else if (ST
.hasAVX2()) {
373 Register Ymm
= MRI
->createVirtualRegister(&X86::VR256RegClass
);
374 BuildMI(MBB
, MI
, DL
, TII
->get(X86::VPXORYrr
), Ymm
)
375 .addReg(Ymm
, RegState::Undef
)
376 .addReg(Ymm
, RegState::Undef
);
377 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
)
379 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::VMOVUPSYmr
)), SS
, 32)
382 assert(ST
.hasSSE2() && "AMX should assume SSE2 enabled");
383 Register Xmm
= MRI
->createVirtualRegister(&X86::VR128RegClass
);
384 BuildMI(MBB
, MI
, DL
, TII
->get(X86::PXORrr
), Xmm
)
385 .addReg(Xmm
, RegState::Undef
)
386 .addReg(Xmm
, RegState::Undef
);
387 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOVUPSmr
)), SS
)
389 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOVUPSmr
)), SS
, 16)
391 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOVUPSmr
)), SS
, 32)
393 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOVUPSmr
)), SS
, 48)
396 // Fill in the palette first.
397 addFrameReference(BuildMI(MBB
, MI
, DL
, TII
->get(X86::MOV8mi
)), SS
).addImm(1);
402 FunctionPass
*llvm::createX86PreTileConfigPass() {
403 return new X86PreTileConfig();