//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// to an AMX load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in spill.) The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}
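
// Return the {row, col} shape pair for the tile at operand OpNo of the given
// AMX intrinsic. For tile load/store intrinsics the shape is carried in the
// first two operands; for the tile dot-product intrinsics it depends on which
// tile operand is being queried.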
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (dyn_cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};
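
// Map a column value (in bytes) to a row value by dividing by the element
// size. For example, a 64-byte column with Granularity 4 (i32/f32 elements)
// yields a row value of 64 / 4 = 16. Results are cached in Col2Row.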
Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
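
// Walk the function bottom-up and combine each bitcast between <256 x i32>
// and x86_amx with its feeding load or consuming store; bitcasts that cannot
// be combined are lowered through a stack slot by transformBitcast(). Dead
// bitcasts and loads are erased afterwards.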
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace
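
// Create a <256 x i32> stack slot in the entry block of BB's function and
// return it as an i8* suitable for the AMX load/store intrinsics.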
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}
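
// Store the tile defined by TileDef to Ptr immediately after its definition,
// using the shape carried by the defining AMX intrinsic.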
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}
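
// Replace the tile value at use U with a fresh tile load from Ptr, reusing
// the shape of the defining intrinsic (or, when the value is a PHI, of its
// first incoming value).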
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}
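
// Return true if any user of I is a PHI node.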
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};
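
// Store every tile-producing incoming value of a PHI into a shared stack slot
// and rewrite the non-PHI uses of those values to reload from that slot.
// Returns the i8* of the slot, which the PHI def itself will later load from.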
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except phi) should load from stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }

  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from a tileload just in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;

public:
  X86LowerAMXCast(Function &F) : Func(F) {}
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};
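
// Erase I if it is trivially dead, queueing any operands that become dead in
// the process into WorkList so the caller can revisit them.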
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A -> B amxcast
///     PHI
///     B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (Value *IncValue : OldPN->incoming_values()) {
      // TODO: Currently, we ignore cases where the incoming value is a
      // constant. In the future, we might support constants.
      if (isa<Constant>(IncValue))
        return false;

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //     %0 = amxcast ...
        //   bb.1:
        //     %1 = amxcast ...
        //   bb.2:
        //     %goodphi = phi %0, %1
        //     %3 = amxcast %goodphi
        //   bb.3:
        //     %goodphi2 = phi %0, %goodphi
        //     %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXCast from a constant.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINodes into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}
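
// Fold pairs of amxcasts (vector->tile feeding tile->vector and vice versa)
// as well as amxcasts routed through PHI nodes, then DCE whatever is left
// dead.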
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXCast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}
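
// Lower every remaining amxcast in the function through a stack slot.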
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast, and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not just
    // by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}