//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an AMX
/// load/store and a <256 x i32> store/load.
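///
/// For example (an illustrative sketch; the row/column operands come from the
/// AMX intrinsic that consumes the tile, and the stride shown is the maximum
/// of 64 bytes):
///   %src = load <256 x i32>, <256 x i32>* %addr, align 64
///   %2 = bitcast <256 x i32> %src to x86_amx
/// is rewritten to
///   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                    i8* %addr, i64 64)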
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transfer to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }
  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
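  // For these dot-product intrinsics, the first three i16 arguments describe
  // the shapes, and the OpNo switch below pairs them up as follows (a reading
  // of the code below rather than a statement from the AMX documentation):
  //   operand 3 takes shape (arg 0, arg 1),
  //   operand 4 takes shape (arg 0, arg 2),
  //   operand 5 takes shape (arg 2 / 4, arg 1).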
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //          i32> ...)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //          x86_amx %117)
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // II->getOperand(2) (%106 in this example).
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}
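
// A sketch of how the PHI overload below finds a shape: starting from the
// first use of the PHI, it walks through AMX casts and further PHIs until it
// reaches an AMX intrinsic whose operand shape is known, e.g.
//   %p  = phi <256 x i32> [ ... ], [ ... ]
//   %t  = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
//   %td = call x86_amx @llvm.x86.tdpbssd.internal(..., x86_amx %t, ...)
// Here the shape of %p is taken from the operand position of %t in %td.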
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO We don't traverse all users. To make the algorithm simple, here we
  // just traverse the first user. If we can find the shape, then return the
  // shape, otherwise just return nullptr and the optimization for undef/zero
  // will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride)
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2, i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride)
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, and shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function handles only PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should change to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   ...
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
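// (Interpretive note: a call or terminator inside the key amx area would force
// a tile value to stay live across it, which fast register allocation cannot
// handle for x86_amx, per the file header; the store-immediately /
// reload-before-use pattern above keeps every tile live range within one such
// area.)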
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}
} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A  ->  B    cast
///     PHI
///     B  ->  A    cast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
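///
/// For example (an illustrative sketch; names and shapes are placeholders):
///   %a0 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///   %a1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   %p  = phi <256 x i32> [ %a0, %bb0 ], [ %a1, %bb1 ]
///   %t  = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// can be rewritten so the PHI works on the tile type directly:
///   %p  = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
/// after which the casts and the old vector-typed PHI become dead.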
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes, before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: currently, we ignore cases where it is a const. In the future, we
      // might support const.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace InValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //     %0 = amxcast ...
        //   bb.1:
        //     %1 = amxcast ...
        //   bb.2:
        //     %goodphi = phi %0, %1
        //     %3 = amxcast %goodphi
        //   bb.3:
        //     %goodphi2 = phi %0, %goodphi
        //     %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // new PHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push PHINode into DeadInst since they are operands
        // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compiling time, we create the dominator tree only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
971 // store the value to stack and reload it from stack before cast.
973 createAllocaInstAtEntry(Builder
, Cast
->getParent(), LD
->getType());
974 Builder
.SetInsertPoint(&*std::next(LD
->getIterator()));
975 Builder
.CreateStore(LD
, AllocaAddr
);
977 Builder
.SetInsertPoint(Cast
);
978 I8Ptr
= Builder
.CreateBitCast(AllocaAddr
, Builder
.getInt8PtrTy());
981 I8Ptr
= Builder
.CreateBitCast(LD
->getOperand(0), Builder
.getInt8PtrTy());
983 std::array
<Value
*, 4> Args
= {Row
, Col
, I8Ptr
, Stride
};
985 Value
*NewInst
= Builder
.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal
,
987 Cast
->replaceAllUsesWith(NewInst
);

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // %1 = vec2tile %0
        // %2 = tile2vec %1
        // %3 = OP %2
        // -->
        // %3 = OP %0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcast after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2, i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcast after combineAMXcast, and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not just
    // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
      // sure the AMX data is volatile; that is necessary for AMX fast
      // register allocation.
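      // (For instance, running "llc -O0 t.ll" on IR produced by "clang -O2 -S
      // -emit-llvm t.c" takes this path; an illustrative restatement of the
      // scenario from the file header.)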
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}