//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions into AMX load/store. If the bitcast
/// cannot be combined with its load/store, we transform the bitcast into an
/// AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in a spill.) volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"
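
// Returns true if \p II is one of the AMX cast intrinsics
// (llvm.x86.cast.vector.to.tile / llvm.x86.cast.tile.to.vector).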
static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
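
// Create an alloca of type \p Ty in the entry block of the function that
// contains \p BB, aligned for AMX tile data.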
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}
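
// Return the first non-alloca instruction in the entry block; new
// instructions that must dominate the whole function are inserted there.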
static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}
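
// Compute the tile shape (Row, Col) that operand \p OpNo of the AMX
// intrinsic \p II must have, based on which intrinsic and which operand
// it is.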
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape
        // for %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> ...)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117)
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into a <store, load> pair through a stack slot.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2, i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
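
// Walk the function in post order and combine each <load, bitcast> or
// <bitcast, store> pair with the AMX intrinsics that use it; bitcasts that
// cannot be combined are lowered through a stack slot by transformBitcast().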
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2
        //
        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If bitcast (%13) has one use, combine bitcast and store to amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace
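
// Create a <256 x i32> stack slot in the entry block of the function that
// contains \p BB and return it as an i8* suitable for the tile load/store
// intrinsics.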
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}
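
// Emit a tilestored64 of \p TileDef to \p Ptr right after the definition of
// \p TileDef, using the shape carried by the defining AMX intrinsic.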
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}
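
// Replace the use \p U of a tile value with a fresh tileloadd64 from \p Ptr,
// inserted right before the user. When \p IsPHI is set, the shape is taken
// from the first incoming value of the PHI instead of the value itself.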
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}
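
// Return true if any user of \p I is a PHI node.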
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, to shorten the life range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics:
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tilestored right after its def.
// 3) These tileloads and tilestores share the same memory slot.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   br i1 %bool.cond, label %if.else, label %if.then
// if.then:
//   def %t0 = ...
// if.else:
//   def %t1 = ...
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
// if.end:
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// | def %td = ...                                          |
// | ...                                                    |
// | "use %td"                                              |
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// | def %td = ...                                          |
// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
// | ...                                                    |
// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
// | "use %td2"                                             |
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace
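
// Combine and lower the AMX cast intrinsics (vector<->tile) emitted by the
// front end, so that as few casts as possible survive to instruction
// selection.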
namespace {

class X86LowerAMXCast {
  Function &Func;

public:
  X86LowerAMXCast(Function &F) : Func(F) {}
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};
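
// Erase \p I if it is trivially dead, nulling out its operands first and
// queueing any operand that becomes trivially dead onto \p WorkList.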
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand
    // becomes dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///     A  ->  B    cast
///     PHI
///     B  ->  A    cast
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (Value *IncValue : OldPN->incoming_values()) {
      // TODO: currently, we ignore cases where the incoming value is a const.
      // In the future, we might support const.
      if (isa<Constant>(IncValue))
        return false;

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   %goodphi = phi %0, %1
        //   %3 = amxcast %goodphi
        //   ...
        //   %goodphi2 = phi %0, %goodphi
        //   %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push PHINode into DeadInst since they are operands
        // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}
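
// Combine pairs of AMX cast intrinsics (vector->tile followed by tile->vector
// and vice versa), erase casts that become dead, and fold casts through PHI
// nodes via optimizeAMXCastFromPhi().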
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // A cast followed by the opposite cast is a no-op: forward the
        // original operand to the users of the second cast.
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);

  // Handle the A->B->A cast when there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2, i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}
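
// Lower every remaining AMX cast intrinsic in the function through a stack
// slot via transformAMXCast().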
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace
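
// Legacy pass wrapper that runs the cast combining, the type lowering, and
// (at O0) the volatile-tile transformation on each function.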
namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // Todo: May better check the volatile model of AMX code, not just
    // by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // amx data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}