// RUN: mlir-opt %s \
// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
// RUN:   --sparsification --sparse-tensor-conversion \
// RUN:   --linalg-bufferize --convert-linalg-to-loops \
// RUN:   --convert-vector-to-scf --convert-scf-to-std \
// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm \
// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner \
// RUN:   -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

#CSR = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>
}>

#DCSR = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>
}>
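
// With these encodings, #CSR keeps the row dimension dense (a pointer pair
// per row) and compresses only the columns, while #DCSR compresses both
// dimensions, so rows without any nonzeros occupy no storage at all.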

module {
  //
  // Computes C = A x B with all matrices dense.
  //
  func @matmul1(%A: tensor<4x8xf64>,
                %B: tensor<8x4xf64>) -> tensor<4x4xf64> {
    %C = arith.constant dense<0.0> : tensor<4x4xf64>
    %D = linalg.matmul
      ins(%A, %B: tensor<4x8xf64>, tensor<8x4xf64>)
      outs(%C: tensor<4x4xf64>) -> tensor<4x4xf64>
    return %D: tensor<4x4xf64>
  }
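
  // Note that linalg.matmul accumulates into its output (C += A x B), which
  // is why %C is zero-initialized above and why the sparse kernels below
  // start from an empty sparse_tensor.init.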

  //
  // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
  //
  func @matmul2(%A: tensor<4x8xf64, #CSR>,
                %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
    %c4 = arith.constant 4 : index
    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #CSR>
    %D = linalg.matmul
      ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
      outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
    return %D: tensor<4x4xf64, #CSR>
  }

  //
  // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
  //
  func @matmul3(%A: tensor<4x8xf64, #DCSR>,
                %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
    %c4 = arith.constant 4 : index
    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>
    %D = linalg.matmul
      ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
      outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
    return %D: tensor<4x4xf64, #DCSR>
  }
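
  // All three kernels share the same linalg.matmul specification; only the
  // tensor encodings differ. The --sparsification pass derives a different
  // loop and storage structure for each encoding from that one specification.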

  //
  // Main driver.
  //
  func @entry() {
    %c0 = arith.constant 0 : index
    %d1 = arith.constant -1.0 : f64

    // Initialize various matrices, dense for stress testing,
    // and sparse to verify correct nonzero structure.
    %da = arith.constant dense<[
        [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
        [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
        [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
        [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
    ]> : tensor<4x8xf64>
    %db = arith.constant dense<[
        [ 10.1, 11.1, 12.1, 13.1 ],
        [ 10.2, 11.2, 12.2, 13.2 ],
        [ 10.3, 11.3, 12.3, 13.3 ],
        [ 10.4, 11.4, 12.4, 13.4 ],
        [ 10.5, 11.5, 12.5, 13.5 ],
        [ 10.6, 11.6, 12.6, 13.6 ],
        [ 10.7, 11.7, 12.7, 13.7 ],
        [ 10.8, 11.8, 12.8, 13.8 ]
    ]> : tensor<8x4xf64>
    %sa = arith.constant dense<[
        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
    ]> : tensor<4x8xf64>
    %sb = arith.constant dense<[
        [ 0.0, 0.0, 0.0, 1.0 ],
        [ 0.0, 0.0, 2.0, 0.0 ],
        [ 0.0, 3.0, 0.0, 0.0 ],
        [ 4.0, 0.0, 0.0, 0.0 ],
        [ 0.0, 0.0, 0.0, 0.0 ],
        [ 0.0, 5.0, 0.0, 0.0 ],
        [ 0.0, 0.0, 6.0, 0.0 ],
        [ 0.0, 0.0, 7.0, 8.0 ]
    ]> : tensor<8x4xf64>
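
    // Note that row 1 of %sa is entirely zero, which exercises the empty-row
    // case that distinguishes CSR (stores an empty row) from DCSR (stores
    // nothing for that row).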

    // Convert all these matrices to sparse format.
    %a1 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
    %a2 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
    %a3 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
    %a4 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
    %b1 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
    %b2 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
    %b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
    %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
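
    // Each conversion materializes a dense constant into the sparse runtime's
    // storage format, dropping the explicit zeros along the way.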

    // Call kernels with dense inputs.
    %0 = call @matmul1(%da, %db)
       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
    %1 = call @matmul2(%a1, %b1)
       : (tensor<4x8xf64, #CSR>,
          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
    %2 = call @matmul3(%a2, %b2)
       : (tensor<4x8xf64, #DCSR>,
          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>

    // Call kernels with one sparse input.
    %3 = call @matmul1(%sa, %db)
       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
    %4 = call @matmul2(%a3, %b1)
       : (tensor<4x8xf64, #CSR>,
          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
    %5 = call @matmul3(%a4, %b2)
       : (tensor<4x8xf64, #DCSR>,
          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>

    // Call kernels with two sparse inputs.
    %6 = call @matmul1(%sa, %sb)
       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
    %7 = call @matmul2(%a3, %b3)
       : (tensor<4x8xf64, #CSR>,
          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
    %8 = call @matmul3(%a4, %b4)
       : (tensor<4x8xf64, #DCSR>,
          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>

    //
    // CHECK: ( ( 388.76, 425.56, 462.36, 499.16 ),
    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
    //
    %m0 = bufferization.to_memref %0 : memref<4x4xf64>
    %v0 = vector.transfer_read %m0[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v0 : vector<4x4xf64>
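
    // Spot check on the dense result: entry (0,0) is the dot product
    //   1.1*10.1 + 2.1*10.2 + 3.1*10.3 + 4.1*10.4
    // + 5.1*10.5 + 6.1*10.6 + 7.1*10.7 + 8.1*10.8 = 388.76.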

    //
    // CHECK: ( ( 388.76, 425.56, 462.36, 499.16 ),
    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
    //
    %c1 = sparse_tensor.convert %1 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
    %m1 = bufferization.to_memref %c1 : memref<4x4xf64>
    %v1 = vector.transfer_read %m1[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v1 : vector<4x4xf64>

    //
    // CHECK: ( ( 388.76, 425.56, 462.36, 499.16 ),
    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
    //
    %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
    %m2 = bufferization.to_memref %c2 : memref<4x4xf64>
    %v2 = vector.transfer_read %m2[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v2 : vector<4x4xf64>

    //
    // CHECK: ( ( 86.08, 94.28, 102.48, 110.68 ),
    // CHECK-SAME: ( 0, 0, 0, 0 ),
    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
    //
    %m3 = bufferization.to_memref %3 : memref<4x4xf64>
    %v3 = vector.transfer_read %m3[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v3 : vector<4x4xf64>
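
    // Row 1 of this result is all zero because row 1 of %sa is empty; the
    // dense output stores these zeros explicitly.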

    //
    // CHECK: ( ( 86.08, 94.28, 102.48, 110.68 ),
    // CHECK-SAME: ( 0, 0, 0, 0 ),
    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
    //
    %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
    %m4 = bufferization.to_memref %c4 : memref<4x4xf64>
    %v4 = vector.transfer_read %m4[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v4 : vector<4x4xf64>

    //
    // CHECK: ( ( 86.08, 94.28, 102.48, 110.68 ),
    // CHECK-SAME: ( 0, 0, 0, 0 ),
    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
    //
    %c5 = sparse_tensor.convert %5 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
    %m5 = bufferization.to_memref %c5 : memref<4x4xf64>
    %v5 = vector.transfer_read %m5[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v5 : vector<4x4xf64>

    //
    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
    //
    %m6 = bufferization.to_memref %6 : memref<4x4xf64>
    %v6 = vector.transfer_read %m6[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v6 : vector<4x4xf64>
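
    // Spot check on the sparse-times-sparse result: entry (0,1) is
    // 2.1*0.0 + 6.1*5.0 = 30.5 and entry (0,2) is 2.1*2.0 = 4.2, matching
    // the pattern above.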

    //
    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
    //
    %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
    %m7 = bufferization.to_memref %c7 : memref<4x4xf64>
    %v7 = vector.transfer_read %m7[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v7 : vector<4x4xf64>

    //
    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
    //
    %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
    %m8 = bufferization.to_memref %c8 : memref<4x4xf64>
    %v8 = vector.transfer_read %m8[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
    vector.print %v8 : vector<4x4xf64>

    //
    // Sanity check on nonzeros.
    //
    // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
    // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
    //
    %val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
    %val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
    %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
    %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
    vector.print %nz7 : vector<8xf64>
    vector.print %nz8 : vector<8xf64>
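
    // Both formats store exactly the five nonzeros; the trailing -1 values
    // are the transfer_read padding (%d1) read past the end of the values
    // array, not stored data.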

    // Release the resources.
    sparse_tensor.release %a1 : tensor<4x8xf64, #CSR>
    sparse_tensor.release %a2 : tensor<4x8xf64, #DCSR>
    sparse_tensor.release %a3 : tensor<4x8xf64, #CSR>
    sparse_tensor.release %a4 : tensor<4x8xf64, #DCSR>
    sparse_tensor.release %b1 : tensor<8x4xf64, #CSR>
    sparse_tensor.release %b2 : tensor<8x4xf64, #DCSR>
    sparse_tensor.release %b3 : tensor<8x4xf64, #CSR>
    sparse_tensor.release %b4 : tensor<8x4xf64, #DCSR>
    sparse_tensor.release %1 : tensor<4x4xf64, #CSR>
    sparse_tensor.release %2 : tensor<4x4xf64, #DCSR>
    sparse_tensor.release %4 : tensor<4x4xf64, #CSR>
    sparse_tensor.release %5 : tensor<4x4xf64, #DCSR>
    sparse_tensor.release %7 : tensor<4x4xf64, #CSR>
    sparse_tensor.release %8 : tensor<4x4xf64, #DCSR>
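
    // The dense results (%0, %3, %6) need no release; they were bufferized
    // into %m0/%m3/%m6 and are freed with the other memrefs below.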
    memref.dealloc %m0 : memref<4x4xf64>
    memref.dealloc %m1 : memref<4x4xf64>
    memref.dealloc %m2 : memref<4x4xf64>
    memref.dealloc %m3 : memref<4x4xf64>
    memref.dealloc %m4 : memref<4x4xf64>
    memref.dealloc %m5 : memref<4x4xf64>
    memref.dealloc %m6 : memref<4x4xf64>
    memref.dealloc %m7 : memref<4x4xf64>
    memref.dealloc %m8 : memref<4x4xf64>

    return
  }
}