//===-- X86LowerAMXIntrinsics.cpp - X86 Scalarize AMX Intrinsics ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it only acts at -O0 or on functions that
/// carry the optnone attribute. In those cases the defs of the tile shapes
/// sit right next to the amx intrinsics that use them, so we are not able to
/// find a point which post-dominates all the shape defs while dominating all
/// the amx intrinsics. To decouple the dependency on the shape, we transform
/// the amx intrinsics into scalar operations, so that compiling doesn't fail.
/// In the long term, we should improve fast register allocation to allocate
/// amx registers.
//
//===----------------------------------------------------------------------===//
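//
// Illustrative sketch (not the exact IR this pass emits): given
//
//   %t = call x86_amx @llvm.x86.tilezero.internal(i16 %m, i16 %n)
//   %v = bitcast x86_amx %t to <256 x i32>
//
// the lowering below rewrites users of %v to use <256 x i32> zeroinitializer
// directly, while tile loads, stores, and dot-products become scalar loop
// nests over the 16 x 16 dword tile, as the comments in each function show.
//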
#include "X86.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"
#ifndef NDEBUG
// Only used in asserts below.
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
#endif
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
namespace {
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}

  bool visit();

private:
  DomTreeUpdater &DTU;
  LoopInfo *LI;
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
};
} // anonymous namespace
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV = PHINode::Create(I16Ty, 2, Name + ".iv",
                                Header->getTerminator()->getIterator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });

  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
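
// For reference, the loop skeleton createLoop builds looks roughly like this
// (a sketch of the emitted blocks, assuming a loop named "x"):
//
//   preheader:                      ; existing block, retargeted to %x.header
//     br label %x.header
//   x.header:
//     %x.iv = phi i16 [ 0, %preheader ], [ %x.step, %x.latch ]
//     br label %x.body
//   x.body:                         ; callers insert the loop payload here
//     br label %x.latch
//   x.latch:
//     %x.step = add i16 %x.iv, %step
//     %x.cond = icmp ne i16 %x.step, %bound
//     br i1 %x.cond, label %x.header, label %exit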
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, 16
    // %idx = add i16 %mul, %col.iv
    // %elt = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %elt, i32* %eltptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
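
// A tile register holds at most 16 rows x 64 bytes, i.e. 16 x 16 dwords, so
// the <256 x i32> vector is the row-major flattening of a 16 x 16 dword tile
// and element (row, col) lives in lane row * 16 + col. For example, row 2,
// col 3 maps to lane 2 * 16 + 3 = 35. This is why the loop nests here and
// below compute their vector index as CurrentRow * 16 + CurrentCol.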
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);
  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO: else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]
  //
  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]
  //
  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]
  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
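
// In effect, each *.scalarize.inner iteration performs, at dword granularity,
// C[row][col] += A[row][k] * B[k][col], where every i32 element packs four i8
// products (the tdpb[su][su]d variants) or two bf16 products (tdpbf16ps).
// For example, tiledpbssd sign-extends 4 x i8 from A and B, multiplies them
// element-wise, reduces the four i32 products, and adds the sum into C's
// dword accumulator.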
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // We cannot assume there is always a bitcast after tiledpbssd, so we
  // insert one as required.
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete the TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}
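
// Note the shape rescaling above: the intrinsic's shapes are expressed in
// rows and bytes, while the scalar loops run over i32 dwords. For instance,
// a 16-row accumulator with n = 64 and k = 64 bytes yields loop bounds
// m = 16, n/4 = 16, and k/4 = 16 via the two lshr-by-2 instructions.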
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // We cannot assume there is always a bitcast after tileload, so we
    // insert one as required.
    Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
    Value *ResAMX = Builder.CreateBitCast(
        ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up.
    for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(U.getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}
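
// The same dword rescaling applies here: N and Stride arrive in bytes, so a
// tileload of a 16 x 64-byte tile with a 64-byte stride walks m = 16 rows and
// n/4 = 16 dword columns, stepping stride/4 = 16 dwords between rows.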
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }
  TileZero->eraseFromParent();
  return true;
}
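
// Unlike loads, stores, and dot-products, tilezero needs no loops at all:
// every bitcast user of the intrinsic is rewritten to the <256 x i32>
// zeroinitializer constant and the intrinsic itself is erased.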
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }

  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}
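
// visit() first collects the AMX intrinsics into a worklist and only then
// lowers them: lowering splits blocks and inserts new ones, which would
// otherwise invalidate a plain traversal over the function's blocks and
// instructions.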
namespace {
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOptLevel::None)
      return false;

    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }

  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
} // anonymous namespace
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
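
// Usage sketch (an assumption, not verified here): with the pass wired into
// the X86 codegen pipeline, the cl::opt defined above gates it, so something
// like
//   llc -O0 -enable-x86-scalar-amx input.ll
// should exercise the lowering; it only rewrites functions compiled at -O0
// or marked optnone, per runOnFunction above.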