1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 2 : 2D Saxpy with TMA
6 # ===----------------------------------------------------------------------===//
8 # This program demonstrates 2D Saxpy. It is the same as Chapter 1,
9 # but it loads data using TMA (Tensor Memory Accelerator)
11 # This chapter demonstrates:
12 # 1. Computes 2D SAXPY in the same way as Ch1.py but loads data using TMA
13 # 2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
14 #  3. Thread-0 issues the TMA load request for each thread block
15 # 4. Each thread block loads <1x32xf32> for x and y.
16 # 5. Wait for completion of TMA load with mbarrier
18 # ===----------------------------------------------------------------------===//
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir import runtime as rt
from mlir.extras import types as T
import numpy as np
def saxpy(x, y, alpha):
    """2D SAXPY (y += alpha * x) on the GPU, loading tiles with TMA.

    Same computation as Chapter 1, but each thread block pulls its
    <1xN> f32 slices of x and y into dynamic shared memory via the
    Tensor Memory Accelerator, synchronized through one mbarrier.

    x, y:  host buffers of shape [M, N] (M, N are module-level
           constants defined elsewhere in this file — TODO confirm).
    alpha: scalar multiplier.
    y is updated in place by the final device-to-host copy.

    NOTE(review): upstream decorates this function with
    @NVDSL.mlir_func on the line above this block — verify it is
    still present in the full file.
    """
    token_ty = gpu.AsyncTokenType.get()

    # Allocate device buffers and copy host data over, threading the
    # async token through every GPU op so they execute in order.
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # TMA descriptors: every thread block transfers one <1xN> row of x
    # and one of y.
    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)
    sz_x = get_type_size(x_tma.tma_memref)
    # BUGFIX: the y-tile size was computed from x_tma.tma_memref
    # (copy-paste). Same value today because both tiles are <1xN xf32>,
    # but it would break silently if the shapes ever diverged.
    sz_y = get_type_size(y_tma.tma_memref)
    # Total dynamic shared memory: x tile followed by y tile.
    sz = sz_x + sz_y

    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        # Only thread 0 of each block drives the TMA machinery.
        isThread0 = tidx == 0

        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=isThread0)

        # 2. Execute Tensor Memory Accelerator (TMA) Load
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        # The barrier phase completes once sz bytes (both tiles) arrive.
        mbar_group[0].arrive(txcount=sz, predicate=isThread0)

        # 3. Wait for completion of TMA load with mbarrier
        mbar_group[0].try_wait()

        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha
        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

    # Copy the result back into the host buffer y and wait for it.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])
# 3. Pass numpy arrays to MLIR
# NOTE(review): M, N, and alpha are module-level constants defined
# earlier in this file (not visible in this chunk) — confirm.
x = np.random.randn(M, N).astype(np.float32)
y = np.ones(shape=(M, N), dtype=np.float32)

# 4. Verify MLIR with reference computation
# NOTE(review): the saxpy(...) invocation and the `ref += x * alpha`
# reference update appear between these statements in the full file;
# without them this check compares two all-ones arrays.
ref = np.ones(shape=(M, N), dtype=np.float32)
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)

# CHECK-NOT: Mismatched elements