[AArch64][NFC] NFC for const vector as Instruction operand (#116790)
[llvm-project.git] / llvm / test / Transforms / LowerMatrixIntrinsics / transpose-opts-lifting.ll
blobfcf83b03bc3d2362b2e9bcab9c65c819d6a6175b
1 ; RUN: opt -p lower-matrix-intrinsics -matrix-print-after-transpose-opt -disable-output -S %s 2>&1 | FileCheck %s
3 ; REQUIRES: asserts
5 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
7 ; FIXME: Lifted transpose dimensions are incorrect.
8 define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double> %a, <6 x double> %b) {
9 ; CHECK-LABEL:  define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double> %a, <6 x double> %b) {
10 ; CHECK-NEXT:  entry:
11 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
12 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
13 ; CHECK-NEXT:    ret <6 x double> [[T]]
15 entry:
16   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
17   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
18   %add = fadd <6 x double> %a.t, %b.t
19   ret <6 x double> %add
22 define <6 x double> @lift_through_add_matching_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
23 ; CHECK-LABEL: define <6 x double> @lift_through_add_matching_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
24 ; CHECK-NEXT:  entry:
25 ; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
26 ; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
27 ; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
28 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
29 ; CHECK-NEXT:    ret <6 x double> [[T]]
31 entry:
32   %a = load <6 x double>, ptr %a.ptr
33   %b = load <6 x double>, ptr %b.ptr
34   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
35   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
36   %add = fadd <6 x double> %a.t, %b.t
37   ret <6 x double> %add
40 define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a, <6 x double> %b) {
41 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a, <6 x double> %b) {
42 ; CHECK-NEXT:  entry:
43 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
44 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
45 ; CHECK-NEXT:    ret <6 x double> [[T]]
47 entry:
48   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
49   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
50   %add = fadd <6 x double> %a.t, %b.t
51   ret <6 x double> %add
54 define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
55 ; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
56 ; CHECK-NEXT:  entry:
57 ; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
58 ; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
59 ; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
60 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 1, i32 6)
61 ; CHECK-NEXT:    ret <6 x double> [[T]]
63 entry:
64   %a = load <6 x double>, ptr %a.ptr
65   %b = load <6 x double>, ptr %b.ptr
66   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
67   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
68   %add = fadd <6 x double> %a.t, %b.t
69   ret <6 x double> %add
72 define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a, <6 x double> %b) {
73 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a, <6 x double> %b) {
74 ; CHECK-NEXT:  entry:
75 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
76 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
77 ; CHECK-NEXT:    ret <6 x double> [[T]]
80 entry:
81   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
82   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 6, i32 1)
83   %add = fadd <6 x double> %a.t, %b.t
84   ret <6 x double> %add
87 define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
88 ; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
89 ; CHECK-NEXT:  entry:
90 ; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
91 ; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
92 ; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
93 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
94 ; CHECK-NEXT:    ret <6 x double> [[T]]
96 entry:
97   %a = load <6 x double>, ptr %a.ptr
98   %b = load <6 x double>, ptr %b.ptr
99   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
100   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 6, i32 1)
101   %add = fadd <6 x double> %a.t, %b.t
102   ret <6 x double> %add
105 define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
106 ; CHECK-LABEL: define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
107 ; CHECK-NEXT:  entry:
108 ; CHECK-NEXT:    [[MUL:%.+]] = call <9 x double> @llvm.matrix.multiply.v9f64.v6f64.v6f64(<6 x double> %b, <6 x double> %a, i32 3, i32 2, i32 3)
109 ; CHECK-NEXT:    [[T:%.+]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MUL]], i32 3, i32 3)
110 ; CHECK-NEXT:   ret <9 x double> [[T]]
112 entry:
113   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
114   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 2, i32 3)
115   %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double> %a.t, <6 x double> %b.t, i32 3, i32 2 , i32 3)
116   ret <9 x double> %mul
119 define <6 x double> @lift_through_multiply_2(<6 x double> %a, <4 x double> %b) {
120 ; CHECK-LABEL: define <6 x double> @lift_through_multiply_2(<6 x double> %a, <4 x double> %b) {
121 ; CHECK-NEXT:  entry:
122 ; CHECK-NEXT:    [[MUL:%.+]] = call <6 x double> @llvm.matrix.multiply.v6f64.v4f64.v6f64(<4 x double> %b, <6 x double> %a, i32 2, i32 2, i32 3)
123 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[MUL]], i32 2, i32 3)
124 ; CHECK-NEXT:    ret <6 x double> [[T]]
126 entry:
127   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
128   %b.t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %b, i32 2, i32 2)
129   %mul = call <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double> %a.t, <4 x double> %b.t, i32 3, i32 2 , i32 2)
130   ret <6 x double> %mul
133 define <6 x double> @lift_through_multiply_3(<4 x double> %a, <6 x double> %b) {
134 ; CHECK-LABEL: define <6 x double> @lift_through_multiply_3(<4 x double> %a, <6 x double> %b) {
135 ; CHECK-NEXT:  entry:
136 ; CHECK-NEXT:    [[MUL:%.+]] = call <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double> %b, <4 x double> %a, i32 3, i32 2, i32 2)
137 ; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[MUL]], i32 3, i32 2)
138 ; CHECK-NEXT:    ret <6 x double> [[T]]
140 entry:
141   %a.t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %a, i32 2, i32 2)
142   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 2, i32 3)
143   %mul = call <6 x double> @llvm.matrix.multiply.v6f64.v4f64.v6f64(<4 x double> %a.t, <6 x double> %b.t, i32 2, i32 2 , i32 3)
144   ret <6 x double> %mul
147 declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
148 declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
149 declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
150 declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
151 declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)