# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s
import ctypes
import errno
import os

import numpy as np

import mlir.all_passes_registration

from mlir import ir
from mlir import runtime as rt
from mlir import execution_engine
from mlir import passmanager

from mlir.dialects import sparse_tensor as st
from mlir.dialects import builtin
from mlir.dialects.linalg.opdsl import lang as dsl
# Linalg OpDSL definition of the SDDMM kernel used by this test.
# A is M x K, B is K x N, and both the sampling tensor S and the output C
# are M x N.
@dsl.linalg_structured_op
def sddmm_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
  # C(m,n) += S(m,n) * A(m,k) * B(k,n); k appears only on the right-hand side.
  C[dsl.D.m, dsl.D.n] += (
      S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n])
def build_SDDMM(attr: st.EncodingAttr):
  """Build SDDMM kernel.

  This method generates a linalg op for sampled matrix multiplication using
  just the Python API. Effectively, a generic linalg op is constructed
  that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.

  Args:
    attr: Sparse tensor encoding attached to the type of the S operand.

  Returns:
    A module whose body holds a single function `sddmm` taking (a, b, s, c).
  """
  module = ir.Module.create()
  f64 = ir.F64Type.get()
  # All operands are 8x8 f64 tensors; only `s` carries the sparse encoding.
  a = ir.RankedTensorType.get([8, 8], f64)
  b = ir.RankedTensorType.get([8, 8], f64)
  c = ir.RankedTensorType.get([8, 8], f64)
  s = ir.RankedTensorType.get([8, 8], f64, attr)
  arguments = [a, b, s, c]
  with ir.InsertionPoint(module.body):

    # The Python function name must stay `sddmm`: the driver emitted by
    # boilerplate() calls `@sddmm`.
    @builtin.FuncOp.from_py_func(*arguments)
    def sddmm(*args):
      return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])

  return module
def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate code for main driver.

  The returned MLIR text defines `@main(%a, %b, %c)`. It materializes a
  constant with three nonzero entries, (0,0)=1.0, (0,2)=2.0 and (4,1)=3.0,
  converts it to the sparse type described by `attr`, and forwards everything
  to `@sddmm`.

  Args:
    attr: Sparse tensor encoding; it is formatted into the tensor types.

  Returns:
    The driver as a string, starting with a newline so it can be appended
    directly to the printed kernel function.
  """
  return f"""
func @main(%a: tensor<8x8xf64>,
           %b: tensor<8x8xf64>,
           %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
  %t = arith.constant sparse<[[0,0], [0,2], [4,1]], [1.0, 2.0, 3.0]> : tensor<8x8xf64>
  %s = sparse_tensor.convert %t : tensor<8x8xf64> to tensor<8x8xf64, {attr}>
  %0 = call @sddmm(%a, %b, %s, %c) : (tensor<8x8xf64>,
                                      tensor<8x8xf64>,
                                      tensor<8x8xf64, {attr}>,
                                      tensor<8x8xf64>) -> tensor<8x8xf64>
  return %0 : tensor<8x8xf64>
}}
"""
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, opt: str,
                                 support_lib: str, compiler):
  """Builds, compiles and runs the SDDMM kernel, then checks the result.

  Args:
    attr: Sparse tensor encoding used for the sampling operand.
    opt: Option string of the compiler; not read here, kept for the callers.
    support_lib: Path of the shared library handed to the execution engine.
    compiler: Callable that lowers the module in place.

  Raises:
    SystemExit: If the computed output does not match the expected values.
  """
  # Build.
  module = build_SDDMM(attr)
  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
  module = ir.Module.parse(func + boilerplate(attr))

  # Compile.
  compiler(module)
  engine = execution_engine.ExecutionEngine(
      module, opt_level=0, shared_libs=[support_lib])

  # Set up numpy input and buffer for output.
  a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
                [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
                [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
                [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
                [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
                [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
                [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
                [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
  b = np.ones((8, 8), np.float64)
  c = np.zeros((8, 8), np.float64)

  # Each argument is handed over as a pointer to a pointer to its descriptor.
  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))

  # Allocate a MemRefDescriptor to receive the output tensor.
  # The buffer itself is allocated inside the MLIR code generation.
  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
  mem_out = ctypes.pointer(ctypes.pointer(ref_out))

  # Invoke the kernel and get numpy output.
  # Built-in bufferization uses in-out buffers.
  # TODO: replace with inplace comprehensive bufferization.
  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

  # Sanity check on computed result. Only a few elements
  # are sampled from the full dense matrix multiplication.
  # The sampled positions and weights mirror the sparse constant in the
  # driver: (0,0)=1.0, (0,2)=2.0 and (4,1)=3.0.
  full_matmul = np.matmul(a, b)
  expected = np.zeros((8, 8), np.float64)
  expected[0, 0] = 1.0 * full_matmul[0, 0]
  expected[0, 2] = 2.0 * full_matmul[0, 2]
  expected[4, 1] = 3.0 * full_matmul[4, 1]
  c = rt.ranked_memref_to_numpy(mem_out[0])
  # NOTE(review): the failure branch was missing from the listing; aborting
  # with 'FAILURE' is an assumption -- confirm against the intended behavior.
  if not np.allclose(c, expected):
    raise SystemExit('FAILURE')
class SparseCompiler:
  """Sparse compiler passes.

  Holds a textual pass pipeline and applies it to a module when called.
  """

  def __init__(self, options: str):
    """Assembles the pipeline string.

    Args:
      options: Option text placed inside the braces of the `sparsification`
        pass.
    """
    # NOTE(review): 'func-bufferize' and 'lower-affine' fill two gaps in the
    # listing -- confirm both entries and their positions.
    pipeline = (
        f'sparsification{{{options}}},'
        'sparse-tensor-conversion,'
        'builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),'
        'convert-scf-to-std,'
        'func-bufferize,'
        'tensor-constant-bufferize,'
        'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
        'convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1},'
        'lower-affine,'
        'convert-memref-to-llvm,'
        'convert-std-to-llvm,'
        'reconcile-unrealized-casts')
    self.pipeline = pipeline

  def __call__(self, module: ir.Module):
    """Parses the stored pipeline and runs it on `module`."""
    passmanager.PassManager.parse(self.pipeline).run(module)
def main():
  """Runs the SDDMM kernel over several sparse annotations and options.

  Reads the support library path from the SUPPORT_LIB environment variable,
  then builds, compiles and runs the kernel once per configuration and
  prints the number of runs for FileCheck.

  Raises:
    AssertionError: If SUPPORT_LIB is not set.
    FileNotFoundError: If the support library path does not exist.
  """
  import itertools  # Local import keeps this block self-contained.

  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: testSDDMMM
  print('\nTEST: testSDDMMM')
  with ir.Context() as ctx, ir.Location.unknown():
    count = 0
    # Loop over various ways to compile and annotate the SDDMM kernel with
    # a *single* sparse tensor. Note that we deliberately do not exhaustively
    # search the full state space to reduce runtime of the test. It is
    # straightforward to adapt the code below to explore more combinations.
    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
              [st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.dense],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    # NOTE(review): the value lists for pwidth, iwidth, par, vec and e were
    # missing from the listing. The ones below are assumptions chosen so the
    # total matches the 16 runs checked at the end -- confirm.
    pwidths = [32]
    iwidths = [32]
    pars = [0]
    vecs = [0, 1]
    index32 = [True]
    # product() visits the combinations in the same order as nested loops
    # written in this argument order.
    for level, ordering, pwidth, iwidth, par, vec, e in itertools.product(
        levels, orderings, pwidths, iwidths, pars, vecs, index32):
      vl = 1 if vec == 0 else 16
      attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
      opt = (f'parallelization-strategy={par} '
             f'vectorization-strategy={vec} '
             f'vl={vl} enable-simd-index32={e}')
      compiler = SparseCompiler(options=opt)
      build_compile_and_run_SDDMMM(attr, opt, support_lib, compiler)
      count += 1
    # CHECK: Passed 16 tests
    print('Passed ', count, 'tests')


if __name__ == '__main__':
  main()