[AMDGPU] Make v8i16/v8f16 legal
[llvm-project.git] / mlir / test / Integration / Dialect / SparseTensor / python / test_stress.py
blob55e64668d4955666473e2b534aa8444c47d6f0ab
1 # RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
2 # RUN: %PYTHON %s | FileCheck %s
4 import ctypes
5 import errno
6 import itertools
7 import os
8 import sys
9 from typing import List, Callable
11 import numpy as np
13 import mlir.all_passes_registration
15 from mlir import ir
16 from mlir import runtime as rt
17 from mlir.execution_engine import ExecutionEngine
18 from mlir.passmanager import PassManager
20 from mlir.dialects import builtin
21 from mlir.dialects import std
22 from mlir.dialects import sparse_tensor as st
24 # ===----------------------------------------------------------------------=== #
26 # TODO: move this boilerplate to its own module, so it can be used by
27 # other tests and programs.
class TypeConverter:
    """Converter between NumPy types and MLIR types.

    The mapping covers f64, f32 and the signless integers of width 64, 32,
    16 and 8. Lookups for anything else raise KeyError.
    """

    def __init__(self, context: ir.Context):
        """Builds the NumPy <-> MLIR lookup tables inside `context`."""
        # Note 1: these are numpy "scalar types" (i.e., the values of
        # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
        #
        # Note 2: we must construct the MLIR types in the same context as the
        # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
        # otherwise, those methods will raise a KeyError.
        types_list = [
            (np.float64, ir.F64Type.get(context=context)),
            (np.float32, ir.F32Type.get(context=context)),
            (np.int64, ir.IntegerType.get_signless(64, context=context)),
            (np.int32, ir.IntegerType.get_signless(32, context=context)),
            (np.int16, ir.IntegerType.get_signless(16, context=context)),
            (np.int8, ir.IntegerType.get_signless(8, context=context)),
        ]
        self._sc2ir = dict(types_list)
        # Inverse table. The loop variables deliberately avoid the name `ir`
        # so the imported module is never shadowed.
        self._ir2sc = {irtype: sctype for sctype, irtype in types_list}

    def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy dtype.

        Raises:
          KeyError: if the dtype's scalar type has no MLIR equivalent here.
        """
        try:
            return self.sctype_to_irtype(dtype.type)
        except KeyError as e:
            raise KeyError(f'Unknown dtype: {dtype}') from e

    def sctype_to_irtype(self, sctype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy scalar type.

        Raises:
          KeyError: if the scalar type is not in the table.
        """
        try:
            return self._sc2ir[sctype]
        except KeyError:
            # `from None` keeps the traceback free of the raw dict lookup
            # failure; only the descriptive message is reported.
            raise KeyError(f'Unknown sctype: {sctype}') from None

    def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
        """Returns the NumPy dtype equivalent of an MLIR type.

        Raises:
          KeyError: if the MLIR type is not in the table.
        """
        return np.dtype(self.irtype_to_sctype(tp))

    def irtype_to_sctype(self, tp: ir.Type):
        """Returns the NumPy scalar-type equivalent of an MLIR type.

        Raises:
          KeyError: if the MLIR type is not in the table.
        """
        try:
            return self._ir2sc[tp]
        except KeyError:
            raise KeyError(f'Unknown ir.Type: {tp}') from None

    def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
        """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
        arrays can only be converted to/from dense tensors, not sparse tensors."""
        # TODO: handle strides as well?
        return ir.RankedTensorType.get(nparray.shape,
                                       self.dtype_to_irtype(nparray.dtype))
81 # ===----------------------------------------------------------------------=== #
class StressTest:
    """Builds, compiles and runs a round-trip chain of tensor conversions.

    The intended call sequence is build() -> compile() -> run(), with
    writeTo() usable at any point after build(). Every step except run()
    returns `self` so the calls can be chained.
    """

    def __init__(self, tyconv: TypeConverter):
        """Stores the type converter; nothing is built until build()."""
        self._tyconv = tyconv
        self._roundtripTp = None
        self._module = None
        self._engine = None

    def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
        """Raises AssertionError unless `tp` equals the roundtrip type."""
        assert self._roundtripTp is not None, \
            'StressTest: uninitialized roundtrip type'
        if tp != self._roundtripTp:
            raise AssertionError(
                f"Type is not equal to the roundtrip type.\n"
                f"\tExpected: {self._roundtripTp}\n"
                f"\tFound: {tp}\n")

    def build(self, types: List[ir.Type]):
        """Builds the ir.Module. The module has only the @main function,
        which will convert the input through the list of types and then back
        to the initial type. The roundtrip type must be a dense tensor.

        The `types` argument is only read; the caller's list is not modified.

        Raises:
          ValueError: if `types` is empty.
        """
        assert self._module is None, 'StressTest: must not call build() repeatedly'
        # Work on a private copy and validate it before any state changes,
        # so a bad argument leaves this object untouched.
        chain = list(types)
        if not chain:
            raise ValueError('StressTest: build() requires at least one type')
        tp0 = chain.pop(0)
        # The conversion path ends back at the initial type.
        chain.append(tp0)
        self._roundtripTp = tp0
        self._module = ir.Module.create()
        with ir.InsertionPoint(self._module.body):
            # TODO: assert dense? assert element type is recognised by the TypeConverter?
            funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
            funcOp = builtin.FuncOp(name='main', type=funcTp)
            funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
            with ir.InsertionPoint(funcOp.add_entry_block()):
                arg0 = funcOp.entry_block.arguments[0]
                self._assertEqualsRoundtripTp(arg0.type)
                v = st.ConvertOp(chain[0], arg0)
                for tp in chain[1:]:
                    w = st.ConvertOp(tp, v)
                    # Release intermediate tensors before they fall out of scope.
                    st.ReleaseOp(v.result)
                    v = w
                self._assertEqualsRoundtripTp(v.result.type)
                std.ReturnOp(v)
        return self

    def writeTo(self, filename):
        """Write the ir.Module to the given file. If the file already exists,
        then raises an error. If the filename is None, then is a no-op.

        The module is written in whatever state it is in at call time.

        Raises:
          FileExistsError: if `filename` already exists.
        """
        assert self._module is not None, \
            'StressTest: must call build() before writeTo()'
        if filename is None:
            # Silent no-op, for convenience.
            return self
        # Mode 'x' creates the file exclusively: the existence check and the
        # creation are a single step, so there is no window in which another
        # process could create the file and have it overwritten.
        with open(filename, 'x') as f:
            f.write(str(self._module))
        return self

    def compile(self, compiler: Callable[[ir.Module], ExecutionEngine]):
        """Compile the ir.Module by passing it to `compiler`, and keep the
        resulting execution engine for run()."""
        assert self._module is not None, \
            'StressTest: must call build() before compile()'
        assert self._engine is None, \
            'StressTest: must not call compile() repeatedly'
        self._engine = compiler(self._module)
        return self

    def run(self, np_arg0: np.ndarray) -> np.ndarray:
        """Runs the test on the given numpy array, and returns the resulting
        numpy array.

        Raises:
          AssertionError: if the array's tensor type differs from the
            roundtrip type established by build().
        """
        assert self._engine is not None, \
            'StressTest: must call compile() before run()'
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
        # Placeholder with the same shape/dtype as the input; it only serves
        # to create a descriptor of the right type for the result slot.
        np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_out))
        mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
        mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
        # The result slot is passed ahead of the input argument.
        self._engine.invoke('main', mem_out, mem_arg0)
        return rt.ranked_memref_to_numpy(mem_out[0])
164 # ===----------------------------------------------------------------------=== #
166 # TODO: move this boilerplate to its own module, so it can be used by
167 # other tests and programs.
class SparseCompiler:
    """Callable wrapper around the sparse-compiler pass pipeline.

    Calling an instance on an ir.Module runs the pipeline on that module
    and returns an ExecutionEngine for it.
    """

    def __init__(self, sparsification_options: str, support_lib: str):
        """Parses the pass pipeline.

        Args:
          sparsification_options: option string spliced into the braces of
            the `sparsification` pass.
          support_lib: path of the shared library handed to the engine.
        """
        self._support_lib = support_lib
        # One entry per pipeline stage, in execution order; the stages are
        # joined with commas to form the textual pipeline.
        stages = [
            'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops)',
            f'sparsification{{{sparsification_options}}}',
            'sparse-tensor-conversion',
            'builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)',
            'convert-scf-to-std',
            'func-bufferize',
            'tensor-constant-bufferize',
            'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)',
            'convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1}',
            'lower-affine',
            'convert-memref-to-llvm',
            'convert-std-to-llvm',
            'reconcile-unrealized-casts',
        ]
        self._pipeline = ','.join(stages)
        # Must be in the scope of a `with ir.Context():`
        self._passmanager = PassManager.parse(self._pipeline)

    def __call__(self, module: ir.Module) -> ExecutionEngine:
        """Runs the pipeline on `module` and wraps it in an engine."""
        self._passmanager.run(module)
        engine = ExecutionEngine(
            module, opt_level=0, shared_libs=[self._support_lib])
        return engine
194 # ===----------------------------------------------------------------------=== #
def main():
    """
    USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

    The SUPPORT_LIB environment variable must name the
    libmlir_c_runner_utils shared library. Both positional arguments are
    optional debugging aids: the first is where the raw/generated
    ir.Module is written, the second is where the compiled version of
    that ir.Module is written.
    """
    support_lib = os.getenv('SUPPORT_LIB')
    assert support_lib is not None, 'SUPPORT_LIB is undefined'
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # Optional dump paths; a missing argument becomes None, which
    # StressTest.writeTo() treats as "do nothing".
    raw_path, compiled_path = (sys.argv[1:3] + [None, None])[:2]

    # CHECK-LABEL: TEST: test_stress
    print("\nTEST: test_stress")
    with ir.Context() as ctx, ir.Location.unknown():
        settings = {
            'parallelization-strategy': 0,
            'vectorization-strategy': 0,
            'vl': 1,
            'enable-simd-index32': False,
        }
        sparsification_options = ' '.join(
            f'{key}={value}' for key, value in settings.items())
        compiler = SparseCompiler(sparsification_options, support_lib)
        f64 = ir.F64Type.get()
        # Be careful about increasing this because
        #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
        shape = range(2, 6)
        rank = len(shape)
        # Every dense/compressed assignment across the dimensions.
        level_choices = [st.DimLevelType.dense, st.DimLevelType.compressed]
        levels = list(itertools.product(level_choices, repeat=rank))
        # Every dimension ordering.
        orderings = [ir.AffineMap.get_permutation(perm)
                     for perm in itertools.permutations(range(rank))]
        bitwidths = [0]
        # The first type must be a dense tensor for numpy conversion to work.
        types = [ir.RankedTensorType.get(shape, f64)]
        for level, ordering, pwidth, iwidth in itertools.product(
                levels, orderings, bitwidths, bitwidths):
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))

        # An exhaustive test would need one or more StressTest paths that
        # together cover every directed pair of entries in `types`. The
        # number of types already grows superexponentially, so that would
        # be prohibitive for a test that runs on every commit. For now a
        # single path is used that at least visits every entry of `types`.
        tyconv = TypeConverter(ctx)
        size = int(np.prod(list(shape)))
        np_arg0 = np.arange(
            size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)

        test = StressTest(tyconv)
        test.build(types)
        test.writeTo(raw_path)
        test.compile(compiler)
        test.writeTo(compiled_path)
        np_out = test.run(np_arg0)

        # CHECK: Passed
        if not np.allclose(np_out, np_arg0):
            sys.exit('FAILURE')
        print('Passed')
# Run the stress test only when this file is executed as a script.
if __name__ == "__main__":
    main()