Revert "[lldb][test] Remove compiler version check and use regex" (#124101)
[llvm-project.git] / llvm / test / CodeGen / X86 / AMX / amx-combine.ll
blob07f489c633c55824c509b0cbcfaf08679c84474d
1 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2 ; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
4 define void @combine_store(ptr%p) {
5 ; CHECK-LABEL: @combine_store(
6 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
7 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64, x86_amx [[T1]])
8 ; CHECK-NEXT:    ret void
10   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
11   %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
12   store <256 x i32> %t2, ptr %p, align 64
13   ret void
16 define <256 x i32> @combine_store_2user(ptr%p) {
17 ; CHECK-LABEL: @combine_store_2user(
18 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
19 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
20 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64, x86_amx [[T1]])
21 ; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024
22 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64, x86_amx [[T1]])
23 ; CHECK-NEXT:    ret <256 x i32> [[TMP2]]
25   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
26   %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
27   store <256 x i32> %t2, ptr %p, align 64
28   ret <256 x i32> %t2
31 define void @combine_load(ptr%p, ptr%p2) {
32 ; CHECK-LABEL: @combine_load(
33 ; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
34 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
35 ; CHECK-NEXT:    ret void
37   %t1 = load <256 x i32>, ptr %p, align 64
38   %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
39   call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr %p2, i64 64, x86_amx %t2)
40   ret void
43 define void @combine_cast_across_store(ptr%p, ptr%p2) {
44 ; CHECK-LABEL: @combine_cast_across_store(
45 ; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
46 ; CHECK-NEXT:    store <256 x i32> zeroinitializer, ptr [[P]], align 64
47 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
48 ; CHECK-NEXT:    ret void
50   %t1 = load <256 x i32>, ptr %p, align 64
51   store <256 x i32> zeroinitializer, ptr %p, align 64
52   %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
53   call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr %p2, i64 64, x86_amx %t2)
54   ret void
57 define <256 x i32> @combine_load_2user(ptr%p, ptr%p2) {
58 ; CHECK-LABEL: @combine_load_2user(
59 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
60 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
61 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
62 ; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64)
63 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
64 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
66   %t1 = load <256 x i32>, ptr %p, align 64
67   %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
68   call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr %p2, i64 64, x86_amx %t2)
69   %t3 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t2)
70   ret <256 x i32> %t3
73 define <256 x i32> @combine_load_3user(ptr%p, ptr%p2) {
74 ; CHECK-LABEL: @combine_load_3user(
75 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
76 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
77 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
78 ; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16)
79 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
80 ; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP2]], x86_amx [[TMP2]], x86_amx [[TMP2]])
81 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
83   %t1 = load <256 x i32>, ptr %p, align 64
84   %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
85   call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr %p2, i64 64, x86_amx %t2)
86   %t3 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t2)
87   call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx %t2, x86_amx %t2, x86_amx %t2)
88   ret <256 x i32> %t3
91 ; the shape is loaded after tile.
92 %struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
93 define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, ptr byval(%struct.__tile1024i_str) align 64 %b, ptr byval(%struct.__tile1024i_str) align 64 %c) {
94 ; CHECK-LABEL: @test_tile_dpbssd(
95 ; CHECK-NEXT:  entry:
96 ; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
97 ; CHECK-NEXT:    [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
98 ; CHECK-NEXT:    [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
99 ; CHECK-NEXT:    [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
100 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
101 ; CHECK-NEXT:    [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
102 ; CHECK-NEXT:    store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
103 ; CHECK-NEXT:    [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
104 ; CHECK-NEXT:    [[A_COL_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 2
105 ; CHECK-NEXT:    [[A_COL:%.*]] = load i16, ptr [[A_COL_PTR]], align 2
106 ; CHECK-NEXT:    [[TMP2:%.*]] = udiv i16 [[A_COL]], 4
107 ; CHECK-NEXT:    [[A_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 64
108 ; CHECK-NEXT:    [[TMP3:%.*]] = sext i16 [[A_COL]] to i64
109 ; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 [[TMP3]])
110 ; CHECK-NEXT:    [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
111 ; CHECK-NEXT:    [[TMP5:%.*]] = sext i16 [[B_ROW]] to i64
112 ; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 [[TMP5]])
113 ; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[B_ROW]], ptr [[TMP0]], i64 [[TMP1]])
114 ; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP6]], x86_amx [[TMP4]], x86_amx [[TMP7]])
115 ; CHECK-NEXT:    ret void
117 entry:
118   %b.row.ptr= getelementptr inbounds i8, ptr %b, i64 2
119   %b.row = load i16, ptr %b.row.ptr, align 2
120   %b.tile.ptr = getelementptr inbounds i8, ptr %b, i64 64
121   %b.tile = load <256 x i32>, ptr %b.tile.ptr, align 64
122   %a.row = load i16, ptr %a, align 64
123   %a.col.ptr = getelementptr inbounds i8, ptr %a, i64 2
124   %a.col = load i16, ptr %a.col.ptr, align 2
125   %a.tile.ptr = getelementptr inbounds i8, ptr %a, i64 64
126   %a.tile = load <256 x i32>, ptr %a.tile.ptr, align 64
127   %c.tile.ptr = getelementptr inbounds %struct.__tile1024i_str, ptr %c, i64 0, i32 3
128   %c.tile = load <256 x i32>, ptr %c.tile.ptr, align 64
129   %c.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %c.tile)
130   %a.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %a.tile)
131   %b.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %b.tile)
132   %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.row, i16 %a.col, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
133   ret void
136 define void @combine_v256i8amcast_with_store(ptr %src_ptr, ptr %dst_ptr) {
137 ; CHECK-LABEL: @combine_v256i8amcast_with_store(
138 ; CHECK-NEXT:  entry:
139 ; CHECK-NEXT:    [[TILE:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, ptr [[SRC_PTR:%.*]], i64 64)
140 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, ptr [[DST_PTR:%.*]], i64 32, x86_amx [[TILE]])
141 ; CHECK-NEXT:    ret void
143 entry:
144   %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, ptr %src_ptr, i64 64)
145   %vec = call <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx %tile)
146   store <256 x i8> %vec, ptr %dst_ptr, align 256
147   ret void
150 define void @combine_v256i8amcast_with_load(ptr %src_ptr, ptr %dst_ptr) {
151 ; CHECK-LABEL: @combine_v256i8amcast_with_load(
152 ; CHECK-NEXT:  entry:
153 ; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, ptr [[SRC_PTR:%.*]], i64 32)
154 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, ptr [[DST_PTR:%.*]], i64 32, x86_amx [[TMP0]])
155 ; CHECK-NEXT:    ret void
157 entry:
158   %vec = load <256 x i8>, ptr %src_ptr, align 256
159   %tile = call x86_amx @llvm.x86.cast.vector.to.tile.v256i8(<256 x i8> %vec)
160   call void @llvm.x86.tilestored64.internal(i16 8, i16 32, ptr %dst_ptr, i64 32, x86_amx %tile)
161   ret void
164 declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
165 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
166 declare x86_amx @llvm.x86.cast.vector.to.tile.v256i8(<256 x i8>)
167 declare <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx)
168 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
169 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
170 declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)
171 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)