[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / llvm / test / Transforms / LowerMatrixIntrinsics / after-transpose-opts.ll
blob4a3b121afb6f534072d0e84fa78e80cabc72a2b3
1 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2 ; REQUIRES: aarch64-registered-target
4 ; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s
6 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
7 target triple = "aarch64-apple-ios"
9 ; k * A^T
10 define void @kat(ptr %Aptr, double %k, ptr %C) {
11 ; CHECK-LABEL: @kat(
12 ; CHECK-NEXT:  entry:
13 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
14 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
15 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
16 ; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
17 ; CHECK-NEXT:    [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3)
18 ; CHECK-NEXT:    store <9 x double> [[MUL]], ptr [[C:%.*]], align 128
19 ; CHECK-NEXT:    ret void
21 entry:
22   %a = load <9 x double>, ptr %Aptr
23   %veck = insertelement <9 x double> poison, double %k, i64 0
24   %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
25   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
26   %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
27   store <9 x double> %mul, ptr %C
28   ret void
31 ; (k * A)^T -> A^T * k
32 define void @ka_t(ptr %Aptr, double %k, ptr %C) {
33 ; CHECK-LABEL: @ka_t(
34 ; CHECK-NEXT:  entry:
35 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
36 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
37 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
38 ; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
39 ; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
40 ; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
41 ; CHECK-NEXT:    ret void
43 entry:
44   %a = load <9 x double>, ptr %Aptr
45   %veck = insertelement <9 x double> poison, double %k, i64 0
46   %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
47   %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
48   %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
49   store <9 x double> %t, ptr %C
50   ret void
53 ; (k * A)^T -> A^T * k with fmul
54 define void @ka_t_fmul(ptr %Aptr, double %k, ptr %C) {
55 ; CHECK-LABEL: @ka_t_fmul(
56 ; CHECK-NEXT:  entry:
57 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
58 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
59 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
60 ; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
61 ; CHECK-NEXT:    [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]]
62 ; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
63 ; CHECK-NEXT:    ret void
65 entry:
66   %a = load <9 x double>, ptr %Aptr
67   %veck = insertelement <9 x double> poison, double %k, i64 0
68   %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
69   %mul = fmul <9 x double> %splat, %a
70   %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
71   store <9 x double> %t, ptr %C
72   ret void
75 ; (k * A)^T -> A^T * k with mul (non-fp types)
76 define void @ka_t_mul(ptr %Aptr, i32 %k, ptr %C) {
77 ; CHECK-LABEL: @ka_t_mul(
78 ; CHECK-NEXT:  entry:
79 ; CHECK-NEXT:    [[A:%.*]] = load <9 x i32>, ptr [[APTR:%.*]], align 64
80 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0
81 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer
82 ; CHECK-NEXT:    [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3)
83 ; CHECK-NEXT:    [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]]
84 ; CHECK-NEXT:    store <9 x i32> [[MMUL]], ptr [[C:%.*]], align 64
85 ; CHECK-NEXT:    ret void
87 entry:
88   %a = load <9 x i32>, ptr %Aptr
89   %veck = insertelement <9 x i32> poison, i32 %k, i64 0
90   %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
91   %mul = mul <9 x i32> %splat, %a
92   %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
93   store <9 x i32> %t, ptr %C
94   ret void
97 ; A^T + B^T -> (A + B)^T
98 define void @at_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
99 ; CHECK-LABEL: @at_plus_bt(
100 ; CHECK-NEXT:  entry:
101 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
102 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
103 ; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
104 ; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
105 ; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
106 ; CHECK-NEXT:    ret void
108 entry:
109   %a = load <9 x double>, ptr %Aptr
110   %b = load <9 x double>, ptr %Bptr
111   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
112   %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
113   %fadd = fadd <9 x double> %at, %bt
114   store <9 x double> %fadd, ptr %C
115   ret void
118 ; (A + B)^T -> A^T + B^T -> (A + B)^T
119 define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
120 ; CHECK-LABEL: @a_plus_b_t(
121 ; CHECK-NEXT:  entry:
122 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
123 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
124 ; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
125 ; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
126 ; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
127 ; CHECK-NEXT:    ret void
129 entry:
130   %a = load <9 x double>, ptr %Aptr
131   %b = load <9 x double>, ptr %Bptr
132   %fadd = fadd <9 x double> %a, %b
133   %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
134   store <9 x double> %t, ptr %C
135   ret void
138 ; A^T * B^T + C^T * D^T -> (B * A + D * C)^T
139 define void @atbt_plus_ctdt(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, ptr %E) {
140 ; CHECK-LABEL: @atbt_plus_ctdt(
141 ; CHECK-NEXT:  entry:
142 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
143 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
144 ; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
145 ; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
146 ; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
147 ; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
148 ; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[TMP0]], [[TMP1]]
149 ; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
150 ; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[E:%.*]], align 128
151 ; CHECK-NEXT:    ret void
153 entry:
154   %a = load <9 x double>, ptr %Aptr
155   %b = load <9 x double>, ptr %Bptr
156   %c = load <9 x double>, ptr %Cptr
157   %d = load <9 x double>, ptr %Dptr
158   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
159   %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
160   %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
161   %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
162   %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
163   %ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
164   %fadd = fadd <9 x double> %atbt, %ctdt
165   store <9 x double> %fadd, ptr %E
166   ret void
169 ; -(A^T) + B^T
170 define void @negat_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
171 ; CHECK-LABEL: @negat_plus_bt(
172 ; CHECK-NEXT:  entry:
173 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
174 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
175 ; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
176 ; CHECK-NEXT:    [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
177 ; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
178 ; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
179 ; CHECK-NEXT:    store <9 x double> [[FADD]], ptr [[C:%.*]], align 128
180 ; CHECK-NEXT:    ret void
182 entry:
183   %a = load <9 x double>, ptr %Aptr
184   %b = load <9 x double>, ptr %Bptr
185   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
186   %negat = fneg <9 x double> %at
187   %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
188   %fadd = fadd <9 x double> %negat, %bt
189   store <9 x double> %fadd, ptr %C
190   ret void
193 ; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * C * k)
194 define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, double %k, ptr %E) {
195 ; CHECK-LABEL: @atbt_plus_kctdt_t(
196 ; CHECK-NEXT:  entry:
197 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
198 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
199 ; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
200 ; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
201 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
202 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
203 ; CHECK-NEXT:    [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
204 ; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
205 ; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
206 ; CHECK-NEXT:    [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
207 ; CHECK-NEXT:    store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
208 ; CHECK-NEXT:    ret void
210 entry:
211   %a = load <9 x double>, ptr %Aptr
212   %b = load <9 x double>, ptr %Bptr
213   %c = load <9 x double>, ptr %Cptr
214   %d = load <9 x double>, ptr %Dptr
215   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
216   %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
217   %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
218   %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
219   %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
220   %veck = insertelement <9 x double> poison, double %k, i64 0
221   %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
222   %kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
223   %kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
224   %fadd = fadd <9 x double> %atbt, %kctdt
225   %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
226   store <9 x double> %t, ptr %E
227   ret void
230 ; (A^T * (k * B^T))^T => (B * k) * A
231 define void @atkbt_t(ptr %Aptr, ptr %Bptr, double %k, ptr %C) {
232 ; CHECK-LABEL: @atkbt_t(
233 ; CHECK-NEXT:  entry:
234 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
235 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
236 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
237 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
238 ; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
239 ; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
240 ; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
241 ; CHECK-NEXT:    ret void
243 entry:
244   %a = load <9 x double>, ptr %Aptr
245   %b = load <9 x double>, ptr %Bptr
246   %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
247   %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
248   %veck = insertelement <9 x double> poison, double %k, i64 0
249   %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
250   %kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
251   %atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
252   %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
253   store <9 x double> %t, ptr %C
254   ret void
257 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
258 declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
259 declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
262 ; (a * b + c)^T -> (a * b)^T + b^T with integer types.
263 define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
264 ; CHECK-LABEL: @mul_add_transpose_int(
265 ; CHECK-NEXT:  entry:
266 ; CHECK-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
267 ; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
268 ; CHECK-NEXT:    [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
269 ; CHECK-NEXT:    [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
270 ; CHECK-NEXT:    ret <4 x i32> [[MADD]]
272 entry:
273   %mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
274   %add = add <4 x i32> %mul, %c
275   %t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
276   ret <4 x i32> %t
279 declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)
281 declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)