archrelease: copy trunk to community-any
[ArchLinux/community.git] / python-pytorch / repos / community-x86_64 / rocblas-batched.patch
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 2906d0acd9..33610c65f7 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -838,6 +838,24 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
     at::Half** C,
     CUDAContext* context,
     TensorProto::DataType math_type) {
+#if defined(USE_ROCM)
+  // loop over matrices in the batch
+  for (int i = 0; i < batch_size; ++i) {
+    Gemm<at::Half, CUDAContext>(
+        trans_A,
+        trans_B,
+        M,
+        N,
+        K,
+        alpha,
+        A[i],
+        B[i],
+        beta,
+        C[i],
+        context,
+        math_type);
+  }
+#else
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   const int lda = (trans_A == CblasNoTrans) ? K : M;
@@ -912,6 +930,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
   } else {
     CAFFE_THROW("Unsupported math type");
   }
+#endif
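
For context on what the ROCm branch above does: the batched half-precision GEMM is decomposed into a plain loop that issues one ordinary GEMM per matrix in the batch, using the A[i]/B[i]/C[i] pointer arrays, so the rocBLAS batched code path is never reached on ROCm builds. Below is a minimal, host-side C++ sketch of that decomposition. It is not part of the patch; the names gemm_ref and gemm_batched_loop are illustrative only, standing in for the Gemm<at::Half, CUDAContext> specialization and the surrounding GemmBatched wrapper.

// Sketch only (not part of the patch): a batched GEMM emulated by looping
// over the batch and running one single-matrix GEMM per iteration, mirroring
// the A[i] / B[i] / C[i] indexing used in the ROCm branch of the patch.
#include <cstdio>
#include <vector>

// Plain row-major C = alpha * A(MxK) * B(KxN) + beta * C(MxN).
static void gemm_ref(int M, int N, int K, float alpha,
                     const float* A, const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}

// Batched GEMM expressed as a loop of single GEMMs over pointer arrays.
static void gemm_batched_loop(int batch_size, int M, int N, int K, float alpha,
                              const float* const* A, const float* const* B,
                              float beta, float* const* C) {
  for (int i = 0; i < batch_size; ++i) {
    gemm_ref(M, N, K, alpha, A[i], B[i], beta, C[i]);
  }
}

int main() {
  const int batch = 2, M = 2, N = 2, K = 3;
  std::vector<float> a0 = {1, 2, 3, 4, 5, 6}, a1 = {6, 5, 4, 3, 2, 1};
  std::vector<float> b0 = {1, 0, 0, 1, 1, 1}, b1 = {2, 2, 2, 2, 2, 2};
  std::vector<float> c0(M * N, 0.0f), c1(M * N, 0.0f);
  const float* A[] = {a0.data(), a1.data()};
  const float* B[] = {b0.data(), b1.data()};
  float* C[] = {c0.data(), c1.data()};
  gemm_batched_loop(batch, M, N, K, 1.0f, A, B, 0.0f, C);
  for (int i = 0; i < batch; ++i) {
    std::printf("C[%d] = [%g %g; %g %g]\n", i, C[i][0], C[i][1], C[i][2], C[i][3]);
  }
  return 0;
}

The trade-off is straightforward: the loop gives up whatever batching the library's batched entry point provides, but each matrix goes through the ordinary single-GEMM specialization, which is the path the patch relies on for ROCm builds.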