//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tilecfg for each area of key AMX intrinsics.
/// All tile operands of a key AMX intrinsic must come from tile loads, and
/// the tile defined by a key AMX intrinsic must be written back by a tile
/// store. Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key AMX area, as follows:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"
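// An intrinsic is an AMX intrinsic if it consumes or produces an x86_amx
// tile value.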
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}
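// Match the two tile-load intrinsics: tileloadd64 and its tileloaddt164
// (cache-hint) variant.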
static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}
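// Match the tile-store intrinsic that writes a tile back to memory.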
static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}
#ifndef NDEBUG
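// Check that II defines a tile without consuming any tile operands, i.e. it
// is a pure tile producer such as a tile load.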
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}
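// Conservatively treat plain (non-intrinsic) calls and terminators as
// breaking the volatile model: a key AMX area must not extend across them.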
static bool brokenVolatile(Instruction *I) {
  // TODO: identifying a normal call only as a CallInst that is not an
  // IntrinsicInst is weak.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif
namespace {

class X86PreAMXConfig {
  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;

  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
// Write the shapes in order into tilecfg's mem. This may not be correct,
// because the first shape may not correspond to the first tmm register, so
// we handle that in X86FastTileConfig::materializeTileCfg() after register
// allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                                      SmallVector<Value *, 8> &Shapes) {
  LLVMContext &Ctx = Builder.getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: Currently we default to Palette = 1; it may be assigned another
  // value in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
  Builder.CreateStore(PaletteValue, PalettePos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
                                      ShapeName + ".shape.row");
    Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
    ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
                                   ShapeName + ".shape.col");
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = Builder.CreateTrunc(Row, I8Ty);
    Builder.CreateStore(Row, RowPos);
    Builder.CreateStore(Col, ColPos);
  }
}
void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);

  preWriteTileCfg(I8Ptr, Builder, Shapes);

  Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt,
                          {I8Ptr});
}
// TODO: We may need to handle the "more than one store" case in the future.
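// In the O0 volatile model every tile value is born from a tile load and dies
// at a tile store inside the same area, so validating an area reduces to:
// every tile operand of KeyAMX is one of the collected loads, every collected
// load is consumed by KeyAMX, and the value being stored is KeyAMX's def.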
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // Only has tileload and tilestore.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: can a key AMX intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
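// Collect the shape operands, in operand order, for each tile KeyAMX touches.
// A tile load carries its (row, col) shape in operands 0 and 1; if KeyAMX is
// not itself a tile store, its own operands 0 and 1 describe the shape of the
// tile it defines.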
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return !Shapes.empty();
}
// Collect the shapes and skip the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // See TileStore as "Config Position End" and check the volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Did not find TileStore!");

  // See KeyAMX as TileStore if there are only TileLoads and TileStores.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}
// Record a key AMX area's shapes with its position.
// Use the first tileload as its position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <-- pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)   (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)         (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}
// Insert ldtilecfg and pre-config the shapes for each area of key AMX
// intrinsics, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// ...                                                         *
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  PosAndShapesMap PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace
namespace {
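// Legacy-PM wrapper that runs the pre-AMX-config transformation on each
// function when compiling at O0.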
class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So here we
      // keep things simple and pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};
} // anonymous namespace
static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}