import functools, sys, ctypes, os, errno
from enum import Enum
from functools import partialmethod

import numpy as np

from mlir import ir
from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler

MLIR_DYNAMIC = -9223372036854775808


def const(value: int, ty=None):
    ty = T.index() if ty is None else ty
    if isinstance(value, ir.Value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        # Already an MLIR SSA value; pass it through unchanged.
        return value
    return arith.constant(ty, value)
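

# For example (illustrative): const(4) materializes an index-typed arith.constant,
# while const(1.0, T.f32()) materializes an f32 constant.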


def get_type_size(ty):
    """Returns the size of the MLIR type in bytes."""
    if ir.MemRefType.isinstance(ty):
        size = get_type_size(ty.element_type)
        for sz in ty.shape:
            size *= sz
        return size
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)
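

# For example (illustrative): get_type_size(T.f16()) == 2, and a 64x64xf32 memref
# type yields 64 * 64 * 4 bytes.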


def get_mlir_func_obj_ty(inputArgs):
    """Packs Python arguments into ctypes objects for the JIT-compiled function."""
    args = []
    c_int_p = ctypes.c_int * 1
    c_float_p = ctypes.c_float * 1
    c_bool_p = ctypes.c_bool * 1
    for arg in inputArgs:
        if isinstance(arg, bool):
            args.append(c_bool_p(arg))
        elif isinstance(arg, int):
            args.append(c_int_p(arg))
        elif isinstance(arg, float):
            args.append(c_float_p(arg))
        elif isinstance(arg, np.ndarray):
            args.append(
                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
            )
        else:
            raise NotImplementedError(arg)
    return args


class Mbarriers:
    """Wraps a group of nvgpu mbarrier objects living in workgroup memory."""

    def __init__(self, number_of_barriers=1):
        self.mbar_ty = ir.Type.parse(
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
            + str(number_of_barriers)
            + ">"
        )
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        # Select the barrier with the given index for the following operations.
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        count_op = const(count)
        if predicate is None:
            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
        else:
            nvgpu.mbarrier_init(
                self.mbar_group_op, count_op, self.id_op, predicate=predicate
            )

    def arrive(self, txcount: int = 0, predicate=None):
        if txcount != 0:
            txcount_op = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
            )
        else:
            nvgpu.mbarrier_arrive(
                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
            )

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        ticks_op = const(ticks)
        phase_op = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op, phase_op, ticks_op, mbarId=self.id_op
        )
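

# Illustrative usage sketch (not part of the module): one barrier group with a
# single barrier, initialized for one arriving thread. Values such as
# `is_producer_thread` and `tma_bytes` are hypothetical SSA values from the
# surrounding kernel.
#
#   barriers = Mbarriers(number_of_barriers=1)
#   barriers[0].init(1, predicate=is_producer_thread)
#   barriers[0].arrive(txcount=tma_bytes, predicate=is_producer_thread)
#   barriers[0].try_wait(phase=False)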
103 """A class that builds a TMA descriptor."""
109 swizzle
=nvgpu
.TensorMapSwizzleKind
.SWIZZLE_NONE
,
110 l2promo
=nvgpu
.TensorMapL2PromoKind
.L2PROMO_NONE
,
111 oob
=nvgpu
.TensorMapOOBKind
.OOB_ZERO
,
112 interleave
=nvgpu
.TensorMapInterleaveKind
.INTERLEAVE_NONE
,
114 self
.swizzle
= swizzle
# mlir.nvgpu.TensorMapSwizzleKind
115 self
.l2promo
= l2promo
# mlir.nvgpu.TensorMapL2PromoKind
116 self
.oob
= oob
# mlir.nvgpu.TensorMapOOBKind
117 self
.interleave
= interleave
# mlir.nvgpu.TensorMapInterleaveKind
118 self
.tma_box_shape
= tma_box_shape
119 self
.memref_ty
= memref_ty
# MemRefType
120 self
.tma_memref
= ir
.MemRefType
.get(tma_box_shape
, memref_ty
.element_type
)
123 def tensormap_descriptor_ty(self
):
124 """Returns a tensormap descriptor type."""
125 tensorMemrefType
= ir
.MemRefType
.get(
127 self
.memref_ty
.element_type
,
128 memory_space
=ir
.Attribute
.parse("3"),
130 return nvgpu
.TensorMapDescriptorType
.get(
138 def create_descriptor(self
, device_ptr
):
139 tma_descriptor_ty
= self
.tensormap_descriptor_ty
140 device_unranked_memref
= memref
.CastOp(
141 ir
.UnrankedMemRefType
.get(
142 self
.memref_ty
.element_type
, self
.memref_ty
.memory_space
146 self
.tma_descriptor
= nvgpu
.TmaCreateDescriptorOp(
147 tma_descriptor_ty
, device_unranked_memref
, map(const
, self
.tma_box_shape
)
149 return self
.tma_descriptor
.result
151 def prefetch(self
, predicate
=None):
152 nvgpu
.tma_prefetch_descriptor(self
.tma_descriptor
, predicate
=predicate
)
154 def load(self
, dest
, mbarrier
: Mbarriers
, coords
=[0], predicate
=None):
155 nvgpu
.TmaAsyncLoadOp(
157 mbarrier
.mbar_group_op
,
159 coordinates
=map(const
, coords
),
160 mbarId
=mbarrier
.id_op
,
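

# Illustrative usage sketch (not part of the module): describe 128x64 tiles of a
# larger device buffer and issue an asynchronous load into shared memory.
# `a_memref_ty`, `a_device`, `a_smem`, and `barriers` are hypothetical values
# created elsewhere in the host/kernel code.
#
#   a_tma = TMA((128, 64), a_memref_ty, swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_128B)
#   a_tma.create_descriptor(a_device)
#   a_tma.load(a_smem, barriers[0], coords=[0, 0])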


WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup


class Warpgroup:
    """Context manager that restricts the enclosed IR to one warpgroup and sets
    its register budget."""

    def __init__(self, primary_thread, register_size):
        assert (primary_thread % WARP_GROUP_SIZE) == 0
        tidx = gpu.thread_id(gpu.Dimension.x)
        self.primary_thread = primary_thread
        self.register_size = register_size
        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
        self.wg_id = tidx / WARP_GROUP_SIZE
        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)

    def __enter__(self):
        if_op = scf.IfOp(self.is_me)
        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
        self.ipoint_op.__enter__()
        if self.register_size < 64:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
        else:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)

    def __exit__(self, exc_type, exc_value, traceback):
        # Terminate the scf.if region before leaving its insertion point.
        scf.yield_([])
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
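

# Illustrative usage sketch (not part of the module): run the enclosed IR only on
# the warpgroup starting at thread 128, with an increased register budget. The
# constants are hypothetical.
#
#   with Warpgroup(primary_thread=128, register_size=232):
#       ...  # ops emitted here execute only in that warpgroup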


class WGMMAType(Enum):
    Accumulator = 1
    Descriptor = 2


class WGMMAMatrix:
    def __init__(
        self,
        matrix_type: WGMMAType,
        shape=None,
        ty=None,
        desc: TMA = None,
        smem=None,
        acc_op=None,
    ):
        if acc_op is None:
            self.M = shape[0]
            self.N = shape[1]
            self.ty = ty
            self.desc = desc
            self.smem = smem
            self.matrix_type = matrix_type
            if matrix_type is WGMMAType.Accumulator:
                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
        else:
            self.acc_op = acc_op
            self.matrix_type = WGMMAType.Accumulator

    @property
    def acc_ty(self):
        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
        return ir.Type.parse(parse_str)

    @property
    def wgmma_ty(self):
        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
        return ir.Type.parse(parse_str)

    def store_accumulator(self, dest):
        assert self.matrix_type == WGMMAType.Accumulator
        nvgpu.warpgroup_mma_store(self.acc_op, dest)

    def update_smem(self, smem):
        self.smem = smem

    def update_accumulator(self, acc_op):
        self.acc_op = acc_op

    def __matmul__(self, rhs):
        lhs = nvgpu.warpgroup_generate_descriptor(
            self.wgmma_ty, self.smem, self.desc.tma_descriptor
        )
        rhs = nvgpu.warpgroup_generate_descriptor(
            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
        )
        return [lhs, rhs]

    def __iadd__(self, matmulResult):
        lhs = matmulResult[0]
        rhs = matmulResult[1]
        acc_op = nvgpu.WarpgroupMmaOp(
            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
        )
        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
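

# Illustrative usage sketch (not part of the module): accumulate D += A @ B with
# warpgroup MMA. `a_tma`, `b_tma`, `a_smem`, `b_smem`, and `d_smem` are
# hypothetical TMA descriptors and shared-memory views created elsewhere.
#
#   A = WGMMAMatrix(WGMMAType.Descriptor, (M, K), ty=T.f16(), desc=a_tma, smem=a_smem)
#   B = WGMMAMatrix(WGMMAType.Descriptor, (K, N), ty=T.f16(), desc=b_tma, smem=b_smem)
#   D = WGMMAMatrix(WGMMAType.Accumulator, (M, N), ty=T.f32())
#   D += A @ B          # lowers to nvgpu warpgroup MMA ops
#   D.store_accumulator(d_smem)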


def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
    """Returns the GPU dynamic shared memory, optionally viewed with a shape."""
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
    )
    if shape is None:
        return dynamic_smem
    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
    return memref.view(
        ir.MemRefType.get(
            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
        ),
        dynamic_smem,
        const(offset),
        [],
    )
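

# Illustrative usage sketch (not part of the module): carve two f16 tiles out of
# the dynamic shared memory block, the second starting after the first tile's
# bytes (128 * 64 elements * 2 bytes).
#
#   a_smem = get_dynamic_shared_memory((128, 64), T.f16())
#   b_smem = get_dynamic_shared_memory((64, 128), T.f16(), offset=128 * 64 * 2)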


def get_mlir_ty(arg):
    """Maps a Python/NumPy argument to the matching MLIR type."""

    def get_mlir_ty_from_np(dtype):
        if dtype == np.float16:
            return T.f16()
        if dtype == np.float32:
            return T.f32()
        if dtype == np.float64:
            return T.f64()
        if dtype == np.int32:
            return T.i32()
        if dtype == np.int64:
            return T.i64()
        raise NotImplementedError(dtype)

    if isinstance(arg, bool):
        return T.bool()
    elif isinstance(arg, int):
        return T.index()
    elif isinstance(arg, float):
        return T.f32()
    elif isinstance(arg, np.ndarray):
        descriptor = rt.get_ranked_memref_descriptor(arg)
        dtype = get_mlir_ty_from_np(arg.dtype)
        shape = descriptor.shape
        return ir.MemRefType.get(shape, dtype)
    raise NotImplementedError(arg)
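

# For example (illustrative): get_mlir_ty(np.zeros((64, 64), np.float16)) yields a
# memref<64x64xf16> type, while get_mlir_ty(7) yields the MLIR index type.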


def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
    """Decorator that wraps the decorated body in a gpu.launch op."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            launch_op = gpu.LaunchOp(
                None,
                [],
                *map(const, grid),
                *map(const, block),
                dynamicSharedMemorySize=arith.constant(T.i32(), smem),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                result = func(*args, **kwargs)
                gpu.terminator()
                return result

        return wrapper

    return decorator


def mlir_func(funcBody):
    """Decorator that builds, compiles, and runs an MLIR function from Python."""

    @functools.wraps(funcBody)
    def wrapper(*args, **kwargs):
        function_name = funcBody.__name__

        def saveIR(module):
            """Save generated IR"""
            if True:  # self.saveIR:
                # print(mlir_nvgpu_module)
                original_stdout = sys.stdout
                with open("nvdsl.mlir", "w") as f:
                    sys.stdout = f
                    print(module)
                sys.stdout = original_stdout

        def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
            """Generate MLIR's Arith dialects binary operations."""
            rhs = const(rhs)
            if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
                op += "F"
                if op.startswith("Cmp"):
                    # e.g. predAtt="ult" selects CmpFPredicate.ULT
                    predicateAttr = getattr(arith, "CmpFPredicate").__dict__[
                        predAtt.upper()
                    ]
            elif arith._is_integer_like_type(
                lhs.type
            ) and arith._is_integer_like_type(rhs.type):
                if op == "Div" or op == "Rem":
                    op += "U"
                op += "I"
                if op.startswith("Cmp"):
                    # e.g. predAtt="ult" selects CmpIPredicate.ult
                    predicateAttr = getattr(arith, "CmpIPredicate").__dict__[
                        predAtt.lower()
                    ]
            else:
                raise NotImplementedError(
                    f"Unsupported '{op}' operands: {lhs}, {rhs}"
                )

            if op.startswith("Cmp"):
                op = getattr(arith, f"{op}Op")
                return op(predicateAttr, lhs, rhs).result

            op = getattr(arith, f"{op}Op")
            return op(lhs, rhs).result

        @ir.register_value_caster(ir.IndexType.static_typeid)
        @ir.register_value_caster(ir.F32Type.static_typeid)
        @ir.register_value_caster(ir.F16Type.static_typeid)
        @ir.register_value_caster(ir.F64Type.static_typeid)
        @ir.register_value_caster(ir.IntegerType.static_typeid)
        class ArithValue(ir.Value):
            """Overloads operators for MLIR's Arith dialects binary operations."""

            def __init__(self, v):
                super().__init__(v)

            __add__ = partialmethod(_binary_op, op="Add")
            __sub__ = partialmethod(_binary_op, op="Sub")
            __mul__ = partialmethod(_binary_op, op="Mul")
            __truediv__ = partialmethod(_binary_op, op="Div")
            __mod__ = partialmethod(_binary_op, op="Rem")
            __xor__ = partialmethod(_binary_op, op="XOr")
            __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
            __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
            __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
            __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
            __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
            __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
            __and__ = partialmethod(_binary_op, op="And")
            __or__ = partialmethod(_binary_op, op="Or")

            def __str__(self):
                return (
                    super()
                    .__str__()
                    .replace(ir.Value.__name__, ArithValue.__name__)
                )
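
        # Illustrative effect (comment only): with the casters registered above,
        # `tidx + 4` on an index-typed SSA value emits arith.AddIOp, and
        # `val < other` emits arith.CmpIOp with an unsigned-less-than predicate.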

        # Generate MLIR Context and start generating IR
        with ir.Context(), ir.Location.unknown():
            types = []
            for arg in args:
                types.append(get_mlir_ty(arg))

            # Build a module with a single exported function
            module = ir.Module.create()
            with ir.InsertionPoint(module.body):
                fop = func.FuncOp(function_name, (types, []))
                fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                with ir.InsertionPoint(fop.add_entry_block()):
                    fargs = []
                    for i, a in enumerate(types):
                        fargs.append(fop.arguments[i])

                    # Call user function body
                    result = funcBody(*fargs, **kwargs)
                    func.ReturnOp([])

            # Save IR in a file
            # saveIR(module)

            # Verify the module
            module.operation.verify()

            # Compile and JIT MLIR module
            options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
            support_lib = os.getenv("SUPPORT_LIB")
            if not os.path.exists(support_lib):
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), support_lib
                )
            compiler = nvgpucompiler.NvgpuCompiler(
                options, opt_level=3, shared_libs=[support_lib]
            )
            engine = compiler.compile_and_jit(module)

            # Convert input arguments to MLIR arguments
            newArgs = get_mlir_func_obj_ty(args)

            # Run the compiled program
            engine.invoke(function_name, *newArgs)

    return wrapper