#include <cstdlib>

#include <rocprim/rocprim.hpp>
// Binary device-side lambda taking two floats and returning a float; it is
// passed to rocprim::transform below as the element-wise operation.
// NOTE(review): the lambda body is not visible in this excerpt — presumably it
// returns x + y to match the host reference loop; confirm.
10 auto xpy
= [] __device__(float x
, float y
) -> float{
15 std::vector
<float> xin(size
);
16 std::vector
<float> yin(size
);
18 std::random_device rd
;
19 std::mt19937
gen(rd());
20 std::uniform_real_distribution
<float> dist(-1.0, 1.0);
22 auto myrand
= [&]() -> float {return dist(gen
);};
24 std::generate(xin
.begin(), xin
.end(), myrand
);
25 std::generate(yin
.begin(), yin
.end(), myrand
);
// Host-side reference result: zref[i] = xin[i] + yin[i] for every index
// below size. The device output is compared against this buffer later.
27 std::vector
<float> zref(size
);
28 for(size_t i
= 0; i
< size
; i
++){
29 zref
[i
] = xin
[i
] + yin
[i
];
35 hipMalloc((void**)&x
, sizeof *x
* size
);
36 hipMalloc((void**)&y
, sizeof *y
* size
);
37 hipMalloc((void**)&z
, sizeof *z
* size
);
39 hipMemcpy(x
, xin
.data(), sizeof *x
* size
, hipMemcpyHostToDevice
);
40 hipMemcpy(y
, yin
.data(), sizeof *y
* size
, hipMemcpyHostToDevice
);
42 rocprim::transform(x
, y
, z
, size
, xpy
);
44 std::vector
<float> zout(size
);
45 hipMemcpy(zout
.data(), z
, sizeof *z
* size
, hipMemcpyDeviceToHost
);
// Compare every device result with the host reference using an absolute
// tolerance of 0.001f; on a mismatch, print the index and both values.
47 for(size_t i
= 0; i
< size
; i
++){
48 if(std::abs(zout
[i
] - zref
[i
]) > 0.001f
){
49 std::cout
<< "Element mismatch at index " << i
<< "\n";
50 std::cout
<< "Got " << zout
[i
] << " but expected " << zref
[i
] << "\n";
// NOTE(review): the lines between the mismatch report and the message below
// are not visible here — confirm a mismatch cannot fall through to
// "TESTS PASSED!".
54 std::cout
<< "TESTS PASSED!" << std::endl
;