[ORC] Add std::tuple support to SimplePackedSerialization.
[llvm-project.git] / llvm / test / CodeGen / X86 / AMX / amx-gemm.ll
blob8f5d0c7383187609c35212767b7578100aac254f
1 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -mattr=+avx512f -verify-machineinstrs | FileCheck %s
3 ; #include <immintrin.h>
5 ; #define TILE_SZ 16
6 ; void inner_product(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) {
7 ;   const int m = M / TILE_SZ;
8 ;   const int n = N / TILE_SZ;
9 ;   const int k = K / TILE_SZ;
11 ;   for (int i = 0; i < m; i++)
12 ;     for (int j = 0; j < n; j++) {
13 ;       __tile1024i c = {TILE_SZ, TILE_SZ*sizeof(int)};
14 ;       __tile_zero(&c);
15 ;       for (int l = 0; l < k; l++) {
16 ;         __tile1024i a = {TILE_SZ, TILE_SZ*sizeof(int)};
17 ;         __tile1024i b = {TILE_SZ, TILE_SZ*sizeof(int)};
18 ;         __tile_loadd(&a, A_mem+(i*TILE_SZ)*K+l*TILE_SZ, K*sizeof(int));
19 ;         __tile_loadd(&b, B_mem+(l*TILE_SZ)*N+j*TILE_SZ, N*sizeof(int));
20 ;         __tile_dpbssd(&c, a, b);
21 ;       }
22 ;       __tile_stored(C_mem+(i*TILE_SZ)*M+j*TILE_SZ, N*sizeof(int), c);
23 ;     }
24 ; }
26 ; CHECK:  ldtilecfg
28 ; Function Attrs: noinline nounwind uwtable
29 define dso_local void @inner_product(i32* %A_mem, i32* %B_mem, i32* %C_mem, i32 %M, i32 %N, i32 %K) local_unnamed_addr {
30 entry:
31   %mul = shl i32 %K, 4
32   %conv = sext i32 %K to i64
33   %mul15 = shl nsw i64 %conv, 2
34   %conv23 = sext i32 %N to i64
35   %mul24 = shl nsw i64 %conv23, 2
36   %cmp8163 = icmp sgt i32 %K, 15
37   %mul25 = shl i32 %M, 4
38   %cmp4173 = icmp sgt i32 %N, 15
39   %cmp187 = icmp sgt i32 %M, 15
40   br i1 %cmp187, label %for.cond3.preheader.preheader, label %for.cond.cleanup
42 for.cond3.preheader.preheader:                    ; preds = %entry
43   %div2 = sdiv i32 %K, 16
44   %div1 = sdiv i32 %N, 16
45   %div209 = lshr i32 %M, 4
46   %wide.trip.count207 = zext i32 %div209 to i64
47   %wide.trip.count203 = zext i32 %div1 to i64
48   %wide.trip.count = zext i32 %div2 to i64
49   %0 = add nsw i64 %wide.trip.count, -1
50   %xtraiter = and i64 %wide.trip.count, 7
51   %1 = icmp ult i64 %0, 7
52   %unroll_iter = and i64 %wide.trip.count, 4294967288
53   %lcmp.mod.not = icmp eq i64 %xtraiter, 0
54   br label %for.cond3.preheader
56 for.cond3.preheader:                              ; preds = %for.cond3.preheader.preheader, %for.cond.cleanup5
57   %indvars.iv205 = phi i64 [ 0, %for.cond3.preheader.preheader ], [ %indvars.iv.next206, %for.cond.cleanup5 ]
58   %2 = trunc i64 %indvars.iv205 to i32
59   %mul11 = mul i32 %mul, %2
60   %idx.ext = sext i32 %mul11 to i64
61   %add.ptr = getelementptr inbounds i32, i32* %A_mem, i64 %idx.ext
62   %mul26 = mul i32 %mul25, %2
63   %idx.ext27 = sext i32 %mul26 to i64
64   %add.ptr28 = getelementptr inbounds i32, i32* %C_mem, i64 %idx.ext27
65   br i1 %cmp4173, label %for.body6, label %for.cond.cleanup5
67 for.cond.cleanup:                                 ; preds = %for.cond.cleanup5, %entry
68   ret void
70 for.cond.cleanup5:                                ; preds = %for.cond.cleanup9, %for.cond3.preheader
71   %indvars.iv.next206 = add nuw nsw i64 %indvars.iv205, 1
72   %exitcond208.not = icmp eq i64 %indvars.iv.next206, %wide.trip.count207
73   br i1 %exitcond208.not, label %for.cond.cleanup, label %for.cond3.preheader
75 for.body6:                                        ; preds = %for.cond3.preheader, %for.cond.cleanup9
76   %indvars.iv199 = phi i64 [ %indvars.iv.next200, %for.cond.cleanup9 ], [ 0, %for.cond3.preheader ]
77   %3 = tail call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
78   %4 = shl nsw i64 %indvars.iv199, 4
79   br i1 %cmp8163, label %for.body10.preheader, label %for.cond.cleanup9
81 for.body10.preheader:                             ; preds = %for.body6
82   %add.ptr19 = getelementptr inbounds i32, i32* %B_mem, i64 %4
83   br i1 %1, label %for.cond.cleanup9.loopexit.unr-lcssa, label %for.body10
85 for.cond.cleanup9.loopexit.unr-lcssa:             ; preds = %for.body10, %for.body10.preheader
86   %.lcssa.ph = phi x86_amx [ undef, %for.body10.preheader ], [ %68, %for.body10 ]
87   %indvars.iv.unr = phi i64 [ 0, %for.body10.preheader ], [ %indvars.iv.next.7, %for.body10 ]
88   %c.sroa.8127.2.in164.unr = phi x86_amx [ %3, %for.body10.preheader ], [ %68, %for.body10 ]
89   br i1 %lcmp.mod.not, label %for.cond.cleanup9, label %for.body10.epil
91 for.body10.epil:                                  ; preds = %for.cond.cleanup9.loopexit.unr-lcssa, %for.body10.epil
92   %indvars.iv.epil = phi i64 [ %indvars.iv.next.epil, %for.body10.epil ], [ %indvars.iv.unr, %for.cond.cleanup9.loopexit.unr-lcssa ]
93   %c.sroa.8127.2.in164.epil = phi x86_amx [ %11, %for.body10.epil ], [ %c.sroa.8127.2.in164.unr, %for.cond.cleanup9.loopexit.unr-lcssa ]
94   %epil.iter = phi i64 [ %epil.iter.sub, %for.body10.epil ], [ %xtraiter, %for.cond.cleanup9.loopexit.unr-lcssa ]
95   %5 = shl nsw i64 %indvars.iv.epil, 4
96   %add.ptr14.epil = getelementptr inbounds i32, i32* %add.ptr, i64 %5
97   %6 = bitcast i32* %add.ptr14.epil to i8*
98   %7 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %6, i64 %mul15)
99   %8 = mul nsw i64 %5, %conv23
100   %add.ptr22.epil = getelementptr inbounds i32, i32* %add.ptr19, i64 %8
101   %9 = bitcast i32* %add.ptr22.epil to i8*
102   %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %9, i64 %mul24)
103   %11 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c.sroa.8127.2.in164.epil, x86_amx %7, x86_amx %10)
104   %indvars.iv.next.epil = add nuw nsw i64 %indvars.iv.epil, 1
105   %epil.iter.sub = add i64 %epil.iter, -1
106   %epil.iter.cmp.not = icmp eq i64 %epil.iter.sub, 0
107   br i1 %epil.iter.cmp.not, label %for.cond.cleanup9, label %for.body10.epil
109 for.cond.cleanup9:                                ; preds = %for.cond.cleanup9.loopexit.unr-lcssa, %for.body10.epil, %for.body6
110   %c.sroa.8127.2.in.lcssa = phi x86_amx [ %3, %for.body6 ], [ %.lcssa.ph, %for.cond.cleanup9.loopexit.unr-lcssa ], [ %11, %for.body10.epil ]
111   %add.ptr31 = getelementptr inbounds i32, i32* %add.ptr28, i64 %4
112   %12 = bitcast i32* %add.ptr31 to i8*
113   tail call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* %12, i64 %mul24, x86_amx %c.sroa.8127.2.in.lcssa)
114   %indvars.iv.next200 = add nuw nsw i64 %indvars.iv199, 1
115   %exitcond204.not = icmp eq i64 %indvars.iv.next200, %wide.trip.count203
116   br i1 %exitcond204.not, label %for.cond.cleanup5, label %for.body6
118 for.body10:                                       ; preds = %for.body10.preheader, %for.body10
119   %indvars.iv = phi i64 [ %indvars.iv.next.7, %for.body10 ], [ 0, %for.body10.preheader ]
120   %c.sroa.8127.2.in164 = phi x86_amx [ %68, %for.body10 ], [ %3, %for.body10.preheader ]
121   %niter = phi i64 [ %niter.nsub.7, %for.body10 ], [ %unroll_iter, %for.body10.preheader ]
122   %13 = shl nsw i64 %indvars.iv, 4
123   %add.ptr14 = getelementptr inbounds i32, i32* %add.ptr, i64 %13
124   %14 = bitcast i32* %add.ptr14 to i8*
125   %15 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %14, i64 %mul15)
126   %16 = mul nsw i64 %13, %conv23
127   %add.ptr22 = getelementptr inbounds i32, i32* %add.ptr19, i64 %16
128   %17 = bitcast i32* %add.ptr22 to i8*
129   %18 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %17, i64 %mul24)
130   %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c.sroa.8127.2.in164, x86_amx %15, x86_amx %18)
131   %indvars.iv.next = shl i64 %indvars.iv, 4
132   %20 = or i64 %indvars.iv.next, 16
133   %add.ptr14.1 = getelementptr inbounds i32, i32* %add.ptr, i64 %20
134   %21 = bitcast i32* %add.ptr14.1 to i8*
135   %22 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %21, i64 %mul15)
136   %23 = mul nsw i64 %20, %conv23
137   %add.ptr22.1 = getelementptr inbounds i32, i32* %add.ptr19, i64 %23
138   %24 = bitcast i32* %add.ptr22.1 to i8*
139   %25 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %24, i64 %mul24)
140   %26 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %19, x86_amx %22, x86_amx %25)
141   %indvars.iv.next.1 = shl i64 %indvars.iv, 4
142   %27 = or i64 %indvars.iv.next.1, 32
143   %add.ptr14.2 = getelementptr inbounds i32, i32* %add.ptr, i64 %27
144   %28 = bitcast i32* %add.ptr14.2 to i8*
145   %29 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %28, i64 %mul15)
146   %30 = mul nsw i64 %27, %conv23
147   %add.ptr22.2 = getelementptr inbounds i32, i32* %add.ptr19, i64 %30
148   %31 = bitcast i32* %add.ptr22.2 to i8*
149   %32 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %31, i64 %mul24)
150   %33 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %26, x86_amx %29, x86_amx %32)
151   %indvars.iv.next.2 = shl i64 %indvars.iv, 4
152   %34 = or i64 %indvars.iv.next.2, 48
153   %add.ptr14.3 = getelementptr inbounds i32, i32* %add.ptr, i64 %34
154   %35 = bitcast i32* %add.ptr14.3 to i8*
155   %36 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %35, i64 %mul15)
156   %37 = mul nsw i64 %34, %conv23
157   %add.ptr22.3 = getelementptr inbounds i32, i32* %add.ptr19, i64 %37
158   %38 = bitcast i32* %add.ptr22.3 to i8*
159   %39 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %38, i64 %mul24)
160   %40 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %33, x86_amx %36, x86_amx %39)
161   %indvars.iv.next.3 = shl i64 %indvars.iv, 4
162   %41 = or i64 %indvars.iv.next.3, 64
163   %add.ptr14.4 = getelementptr inbounds i32, i32* %add.ptr, i64 %41
164   %42 = bitcast i32* %add.ptr14.4 to i8*
165   %43 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %42, i64 %mul15)
166   %44 = mul nsw i64 %41, %conv23
167   %add.ptr22.4 = getelementptr inbounds i32, i32* %add.ptr19, i64 %44
168   %45 = bitcast i32* %add.ptr22.4 to i8*
169   %46 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %45, i64 %mul24)
170   %47 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %40, x86_amx %43, x86_amx %46)
171   %indvars.iv.next.4 = shl i64 %indvars.iv, 4
172   %48 = or i64 %indvars.iv.next.4, 80
173   %add.ptr14.5 = getelementptr inbounds i32, i32* %add.ptr, i64 %48
174   %49 = bitcast i32* %add.ptr14.5 to i8*
175   %50 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %49, i64 %mul15)
176   %51 = mul nsw i64 %48, %conv23
177   %add.ptr22.5 = getelementptr inbounds i32, i32* %add.ptr19, i64 %51
178   %52 = bitcast i32* %add.ptr22.5 to i8*
179   %53 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %52, i64 %mul24)
180   %54 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %47, x86_amx %50, x86_amx %53)
181   %indvars.iv.next.5 = shl i64 %indvars.iv, 4
182   %55 = or i64 %indvars.iv.next.5, 96
183   %add.ptr14.6 = getelementptr inbounds i32, i32* %add.ptr, i64 %55
184   %56 = bitcast i32* %add.ptr14.6 to i8*
185   %57 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %56, i64 %mul15)
186   %58 = mul nsw i64 %55, %conv23
187   %add.ptr22.6 = getelementptr inbounds i32, i32* %add.ptr19, i64 %58
188   %59 = bitcast i32* %add.ptr22.6 to i8*
189   %60 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %59, i64 %mul24)
190   %61 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %54, x86_amx %57, x86_amx %60)
191   %indvars.iv.next.6 = shl i64 %indvars.iv, 4
192   %62 = or i64 %indvars.iv.next.6, 112
193   %add.ptr14.7 = getelementptr inbounds i32, i32* %add.ptr, i64 %62
194   %63 = bitcast i32* %add.ptr14.7 to i8*
195   %64 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %63, i64 %mul15)
196   %65 = mul nsw i64 %62, %conv23
197   %add.ptr22.7 = getelementptr inbounds i32, i32* %add.ptr19, i64 %65
198   %66 = bitcast i32* %add.ptr22.7 to i8*
199   %67 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %66, i64 %mul24)
200   %68 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %61, x86_amx %64, x86_amx %67)
201   %indvars.iv.next.7 = add nuw nsw i64 %indvars.iv, 8
202   %niter.nsub.7 = add i64 %niter, -8
203   %niter.ncmp.7 = icmp eq i64 %niter.nsub.7, 0
204   br i1 %niter.ncmp.7, label %for.cond.cleanup9.loopexit.unr-lcssa, label %for.body10
207 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
208 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
209 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
210 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)