# mlir/test/Examples/NVGPU/tools/nvdsl.py
from enum import Enum
import functools, sys, ctypes, os, errno
import numpy as np
from functools import partialmethod
from mlir import ir
from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler

# INT64_MIN, MLIR's sentinel value for a dynamic dimension in a shaped type.
MLIR_DYNAMIC = -9223372036854775808


def const(value: int, ty=None):
    ty = T.index() if ty is None else ty
    if isinstance(value, ir.Value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value
    return arith.constant(ty, value)
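
# For example, const(32) materializes an index-typed arith.constant, while an
# ir.Value of an accepted type is passed through unchanged.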


def get_type_size(ty):
    if ir.MemRefType.isinstance(ty):
        size = get_type_size(ty.element_type)
        for sz in ty.shape:
            size *= sz
        return size
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)
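
# For example, get_type_size(T.f16()) is 2 bytes, and a memref<128x64xf16>
# works out to 128 * 64 * 2 bytes.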


def get_mlir_func_obj_ty(inputArgs):
    args = []
    c_int_p = ctypes.c_int * 1
    c_float_p = ctypes.c_float * 1
    c_bool_p = ctypes.c_bool * 1
    for arg in inputArgs:
        if isinstance(arg, bool):
            args.append(c_bool_p(arg))
        elif isinstance(arg, int):
            args.append(c_int_p(arg))
        elif isinstance(arg, float):
            args.append(c_float_p(arg))
        elif isinstance(arg, np.ndarray):
            args.append(
                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
            )
        else:
            raise NotImplementedError(arg)
    return args


class Mbarriers:
    def __init__(self, number_of_barriers=1):
        self.mbar_ty = ir.Type.parse(
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
            + str(number_of_barriers)
            + ">"
        )
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        count_op = const(count)
        if predicate is None:
            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
        else:
            nvgpu.mbarrier_init(
                self.mbar_group_op, count_op, self.id_op, predicate=predicate
            )

    def arrive(self, txcount: int = 0, predicate=None):
        if txcount != 0:
            txcount_op = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
            )
        else:
            nvgpu.mbarrier_arrive(
                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
            )

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        ticks_op = const(ticks)
        phase_op = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op,
            phase_op,
            ticks_op,
            mbarId=self.id_op,
        )
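

# Illustrative sketch (never called; for documentation only): how a barrier
# group is typically driven inside a kernel body. Indexing with [] selects a
# barrier, and the counts below are hypothetical.
def _example_mbarriers_usage():
    mbar_group = Mbarriers(number_of_barriers=2)
    mbar_group[0].init(1)                # a single arrival completes a phase
    mbar_group[0].arrive(txcount=4096)   # expect 4 KiB of in-flight TMA data
    mbar_group[0].try_wait(phase=False)  # spin until the phase bit flips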


class TMA:
    """A class that builds a TMA descriptor."""

    def __init__(
        self,
        tma_box_shape,
        memref_ty,
        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    ):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType
        self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def create_descriptor(self, device_ptr):
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
        )
        return self.tma_descriptor.result

    def prefetch(self, predicate=None):
        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)

    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
        nvgpu.TmaAsyncLoadOp(
            dest,
            mbarrier.mbar_group_op,
            self.tma_descriptor,
            coordinates=map(const, coords),
            mbarId=mbarrier.id_op,
            predicate=predicate,
        )
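

# Illustrative sketch (never called; for documentation only): building a TMA
# descriptor and loading one box into shared memory. `device_ptr`, `smem_dest`,
# `mbar_group`, and `memref_ty` are hypothetical values produced elsewhere in
# a kernel.
def _example_tma_usage(device_ptr, smem_dest, mbar_group, memref_ty):
    a_tma = TMA([128, 64], memref_ty)    # 128x64 box with the default swizzle
    a_tma.create_descriptor(device_ptr)  # emit nvgpu.tma.create.descriptor
    a_tma.prefetch()                     # prefetch the descriptor
    a_tma.load(smem_dest, mbar_group[0], coords=[0, 0])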


WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup


class Warpgroup:
    def __init__(self, primary_thread, register_size):
        assert (primary_thread % WARP_GROUP_SIZE) == 0
        tidx = gpu.thread_id(gpu.Dimension.x)
        self.primary_thread = primary_thread
        self.register_size = register_size
        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
        self.wg_id = tidx / WARP_GROUP_SIZE
        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)

    def __enter__(self):
        if_op = scf.IfOp(self.is_me)
        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
        self.ipoint_op.__enter__()
        if self.register_size < 64:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
        else:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)

    def __exit__(self, exc_type, exc_value, traceback):
        scf.yield_([])
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
        return True
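

# Illustrative sketch (never called; for documentation only): a Warpgroup is
# used as a context manager; the body is emitted under an scf.if so that only
# the chosen warpgroup executes it. The primary thread and register budget
# below are hypothetical.
def _example_warpgroup_usage():
    producer = Warpgroup(primary_thread=128, register_size=40)
    with producer:
        pass  # e.g. issue TMA loads and mbarrier arrives here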


class WGMMAType(Enum):
    Accumulator = 1
    Descriptor = 2


class WGMMAMatrix:
    def __init__(
        self,
        matrix_type: WGMMAType,
        shape: list = None,
        desc: TMA = None,
        smem=None,
        ty=None,
        acc_op=None,
    ):
        if acc_op is None:
            self.M = shape[0]
            self.N = shape[1]
            self.ty = ty
            self.matrix_type = matrix_type
            self.desc = desc
            self.smem = smem
            if matrix_type is WGMMAType.Accumulator:
                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
        elif acc_op:
            self.acc_op = acc_op
            self.matrix_type = WGMMAType.Accumulator

    @property
    def acc_ty(self):
        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
        return ir.Type.parse(parse_str)

    @property
    def wgmma_ty(self):
        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
        return ir.Type.parse(parse_str)

    def store_accumulator(self, dest):
        assert self.matrix_type == WGMMAType.Accumulator
        nvgpu.warpgroup_mma_store(self.acc_op, dest)

    def update_smem(self, smem):
        self.smem = smem

    def update_accumulator(self, acc_op):
        self.acc_op = acc_op

    def __matmul__(self, rhs):
        lhs = nvgpu.warpgroup_generate_descriptor(
            self.wgmma_ty, self.smem, self.desc.tma_descriptor
        )
        rhs = nvgpu.warpgroup_generate_descriptor(
            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
        )
        return [lhs, rhs]

    def __iadd__(self, matmulResult):
        lhs = matmulResult[0]
        rhs = matmulResult[1]
        acc_op = nvgpu.WarpgroupMmaOp(
            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
        )
        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
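

# Illustrative sketch (never called; for documentation only): the operator
# overloads let a warpgroup GEMM be written as `D += A @ B`, where __matmul__
# builds the wgmma descriptors and __iadd__ emits the warpgroup MMA op. The
# shapes, TMA descriptors, and shared-memory views below are hypothetical.
def _example_wgmma_usage(a_tma, b_tma, a_smem, b_smem):
    A = WGMMAMatrix(WGMMAType.Descriptor, [128, 64], desc=a_tma, smem=a_smem)
    B = WGMMAMatrix(WGMMAType.Descriptor, [64, 128], desc=b_tma, smem=b_smem)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[128, 128], ty=T.f32())
    D += A @ B
    return D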


def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
    )
    if shape is None:
        return dynamic_smem
    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
    return memref.view(
        ir.MemRefType.get(
            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
        ),
        dynamic_smem,
        const(offset),
        [],
    )
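

# Illustrative sketch (never called; for documentation only): carving two f16
# tiles out of dynamic shared memory; the shapes and the byte offset of the
# second view are hypothetical.
def _example_dynamic_smem_usage():
    a_smem = get_dynamic_shared_memory((128, 64), T.f16())
    b_smem = get_dynamic_shared_memory((64, 128), T.f16(), offset=128 * 64 * 2)
    return a_smem, b_smem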


def get_mlir_ty(arg):
    def get_mlir_ty_from_np(dtype):
        if dtype == np.float16:
            return T.f16()
        if dtype == np.float32:
            return T.f32()
        if dtype == np.float64:
            return T.f64()
        if dtype == np.int32:
            return T.i32()
        if dtype == np.int64:
            return T.i64()
        raise NotImplementedError(dtype)

    if isinstance(arg, bool):
        return T.bool()
    elif isinstance(arg, int):
        return T.index()
    elif isinstance(arg, float):
        return T.f32()
    elif isinstance(arg, np.ndarray):
        descriptor = rt.get_ranked_memref_descriptor(arg)
        dtype = get_mlir_ty_from_np(arg.dtype)
        shape = descriptor.shape
        return memref.MemRefType.get(shape, dtype)
    raise NotImplementedError(arg)


class NVDSL:
    @staticmethod
    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                launch_op = gpu.LaunchOp(
                    None,
                    [],
                    *map(const, grid),
                    *map(const, block),
                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
                )
                launch_op.body.blocks.append(*([T.index()] * 12))
                with ir.InsertionPoint(launch_op.body.blocks[0]):
                    result = func(*args, **kwargs)
                    gpu.terminator()
                    return result

            return wrapper

        return decorator

    @staticmethod
    def mlir_func(funcBody):
        @functools.wraps(funcBody)
        def wrapper(*args, **kwargs):
            function_name = funcBody.__name__

            def saveIR(module):
                """Save generated IR"""
                if True:  # self.saveIR:
                    # print(mlir_nvgpu_module)
                    original_stdout = sys.stdout
                    with open("nvdsl.mlir", "w") as f:
                        sys.stdout = f
                        print(module)
                        sys.stdout = original_stdout

            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
                """Generate MLIR's Arith dialects binary operations."""
                rhs = const(rhs)
                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
                    op += "F"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
                            predAtt
                        ]
                elif arith._is_integer_like_type(
                    lhs.type
                ) and arith._is_integer_like_type(rhs.type):
                    if op == "Div" or op == "Rem":
                        op += "U"
                    op += "I"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
                            predAtt
                        ]
                else:
                    raise NotImplementedError(
                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
                    )

                if op.startswith("Cmp"):
                    op = getattr(arith, f"{op}Op")
                    return op(predicateAttr, lhs, rhs).result
                else:
                    op = getattr(arith, f"{op}Op")
                    return op(lhs, rhs).result

            @ir.register_value_caster(ir.IndexType.static_typeid)
            @ir.register_value_caster(ir.F32Type.static_typeid)
            @ir.register_value_caster(ir.F16Type.static_typeid)
            @ir.register_value_caster(ir.F64Type.static_typeid)
            @ir.register_value_caster(ir.IntegerType.static_typeid)
            class ArithValue(ir.Value):
                """Overloads operators for MLIR's Arith dialects binary operations."""

                def __init__(self, v):
                    super().__init__(v)

                __add__ = partialmethod(_binary_op, op="Add")
                __sub__ = partialmethod(_binary_op, op="Sub")
                __mul__ = partialmethod(_binary_op, op="Mul")
                __truediv__ = partialmethod(_binary_op, op="Div")
                __mod__ = partialmethod(_binary_op, op="Rem")
                __xor__ = partialmethod(_binary_op, op="XOr")
                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
                __and__ = partialmethod(_binary_op, op="And")
                __or__ = partialmethod(_binary_op, op="Or")

                def __str__(self):
                    return (
                        super()
                        .__str__()
                        .replace(ir.Value.__name__, ArithValue.__name__)
                    )

            # Generate MLIR Context and start generating IR
            with ir.Context(), ir.Location.unknown():
                types = []
                for arg in args:
                    types.append(get_mlir_ty(arg))

                # Build IR
                module = ir.Module.create()
                with ir.InsertionPoint(module.body):
                    fop = func.FuncOp(function_name, (types, []))
                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                    with ir.InsertionPoint(fop.add_entry_block()):
                        fargs = []
                        for i, a in enumerate(types):
                            fargs.append(fop.arguments[i])

                        # Call user function body
                        result = funcBody(*fargs, **kwargs)
                        func.ReturnOp([])

                # Save IR in a file
                # saveIR(module)

                # Verify the module
                module.operation.verify()

                # Compile and JIT MLIR module
                options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
                support_lib = os.getenv("SUPPORT_LIB")
                if not os.path.exists(support_lib):
                    raise FileNotFoundError(
                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
                    )
                compiler = nvgpucompiler.NvgpuCompiler(
                    options, opt_level=3, shared_libs=[support_lib]
                )
                engine = compiler.compile_and_jit(module)

            # Convert input arguments to MLIR arguments
            newArgs = get_mlir_func_obj_ty(args)

            # Run the compiled program
            engine.invoke(function_name, *newArgs)

            return result

        return wrapper
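

# Illustrative sketch (never called; for documentation only): how the two
# decorators compose. @NVDSL.mlir_func builds, verifies, compiles, and JIT-runs
# the decorated function; @NVDSL.mlir_gpu_launch wraps a nested kernel in a
# gpu.launch region. The kernel, grid/block sizes, and host arrays below are
# hypothetical.
def _example_nvdsl_usage():
    @NVDSL.mlir_func
    def saxpy(x, y, alpha):
        @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1))
        def saxpy_kernel():
            tidx = gpu.thread_id(gpu.Dimension.x)
            x_val = memref.load(x, [tidx])
            y_val = memref.load(y, [tidx])
            memref.store(y_val + x_val * alpha, y, [tidx])

        saxpy_kernel()

    x = np.ones((128,), np.float32)
    y = np.ones((128,), np.float32)
    saxpy(x, y, 2.0)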