[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Examples / NVGPU / Ch4.py
blob0e3460ff8d63b269d9a9b812f9ac64da9b78b542
1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 4 : Multistage GEMM with Tensor Core
6 # ===----------------------------------------------------------------------===//
8 # This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
9 # Multistage method with a tile size of 128x128x64. The code completely
10 # parallelizes the two outermost loops into thread blocks. It launches one Warp
11 # Group (128 threads in total) and allocates multiple slots/stages in the
12 # shared memory. The program consists of three main parts: prologue, mainloop,
13 # and epilogue. In the prologue, thread0 requests for TMA to load data into
14 # shared memory slots. The mainloop executes MMA while simultaneously loading
15 # TMA for the utilized slots. This overlap of TMA and MMA operations enhances
16 # performance by maximizing computational throughput.
18 # Loops illustration:
20 # for s in range(num_stages):
21 # TMA_128x64_64x128...
22 # for ti in range(M//128): # -> blockIdx.x
23 # for tj in range(N//128): # -> blockIdx.y
24 # for tk in range(K//64):
25 # MMA_128x128x64...
26 # TMA_128x64_64x128...
27 # Epilogue...
29 # This chapter demonstrates:
30 # 1. Partition shape based on block IDs
31 # 2. Prologue
32 # 2.1 Execute TMA Load for two input matrices for each stage
33 # 3. Main loop
34 # 3.1 Wait for completion of TMA load with mbarrier
35 # 3.2 Perform Tensor Core GEMM 128x128x64 by warpgroup
36 # 3.3 Load next stage if needed
37 # 4. Epilogue
38 # 4.1 Store fragmented registers to shared memory
39 # 4.2 Store shared memory to global
41 # ===----------------------------------------------------------------------===//
44 from mlir import ir
45 from mlir.dialects import gpu, scf, nvgpu, nvvm
46 from mlir.extras import types as T
47 from tools.nvdsl import *
48 import numpy as np
def partition_shape():
    """
    Compute the origin of the tile handled by the current thread block.

    The iteration space is distributed over the launch grid like below:
        for(.. i < M ...)   --> blockIdx.x
        for(.. j < N ...)   --> blockIdx.y
            for(.. k < K ...)

    Returns:
        A pair `(dimX, dimY)` where `dimX` is `blockIdx.x * TILE_M` and
        `dimY` is `blockIdx.y * TILE_N`.
    """
    block_x = gpu.block_id(gpu.Dimension.x)
    block_y = gpu.block_id(gpu.Dimension.y)
    return block_x * TILE_M, block_y * TILE_N
def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    Issue the TMA loads that fill one shared-memory slot with tiles of A and B.

    Three loads are issued, all tracked by `mbar_group[slot]`:

       - a_tma -> A buffer of `slot`,          coords [stage * 64, dimX]
       - b_tma -> first B buffer of `slot`,    coords [stage * 64 is c1: [dimY, c1]]
       - b_tma -> second B buffer of `slot`,   coords [dimY + 64, c1]

    where `(dimX, dimY)` comes from `partition_shape()` and `c1 = stage * 64`.

    Before the loads, the mbarrier arrives with
    `ta_count = size(a_tma.tma_memref) + 2 * size(b_tma.tma_memref)`
    (sizes as reported by `get_type_size`; presumably the number of bytes the
    three loads will deliver -- confirm against `Mbarriers.arrive`).

    Args:
        mbar_group: Group of mbarriers, indexed by slot.
        a_tma: TMA object for input matrix A.
        b_tma: TMA object for input matrix B.
        slot: Index of the shared-memory slot (and mbarrier) to fill.
        stage: Index of the K-tile to load; scaled by 64 to form a coordinate.
        num_stages: Total number of slots; the B buffers start after the
            `num_stages` A buffers in dynamic shared memory.
        p: Optional predicate guarding the arrive and the loads. When omitted,
            only the thread with `thread_id.x == 0` is selected.
    """
    dimX, dimY = partition_shape()

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    begin_b = num_stages * size_tma_a
    ta_count = size_tma_a + (size_tma_b * 2)

    # Only materialize the thread id when the caller gave no predicate.
    if p is None:
        p = gpu.thread_id(gpu.Dimension.x) == 0

    # Shared-memory layout: all A buffers first, then the B buffers.
    # NOTE(review): B slots are strided by size_tma_a, which only matches the
    # space used by the two B loads when size_tma_a == 2 * size_tma_b.
    off_a = slot * size_tma_a
    off_b = off_a + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Create the mbarrier group and set it up from a single thread.

    Inside a region executed only when `thread_id.x == 0`, every one of the
    `num_stages` mbarriers is initialized with a count of 1 and both TMA
    descriptors are prefetched.

    Returns:
        The `Mbarriers` group holding `num_stages` barriers.
    """
    thread_x = gpu.thread_id(gpu.Dimension.x)
    barriers = Mbarriers(number_of_barriers=num_stages)
    is_leader = thread_x == const(0)
    leader_only = scf.IfOp(is_leader)
    with ir.InsertionPoint(leader_only.then_block):
        for idx in scf.for_(0, num_stages, 1):
            barriers[idx].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return barriers
def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Prologue of the GEMM kernel: fill the pipeline before the main loop.

    It issues the TMA loads of both input matrices for the leading stages,
    using slot index == stage index:

        for stage in range(num_stages - 1):   # a single stage if num_stages == 1
            tma_load x, y, stage
    """
    if num_stages == 1:
        prefetch_count = num_stages
    else:
        prefetch_count = num_stages - 1
    for stage in scf.for_(0, prefetch_count, 1):
        tma_load(mbar_group, a_tma, b_tma, stage, stage, num_stages)
        scf.yield_([])
def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the Multistage GEMM kernel.

    It walks over the K dimension tile by tile, multiplying the tiles that are
    already in shared memory while requesting the TMA load of a later tile:

        MatrixAccumulator D
        for k in range(K // TILE_K):

            try_wait(stage, ...)    # Wait TMA load

            Matrix A(stage, ...)    # Find shared memory slot
            Matrix B(stage, ...)    # Find shared memory slot
            D += A @ B              # Multiply and accumulate

            if(needLoad)            # Load next stage if needed
                tma_load(x, y, nextSlot, nextStage)

    The loop carries two values: the accumulator and the mbarrier phase
    parity, which is flipped each time the last slot has been consumed.

    Returns:
        The accumulator `WGMMAMatrix`, updated with the loop result.
    """
    # How far ahead of the current iteration the next load is issued.
    lookahead = num_stages if num_stages == 1 else num_stages - 1

    thread_x = gpu.thread_id(gpu.Dimension.x)
    # The B buffers start after the A buffers of all stages.
    b_region_start = num_stages * get_type_size(a_tma.tma_memref)

    a_tile_size = TILE_M * TILE_K * get_type_size(T.f16())

    # Operand descriptors for A and B, and the f32 accumulator.
    lhs = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    rhs = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    acc = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    initial_phase = const(False, ty=T.bool())

    # Loop over the K tiles, carrying (accumulator, phase).
    lower = const(0)
    upper = const(K // TILE_K)
    step = const(1)
    k_loop = scf.ForOp(lower, upper, step, [acc.acc_op, initial_phase])
    with ir.InsertionPoint(k_loop.body):
        parity = k_loop.inner_iter_args[1]
        k_idx = k_loop.induction_variable
        cur_slot = k_idx % num_stages

        # Block until the data of the current slot has landed.
        mbar_group[cur_slot].try_wait(phase=parity)

        # Locate the A and B buffers of the current slot.
        a_offset = cur_slot * a_tile_size
        b_offset = a_offset + b_region_start
        a_view = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), a_offset)
        b_view = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), b_offset)

        # Point the operands at this slot and pick up the carried accumulator.
        lhs.update_smem(a_view)
        rhs.update_smem(b_view)
        acc.update_accumulator(k_loop.inner_iter_args[0])

        # Multiply and accumulate.
        acc += lhs @ rhs

        # With a single stage the slot is reused immediately, so wait here.
        if num_stages == 1:
            nvvm.WgmmaWaitGroupSyncOp(0)

        # Request the next tile, from thread 0 only and only if one remains.
        has_next = (k_idx + lookahead) < const(K // TILE_K)
        is_leader = thread_x == 0
        load_pred = has_next & is_leader
        next_stage = k_idx + lookahead
        next_slot = next_stage % num_stages
        tma_load(
            mbar_group, a_tma, b_tma, next_slot, next_stage, num_stages, load_pred
        )

        # Flip the mbarrier phase parity after the last slot was used.
        at_last_slot = cur_slot == (num_stages - 1)
        flipped = parity ^ const(True, ty=T.bool())
        next_parity = arith.select(at_last_slot, flipped, parity)
        scf.yield_([acc.acc_op, next_parity])

    nvvm.WgmmaWaitGroupSyncOp(0)

    acc.update_accumulator(k_loop.results[0])
    return acc
def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel: write the accumulator out to `d_dev`.

        MatrixAccumulator D               # Fragmented results
        store D -> Shared Memory          # Store Shared Memory
        Shared Memory -> Z[dimX][dimY]    # Store Shared Memory to Global Memory

    The accumulator is first stored to a [TILE_M, TILE_N] f32 buffer in
    dynamic shared memory. After a barrier, each thread copies the column
    selected by its `thread_id.x` for all TILE_M rows into the
    [TILE_M, TILE_N] subview of `d_dev` at this block's tile origin.
    """
    thread_x = gpu.thread_id(gpu.Dimension.x)
    tile_row, tile_col = partition_shape()

    smem_tile = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    gmem_tile = memref.subview(
        d_dev, [tile_row, tile_col], [TILE_M, TILE_N], [1, 1]
    )

    # Registers -> shared memory; the barrier orders it before the copy below.
    D.store_accumulator(smem_tile)
    gpu.barrier()

    # Shared memory -> global memory, one element per thread per row.
    for row in scf.for_(0, TILE_M, 1):
        element = memref.load(smem_tile, [row, thread_x])
        memref.store(element, gmem_tile, [row, thread_x])
        scf.yield_([])
# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    """
    Host-side driver: copy inputs to the device, launch the kernel, copy back.

    Steps, chained through async tokens:
      1. Allocate device buffers for `a`, `b` and `d`, and copy `a` and `b` in.
      2. Build 128B-swizzled TMA descriptors for the device copies of A and B.
      3. Launch `gemm_multistage_kernel` on an (M/TILE_M) x (N/TILE_N) grid of
         128-thread blocks, with dynamic shared memory sized for `num_stages`
         pairs of A and B tiles.
      4. Copy the device result back into `d` and wait for it.
    """
    token_ty = gpu.AsyncTokenType.get()
    start = gpu.wait(token_ty, [])
    a_dev, a_ready = gpu.alloc(a.type, token_ty, [start], [], [])
    b_dev, b_ready = gpu.alloc(b.type, token_ty, [a_ready], [], [])
    d_dev, d_ready = gpu.alloc(d.type, token_ty, [b_ready], [], [])
    a_copied = gpu.memcpy(token_ty, [d_ready], a_dev, a)
    b_copied = gpu.memcpy(token_ty, [a_copied], b_dev, b)
    inputs_ready = gpu.wait(token_ty, [b_copied])

    swizzle = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=swizzle)
    b_tma = TMA([64, 64], b.type, swizzle=swizzle)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid_dims = [(M // TILE_M), (N // TILE_N), 1]
    block_dims = [128, 1, 1]

    # One A tile plus one B tile per pipeline stage.
    a_tile_bytes = get_type_size(a.type.element_type) * TILE_M * TILE_K
    b_tile_bytes = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (a_tile_bytes + b_tile_bytes) * num_stages

    @NVDSL.mlir_gpu_launch(grid=grid_dims, block=block_dims, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        barriers = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(barriers, a_tma, b_tma, num_stages)

        # Main loop
        accumulator = mainloop(barriers, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(accumulator, d_dev)

    gemm_multistage_kernel()

    result_copied = gpu.memcpy(token_ty, [inputs_ready], d, d_dev)
    gpu.wait(None, [result_copied])
# Python pass arguments to MLIR
# Problem and tile sizes; these module-level names are read by the
# kernel-building functions of this file.
M, N, K = 512, 256, 1024
TILE_M, TILE_N, TILE_K = 128, 128, 64

# Random f16 inputs and a zero-initialized f32 output buffer.
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), dtype=np.float32)

gemm_multistage(a, b, d, num_stages=7)


# Verify MLIR with reference computation
ref_d = np.matmul(a.astype(np.float16), b.astype(np.float16))
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements