AMDGPU: Mark test as XFAIL in expensive_checks builds
[llvm-project.git] / llvm / lib / Target / X86 / X86LowerAMXIntrinsics.cpp
blob 0a187ee42e3f8b7ede08cdde90bb7f1c3addca14
1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to transform AMX intrinsics into scalar operations.
/// The pass does nothing unless AMX scalarization is enabled with
/// -enable-x86-scalar-amx and the function is compiled at -O0 or carries the
/// optnone attribute. In that case the definitions of the AMX shape values are
/// placed right next to the AMX intrinsics that use them, so we cannot find a
/// point that post-dominates all the shape definitions and dominates all the
/// AMX intrinsics. To decouple the intrinsics from the shape definitions, we
/// transform them into scalar operations so that compilation does not fail.
/// In the long term, we should improve fast register allocation so that it can
/// allocate AMX registers.
17 //===----------------------------------------------------------------------===//
19 #include "X86.h"
20 #include "llvm/Analysis/DomTreeUpdater.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/Analysis/TargetTransformInfo.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/CodeGen/TargetPassConfig.h"
25 #include "llvm/CodeGen/ValueTypes.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/IntrinsicsX86.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Target/TargetMachine.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include "llvm/Transforms/Utils/LoopUtils.h"
40 using namespace llvm;
41 using namespace PatternMatch;
43 #define DEBUG_TYPE "lower-amx-intrinsics"
45 #ifndef NDEBUG
46 static bool isV256I32Ty(Type *Ty) {
47 if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
48 return FVT->getNumElements() == 256 &&
49 FVT->getElementType()->isIntegerTy(32);
50 return false;
52 #endif
54 static cl::opt<bool>
55 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
56 cl::desc("X86: enable AMX scalarizition."));
58 namespace {
59 class X86LowerAMXIntrinsics {
60 Function &Func;
62 public:
63 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
64 : Func(F), DTU(DomTU), LI(LoopI) {}
65 bool visit();
67 private:
68 DomTreeUpdater &DTU;
69 LoopInfo *LI;
70 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
71 Value *Step, StringRef Name, IRBuilderBase &B,
72 Loop *L);
73 template <bool IsTileLoad>
74 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
75 IRBuilderBase &B, Value *Row, Value *Col,
76 Value *Ptr, Value *Stride, Value *Tile);
77 template <Intrinsic::ID IntrID>
78 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
79 IntrID == Intrinsic::x86_tdpbsud_internal ||
80 IntrID == Intrinsic::x86_tdpbusd_internal ||
81 IntrID == Intrinsic::x86_tdpbuud_internal ||
82 IntrID == Intrinsic::x86_tdpbf16ps_internal,
83 Value *>
84 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
85 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
86 Value *RHS);
87 template <bool IsTileLoad>
88 bool lowerTileLoadStore(Instruction *TileLoadStore);
89 template <Intrinsic::ID IntrID>
90 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
91 IntrID == Intrinsic::x86_tdpbsud_internal ||
92 IntrID == Intrinsic::x86_tdpbusd_internal ||
93 IntrID == Intrinsic::x86_tdpbuud_internal ||
94 IntrID == Intrinsic::x86_tdpbf16ps_internal,
95 bool>
96 lowerTileDP(Instruction *TileDP);
97 bool lowerTileZero(Instruction *TileZero);
99 } // anonymous namespace
101 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
102 BasicBlock *Exit, Value *Bound,
103 Value *Step, StringRef Name,
104 IRBuilderBase &B, Loop *L) {
105 LLVMContext &Ctx = Preheader->getContext();
106 BasicBlock *Header =
107 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
108 BasicBlock *Body =
109 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
110 BasicBlock *Latch =
111 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
113 Type *I16Ty = Type::getInt16Ty(Ctx);
114 BranchInst::Create(Body, Header);
115 BranchInst::Create(Latch, Body);
116 PHINode *IV =
117 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
118 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
120 B.SetInsertPoint(Latch);
121 Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
122 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
123 BranchInst::Create(Header, Exit, Cond, Latch);
124 IV->addIncoming(Inc, Latch);
126 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
127 BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
128 PreheaderBr->setSuccessor(0, Header);
129 DTU.applyUpdatesPermissive({
130 {DominatorTree::Delete, Preheader, Tmp},
131 {DominatorTree::Insert, Header, Body},
132 {DominatorTree::Insert, Body, Latch},
133 {DominatorTree::Insert, Latch, Header},
134 {DominatorTree::Insert, Latch, Exit},
135 {DominatorTree::Insert, Preheader, Header},
137 if (LI) {
138 L->addBasicBlockToLoop(Header, *LI);
139 L->addBasicBlockToLoop(Body, *LI);
140 L->addBasicBlockToLoop(Latch, *LI);
142 return Body;
145 template <bool IsTileLoad>
146 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
147 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
148 Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
149 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
150 Loop *RowLoop = nullptr;
151 Loop *ColLoop = nullptr;
152 if (LI) {
153 RowLoop = LI->AllocateLoop();
154 ColLoop = LI->AllocateLoop();
155 RowLoop->addChildLoop(ColLoop);
156 if (Loop *ParentL = LI->getLoopFor(Start))
157 ParentL->addChildLoop(RowLoop);
158 else
159 LI->addTopLevelLoop(RowLoop);
162 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
163 IntrinName + ".scalarize.rows", B, RowLoop);
164 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
166 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
167 IntrinName + ".scalarize.cols", B, ColLoop);
169 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
170 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
171 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
172 Value *CurrentRow = &*RowLoopHeader->begin();
173 Value *CurrentCol = &*ColLoopHeader->begin();
174 Type *EltTy = B.getInt32Ty();
175 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
177 // Common part for tileload and tilestore
178 // *.scalarize.cols.body:
179 // Calculate %idxmem and %idxvec
180 B.SetInsertPoint(ColBody->getTerminator());
181 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
182 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
183 Value *Offset =
184 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
185 Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
186 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
187 if (IsTileLoad) {
188 // tileload.scalarize.rows.header:
189 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
190 // %tileload.scalarize.rows.latch ]
191 B.SetInsertPoint(RowLoopHeader->getTerminator());
192 Value *VecZero = Constant::getNullValue(V256I32Ty);
193 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
194 VecCPhiRowLoop->addIncoming(VecZero, Start);
196 // tileload.scalarize.cols.header:
197 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
198 // ], [ %ResVec, %tileload.scalarize.cols.latch ]
199 B.SetInsertPoint(ColLoopHeader->getTerminator());
200 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
201 VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
203 // tileload.scalarize.cols.body:
204 // Calculate %idxmem and %idxvec
205 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
206 // %elt = load i32, i32* %ptr
207 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
208 B.SetInsertPoint(ColBody->getTerminator());
209 Value *Elt = B.CreateLoad(EltTy, EltPtr);
210 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
211 VecPhi->addIncoming(ResVec, ColLoopLatch);
212 VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
214 return ResVec;
215 } else {
216 auto *BitCast = cast<BitCastInst>(Tile);
217 Value *Vec = BitCast->getOperand(0);
218 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
219 // tilestore.scalarize.cols.body:
220 // %mul = mul i16 %row.iv, i16 16
221 // %idx = add i16 %mul, i16 %col.iv
222 // %vec = extractelement <16 x i32> %vec, i16 %idx
223 // store i32 %vec, i32* %ptr
224 B.SetInsertPoint(ColBody->getTerminator());
225 Value *Elt = B.CreateExtractElement(Vec, Idx);
227 B.CreateStore(Elt, EltPtr);
228 return nullptr;
232 template <Intrinsic::ID IntrID>
233 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
234 IntrID == Intrinsic::x86_tdpbsud_internal ||
235 IntrID == Intrinsic::x86_tdpbusd_internal ||
236 IntrID == Intrinsic::x86_tdpbuud_internal ||
237 IntrID == Intrinsic::x86_tdpbf16ps_internal,
238 Value *>
239 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
240 IRBuilderBase &B, Value *Row,
241 Value *Col, Value *K, Value *Acc,
242 Value *LHS, Value *RHS) {
243 std::string IntrinName;
244 switch (IntrID) {
245 case Intrinsic::x86_tdpbssd_internal:
246 IntrinName = "tiledpbssd";
247 break;
248 case Intrinsic::x86_tdpbsud_internal:
249 IntrinName = "tiledpbsud";
250 break;
251 case Intrinsic::x86_tdpbusd_internal:
252 IntrinName = "tiledpbusd";
253 break;
254 case Intrinsic::x86_tdpbuud_internal:
255 IntrinName = "tiledpbuud";
256 break;
257 case Intrinsic::x86_tdpbf16ps_internal:
258 IntrinName = "tiledpbf16ps";
259 break;
261 Loop *RowLoop = nullptr;
262 Loop *ColLoop = nullptr;
263 Loop *InnerLoop = nullptr;
264 if (LI) {
265 RowLoop = LI->AllocateLoop();
266 ColLoop = LI->AllocateLoop();
267 InnerLoop = LI->AllocateLoop();
268 ColLoop->addChildLoop(InnerLoop);
269 RowLoop->addChildLoop(ColLoop);
270 if (Loop *ParentL = LI->getLoopFor(Start))
271 ParentL->addChildLoop(RowLoop);
272 else
273 LI->addTopLevelLoop(RowLoop);
276 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
277 IntrinName + ".scalarize.rows", B, RowLoop);
278 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
280 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
281 IntrinName + ".scalarize.cols", B, ColLoop);
283 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
285 B.SetInsertPoint(ColBody->getTerminator());
286 BasicBlock *InnerBody =
287 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
288 IntrinName + ".scalarize.inner", B, InnerLoop);
290 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
291 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
292 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
293 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
294 Value *CurrentRow = &*RowLoopHeader->begin();
295 Value *CurrentCol = &*ColLoopHeader->begin();
296 Value *CurrentInner = &*InnerLoopHeader->begin();
298 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
299 auto *BitCastAcc = cast<BitCastInst>(Acc);
300 Value *VecC = BitCastAcc->getOperand(0);
301 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
302 // TODO else create BitCast from x86amx to v256i32.
303 // Store x86amx to memory, and reload from memory
304 // to vector. However with -O0, it doesn't happen.
305 auto *BitCastLHS = cast<BitCastInst>(LHS);
306 Value *VecA = BitCastLHS->getOperand(0);
307 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
308 auto *BitCastRHS = cast<BitCastInst>(RHS);
309 Value *VecB = BitCastRHS->getOperand(0);
310 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
312 // tiledpbssd.scalarize.rows.header:
313 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
314 // %tiledpbssd.scalarize.rows.latch ]
316 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
317 // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
318 B.SetInsertPoint(RowLoopHeader->getTerminator());
319 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
320 VecCPhiRowLoop->addIncoming(VecC, Start);
321 Value *VecZero = Constant::getNullValue(V256I32Ty);
322 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
323 VecDPhiRowLoop->addIncoming(VecZero, Start);
325 // tiledpbssd.scalarize.cols.header:
326 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
327 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
328 // %tiledpbssd.scalarize.cols.latch ]
330 // %vec.d.phi.col = phi <256 x i32> [
331 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
332 // %tiledpbssd.scalarize.cols.latch ]
334 // calculate idxc.
335 B.SetInsertPoint(ColLoopHeader->getTerminator());
336 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
337 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
338 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
339 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
340 Value *IdxC =
341 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
343 // tiledpbssd.scalarize.inner.header:
344 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
345 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
346 // %tiledpbssd.scalarize.inner.latch ]
348 B.SetInsertPoint(InnerLoopHeader->getTerminator());
349 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
350 VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
352 B.SetInsertPoint(InnerBody->getTerminator());
353 Value *IdxA =
354 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
355 Value *IdxB =
356 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
357 Value *NewVecC = nullptr;
359 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
360 // tiledpbssd.scalarize.inner.body:
361 // calculate idxa, idxb
362 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
363 // %elta = extractelement <256 x i32> %veca, i16 %idxa
364 // %eltav4i8 = bitcast i32 %elta to <4 x i8>
365 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
366 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
367 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
368 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
369 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
370 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
371 // %neweltc = add i32 %elt, %acc
372 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
373 // i16 %idxc
374 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
375 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
376 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
377 Value *EltA = B.CreateExtractElement(VecA, IdxA);
378 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
379 Value *EltB = B.CreateExtractElement(VecB, IdxB);
380 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
381 Value *SEXTSubVecB = nullptr;
382 Value *SEXTSubVecA = nullptr;
383 switch (IntrID) {
384 case Intrinsic::x86_tdpbssd_internal:
385 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
386 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
387 break;
388 case Intrinsic::x86_tdpbsud_internal:
389 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
390 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
391 break;
392 case Intrinsic::x86_tdpbusd_internal:
393 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
394 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
395 break;
396 case Intrinsic::x86_tdpbuud_internal:
397 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
398 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
399 break;
400 default:
401 llvm_unreachable("Invalid intrinsic ID!");
403 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
404 Value *ResElt = B.CreateAdd(EltC, SubVecR);
405 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
406 } else {
407 // tiledpbf16ps.scalarize.inner.body:
408 // calculate idxa, idxb, idxc
409 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
410 // %eltcf32 = bitcast i32 %eltc to float
411 // %elta = extractelement <256 x i32> %veca, i16 %idxa
412 // %eltav2i16 = bitcast i32 %elta to <2 x i16>
413 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
414 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
415 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
416 // x i32> <i32 2, i32 0, i32 3, i32 1>
417 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
418 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
419 // i32> <i32 2, i32 0, i32 3, i32 1>
420 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
421 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
422 // %acc = call float
423 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
424 // %neweltc = bitcast float %acc to i32
425 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
426 // i16 %idxc
427 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
428 // i16 %idxc
429 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
430 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
431 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
432 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
433 Value *EltA = B.CreateExtractElement(VecA, IdxA);
434 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
435 Value *EltB = B.CreateExtractElement(VecB, IdxB);
436 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
437 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
438 int ShuffleMask[4] = {2, 0, 3, 1};
439 auto ShuffleArray = ArrayRef(ShuffleMask);
440 Value *AV2F32 = B.CreateBitCast(
441 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
442 Value *BV2F32 = B.CreateBitCast(
443 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
444 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
445 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
446 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
449 // tiledpbssd.scalarize.cols.latch:
450 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
451 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
452 // i16 %idxc
453 B.SetInsertPoint(ColLoopLatch->getTerminator());
454 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
455 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
457 VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
458 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
459 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
460 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
461 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
463 return NewVecD;
466 template <Intrinsic::ID IntrID>
467 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
468 IntrID == Intrinsic::x86_tdpbsud_internal ||
469 IntrID == Intrinsic::x86_tdpbusd_internal ||
470 IntrID == Intrinsic::x86_tdpbuud_internal ||
471 IntrID == Intrinsic::x86_tdpbf16ps_internal,
472 bool>
473 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
474 Value *M, *N, *K, *C, *A, *B;
475 match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
476 m_Value(C), m_Value(A), m_Value(B)));
477 Instruction *InsertI = TileDP;
478 IRBuilder<> PreBuilder(TileDP);
479 PreBuilder.SetInsertPoint(TileDP);
480 // We visit the loop with (m, n/4, k/4):
481 // %n_dword = lshr i16 %n, 2
482 // %k_dword = lshr i16 %k, 2
483 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
484 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
485 BasicBlock *Start = InsertI->getParent();
486 BasicBlock *End =
487 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
488 IRBuilder<> Builder(TileDP);
489 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
490 KDWord, C, A, B);
491 // we cannot assume there always be bitcast after tiledpbssd. So we need to
492 // insert one bitcast as required
493 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
494 Value *ResAMX =
495 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
496 // Delete TileDP intrinsic and do some clean-up.
497 for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
498 Instruction *I = cast<Instruction>(U.getUser());
499 Value *Vec;
500 if (match(I, m_BitCast(m_Value(Vec)))) {
501 I->replaceAllUsesWith(ResVec);
502 I->eraseFromParent();
505 TileDP->replaceAllUsesWith(ResAMX);
506 TileDP->eraseFromParent();
507 return true;
510 template <bool IsTileLoad>
511 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
512 Value *M, *N, *Ptr, *Stride, *Tile;
513 if (IsTileLoad)
514 match(TileLoadStore,
515 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
516 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
517 else
518 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
519 m_Value(M), m_Value(N), m_Value(Ptr),
520 m_Value(Stride), m_Value(Tile)));
522 Instruction *InsertI = TileLoadStore;
523 IRBuilder<> PreBuilder(TileLoadStore);
524 PreBuilder.SetInsertPoint(TileLoadStore);
525 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
526 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
527 BasicBlock *Start = InsertI->getParent();
528 BasicBlock *End =
529 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
530 IRBuilder<> Builder(TileLoadStore);
531 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
532 Start, End, Builder, M, NDWord, Ptr, StrideDWord,
533 IsTileLoad ? nullptr : Tile);
534 if (IsTileLoad) {
535 // we cannot assume there always be bitcast after tileload. So we need to
536 // insert one bitcast as required
537 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
538 Value *ResAMX =
539 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
540 // Delete tileloadd6 intrinsic and do some clean-up
541 for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
542 Instruction *I = cast<Instruction>(U.getUser());
543 Value *Vec;
544 if (match(I, m_BitCast(m_Value(Vec)))) {
545 I->replaceAllUsesWith(ResVec);
546 I->eraseFromParent();
549 TileLoadStore->replaceAllUsesWith(ResAMX);
551 TileLoadStore->eraseFromParent();
552 return true;
555 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
556 IRBuilder<> Builder(TileZero);
557 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
558 Value *VecZero = Constant::getNullValue(V256I32Ty);
559 for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
560 Instruction *I = cast<Instruction>(U.getUser());
561 Value *Vec;
562 if (match(I, m_BitCast(m_Value(Vec)))) {
563 I->replaceAllUsesWith(VecZero);
564 I->eraseFromParent();
567 TileZero->eraseFromParent();
568 return true;
571 bool X86LowerAMXIntrinsics::visit() {
572 bool C = false;
573 SmallVector<IntrinsicInst *, 8> WorkList;
574 for (BasicBlock *BB : depth_first(&Func)) {
575 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
576 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
577 switch (Inst->getIntrinsicID()) {
578 case Intrinsic::x86_tdpbssd_internal:
579 case Intrinsic::x86_tdpbsud_internal:
580 case Intrinsic::x86_tdpbusd_internal:
581 case Intrinsic::x86_tdpbuud_internal:
582 case Intrinsic::x86_tileloadd64_internal:
583 case Intrinsic::x86_tilestored64_internal:
584 case Intrinsic::x86_tilezero_internal:
585 case Intrinsic::x86_tdpbf16ps_internal:
586 WorkList.push_back(Inst);
587 break;
588 default:
589 break;
595 for (auto *Inst : WorkList) {
596 switch (Inst->getIntrinsicID()) {
597 case Intrinsic::x86_tdpbssd_internal:
598 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
599 break;
600 case Intrinsic::x86_tdpbsud_internal:
601 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
602 break;
603 case Intrinsic::x86_tdpbusd_internal:
604 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
605 break;
606 case Intrinsic::x86_tdpbuud_internal:
607 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
608 break;
609 case Intrinsic::x86_tdpbf16ps_internal:
610 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
611 break;
612 case Intrinsic::x86_tileloadd64_internal:
613 C = lowerTileLoadStore<true>(Inst) || C;
614 break;
615 case Intrinsic::x86_tilestored64_internal:
616 C = lowerTileLoadStore<false>(Inst) || C;
617 break;
618 case Intrinsic::x86_tilezero_internal:
619 C = lowerTileZero(Inst) || C;
620 break;
621 default:
622 llvm_unreachable("invalid amx intrinsics!");
626 return C;
629 namespace {
630 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
631 public:
632 static char ID;
634 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
635 initializeX86LowerAMXIntrinsicsLegacyPassPass(
636 *PassRegistry::getPassRegistry());
639 bool runOnFunction(Function &F) override {
640 if (!X86ScalarizeAMX)
641 return false;
642 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
643 if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
644 TM->getOptLevel() != CodeGenOptLevel::None)
645 return false;
647 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
648 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
649 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
650 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
651 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
653 X86LowerAMXIntrinsics LAT(F, DTU, LI);
654 return LAT.visit();
656 StringRef getPassName() const override { return "Lower AMX intrinsics"; }
658 void getAnalysisUsage(AnalysisUsage &AU) const override {
659 AU.addPreserved<DominatorTreeWrapperPass>();
660 AU.addPreserved<LoopInfoWrapperPass>();
661 AU.addRequired<TargetPassConfig>();
664 } // namespace
666 static const char PassName[] = "Lower AMX intrinsics";
667 char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
668 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
669 false, false)
670 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
671 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
672 false, false)
674 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
675 return new X86LowerAMXIntrinsicsLegacyPass();