1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 4 : Multistage GEMM with Tensor Core
6 # ===----------------------------------------------------------------------===//
# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
# multistage method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches one warp
# group (128 threads in total) and allocates multiple slots (stages) in
# shared memory. The program consists of three main parts: prologue, mainloop,
# and epilogue. In the prologue, thread 0 issues TMA requests that load data
# into the shared memory slots. The mainloop executes MMA while simultaneously
# issuing TMA loads into slots that have already been consumed. This overlap of
# TMA and MMA operations enhances performance by maximizing computational throughput.
20 # for s in range(num_stages):
21 # TMA_128x64_64x128...
22 # for ti in range(M//128): # -> blockIdx.x
23 # for tj in range(N//128): # -> blockIdx.y
24 # for tk in range(K//64):
26 # TMA_128x64_64x128...
# This chapter demonstrates:
30 # 1. Partition shape based on block IDs
32 # 2.1 Execute TMA Load for two input matrices for each stage
34 # 3.1 Wait for completion of TMA load with mbarrier
35 # 3.2 Performs Tensor Core GEMM 64x128x64 by warpgroup
36 # 3.3 Load next stage if needed
38 # 4.1 Store fragmented registers to shared memory
39 # 4.2 Store shared memory to global
41 # ===----------------------------------------------------------------------===//
45 from mlir
.dialects
import gpu
, scf
, nvgpu
, nvvm
46 from mlir
.extras
import types
as T
47 from tools
.nvdsl
import *
def partition_shape():
    """
    Calculate the partition shape based on the block IDs.

    It partitions the shape like below:
    for(.. i < M ...)   --> blockIdx.x
     for(.. j < N ...)  --> blockIdx.y

    Returns:
        dimX: Row offset (along M) of this block's TILE_M x TILE_N output tile.
        dimY: Column offset (along N) of this block's output tile.
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    # The launch grid is (M // TILE_M, N // TILE_N), so each block owns one
    # TILE_M x TILE_N tile whose origin is its block id scaled by the tile size.
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY
def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads two input matrices from global memory to shared memory.

    For pipeline ``slot`` it performs the following operations:

       - tma.load A -> smem[off_a]  at coordinate [c1, dimX]      (Loads 128x64)
       - tma.load B -> smem[off_b]  at coordinate [dimY, c1]      (Loads 64x64)
       - tma.load B -> smem[off_b2] at coordinate [dimY + 64, c1] (Loads 64x64)

       mbarrier.arrive with ta_count = bytes(A box) + 2 * bytes(B box)

    Dynamic shared memory holds all ``num_stages`` A slots first, followed by
    the B slots; ``begin_b`` is the byte offset where the B region starts.

    Args:
        mbar_group: One mbarrier per slot; ``mbar_group[slot]`` tracks these loads.
        a_tma: TMA descriptor for the A operand.
        b_tma: TMA descriptor for the B operand.
        slot: Shared-memory slot to fill.
        stage: Index of the K tile to load; selects the K coordinate ``c1``.
        num_stages: Number of shared-memory slots.
        p: Predicate guarding the arrive and the loads. Defaults to
            ``thread_id.x == 0`` so that a single thread issues them.
    """
    # NOTE(review): parameter names/order reconstructed from the calls in
    # prologue() and mainloop(); confirm against the original header.
    dimX, dimY = partition_shape()

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    begin_b = num_stages * size_tma_a
    # Bytes that must land before the slot is complete: one A box plus the two
    # B boxes that together cover the TILE_N columns of B.
    ta_count = size_tma_a + (size_tma_b * 2)

    if p is None:
        tidx = gpu.thread_id(gpu.Dimension.x)
        p = tidx == 0

    off_a = slot * size_tma_a
    # NOTE(review): B slots are strided by size_tma_a (not 2 * size_tma_b),
    # matching mainloop(); the two are equal for the 128x64 A box and the two
    # 64x64 B boxes of the same element type used in this file.
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    # K coordinate of this stage's tile.
    # NOTE(review): assumes each stage advances TILE_K along K; confirm.
    c1 = stage * TILE_K
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.

    Only thread 0 does the work: it initializes each of the ``num_stages``
    mbarriers with a count of 1 and prefetches both TMA descriptors.

    Returns:
        The Mbarriers group holding ``num_stages`` barriers.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group[i].init(1)
            # scf.for bodies need an explicit terminator.
            scf.yield_([])
        # NOTE(review): assumes the tools.nvdsl TMA helper provides prefetch().
        a_tma.prefetch()
        b_tma.prefetch()
        # Terminator for the scf.if then-region.
        scf.yield_([])
    return mbar_group
def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Prologue of the GEMM kernel. It pre-fills the pipeline by loading both
    input matrices for the first ``ns`` stages, like below:

    for stage in range(ns):        # ns = num_stages - 1 (1 if num_stages == 1)
        tma_load x, y, stage

    The main loop keeps the pipeline full by always loading the K tile ``ns``
    iterations ahead of the one it is consuming.
    """
    ns = num_stages if num_stages == 1 else num_stages - 1
    for iv in scf.for_(0, ns, 1):
        # K tile ``iv`` goes into slot ``iv``.
        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
        # scf.for bodies need an explicit terminator.
        scf.yield_([])
def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the Multistage GEMM kernel. It iterates over K in TILE_K-wide
    steps, consuming one shared-memory slot per step while TMA refills a slot
    for a later step. It works like the following:

    for k in range(K // TILE_K):
        try_wait(stage, ...)                # Wait TMA load
        Matrix A(stage, ...)                # Find shared memory slot
        Matrix B(stage, ...)                # Find shared memory slot
        D += A @ B                          # Multiply and accumulate
        if(needLoad)                        # Load next stage if needed
            tma_load(x, y, nextSlot, nextStage)

    Returns:
        The WGMMAMatrix accumulator ``D`` holding this block's TILE_M x TILE_N
        result.
    """
    ns = num_stages if num_stages == 1 else num_stages - 1

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    # Per-slot stride; must match the slot stride used by tma_load().
    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    # Initialize A and B (input matrices) and D (accumulator)
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    phase = const(False, ty=T.bool())

    # Loop-carried values: the accumulator and the mbarrier phase parity.
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for current stage
        mbar_group[stage].try_wait(phase=phase)

        # Find shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate input matrices, update accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Multiply and accumulate.
        # NOTE(review): statement taken from the docstring's D += A @ B.
        D += A @ B

        # Wait for the issued MMA before its slot can be refilled below.
        # NOTE(review): an earlier comment said this wait is for the
        # single-stage case; confirm whether it should be guarded by
        # `num_stages == 1`.
        nvvm.WgmmaWaitGroupSyncOp(0)

        # Load the K tile `ns` iterations ahead, if it exists (thread 0 only).
        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
        nextStage = iv + ns
        nextSlot = nextStage % num_stages
        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)

        # Switch phase parity for the mbarrier after the last slot is used.
        newPhase = arith.select(
            stage == (num_stages - 1),
            (phase ^ const(True, ty=T.bool())),
            phase,
        )
        scf.yield_([D.acc_op, newPhase])

    nvvm.WgmmaWaitGroupSyncOp(0)

    D.update_accumulator(for_op.results[0])
    return D
def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.

    MatrixAccumulator D               # Fragmented results
    store D -> Shared Memory          # Store Shared Memory
    Shared Memory -> Z[dimX][dimY]    # Store Shared Memory to Global Memory

    Args:
        D: Accumulator returned by ``mainloop``.
        d_dev: Device memref holding the full M x N output matrix.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()

    # No offset is given, so the staging tile overlaps the A/B pipeline slots;
    # presumably safe because the main loop has drained them — confirm.
    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    # The copy below reads elements written by other threads, so every store
    # must be complete block-wide before it starts.
    gpu.barrier()

    # Store (shared memory --> global memory): thread `tidx` copies column tidx.
    # NOTE(review): assumes blockDim.x == TILE_N so every column is covered.
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        # scf.for bodies need an explicit terminator.
        scf.yield_([])
# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
# NOTE(review): decorator name assumed from tools.nvdsl; confirm.
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    """
    Build and run the multistage GEMM ``d = a @ b`` on the GPU.

    Copies ``a`` and ``b`` to device memory and builds TMA descriptors for
    them. Launches one thread block per TILE_M x TILE_N output tile, then
    copies the result back into ``d``.

    Args:
        a: (M, K) f16 input.
        b: (K, N) f16 input.
        d: (M, N) f32 output.
        num_stages: Number of shared-memory pipeline slots.

    Raises:
        ValueError: If M, N or K is not a multiple of its tile size; the
            kernel has no handling for partial tiles.
    """
    if M % TILE_M or N % TILE_N or K % TILE_K:
        raise ValueError(
            f"M, N, K = {M}, {N}, {K} must be multiples of the tile sizes "
            f"{TILE_M}, {TILE_N}, {TILE_K}"
        )

    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    # One warp group (128 threads) per block.
    block = [128, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    # The epilogue stages the TILE_M x TILE_N accumulator through the same
    # dynamic shared memory, so the allocation must fit it too. With the f16
    # 128x128x64 tiles here, one stage of A/B is 32 KiB but the f32 tile is
    # 64 KiB, so num_stages == 1 would otherwise overflow.
    size_d = get_type_size(d.type.element_type) * TILE_M * TILE_N
    smem_size_in_bytes = max((size_a + size_b) * num_stages, size_d)

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        mbar_group = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(mbar_group, a_tma, b_tma, num_stages)

        # Main loop
        D = mainloop(mbar_group, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(D, d_dev)

    gemm_multistage_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    # Host-synchronizing wait so the copy-back has finished before returning.
    gpu.wait(None, [t8])
# Python pass arguments to MLIR
# M, N, K and TILE_M / TILE_N / TILE_K are module-level constants that must be
# defined before this point (the kernel builders above read them as well).
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_multistage(a, b, d, num_stages=7)


# Verify MLIR with reference computation.
# The kernel computes f32 += f16 * f16, so the reference multiplies the f16
# inputs with f32 accumulation. A float16 @ float16 matmul would return a
# float16 result, rounding the reference itself.
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


# CHECK-NOT: Mismatched elements