//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert tilecfg for each area of key AMX intrinsic.
/// All the key AMX intrinsic's tile operands must come from tileload, and the
/// def tile of a key AMX intrinsic must be tilestored.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass will insert tilecfg before every key-amx-area, something like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"
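
// An intrinsic is treated as an AMX intrinsic if it produces or consumes an
// x86_amx tile value. The check is purely type based, so it covers every tile
// intrinsic without listing them by ID.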
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}
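
// tileloadd64 and tileloaddt164 are the two internal tile-load intrinsics
// that define a tile register from memory.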
static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}
static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}
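
// Returns true if II defines a tile without consuming one, i.e. it returns an
// x86_amx value but none of its operands is x86_amx. The first AMX intrinsic
// of a key amx area (a tileload) must satisfy this.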
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}
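
// Instructions that would break the volatile model inside a key amx area:
// plain (non-intrinsic) calls and terminators must not appear between the
// first tileload and the final tilestore.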
static bool brokenVolatile(Instruction *I) {
  // Todo: it is weak to identify a normal call here.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
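
// X86PreAMXConfig finds each key amx area in a function, records the shapes
// of the tiles it uses, and inserts the corresponding tile configuration code
// in front of the area.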
namespace {

class X86PreAMXConfig {
  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
// Write the shapes into tilecfg's mem in order. This may not be right,
// because the first shape may not correspond to the first tmm register, so we
// need to handle it at X86FastTileConfig::materializeTileCfg() after register
// allocation.
//
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ... pre-config shape of %t2                                 * shapes
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);
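
  // The offsets written below follow the 64-byte tile configuration layout
  // consumed by ldtilecfg: byte 0 holds the palette, bytes 16..31 the 16-bit
  // column sizes of tmm0..tmm7, and bytes 48..55 their 8-bit row counts.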
  // TODO: Currently we default to Palette = 1; it may be assigned another
  // value in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}
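
// Allocate the 64-byte config memory in the entry block, zero-initialize it,
// pre-write the palette and shapes, and emit ldtilecfg right in front of the
// key amx area that starts at ModelStart.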
bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  Value *Val0 = Constant::getNullValue(V512Ty);
  Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
  assert(Init0 && "Failed to zero-initialize the cfg mem!");

  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return Init0 != nullptr;
}
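
// A key amx area obeys the "volatile model": every tile operand of the key
// intrinsic comes from one of the recorded tile loads, and the value being
// stored is the def of the key intrinsic (or the single load when the area
// has no key intrinsic at all).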
// Todo: We may need to handle "more than one store" case in the future.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // Only has tileload and tilestore.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // Todo: can the key amx have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
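
// Collect the (row, col) shape operands for the key amx area in order: the
// first two operands of every tile load feeding KeyAMX, plus KeyAMX's own
// (row, col) for its def when KeyAMX is not itself the tile store.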
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return Shapes.size() != 0;
}
// Collect the shapes and skip the area of the current key amx intrinsic.
//
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)   record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)   record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // See TileStore as "Config Position End" and check volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Not reach tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key amx intrinsic!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Not find TileStore!");

  // See KeyAMX as TileStore if only TileLoad and TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}
// Record a key amx area's shapes with its position.
// Use the first tileload as its position.
//
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}
// Insert ldtilecfg and pre-config the shapes for each area of key AMX
// intrinsic, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace
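
// Legacy pass wrapper. The pre-configuration is only done at
// CodeGenOpt::None, to prepare for fast register allocation at O0.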
namespace {

class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So here we
      // keep things simple and pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace
static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}