// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

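// In-bounds 1-D read: lowers to xegpu.create_nd_tdesc + xegpu.load_nd with
// boundary checking disabled.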
func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]

// -----

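// In-bounds 2-D read: same lowering, producing a matching 2-D tensor descriptor.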
func.func @load_2D_vector(%source: memref<8x16x32xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----

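// Out-of-bounds read padded with the default zero value: still converted,
// but the descriptor keeps boundary checking enabled.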
func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_zero_pad_out_of_bounds(
// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: boundary_check = true
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----

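// Transposing permutation map: the descriptor is created with the in-memory
// shape (16x8) and xegpu.load_nd applies the transpose to yield vector<8x16xf32>.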
func.func @load_transposed(%source: memref<32x64xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_transposed(
// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
// CHECK-SAME: -> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----

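// Dynamically shaped source: sizes and row-major strides are recovered with
// memref.dim and passed to create_nd_tdesc explicitly.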
func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----

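// Negative case: out-of-bounds reads with a non-zero (or unknown) padding
// value are not converted.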
func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
    %offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
  %c1 = arith.constant 1.0 : f32
  %0 = vector.transfer_read %source[%offset, %arg2], %c1
    {in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
  %1 = vector.transfer_read %source[%arg2, %offset], %pad
    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
  return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
}

// CHECK-LABEL: @no_load_out_of_bounds_non_zero_pad(
// CHECK-COUNT-2: vector.transfer_read

// -----

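// Negative case: masked reads are not converted.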
func.func @no_load_masked(%source: memref<4xf32>,
    %offset: index) -> vector<4xf32> {
  %c0 = arith.constant 0.0 : f32
  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
  %0 = vector.transfer_read %source[%offset], %c0, %mask
    {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: @no_load_masked(
// CHECK: vector.transfer_read

// -----

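// Negative case: only memref sources are converted; tensor sources are left as is.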
func.func @no_load_tensor(%source: tensor<32x64xf32>,
    %offset: index, %arg2: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %arg2], %c0
    {in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @no_load_tensor(
// CHECK: vector.transfer_read

// -----

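// Negative case: vectors of rank greater than 2 are not converted.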
func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
  return %0 : vector<8x16x32xf32>
}

// CHECK-LABEL: @no_load_high_dim_vector(
// CHECK: vector.transfer_read

// -----

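// Negative case: the source must have a unit innermost stride; a dynamic
// inner stride is rejected.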
func.func @no_load_non_unit_inner_stride(
    %source: memref<32xf32, strided<[?], offset: ?>>,
    %offset: index) -> vector<8xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
    : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: @no_load_non_unit_inner_stride(
// CHECK: vector.transfer_read

// -----

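// Negative case: this permutation map projects out d1 rather than being a
// minor identity or a transpose, so the read is not converted.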
func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
    in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @no_load_unsupported_map(
// CHECK: vector.transfer_read

// -----

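// Negative case: transposed loads of f16 elements are not supported, so the
// read is not converted.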
func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
    %offset: index) -> vector<8x16xf16> {
  %c0 = arith.constant 0.0 : f16
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
  return %0 : vector<8x16xf16>
}

// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
// CHECK: vector.transfer_read