Re-land [openmp] Fix warnings when building on Windows with latest MSVC or Clang...
[llvm-project.git] / llvm / test / CodeGen / X86 / avx512vnni-combine.ll
blobc491a952682d53f81ba407f7e5bd042c1c65fb5d
1 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
2 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=sapphirerapids -verify-machineinstrs | FileCheck %s
4 define <8 x i64> @foo_reg_512(<8 x i64> %0, <8 x i64> %1, <8 x i64> %2, <8 x i64> %3, <8 x i64> %4, <8 x i64> %5) {
5 ; CHECK-LABEL: foo_reg_512:
6 ; CHECK:       # %bb.0:
7 ; CHECK-NEXT:    vpdpwssd %zmm2, %zmm1, %zmm0
8 ; CHECK-NEXT:    vpmaddwd %zmm3, %zmm1, %zmm2
9 ; CHECK-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
10 ; CHECK-NEXT:    vpmaddwd %zmm4, %zmm1, %zmm2
11 ; CHECK-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
12 ; CHECK-NEXT:    vpmaddwd %zmm5, %zmm1, %zmm1
13 ; CHECK-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
14 ; CHECK-NEXT:    retq
15   %7 = bitcast <8 x i64> %0 to <16 x i32>
16   %8 = bitcast <8 x i64> %1 to <16 x i32>
17   %9 = bitcast <8 x i64> %2 to <16 x i32>
18   %10 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %7, <16 x i32> %8, <16 x i32> %9)
19   %11 = bitcast <8 x i64> %3 to <16 x i32>
20   %12 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %10, <16 x i32> %8, <16 x i32> %11)
21   %13 = bitcast <8 x i64> %4 to <16 x i32>
22   %14 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %12, <16 x i32> %8, <16 x i32> %13)
23   %15 = bitcast <8 x i64> %5 to <16 x i32>
24   %16 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %14, <16 x i32> %8, <16 x i32> %15)
25   %17 = bitcast <16 x i32> %16 to <8 x i64>
26   ret <8 x i64> %17
29 ; __m512i foo(int cnt, __m512i c, __m512i b, __m512i *p) {
31 ;     for (int i = 0; i < cnt; ++i) {
32 ;         __m512i a = p[i];
33 ;         __m512i m = _mm512_madd_epi16(b, a);
34 ;         c = _mm512_add_epi32(m, c);
35 ;     }
37 ;     return c;
38 ; }
39 define <8 x i64> @foo_512(i32 %0, <8 x i64> %1, <8 x i64> %2, ptr %3) {
40 ; CHECK-LABEL: foo_512:
41 ; CHECK:       # %bb.0:
42 ; CHECK-NEXT:    testl %edi, %edi
43 ; CHECK-NEXT:    jle .LBB1_6
44 ; CHECK-NEXT:  # %bb.1:
45 ; CHECK-NEXT:    movl %edi, %edx
46 ; CHECK-NEXT:    movl %edx, %eax
47 ; CHECK-NEXT:    andl $3, %eax
48 ; CHECK-NEXT:    cmpl $4, %edi
49 ; CHECK-NEXT:    jae .LBB1_7
50 ; CHECK-NEXT:  # %bb.2:
51 ; CHECK-NEXT:    xorl %ecx, %ecx
52 ; CHECK-NEXT:    jmp .LBB1_3
53 ; CHECK-NEXT:  .LBB1_7:
54 ; CHECK-NEXT:    andl $-4, %edx
55 ; CHECK-NEXT:    leaq 192(%rsi), %rdi
56 ; CHECK-NEXT:    xorl %ecx, %ecx
57 ; CHECK-NEXT:    .p2align 4, 0x90
58 ; CHECK-NEXT:  .LBB1_8: # =>This Inner Loop Header: Depth=1
59 ; CHECK-NEXT:    vpdpwssd -192(%rdi), %zmm1, %zmm0
60 ; CHECK-NEXT:    vpmaddwd -128(%rdi), %zmm1, %zmm2
61 ; CHECK-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
62 ; CHECK-NEXT:    vpmaddwd -64(%rdi), %zmm1, %zmm2
63 ; CHECK-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
64 ; CHECK-NEXT:    vpmaddwd (%rdi), %zmm1, %zmm2
65 ; CHECK-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
66 ; CHECK-NEXT:    addq $4, %rcx
67 ; CHECK-NEXT:    addq $256, %rdi # imm = 0x100
68 ; CHECK-NEXT:    cmpq %rcx, %rdx
69 ; CHECK-NEXT:    jne .LBB1_8
70 ; CHECK-NEXT:  .LBB1_3:
71 ; CHECK-NEXT:    testq %rax, %rax
72 ; CHECK-NEXT:    je .LBB1_6
73 ; CHECK-NEXT:  # %bb.4: # %.preheader
74 ; CHECK-NEXT:    shlq $6, %rcx
75 ; CHECK-NEXT:    addq %rcx, %rsi
76 ; CHECK-NEXT:    shll $6, %eax
77 ; CHECK-NEXT:    xorl %ecx, %ecx
78 ; CHECK-NEXT:    .p2align 4, 0x90
79 ; CHECK-NEXT:  .LBB1_5: # =>This Inner Loop Header: Depth=1
80 ; CHECK-NEXT:    vpdpwssd (%rsi,%rcx), %zmm1, %zmm0
81 ; CHECK-NEXT:    addq $64, %rcx
82 ; CHECK-NEXT:    cmpq %rcx, %rax
83 ; CHECK-NEXT:    jne .LBB1_5
84 ; CHECK-NEXT:  .LBB1_6:
85 ; CHECK-NEXT:    retq
86   %5 = icmp sgt i32 %0, 0
87   br i1 %5, label %6, label %33
89 6:                                                ; preds = %4
90   %7 = bitcast <8 x i64> %2 to <32 x i16>
91   %8 = bitcast <8 x i64> %1 to <16 x i32>
92   %9 = zext i32 %0 to i64
93   %10 = and i64 %9, 3
94   %11 = icmp ult i32 %0, 4
95   br i1 %11, label %14, label %12
97 12:                                               ; preds = %6
98   %13 = and i64 %9, 4294967292
99   br label %35
101 14:                                               ; preds = %35, %6
102   %15 = phi <16 x i32> [ undef, %6 ], [ %57, %35 ]
103   %16 = phi i64 [ 0, %6 ], [ %58, %35 ]
104   %17 = phi <16 x i32> [ %8, %6 ], [ %57, %35 ]
105   %18 = icmp eq i64 %10, 0
106   br i1 %18, label %30, label %19
108 19:                                               ; preds = %14, %19
109   %20 = phi i64 [ %27, %19 ], [ %16, %14 ]
110   %21 = phi <16 x i32> [ %26, %19 ], [ %17, %14 ]
111   %22 = phi i64 [ %28, %19 ], [ 0, %14 ]
112   %23 = getelementptr inbounds <8 x i64>, ptr %3, i64 %20
113   %24 = load <32 x i16>, ptr %23, align 64
114   %25 = tail call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %7, <32 x i16> %24)
115   %26 = add <16 x i32> %25, %21
116   %27 = add nuw nsw i64 %20, 1
117   %28 = add i64 %22, 1
118   %29 = icmp eq i64 %28, %10
119   br i1 %29, label %30, label %19
121 30:                                               ; preds = %19, %14
122   %31 = phi <16 x i32> [ %15, %14 ], [ %26, %19 ]
123   %32 = bitcast <16 x i32> %31 to <8 x i64>
124   br label %33
126 33:                                               ; preds = %30, %4
127   %34 = phi <8 x i64> [ %32, %30 ], [ %1, %4 ]
128   ret <8 x i64> %34
130 35:                                               ; preds = %35, %12
131   %36 = phi i64 [ 0, %12 ], [ %58, %35 ]
132   %37 = phi <16 x i32> [ %8, %12 ], [ %57, %35 ]
133   %38 = phi i64 [ 0, %12 ], [ %59, %35 ]
134   %39 = getelementptr inbounds <8 x i64>, ptr %3, i64 %36
135   %40 = load <32 x i16>, ptr %39, align 64
136   %41 = tail call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %7, <32 x i16> %40)
137   %42 = add <16 x i32> %41, %37
138   %43 = or disjoint i64 %36, 1
139   %44 = getelementptr inbounds <8 x i64>, ptr %3, i64 %43
140   %45 = load <32 x i16>, ptr %44, align 64
141   %46 = tail call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %7, <32 x i16> %45)
142   %47 = add <16 x i32> %46, %42
143   %48 = or disjoint i64 %36, 2
144   %49 = getelementptr inbounds <8 x i64>, ptr %3, i64 %48
145   %50 = load <32 x i16>, ptr %49, align 64
146   %51 = tail call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %7, <32 x i16> %50)
147   %52 = add <16 x i32> %51, %47
148   %53 = or disjoint i64 %36, 3
149   %54 = getelementptr inbounds <8 x i64>, ptr %3, i64 %53
150   %55 = load <32 x i16>, ptr %54, align 64
151   %56 = tail call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %7, <32 x i16> %55)
152   %57 = add <16 x i32> %56, %52
153   %58 = add nuw nsw i64 %36, 4
154   %59 = add i64 %38, 4
155   %60 = icmp eq i64 %59, %13
156   br i1 %60, label %14, label %35
159 ; void bar(int cnt, __m512i *c, __m512i b, __m512i *p) {
160 ;     for (int i = 0; i < cnt; ++i) {
161 ;         __m512i a = p[i];
162 ;         c[i] = _mm512_dpwssd_epi32(c[i], b, a);
163 ;     }
164 ; }
165 define void @bar_512(i32 %0, ptr %1, <8 x i64> %2, ptr %3) {
166 ; CHECK-LABEL: bar_512:
167 ; CHECK:       # %bb.0:
168 ; CHECK-NEXT:    testl %edi, %edi
169 ; CHECK-NEXT:    jle .LBB2_5
170 ; CHECK-NEXT:  # %bb.1:
171 ; CHECK-NEXT:    movl %edi, %eax
172 ; CHECK-NEXT:    cmpl $1, %edi
173 ; CHECK-NEXT:    jne .LBB2_6
174 ; CHECK-NEXT:  # %bb.2:
175 ; CHECK-NEXT:    xorl %ecx, %ecx
176 ; CHECK-NEXT:    jmp .LBB2_3
177 ; CHECK-NEXT:  .LBB2_6:
178 ; CHECK-NEXT:    movl %eax, %edi
179 ; CHECK-NEXT:    andl $-2, %edi
180 ; CHECK-NEXT:    movl $64, %r8d
181 ; CHECK-NEXT:    xorl %ecx, %ecx
182 ; CHECK-NEXT:    .p2align 4, 0x90
183 ; CHECK-NEXT:  .LBB2_7: # =>This Inner Loop Header: Depth=1
184 ; CHECK-NEXT:    vmovdqa64 (%rsi,%r8), %zmm1
185 ; CHECK-NEXT:    vpmaddwd -64(%rdx,%r8), %zmm0, %zmm2
186 ; CHECK-NEXT:    vpaddd -64(%rsi,%r8), %zmm2, %zmm2
187 ; CHECK-NEXT:    vmovdqa64 %zmm2, -64(%rsi,%r8)
188 ; CHECK-NEXT:    vpmaddwd (%rdx,%r8), %zmm0, %zmm2
189 ; CHECK-NEXT:    vpaddd %zmm2, %zmm1, %zmm1
190 ; CHECK-NEXT:    vmovdqa64 %zmm1, (%rsi,%r8)
191 ; CHECK-NEXT:    addq $2, %rcx
192 ; CHECK-NEXT:    subq $-128, %r8
193 ; CHECK-NEXT:    cmpq %rcx, %rdi
194 ; CHECK-NEXT:    jne .LBB2_7
195 ; CHECK-NEXT:  .LBB2_3:
196 ; CHECK-NEXT:    testb $1, %al
197 ; CHECK-NEXT:    je .LBB2_5
198 ; CHECK-NEXT:  # %bb.4:
199 ; CHECK-NEXT:    shlq $6, %rcx
200 ; CHECK-NEXT:    vpmaddwd (%rdx,%rcx), %zmm0, %zmm0
201 ; CHECK-NEXT:    vpaddd (%rsi,%rcx), %zmm0, %zmm0
202 ; CHECK-NEXT:    vmovdqa64 %zmm0, (%rsi,%rcx)
203 ; CHECK-NEXT:  .LBB2_5:
204 ; CHECK-NEXT:    vzeroupper
205 ; CHECK-NEXT:    retq
206   %5 = icmp sgt i32 %0, 0
207   br i1 %5, label %6, label %22
209 6:                                                ; preds = %4
210   %7 = bitcast <8 x i64> %2 to <16 x i32>
211   %8 = zext i32 %0 to i64
212   %9 = and i64 %8, 1
213   %10 = icmp eq i32 %0, 1
214   br i1 %10, label %13, label %11
216 11:                                               ; preds = %6
217   %12 = and i64 %8, 4294967294
218   br label %23
220 13:                                               ; preds = %23, %6
221   %14 = phi i64 [ 0, %6 ], [ %37, %23 ]
222   %15 = icmp eq i64 %9, 0
223   br i1 %15, label %22, label %16
225 16:                                               ; preds = %13
226   %17 = getelementptr inbounds <8 x i64>, ptr %3, i64 %14
227   %18 = load <16 x i32>, ptr %17, align 64
228   %19 = getelementptr inbounds <8 x i64>, ptr %1, i64 %14
229   %20 = load <16 x i32>, ptr %19, align 64
230   %21 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %20, <16 x i32> %7, <16 x i32> %18)
231   store <16 x i32> %21, ptr %19, align 64
232   br label %22
234 22:                                               ; preds = %16, %13, %4
235   ret void
237 23:                                               ; preds = %23, %11
238   %24 = phi i64 [ 0, %11 ], [ %37, %23 ]
239   %25 = phi i64 [ 0, %11 ], [ %38, %23 ]
240   %26 = getelementptr inbounds <8 x i64>, ptr %3, i64 %24
241   %27 = load <16 x i32>, ptr %26, align 64
242   %28 = getelementptr inbounds <8 x i64>, ptr %1, i64 %24
243   %29 = load <16 x i32>, ptr %28, align 64
244   %30 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %29, <16 x i32> %7, <16 x i32> %27)
245   store <16 x i32> %30, ptr %28, align 64
246   %31 = or disjoint i64 %24, 1
247   %32 = getelementptr inbounds <8 x i64>, ptr %3, i64 %31
248   %33 = load <16 x i32>, ptr %32, align 64
249   %34 = getelementptr inbounds <8 x i64>, ptr %1, i64 %31
250   %35 = load <16 x i32>, ptr %34, align 64
251   %36 = tail call <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32> %35, <16 x i32> %7, <16 x i32> %33)
252   store <16 x i32> %36, ptr %34, align 64
253   %37 = add nuw nsw i64 %24, 2
254   %38 = add i64 %25, 2
255   %39 = icmp eq i64 %38, %12
256   br i1 %39, label %13, label %23
259 declare <16 x i32> @llvm.x86.avx512.vpdpwssd.512(<16 x i32>, <16 x i32>, <16 x i32>) #3
260 declare <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16>, <32 x i16>) #3