Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / test / CodeGen / X86 / AMX / amx-gemm.ll
blobb8771d525a54b38702f10845492739c7dd5c6131
1 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -mattr=+avx512f -verify-machineinstrs | FileCheck %s
3 ; #include <immintrin.h>
5 ; #define TILE_SZ 16
6 ; void inner_product(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) {
7 ;   const int m = M / TILE_SZ;
8 ;   const int n = N / TILE_SZ;
9 ;   const int k = K / TILE_SZ;
11 ;   for (int i = 0; i < m; i++)
12 ;     for (int j = 0; j < n; j++) {
13 ;       __tile1024i c = {TILE_SZ, TILE_SZ*sizeof(int)};
14 ;       __tile_zero(&c);
15 ;       for (int l = 0; l < k; l++) {
16 ;         __tile1024i a = {TILE_SZ, TILE_SZ*sizeof(int)};
17 ;         __tile1024i b = {TILE_SZ, TILE_SZ*sizeof(int)};
18 ;         __tile_loadd(&a, A_mem+(i*TILE_SZ)*K+l*TILE_SZ, K*sizeof(int));
19 ;         __tile_loadd(&b, B_mem+(l*TILE_SZ)*N+j*TILE_SZ, N*sizeof(int));
20 ;         __tile_dpbssd(&c, a, b);
21 ;       }
22 ;       __tile_stored(C_mem+(i*TILE_SZ)*M+j*TILE_SZ, N*sizeof(int), c);
23 ;     }
24 ; }
26 ; CHECK:  ldtilecfg
28 define dso_local void @inner_product(ptr %A_mem, ptr %B_mem, ptr %C_mem, i32 %M, i32 %N, i32 %K) local_unnamed_addr {
29 entry:
30   %mul = shl i32 %K, 4
31   %conv = sext i32 %K to i64
32   %mul15 = shl nsw i64 %conv, 2
33   %conv23 = sext i32 %N to i64
34   %mul24 = shl nsw i64 %conv23, 2
35   %cmp8163 = icmp sgt i32 %K, 15
36   %mul25 = shl i32 %M, 4
37   %cmp4173 = icmp sgt i32 %N, 15
38   %cmp187 = icmp sgt i32 %M, 15
39   br i1 %cmp187, label %for.cond3.preheader.preheader, label %for.cond.cleanup
41 for.cond3.preheader.preheader:                    ; preds = %entry
42   %div2 = sdiv i32 %K, 16
43   %div1 = sdiv i32 %N, 16
44   %div209 = lshr i32 %M, 4
45   %wide.trip.count207 = zext i32 %div209 to i64
46   %wide.trip.count203 = zext i32 %div1 to i64
47   %wide.trip.count = zext i32 %div2 to i64
48   %i = add nsw i64 %wide.trip.count, -1
49   %xtraiter = and i64 %wide.trip.count, 7
50   %i1 = icmp ult i64 %i, 7
51   %unroll_iter = and i64 %wide.trip.count, 4294967288
52   %lcmp.mod.not = icmp eq i64 %xtraiter, 0
53   br label %for.cond3.preheader
55 for.cond3.preheader:                              ; preds = %for.cond.cleanup5, %for.cond3.preheader.preheader
56   %indvars.iv205 = phi i64 [ 0, %for.cond3.preheader.preheader ], [ %indvars.iv.next206, %for.cond.cleanup5 ]
57   %i2 = trunc i64 %indvars.iv205 to i32
58   %mul11 = mul i32 %mul, %i2
59   %idx.ext = sext i32 %mul11 to i64
60   %add.ptr = getelementptr inbounds i32, ptr %A_mem, i64 %idx.ext
61   %mul26 = mul i32 %mul25, %i2
62   %idx.ext27 = sext i32 %mul26 to i64
63   %add.ptr28 = getelementptr inbounds i32, ptr %C_mem, i64 %idx.ext27
64   br i1 %cmp4173, label %for.body6, label %for.cond.cleanup5
66 for.cond.cleanup:                                 ; preds = %for.cond.cleanup5, %entry
67   ret void
69 for.cond.cleanup5:                                ; preds = %for.cond.cleanup9, %for.cond3.preheader
70   %indvars.iv.next206 = add nuw nsw i64 %indvars.iv205, 1
71   %exitcond208.not = icmp eq i64 %indvars.iv.next206, %wide.trip.count207
72   br i1 %exitcond208.not, label %for.cond.cleanup, label %for.cond3.preheader
74 for.body6:                                        ; preds = %for.cond.cleanup9, %for.cond3.preheader
75   %indvars.iv199 = phi i64 [ %indvars.iv.next200, %for.cond.cleanup9 ], [ 0, %for.cond3.preheader ]
76   %i3 = tail call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
77   %i4 = shl nsw i64 %indvars.iv199, 4
78   br i1 %cmp8163, label %for.body10.preheader, label %for.cond.cleanup9
80 for.body10.preheader:                             ; preds = %for.body6
81   %add.ptr19 = getelementptr inbounds i32, ptr %B_mem, i64 %i4
82   br i1 %i1, label %for.cond.cleanup9.loopexit.unr-lcssa, label %for.body10
84 for.cond.cleanup9.loopexit.unr-lcssa:             ; preds = %for.body10, %for.body10.preheader
85   %.lcssa.ph = phi x86_amx [ undef, %for.body10.preheader ], [ %i68, %for.body10 ]
86   %indvars.iv.unr = phi i64 [ 0, %for.body10.preheader ], [ %indvars.iv.next.7, %for.body10 ]
87   %c.sroa.8127.2.in164.unr = phi x86_amx [ %i3, %for.body10.preheader ], [ %i68, %for.body10 ]
88   br i1 %lcmp.mod.not, label %for.cond.cleanup9, label %for.body10.epil
90 for.body10.epil:                                  ; preds = %for.body10.epil, %for.cond.cleanup9.loopexit.unr-lcssa
91   %indvars.iv.epil = phi i64 [ %indvars.iv.next.epil, %for.body10.epil ], [ %indvars.iv.unr, %for.cond.cleanup9.loopexit.unr-lcssa ]
92   %c.sroa.8127.2.in164.epil = phi x86_amx [ %i11, %for.body10.epil ], [ %c.sroa.8127.2.in164.unr, %for.cond.cleanup9.loopexit.unr-lcssa ]
93   %epil.iter = phi i64 [ %epil.iter.sub, %for.body10.epil ], [ %xtraiter, %for.cond.cleanup9.loopexit.unr-lcssa ]
94   %i5 = shl nsw i64 %indvars.iv.epil, 4
95   %add.ptr14.epil = getelementptr inbounds i32, ptr %add.ptr, i64 %i5
96   %i7 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr %add.ptr14.epil, i64 %mul15)
97   %i8 = mul nsw i64 %i5, %conv23
98   %add.ptr22.epil = getelementptr inbounds i32, ptr %add.ptr19, i64 %i8
99   %i10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr %add.ptr22.epil, i64 %mul24)
100   %i11 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c.sroa.8127.2.in164.epil, x86_amx %i7, x86_amx %i10)
101   %indvars.iv.next.epil = add nuw nsw i64 %indvars.iv.epil, 1
102   %epil.iter.sub = add i64 %epil.iter, -1
103   %epil.iter.cmp.not = icmp eq i64 %epil.iter.sub, 0
104   br i1 %epil.iter.cmp.not, label %for.cond.cleanup9, label %for.body10.epil
106 for.cond.cleanup9:                                ; preds = %for.body10.epil, %for.cond.cleanup9.loopexit.unr-lcssa, %for.body6
107   %c.sroa.8127.2.in.lcssa = phi x86_amx [ %i3, %for.body6 ], [ %.lcssa.ph, %for.cond.cleanup9.loopexit.unr-lcssa ], [ %i11, %for.body10.epil ]
108   %add.ptr31 = getelementptr inbounds i32, ptr %add.ptr28, i64 %i4
109   tail call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr %add.ptr31, i64 %mul24, x86_amx %c.sroa.8127.2.in.lcssa)
110   %indvars.iv.next200 = add nuw nsw i64 %indvars.iv199, 1
111   %exitcond204.not = icmp eq i64 %indvars.iv.next200, %wide.trip.count203
112   br i1 %exitcond204.not, label %for.cond.cleanup5, label %for.body6
114 for.body10:                                       ; preds = %for.body10, %for.body10.preheader
115   %indvars.iv = phi i64 [ %indvars.iv.next.7, %for.body10 ], [ 0, %for.body10.preheader ]
116   %c.sroa.8127.2.in164 = phi x86_amx [ %i68, %for.body10 ], [ %i3, %for.body10.preheader ]
117   %niter = phi i64 [ %niter.nsub.7, %for.body10 ], [ %unroll_iter, %for.body10.preheader ]
118   %i13 = shl nsw i64 %indvars.iv, 4
119   %add.ptr14 = getelementptr inbounds i32, ptr %add.ptr, i64 %i13
120   %i15 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr %add.ptr14, i64 %mul15)
121   %i16 = mul nsw i64 %i13, %conv23
122   %add.ptr22 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i16
123   %i18 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr %add.ptr22, i64 %mul24)
124   %i19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c.sroa.8127.2.in164, x86_amx %i15, x86_amx %i18)
125   %indvars.iv.next = shl i64 %indvars.iv, 4
126   %i20 = or i64 %indvars.iv.next, 16
127   %add.ptr14.1 = getelementptr inbounds i32, ptr %add.ptr, i64 %i20
128   %i22 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.1, i64 %mul15)
129   %i23 = mul nsw i64 %i20, %conv23
130   %add.ptr22.1 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i23
131   %i25 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.1, i64 %mul24)
132   %i26 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i19, x86_amx %i22, x86_amx %i25)
133   %indvars.iv.next.1 = shl i64 %indvars.iv, 4
134   %i27 = or i64 %indvars.iv.next.1, 32
135   %add.ptr14.2 = getelementptr inbounds i32, ptr %add.ptr, i64 %i27
136   %i29 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.2, i64 %mul15)
137   %i30 = mul nsw i64 %i27, %conv23
138   %add.ptr22.2 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i30
139   %i32 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.2, i64 %mul24)
140   %i33 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i26, x86_amx %i29, x86_amx %i32)
141   %indvars.iv.next.2 = shl i64 %indvars.iv, 4
142   %i34 = or i64 %indvars.iv.next.2, 48
143   %add.ptr14.3 = getelementptr inbounds i32, ptr %add.ptr, i64 %i34
144   %i36 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.3, i64 %mul15)
145   %i37 = mul nsw i64 %i34, %conv23
146   %add.ptr22.3 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i37
147   %i39 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.3, i64 %mul24)
148   %i40 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i33, x86_amx %i36, x86_amx %i39)
149   %indvars.iv.next.3 = shl i64 %indvars.iv, 4
150   %i41 = or i64 %indvars.iv.next.3, 64
151   %add.ptr14.4 = getelementptr inbounds i32, ptr %add.ptr, i64 %i41
152   %i43 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.4, i64 %mul15)
153   %i44 = mul nsw i64 %i41, %conv23
154   %add.ptr22.4 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i44
155   %i46 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.4, i64 %mul24)
156   %i47 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i40, x86_amx %i43, x86_amx %i46)
157   %indvars.iv.next.4 = shl i64 %indvars.iv, 4
158   %i48 = or i64 %indvars.iv.next.4, 80
159   %add.ptr14.5 = getelementptr inbounds i32, ptr %add.ptr, i64 %i48
160   %i50 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.5, i64 %mul15)
161   %i51 = mul nsw i64 %i48, %conv23
162   %add.ptr22.5 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i51
163   %i53 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.5, i64 %mul24)
164   %i54 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i47, x86_amx %i50, x86_amx %i53)
165   %indvars.iv.next.5 = shl i64 %indvars.iv, 4
166   %i55 = or i64 %indvars.iv.next.5, 96
167   %add.ptr14.6 = getelementptr inbounds i32, ptr %add.ptr, i64 %i55
168   %i57 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.6, i64 %mul15)
169   %i58 = mul nsw i64 %i55, %conv23
170   %add.ptr22.6 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i58
171   %i60 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.6, i64 %mul24)
172   %i61 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i54, x86_amx %i57, x86_amx %i60)
173   %indvars.iv.next.6 = shl i64 %indvars.iv, 4
174   %i62 = or i64 %indvars.iv.next.6, 112
175   %add.ptr14.7 = getelementptr inbounds i32, ptr %add.ptr, i64 %i62
176   %i64 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr14.7, i64 %mul15)
177   %i65 = mul nsw i64 %i62, %conv23
178   %add.ptr22.7 = getelementptr inbounds i32, ptr %add.ptr19, i64 %i65
179   %i67 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr nonnull %add.ptr22.7, i64 %mul24)
180   %i68 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %i61, x86_amx %i64, x86_amx %i67)
181   %indvars.iv.next.7 = add nuw nsw i64 %indvars.iv, 8
182   %niter.nsub.7 = add i64 %niter, -8
183   %niter.ncmp.7 = icmp eq i64 %niter.nsub.7, 0
184   br i1 %niter.ncmp.7, label %for.cond.cleanup9.loopexit.unr-lcssa, label %for.body10
187 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
188 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
189 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
190 declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)