# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 2 : 2D Saxpy with TMA
# ===----------------------------------------------------------------------===//

# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
# but it loads the data using TMA (Tensor Memory Accelerator).

# This chapter demonstrates:
#  1. Compute 2D SAXPY in the same way as Ch1.py, but load the data using TMA
#  2. Create and initialize one asynchronous transactional barrier (mbarrier)
#  3. Have thread 0 of each thread block issue the TMA load requests
#  4. Have each thread block load a <1x32xf32> tile for x and for y
#  5. Wait for completion of the TMA loads with the mbarrier

# ===----------------------------------------------------------------------===//

from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir import runtime as rt
from mlir.extras import types as T
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
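    # Allocate device buffers for x and y and copy the host arrays over,
    # chaining each step through the async tokens returned above.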
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

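    # Build one TMA descriptor per input: each describes a <1xN> tile of the
    # corresponding device buffer that a thread block will load into shared
    # memory.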
    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)
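    # Dynamic shared memory must hold one x tile and one y tile per block;
    # add their byte sizes to get the amount to reserve at launch time.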
    sz_x = get_type_size(x_tma.tma_memref)
    sz_y = get_type_size(y_tma.tma_memref)
    sz = sz_x + sz_y

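    # Launch M thread blocks (one per row) with N threads each (one per
    # column), reserving `sz` bytes of dynamic shared memory for the tiles.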
    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        isThread0 = tidx == 0

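        # Only thread 0 of each block initializes the mbarrier and issues the
        # TMA loads; the `predicate=isThread0` arguments below guard this.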
        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=isThread0)

        # 2. Execute Tensor Memory Accelerator (TMA) Load
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
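        # Thread 0 sets the expected transaction count of the mbarrier to the
        # total number of bytes the two TMA loads write into shared memory.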
        mbar_group[0].arrive(txcount=sz, predicate=isThread0)

        # 3. Wait for completion of TMA load with mbarrier
        mbar_group[0].try_wait()

        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

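    # Copy the result back to the host and wait for all device work to finish.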
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])


# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements