1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 5 : Warp Specialized GEMM with Tensor Core
6 # ===----------------------------------------------------------------------===//
8 # This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
9 # Warp Specialized method with a tile size of 128x128x64. The code completely
10 # parallelizes the two outermost loops into thread blocks. It launches two Warp
11 # Groups (256 threads in total): one for the producer and the other for the consumer.
12 # Each group takes a different control-flow. The producer thread group is responsible
13 # for loading data into shared memory, while the consumer group executes the Tensor
14 # Core GEMM operation and epilogue.
16 # for ti in range(M//128): # -> blockIdx.x
17 # for tj in range(N//128): # -> blockIdx.y
19 # for tk in range(K//64):
20 # TMA_128x64_64x128...
22 # for tk in range(K//64):
26 # This chapter demonstrates:
29 # 2.1.1 Wait MMA Barrier
30 # 2.1.1 Load TMA with TMA barrier
31 # 2.1.1 Arrive TMA barrier with txcount
35 # Performs Tensor Core GEMM 64x128x64 by warpgroup
38 # Store fragmented registers to shared memory
39 # Store shared memory to global
41 # ===----------------------------------------------------------------------===//
45 from mlir
.dialects
import gpu
, scf
, nvgpu
, nvvm
46 from mlir
.extras
import types
as T
47 from tools
.nvdsl
import *
51 def partition_shape():
53 Calculate the partition shape based on the block IDs.
55 It parallelizes the two outermost loops into thread blocks.
56 for ti in range(M//128): # -> blockIdx.x
57 for tj in range(N//128): # -> blockIdx.y
59 for tk in range(K//64):
66 dimX (int): Dimension along the x-axis.
67 dimY (int): Dimension along the y-axis.
69 bidx
= gpu
.block_id(gpu
.Dimension
.x
)
70 bidy
= gpu
.block_id(gpu
.Dimension
.y
)
77 mbar_group
: Mbarriers
,
86 TMA loads two input matrices from global memory to shared memory. It performs the following operations:
88 - tma.load a_shared_memory[off_x] at coordinate [x, z] (Loads 128x64)
89 - tma.load b_shared_memory[off_y1] at coordinate [y, x] (Loads 64x64)
90 - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
92 mbarrier.arrive ta_count = 128x64x2x4
94 dimX
, dimY
= partition_shape()
96 tidx
= gpu
.thread_id(gpu
.Dimension
.x
)
97 begin_b
= num_stages
* get_type_size(a_tma
.tma_memref
)
98 size_tma_a
= get_type_size(a_tma
.tma_memref
)
99 size_tma_b
= get_type_size(b_tma
.tma_memref
)
100 ta_count
= size_tma_a
+ (size_tma_b
* 2)
102 off_a
= slot
* size_tma_a
103 off_b
= (slot
* size_tma_a
) + begin_b
104 off_b2
= off_b
+ size_tma_b
105 a_elem_ty
= a_tma
.tma_memref
.element_type
106 b_elem_ty
= b_tma
.tma_memref
.element_type
107 a
= get_dynamic_shared_memory(a_tma
.tma_memref
.shape
, a_elem_ty
, off_a
)
108 b1
= get_dynamic_shared_memory(b_tma
.tma_memref
.shape
, b_elem_ty
, off_b
)
109 b2
= get_dynamic_shared_memory(b_tma
.tma_memref
.shape
, b_elem_ty
, off_b2
)
111 mbar_group
[slot
].arrive(ta_count
, predicate
=p
)
112 p
= (tidx
% WARP_GROUP_SIZE
) == 0
114 a_tma
.load(a
, mbar_group
[slot
], coords
=[c1
, dimX
], predicate
=p
)
115 b_tma
.load(b1
, mbar_group
[slot
], coords
=[dimY
, c1
], predicate
=p
)
116 b_tma
.load(b2
, mbar_group
[slot
], coords
=[dimY
+ 64, c1
], predicate
=p
)
119 def initialize(a_tma
: TMA
, b_tma
: TMA
, num_stages
):
121 Initialize mbarriers and prefetch TMA descriptors.
123 tidx
= gpu
.thread_id(gpu
.Dimension
.x
)
124 mbar_group_tma
= Mbarriers(number_of_barriers
=num_stages
)
125 mbar_group_mma
= Mbarriers(number_of_barriers
=num_stages
)
126 isThread0
= tidx
== const(0)
127 with ir
.InsertionPoint(scf
.IfOp(isThread0
).then_block
):
128 for i
in scf
.for_(0, num_stages
, 1):
129 mbar_group_tma
[i
].init(1)
130 mbar_group_mma
[i
].init(1)
136 return mbar_group_tma
, mbar_group_mma
139 def switch_phase(stage
, phase
, num_stages
):
140 p
= stage
== (num_stages
- 1)
141 phase
= arith
.select(
143 (phase ^
const(True, ty
=T
.bool())),
157 phase
= const(True, ty
=T
.bool())
159 for iv
, phase
in scf
.for_(0, (K
// TILE_K
), 1, [phase
]):
160 stage
= iv
% num_stages
161 # Wait MMA to be done
162 mbar_mma
[stage
].try_wait(phase
)
163 # New phase for mbarrier
164 phase
= switch_phase(stage
, phase
, num_stages
)
166 tma_load(mbar_tma
, a_tma
, b_tma
, stage
, iv
, num_stages
, wg_me
.is_wg_primary
)
178 begin_b
= num_stages
* get_type_size(a_tma
.tma_memref
)
180 size_a
= TILE_M
* TILE_K
* get_type_size(T
.f16())
182 phase
= const(False, ty
=T
.bool())
183 A
= WGMMAMatrix(WGMMAType
.Descriptor
, [TILE_M
, TILE_K
], desc
=a_tma
)
184 B
= WGMMAMatrix(WGMMAType
.Descriptor
, [TILE_K
, TILE_N
], desc
=b_tma
)
185 D
= WGMMAMatrix(WGMMAType
.Accumulator
, shape
=[TILE_M
, TILE_N
], ty
=T
.f32())
187 for_op
= scf
.ForOp(const(0), const(K
// TILE_K
), const(1), [D
.acc_op
, phase
])
188 with ir
.InsertionPoint(for_op
.body
):
189 phase
= for_op
.inner_iter_args
[1]
190 iv
= for_op
.induction_variable
191 stage
= iv
% num_stages
193 # Wait TMA for current stage
194 mbar_tma
[stage
].try_wait(phase
)
196 # Find shared memory slot
197 offset_a
= stage
* size_a
198 offset_b
= offset_a
+ begin_b
199 a_smem
= get_dynamic_shared_memory([TILE_M
, TILE_K
], T
.f16(), offset_a
)
200 b_smem
= get_dynamic_shared_memory([TILE_K
, TILE_N
], T
.f16(), offset_b
)
202 # Iterate input matrices, update accumulator
203 A
.update_smem(a_smem
)
204 B
.update_smem(b_smem
)
205 D
.update_accumulator(for_op
.inner_iter_args
[0])
211 p_arrive
= (iv
> 0) & wg_me
.is_wg_primary
212 with ir
.InsertionPoint(scf
.IfOp(p_arrive
).then_block
):
213 barId
= arith
.select((stage
== 0), const(num_stages
- 1), (stage
- 1))
214 mbar_mma
[barId
].arrive()
217 phase
= switch_phase(stage
, phase
, num_stages
)
218 scf
.yield_([D
.acc_op
, phase
])
220 nvvm
.WgmmaWaitGroupSyncOp(0)
221 D
.update_accumulator(for_op
.results
[0])
225 def epilogue(D
: WGMMAMatrix
, d_dev
):
227 Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
229 MatrixAccumulator D # Fragmented results
230 store D -> Shared Memory # Store Shared Memory
231 Shared Memory -> Z[dimX][dimY] # Store Shared Memory to Global Memory
234 tidx
= gpu
.thread_id(gpu
.Dimension
.x
)
235 dimX
, dimY
= partition_shape()
236 # s = tidx - WARP_GROUP_SIZE
237 # debug_print("[Epilogue] store to global memory @ s={}", s)
239 d_smem
= get_dynamic_shared_memory([TILE_M
, TILE_N
], T
.f32())
240 d_gmem
= memref
.subview(d_dev
, [dimX
, dimY
], [TILE_M
, TILE_N
], [1, 1])
242 # Store (registers -> shared memory)
243 D
.store_accumulator(d_smem
)
246 # Store (shared memory --> global memory)
247 for i
in scf
.for_(0, TILE_M
, 1):
248 val
= memref
.load(d_smem
, [i
, tidx
])
249 memref
.store(val
, d_gmem
, [i
, tidx
])
254 def gemm_warp_specialized(a
, b
, d
, num_stages
):
255 token_ty
= gpu
.AsyncTokenType
.get()
256 t1
= gpu
.wait(token_ty
, [])
257 a_dev
, t2
= gpu
.alloc(a
.type, token_ty
, [t1
], [], [])
258 b_dev
, t3
= gpu
.alloc(b
.type, token_ty
, [t2
], [], [])
259 d_dev
, t4
= gpu
.alloc(d
.type, token_ty
, [t3
], [], [])
260 t5
= gpu
.memcpy(token_ty
, [t4
], a_dev
, a
)
261 t6
= gpu
.memcpy(token_ty
, [t5
], b_dev
, b
)
262 t7
= gpu
.wait(token_ty
, [t6
])
264 sw
= nvgpu
.TensorMapSwizzleKind
.SWIZZLE_128B
265 a_tma
= TMA([128, 64], a
.type, swizzle
=sw
)
266 b_tma
= TMA([64, 64], b
.type, swizzle
=sw
)
267 a_tma
.create_descriptor(a_dev
)
268 b_tma
.create_descriptor(b_dev
)
270 grid
= [(M
// TILE_M
), (N
// TILE_N
), 1]
273 size_a
= get_type_size(a
.type.element_type
) * TILE_M
* TILE_K
274 size_b
= get_type_size(b
.type.element_type
) * TILE_N
* TILE_K
275 smem_size_in_bytes
= (size_a
+ size_b
) * num_stages
277 @NVDSL.mlir_gpu_launch(grid
=grid
, block
=block
, smem
=smem_size_in_bytes
)
278 def gemm_warp_specialized_kernel():
280 wg_producer
= Warpgroup(primary_thread
=128, register_size
=40)
281 wg_consumer
= Warpgroup(primary_thread
=0, register_size
=232)
283 # Initialize mbarriers and prefetch TMA descriptors
284 mbar_mma
, mbar_tma
= initialize(a_tma
, b_tma
, num_stages
)
286 # Producer performs TMA
288 producer_loop(mbar_tma
, mbar_mma
, a_tma
, b_tma
, wg_producer
, num_stages
)
290 # Consumer performs MMA/Tensor Core
292 D
= consumer_loop(mbar_tma
, mbar_mma
, a_tma
, b_tma
, wg_consumer
, num_stages
)
295 gemm_warp_specialized_kernel()
297 t8
= gpu
.memcpy(token_ty
, [t7
], d
, d_dev
)
301 # Python pass arguments to MLIR
308 a
= np
.random
.randn(M
, K
).astype(np
.float16
)
309 b
= np
.random
.randn(K
, N
).astype(np
.float16
)
310 d
= np
.zeros((M
, N
), np
.float32
)
312 gemm_warp_specialized(a
, b
, d
, num_stages
=7)
315 # Verify MLIR with reference computation
316 ref_d
= a
.astype(np
.float16
) @ b
.astype(np
.float16
)
317 np
.testing
.assert_allclose(d
, ref_d
, rtol
=5e-03, atol
=1e-01)
321 # CHECK-NOT: Mismatched elements