[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Examples / NVGPU / Ch3.py
blobeb96b11c634165cf4f9d0af4863aba309c127203
1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 3 : GEMM 128x128x64 with Tensor Core
6 # ===----------------------------------------------------------------------===//
8 # This program demonstrates a GEMM operation with 128x128x64 matrix multiplication
# This chapter demonstrates:
#  1. Executing TMA loads for the two input matrices
#  2. Performing a Tensor Core GEMM 128x128x64 by warpgroup
#  3. Storing the fragmented accumulator registers to global memory by warpgroup
15 # ===----------------------------------------------------------------------===//
18 from mlir import ir
19 from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
20 from tools.nvdsl import *
21 from mlir.extras import types as T
22 import numpy as np
def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    Stage both GEMM operands in dynamic shared memory with TMA loads.

    Dynamic shared memory is filled as [A tile | B tile 0 | B tile 1]:

       - a_tma.load -> A tile   at coordinate [0, 0]  (Loads 128x64)
       - b_tma.load -> B tile 0 at coordinate [0, 0]  (Loads 64x64)
       - b_tma.load -> B tile 1 at coordinate [64, 0] (Loads 64x64)

       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16

    Args:
        mbar_group: barrier group; barrier 0 is armed with the expected
            transaction count and is passed to all three loads.
        a_tma: TMA descriptor wrapper for matrix A.
        b_tma: TMA descriptor wrapper for matrix B.
        p: predicate forwarded to the arrive and to every load (the kernel
            in this file passes a "thread 0" predicate).
    """
    bytes_a = get_type_size(a_tma.tma_memref)
    bytes_b = get_type_size(b_tma.tma_memref)
    # One A tile plus two B tiles complete on the same barrier.
    tx_count = bytes_a + bytes_b * 2

    a_box = a_tma.tma_memref
    b_box = b_tma.tma_memref
    a_view = get_dynamic_shared_memory(a_box.shape, a_box.element_type)
    # The two B tiles sit back to back, right after the A tile.
    b_views = [
        get_dynamic_shared_memory(b_box.shape, b_box.element_type, offset)
        for offset in (bytes_a, bytes_a + bytes_b)
    ]

    mbar_group[0].arrive(tx_count, predicate=p)

    a_tma.load(a_view, mbar_group[0], coords=[0, 0], predicate=p)
    for b_view, coord0 in zip(b_views, (0, 64)):
        b_tma.load(b_view, mbar_group[0], coords=[coord0, 0], predicate=p)
@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    """Host side of the 128x128x64 TMA + WGMMA GEMM example (d = a @ b).

    Allocates device buffers for a, b and d, copies a and b to the device,
    builds TMA descriptors for both operands, launches `gemm_tma_kernel` on a
    single block of 128 threads, then copies the result back into d.

    NOTE(review): the kernel reads the module-level globals M, N and K
    (assigned by the driver script at the bottom of this file) instead of
    deriving them from a.type / b.type; they must match the operand shapes.
    NOTE(review): a_dev, b_dev and d_dev are never deallocated here - fine
    for a one-shot test, confirm before reusing this function.

    Args:
        a: operand A; the script passes a 128x64 float16 NumPy array
            (presumably exposed here as a memref Value - `a.type` is used
            as the device allocation type).
        b: operand B; the script passes a 64x128 float16 array.
        d: output; the script passes a 128x128 float32 array.
    """
    token_ty = gpu.AsyncTokenType.get()
    # Allocations and copies are chained through async tokens: each op lists
    # the token produced by the previous op as its dependency.
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)  # a -> a_dev
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)  # b -> b_dev
    t7 = gpu.wait(token_ty, [t6])

    # 128B-swizzled TMA descriptors: A is fetched as one 128x64 box, B as
    # 64x64 boxes (tma_load issues two of them).
    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    # Dynamic shared memory holds all of A followed by all of B.
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        """Device kernel: TMA-load A and B, wait on the mbarrier, run the
        warpgroup GEMM and store the accumulator to d_dev."""
        tidx = gpu.thread_id(gpu.Dimension.x)

        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        # Only thread 0 initializes the barrier and prefetches the descriptors.
        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

        # Shared-memory views of A (offset 0) and B (offset a_size), matching
        # where tma_load places the tiles.
        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. Thread 0 issues the TMA loads for both input matrices.
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait on the mbarrier for the TMA loads to complete.
        mbar_group[0].try_wait()

        # 3. Tensor Core GEMM 128x128x64, performed by the warpgroup.
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix multiply-accumulate into the f32 accumulator D.
        D += A @ B

        # 4. Store the fragmented accumulator registers to global memory (d_dev).
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    # Copy the result back into d and wait for the copy to finish.
    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])
# Host-side inputs passed to the MLIR function. M, N and K must stay
# module-level globals: gemm_tma_kernel reads them to size its shared-memory
# views and WGMMA operands.
M = 128
N = 128
K = 64
# Fixed seed so that any mismatch is reproducible.
rng = np.random.default_rng(seed=0)
a = rng.standard_normal((M, K)).astype(np.float16)
b = rng.standard_normal((K, N)).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

# Compute the reference in float32 to match the float32 result buffer `d`
# (the kernel's accumulator is f32). a and b are already float16, so casting
# them to float16 would be a no-op and give a float16-rounded reference.
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements