# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32 += f16 * f16`, using the
# Warp Specialized method with a tile size of 128x128x64. The code fully
# parallelizes the two outermost loops into thread blocks. It launches two Warp
# Groups (256 threads in total): one producer and one consumer. Each group
# follows a different control flow: the producer warp group loads data into
# shared memory, while the consumer warp group executes the Tensor Core GEMM
# operation and the epilogue.
#
#  for ti in range(M//128):    # -> blockIdx.x
#   for tj in range(N//128):   # -> blockIdx.y
#    with wg_producer:
#     for tk in range(K//64):
#        TMA_128x64_64x128...
#    with wg_consumer:
#     for tk in range(K//64):
#        MMA_128x128x64...
#     Epilogue..
#
# This chapter demonstrates:
#  2 WG (warpgroups)
#    Producer:
#       - Wait for the MMA barrier
#       - Load tiles via TMA with the TMA barrier
#       - Arrive on the TMA barrier with the transaction count (txcount)
#    Consumer:
#       Loop:
#       - Wait for the TMA barrier
#       - Perform the Tensor Core GEMM 64x128x64 per warpgroup
#       - Arrive on the MMA barrier
#       Epilogue:
#       - Store fragmented registers to shared memory
#       - Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
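# The wildcard import provides the NVDSL helpers used below (TMA, Mbarriers,
# Warpgroup, WGMMAMatrix, const, get_type_size, get_dynamic_shared_memory,
# WARP_GROUP_SIZE, NVDSL, ...).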
import numpy as np
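

# The helper below is an editorial illustration (not part of the original example
# and never called by the test): a minimal NumPy sketch of the block-tiled GEMM
# that the header comment describes, assuming the same 128x128x64 tiling and
# problem sizes that divide the tiles evenly. The name `_reference_blocked_gemm`
# is hypothetical. Each (ti, tj) pair corresponds to one thread block and each
# tk iteration to one pipeline stage.
def _reference_blocked_gemm(a, b, tile_m=128, tile_n=128, tile_k=64):
    m, k = a.shape
    _, n = b.shape
    out = np.zeros((m, n), np.float32)
    for ti in range(m // tile_m):  # -> blockIdx.x
        for tj in range(n // tile_n):  # -> blockIdx.y
            acc = np.zeros((tile_m, tile_n), np.float32)  # accumulator (like D)
            for tk in range(k // tile_k):  # pipeline over K
                a_tile = a[ti * tile_m : (ti + 1) * tile_m, tk * tile_k : (tk + 1) * tile_k]
                b_tile = b[tk * tile_k : (tk + 1) * tile_k, tj * tile_n : (tj + 1) * tile_n]
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            out[ti * tile_m : (ti + 1) * tile_m, tj * tile_n : (tj + 1) * tile_n] = acc
    return out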


def partition_shape():
    """
    Calculate the partition shape based on the block IDs.

    It parallelizes the two outermost loops into thread blocks.
     for ti in range(M//128):      # -> blockIdx.x
      for tj in range(N//128):     # -> blockIdx.y
        D = 0
        for tk in range(K//64):
          for i in range(128):
            for j in range(128):
              for k in range(64):
                FMA

    Returns:
        dimX (int): Dimension along the x-axis.
        dimY (int): Dimension along the y-axis.
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads two input matrices from global memory to shared memory. It performs the following operations:

       - tma.load a_shared_memory[off_x]  at coordinate [x, z]      (Loads 128x64)
       - tma.load b_shared_memory[off_y1] at coordinate [y, x]      (Loads 64x64)
       - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)

       mbarrier.arrive ta_count = 128x64x2x4
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
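    # Dynamic shared memory layout, per the offsets computed below: the A tiles
    # for all stages come first, then the B tiles, so `begin_b` is the byte
    # offset at which the B region starts. A 128x64 f16 A tile and a pair of
    # 64x64 f16 B tiles occupy the same number of bytes, which is why
    # `slot * size_tma_a` indexes both regions.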
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
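
    # Expect `ta_count` bytes (one A tile plus two B tiles) to arrive at this
    # stage's TMA mbarrier. `p` is the predicate passed in by the caller (the
    # producer warp group's primary thread).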
    mbar_group[slot].arrive(ta_count, predicate=p)
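    # Only the first thread of the warp group issues the TMA loads.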
    p = (tidx % WARP_GROUP_SIZE) == 0
    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
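    # Only thread 0 initializes the mbarriers and prefetches the TMA descriptors.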
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group_tma[i].init(1)
            mbar_group_mma[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group_tma, mbar_group_mma


def switch_phase(stage, phase, num_stages):
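    # Flip the mbarrier parity phase once the last pipeline stage has been
    # processed (stage == num_stages - 1); otherwise keep it unchanged.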
    p = stage == (num_stages - 1)
    phase = arith.select(
        p,
        (phase ^ const(True, ty=T.bool())),
        phase,
    )
    return phase


def producer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    phase = const(True, ty=T.bool())

    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
        stage = iv % num_stages
        # Wait for the MMA on this stage to be done
        mbar_mma[stage].try_wait(phase)
        # New phase for the mbarrier
        phase = switch_phase(stage, phase, num_stages)
        # TMA load
        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
        scf.yield_([phase])


def consumer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    phase = const(False, ty=T.bool())
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
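
    # Build the K loop manually so its iter_args can carry both the WGMMA
    # accumulator and the mbarrier parity phase across iterations.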
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for the TMA of the current stage
        mbar_tma[stage].try_wait(phase)

        # Find the shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate over the input matrices and update the accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix multiply
        D += A @ B

        # MMA barrier arrive
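        # Starting from the second iteration, the primary consumer thread signals
        # that the previous stage's buffers have been consumed, so the producer
        # may refill them.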
        p_arrive = (iv > 0) & wg_me.is_wg_primary
        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
            mbar_mma[barId].arrive()
            scf.yield_([])

        phase = switch_phase(stage, phase, num_stages)
        scf.yield_([D.acc_op, phase])
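
    # Wait until all in-flight WGMMA groups have completed before reading the
    # accumulator outside the loop.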
    nvvm.WgmmaWaitGroupSyncOp(0)
    D.update_accumulator(for_op.results[0])
    return D


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.

    MatrixAccumulator D              # Fragmented results
    store D -> Shared Memory         # Store to shared memory
    Shared Memory -> Z[dimX][dimY]   # Store shared memory to global memory
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()
    # s = tidx - WARP_GROUP_SIZE
    # debug_print("[Epilogue] store to global memory @ s={}", s)

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory)
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [256, 1, 1]
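
    # Dynamic shared memory per pipeline stage: one 128x64 f16 A tile (16 KiB)
    # plus 128x64 of f16 B (16 KiB, loaded as two 64x64 TMA boxes).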
    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_warp_specialized_kernel():
        # Init Warpgroups
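        # The producer warp group runs on threads 128..255 and the consumer warp
        # group on threads 0..127. `register_size` requests fewer registers for
        # the producer and more for the consumer, which holds the WGMMA
        # accumulator in registers.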
        wg_producer = Warpgroup(primary_thread=128, register_size=40)
        wg_consumer = Warpgroup(primary_thread=0, register_size=232)

        # Initialize mbarriers and prefetch TMA descriptors
        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)

        # Producer performs TMA
        with wg_producer:
            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)

        # Consumer performs MMA/Tensor Core
        with wg_consumer:
            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
            epilogue(D, d_dev)

    gemm_warp_specialized_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Python passes arguments to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
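
# 7 pipeline stages => 7 * (16 KiB + 16 KiB) = 224 KiB of dynamic shared memory.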
gemm_warp_specialized(a, b, d, num_stages=7)

# Verify the MLIR result against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements