# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation with 128x128x64 matrix multiplication.
#
# This chapter demonstrates:
#  1. Executing a TMA load for the two input matrices
#  2. Performing a Tensor Core GEMM 128x128x64 with a warpgroup
#  3. Storing the fragmented accumulator registers to global memory with a warpgroup
#
# ===----------------------------------------------------------------------===//

from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    TMA loads two input matrices from global memory to shared memory. It performs the following operations:

       - tma.load a_shared_memory[0] at coordinate [0, 0]  (Loads 128x64)
       - tma.load b_shared_memory[0] at coordinate [0, 0]  (Loads 64x64)
       - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)

       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
    """
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)
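
    # Shared-memory layout used below: the A tile at offset 0, the two B tiles
    # immediately after it. With 128x64xf16 for A (16384 B) and two 64x64xf16
    # tiles for B (2 x 8192 B), ta_count above comes to 32768 bytes.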
    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
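
    # Only the predicated thread arrives on the mbarrier, announcing the total
    # number of bytes the subsequent TMA loads are expected to deliver.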
    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
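    # Allocate device buffers for A, B and D and copy the host inputs over;
    # the async tokens chain these GPU operations in order.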
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])
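
    # Build TMA descriptors with 128B swizzling; A is fetched as a single
    # 128x64 tile and B as two 64x64 tiles (see tma_load above).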
    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size

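    # The kernel needs 128x64xf16 (16 KB) for A plus 64x128xf16 (16 KB) for B,
    # i.e. 32 KB of dynamic shared memory.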
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)
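
        # A single thread initializes the mbarrier and issues the TMA prefetch
        # and load requests; the rest of the block just waits on the barrier.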
        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0
        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

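        # Views into dynamic shared memory: A at offset 0, B right after A.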
        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA Load for two input matrices
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait for the TMA loads to complete
        mbar_group[0].try_wait()

        # 3. Perform the Tensor Core GEMM 128x128x64 with a warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix multiply: accumulate A @ B into D
        D += A @ B

        # 4. Store the fragmented accumulator registers to global memory with the warpgroup
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass Python arguments to MLIR
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

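# Verify the GPU result against a numpy reference matmul.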
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)

# CHECK-NOT: Mismatched elements