[AMDGPU] Make v8i16/v8f16 legal
[llvm-project.git] / mlir / test / Integration / Dialect / SparseTensor / CPU / sparse_matmul.mlir
blob8bf99f50da5b7efd4694511d654a5a39006c4ebf
1 // RUN: mlir-opt %s \
2 // RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
3 // RUN:   --sparsification --sparse-tensor-conversion \
4 // RUN:   --linalg-bufferize --convert-linalg-to-loops \
5 // RUN:   --convert-vector-to-scf --convert-scf-to-std \
6 // RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
7 // RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
8 // RUN:   --convert-vector-to-llvm --convert-memref-to-llvm \
9 // RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
10 // RUN: mlir-cpu-runner \
11 // RUN:  -e entry -entry-point-result=void  \
12 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
13 // RUN: FileCheck %s
16 #CSR = #sparse_tensor.encoding<{
17   dimLevelType = [ "dense", "compressed" ],
18   dimOrdering = affine_map<(i,j) -> (i,j)>
21 #DCSR = #sparse_tensor.encoding<{
22   dimLevelType = [ "compressed", "compressed" ],
23   dimOrdering = affine_map<(i,j) -> (i,j)>
26 module {
27   //
28   // Computes C = A x B with all matrices dense.
29   //
30   func @matmul1(%A: tensor<4x8xf64>,
31                 %B: tensor<8x4xf64>) -> tensor<4x4xf64> {
32     %C = arith.constant dense<0.0> : tensor<4x4xf64>
33     %D = linalg.matmul
34       ins(%A, %B: tensor<4x8xf64>, tensor<8x4xf64>)
35          outs(%C: tensor<4x4xf64>) -> tensor<4x4xf64>
36     return %D: tensor<4x4xf64>
37   }
39   //
40   // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
41   //
42   func @matmul2(%A: tensor<4x8xf64, #CSR>,
43                 %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
44     %c4 = arith.constant 4 : index
45     %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #CSR>
46     %D = linalg.matmul
47       ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
48          outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
49     return %D: tensor<4x4xf64, #CSR>
50   }
52   //
53   // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
54   //
55   func @matmul3(%A: tensor<4x8xf64, #DCSR>,
56                 %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
57     %c4 = arith.constant 4 : index
58     %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>
59     %D = linalg.matmul
60       ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
61          outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
62     return %D: tensor<4x4xf64, #DCSR>
63   }
65   //
66   // Main driver.
67   //
68   func @entry() {
69     %c0 = arith.constant 0 : index
70     %d1 = arith.constant -1.0 : f64
72     // Initialize various matrices, dense for stress testing,
73     // and sparse to verify correct nonzero structure.
74     %da = arith.constant dense<[
75         [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
76         [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
77         [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
78         [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
79     ]> : tensor<4x8xf64>
80     %db = arith.constant dense<[
81         [ 10.1, 11.1, 12.1, 13.1 ],
82         [ 10.2, 11.2, 12.2, 13.2 ],
83         [ 10.3, 11.3, 12.3, 13.3 ],
84         [ 10.4, 11.4, 12.4, 13.4 ],
85         [ 10.5, 11.5, 12.5, 13.5 ],
86         [ 10.6, 11.6, 12.6, 13.6 ],
87         [ 10.7, 11.7, 12.7, 13.7 ],
88         [ 10.8, 11.8, 12.8, 13.8 ]
89     ]> : tensor<8x4xf64>
90     %sa = arith.constant dense<[
91         [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
92         [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
93         [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
94         [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
95     ]> : tensor<4x8xf64>
96     %sb = arith.constant dense<[
97         [ 0.0, 0.0, 0.0, 1.0 ],
98         [ 0.0, 0.0, 2.0, 0.0 ],
99         [ 0.0, 3.0, 0.0, 0.0 ],
100         [ 4.0, 0.0, 0.0, 0.0 ],
101         [ 0.0, 0.0, 0.0, 0.0 ],
102         [ 0.0, 5.0, 0.0, 0.0 ],
103         [ 0.0, 0.0, 6.0, 0.0 ],
104         [ 0.0, 0.0, 7.0, 8.0 ]
105     ]> : tensor<8x4xf64>
107     // Convert all these matrices to sparse format.
108     %a1 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
109     %a2 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
110     %a3 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
111     %a4 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
112     %b1 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
113     %b2 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
114     %b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
115     %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
117     // Call kernels with dense.
118     %0 = call @matmul1(%da, %db)
119        : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
120     %1 = call @matmul2(%a1, %b1)
121        : (tensor<4x8xf64, #CSR>,
122           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
123     %2 = call @matmul3(%a2, %b2)
124        : (tensor<4x8xf64, #DCSR>,
125           tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
127     // Call kernels with one sparse.
128     %3 = call @matmul1(%sa, %db)
129        : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
130     %4 = call @matmul2(%a3, %b1)
131        : (tensor<4x8xf64, #CSR>,
132           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
133     %5 = call @matmul3(%a4, %b2)
134        : (tensor<4x8xf64, #DCSR>,
135           tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
137     // Call kernels with sparse.
138     %6 = call @matmul1(%sa, %sb)
139        : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
140     %7 = call @matmul2(%a3, %b3)
141        : (tensor<4x8xf64, #CSR>,
142           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
143     %8 = call @matmul3(%a4, %b4)
144        : (tensor<4x8xf64, #DCSR>,
145           tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
147     //
148     // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
149     // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
150     // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
151     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
152     //
153     %m0 = bufferization.to_memref %0 : memref<4x4xf64>
154     %v0 = vector.transfer_read %m0[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
155     vector.print %v0 : vector<4x4xf64>
157     //
158     // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
159     // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
160     // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
161     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
162     //
163     %c1 = sparse_tensor.convert %1 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
164     %m1 = bufferization.to_memref %c1 : memref<4x4xf64>
165     %v1 = vector.transfer_read %m1[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
166     vector.print %v1 : vector<4x4xf64>
168     //
169     // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
170     // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
171     // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
172     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
173     //
174     %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
175     %m2 = bufferization.to_memref %c2 : memref<4x4xf64>
176     %v2 = vector.transfer_read %m2[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
177     vector.print %v2 : vector<4x4xf64>
179     //
180     // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
181     // CHECK-SAME: ( 0, 0, 0, 0 ),
182     // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
183     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
184     //
185     %m3 = bufferization.to_memref %3 : memref<4x4xf64>
186     %v3 = vector.transfer_read %m3[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
187     vector.print %v3 : vector<4x4xf64>
189     //
190     // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
191     // CHECK-SAME: ( 0, 0, 0, 0 ),
192     // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
193     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
194     //
195     %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
196     %m4 = bufferization.to_memref %c4 : memref<4x4xf64>
197     %v4 = vector.transfer_read %m4[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
198     vector.print %v4 : vector<4x4xf64>
200     //
201     // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
202     // CHECK-SAME: ( 0, 0, 0, 0 ),
203     // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
204     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
205     //
206     %c5 = sparse_tensor.convert %5 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
207     %m5 = bufferization.to_memref %c5 : memref<4x4xf64>
208     %v5 = vector.transfer_read %m5[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
209     vector.print %v5 : vector<4x4xf64>
211     //
212     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
213     //
214     %m6 = bufferization.to_memref %6 : memref<4x4xf64>
215     %v6 = vector.transfer_read %m6[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
216     vector.print %v6 : vector<4x4xf64>
218     //
219     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
220     //
221     %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
222     %m7 = bufferization.to_memref %c7 : memref<4x4xf64>
223     %v7 = vector.transfer_read %m7[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
224     vector.print %v7 : vector<4x4xf64>
226     //
227     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
228     //
229     %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
230     %m8 = bufferization.to_memref %c8 : memref<4x4xf64>
231     %v8 = vector.transfer_read %m8[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
232     vector.print %v8 : vector<4x4xf64>
234     //
235     // Sanity check on nonzeros.
236     //
237     // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
238     // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
239     //
240     %val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
241     %val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
242     %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
243     %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
244     vector.print %nz7 : vector<8xf64>
245     vector.print %nz8 : vector<8xf64>
247     // Release the resources.
248     sparse_tensor.release %a1 : tensor<4x8xf64, #CSR>
249     sparse_tensor.release %a2 : tensor<4x8xf64, #DCSR>
250     sparse_tensor.release %a3 : tensor<4x8xf64, #CSR>
251     sparse_tensor.release %a4 : tensor<4x8xf64, #DCSR>
252     sparse_tensor.release %b1 : tensor<8x4xf64, #CSR>
253     sparse_tensor.release %b2 : tensor<8x4xf64, #DCSR>
254     sparse_tensor.release %b3 : tensor<8x4xf64, #CSR>
255     sparse_tensor.release %b4 : tensor<8x4xf64, #DCSR>
256     sparse_tensor.release %1 : tensor<4x4xf64, #CSR>
257     sparse_tensor.release %2 : tensor<4x4xf64, #DCSR>
258     sparse_tensor.release %4 : tensor<4x4xf64, #CSR>
259     sparse_tensor.release %5 : tensor<4x4xf64, #DCSR>
260     sparse_tensor.release %7 : tensor<4x4xf64, #CSR>
261     sparse_tensor.release %8 : tensor<4x4xf64, #DCSR>
262     memref.dealloc %m0 : memref<4x4xf64>
263     memref.dealloc %m1 : memref<4x4xf64>
264     memref.dealloc %m2 : memref<4x4xf64>
265     memref.dealloc %m3 : memref<4x4xf64>
266     memref.dealloc %m4 : memref<4x4xf64>
267     memref.dealloc %m5 : memref<4x4xf64>
268     memref.dealloc %m6 : memref<4x4xf64>
269     memref.dealloc %m7 : memref<4x4xf64>
270     memref.dealloc %m8 : memref<4x4xf64>
272     return
273   }