# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import os
import sys

import numpy as np

from mlir import ir
from mlir import runtime as rt
from mlir.dialects import sparse_tensor as st
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects.linalg.opdsl import lang as dsl

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparsifier
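

# SDDMM (sampled dense-dense matrix multiplication) declared through linalg's
# OpDSL: the decorated function below is a tensor-comprehension specification
# rather than ordinary Python code; invoking it under an MLIR insertion point
# emits the corresponding structured linalg op.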
@dsl.linalg_structured_op
def sddmm_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
    C[dsl.D.m, dsl.D.n] += (
        S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
    )
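

# In numpy terms the kernel above computes (a sketch for a dense S, for
# illustration only, not part of the test):
#
#   C += S * (A @ B)
#
# i.e. the dense product A @ B is sampled elementwise by S, so with a sparse
# S only the entries where S is nonzero ever need to be computed.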


def build_SDDMM(attr: st.EncodingAttr):
    """Build SDDMM kernel.

    This method generates a linalg op for matrix multiplication using just
    the Python API. Effectively, a generic linalg op is constructed that
    computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
    """
    module = ir.Module.create()
    f64 = ir.F64Type.get()
    a = ir.RankedTensorType.get([8, 8], f64)
    b = ir.RankedTensorType.get([8, 8], f64)
    c = ir.RankedTensorType.get([8, 8], f64)
    # The sampling operand carries the given sparse tensor encoding.
    s = ir.RankedTensorType.get([8, 8], f64, attr)
    arguments = [a, b, s, c]
    with ir.InsertionPoint(module.body):

        @func.FuncOp.from_py_func(*arguments)
        def sddmm(*args):
            return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])

    return module


def boilerplate(attr: st.EncodingAttr):
    """Returns boilerplate code for main driver."""
    return f"""
func.func @main(%a: tensor<8x8xf64>,
                %b: tensor<8x8xf64>,
                %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
  %t = arith.constant sparse<[[0,0], [0,2], [4,1]], [1.0, 2.0, 3.0]> : tensor<8x8xf64>
  %s = sparse_tensor.convert %t : tensor<8x8xf64> to tensor<8x8xf64, {attr}>
  %0 = call @sddmm(%a, %b, %s, %c) : (tensor<8x8xf64>,
                                      tensor<8x8xf64>,
                                      tensor<8x8xf64, {attr}>,
                                      tensor<8x8xf64>) -> tensor<8x8xf64>
  return %0 : tensor<8x8xf64>
}}
"""
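

# Note: the `llvm.emit_c_interface` attribute makes the lowering emit a C-ABI
# wrapper (_mlir_ciface_main) for @main, which is what ExecutionEngine.invoke
# calls into below.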


def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
    # Build.
    module = build_SDDMM(attr)
    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
    module = ir.Module.parse(func + boilerplate(attr))

    # Compile.
    engine = compiler.compile_and_jit(module)
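
    # (`compile_and_jit` comes from the local tools/sparsifier helper; it
    # lowers the module through the sparsifier pipeline and returns a JIT'd
    # ExecutionEngine whose exported functions are invoked below.)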

    # Set up numpy input and buffer for output.
    a = np.array(
        [
            [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
            [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
            [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
            [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
            [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
            [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
            [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
            [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
        ],
        np.float64,
    )
    b = np.ones((8, 8), np.float64)
    c = np.zeros((8, 8), np.float64)

    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))

    # Allocate a MemRefDescriptor to receive the output tensor.
    # The buffer itself is allocated inside the MLIR code generation.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))

    # Invoke the kernel and get numpy output.
    # Built-in bufferization uses in-out buffers.
    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
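    # (Each argument is passed as a pointer-to-pointer to its memref
    # descriptor, and the returned tensor comes back through the leading
    # mem_out argument, per the C interface emitted by llvm.emit_c_interface.)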

    # Sanity check on computed result. Only a few elements are sampled from
    # the full dense matrix multiplication; the sampled positions and scale
    # factors match the nonzeros of the sparse constant %t in boilerplate().
    full_matmul = np.matmul(a, b)
    expected = np.zeros((8, 8), np.float64)
    expected[0, 0] = 1.0 * full_matmul[0, 0]
    expected[0, 2] = 2.0 * full_matmul[0, 2]
    expected[4, 1] = 3.0 * full_matmul[4, 1]
    c = rt.ranked_memref_to_numpy(mem_out[0])
    if np.allclose(c, expected):
        pass
    else:
        quit("FAILURE")


def main():
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), support_lib
        )

    # CHECK-LABEL: TEST: testSDDMMM
    print("\nTEST: testSDDMMM")

    count = 0
    with ir.Context() as ctx, ir.Location.unknown():
        # Loop over various ways to compile and annotate the SDDMM kernel with
        # a *single* sparse tensor. Note that we deliberately do not
        # exhaustively search the full state space, in order to keep the
        # runtime of the test down. It is straightforward to adapt the code
        # below to explore more combinations. For these simple orderings,
        # dim2lvl and lvl2dim are the same.
        builder = st.EncodingAttr.build_level_type
        fmt = st.LevelFormat
        prop = st.LevelProperty
        levels = [
            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
            [builder(fmt.dense), builder(fmt.dense)],
            [builder(fmt.dense), builder(fmt.compressed)],
            [builder(fmt.compressed), builder(fmt.dense)],
            [builder(fmt.compressed), builder(fmt.compressed)],
        ]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0]),
        ]
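
        # 5 level-type choices x 2 orderings x 1 pwidth x 1 iwidth yields the
        # 10 combinations checked by the "Passed 10 tests" line below.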
        for level in levels:
            for ordering in orderings:
                for pwidth in [32]:
                    for iwidth in [32]:
                        attr = st.EncodingAttr.get(
                            level, ordering, ordering, pwidth, iwidth
                        )
                        opt = "parallelization-strategy=none"
                        compiler = sparsifier.Sparsifier(
                            extras="",
                            options=opt,
                            opt_level=0,
                            shared_libs=[support_lib],
                        )
                        build_compile_and_run_SDDMMM(attr, compiler)
                        count = count + 1

    # CHECK: Passed 10 tests
    print("Passed ", count, "tests")


if __name__ == "__main__":
    main()