[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Examples / NVGPU / Ch1.py
blobcfb48d56f8d49499b97ddfe365ffa728afcfa682
1 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2 # RUN: %PYTHON %s | FileCheck %s
4 # ===----------------------------------------------------------------------===//
5 # Chapter 1 : 2D Saxpy
6 # ===----------------------------------------------------------------------===//
# This program demonstrates 2D SAXPY:
#  1. Use the GPU dialect to allocate device memory and copy data from host to GPU and back
#  2. Compute the 2D SAXPY kernel using operator overloading
#  3. Pass NumPy arrays to MLIR as memref arguments
#  4. Verify the MLIR program against a reference computation in Python
14 # ===----------------------------------------------------------------------===//
17 from mlir import ir
18 from mlir.dialects import gpu, memref
19 from tools.nvdsl import *
20 import numpy as np
@NVDSL.mlir_func
def saxpy(x, y, alpha):
    """Emit a 2D SAXPY (``y += alpha * x``) that runs on the GPU.

    Args:
        x: 2D memref input (built from the NumPy array passed by the caller).
        y: 2D memref updated in place; expected to have the same shape as
            ``x`` since both are indexed with the same (block, thread) pair.
        alpha: scalar multiplier applied to ``x``.

    Device buffers are allocated, filled, used, and freed along a single
    chain of GPU async tokens. The last wait is the blocking form, so the
    copy of the result back into ``y`` has finished when this returns.
    """
    # One thread block per row, one thread per column. Take the sizes from
    # the argument's memref type rather than from module-level globals that
    # are only defined further down the file.
    rows, cols = ir.MemRefType(x.type).shape

    # 1. Use MLIR GPU dialect to allocate and copy memory
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # 2. Compute 2D SAXPY kernel
    @NVDSL.mlir_gpu_launch(grid=(rows, 1, 1), block=(cols, 1, 1))
    def saxpy_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)  # row index
        tidx = gpu.thread_id(gpu.Dimension.x)  # column index
        x_val = memref.load(x_dev, [bidx, tidx])
        y_val = memref.load(y_dev, [bidx, tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_kernel()

    # Copy the result back and release both device buffers.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    t8 = gpu.dealloc(token_ty, [t7], x_dev)
    t9 = gpu.dealloc(token_ty, [t8], y_dev)
    # A gpu.wait given a result type is the async form: it only produces a
    # new token and does not block. Passing None emits the blocking form,
    # so the host waits for the whole chain before returning.
    gpu.wait(None, [t9])
# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
# Fixed seed so that a failing run can be reproduced with identical inputs.
rng = np.random.default_rng(seed=0)
x = rng.standard_normal((M, N)).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# An assertion failure is reported on stderr, which is not piped into the
# checker, so additionally require the PASS line to appear in stdout.
# CHECK-NOT: Mismatched elements
# CHECK: PASS