// archrelease: copy trunk to community-any
// [ArchLinux/community.git] / rocalution / trunk / test.cpp
// blob 5431de84aecb77f3d3385a2b1048f13272bbc90a
1 #include <rocalution/rocalution.hpp>
2 #include <vector>
3 #include <iostream>
5 using namespace rocalution;
7 int main()
9 init_rocalution();
10 info_rocalution();
11 size_t n = 128;
14 float *data = new float[3 * n];
15 int *row_ptr = new int[n + 1];
16 int *col = new int[3 * n];
17 row_ptr[0] = 0;
18 int off;
19 for(int i = 0; i < n; i++){
20 off = row_ptr[i];
21 if(i > 0){
22 data[off] = -1.0;
23 col[off++] = i - 1;
25 data[off] = 2.0;
26 col[off++] = i;
27 if(i < n - 1){
28 data[off] = -1.0;
29 col[off++] = i + 1;
31 row_ptr[i + 1] = off;
35 LocalVector<float> x;
36 LocalVector<float> b;
37 LocalVector<float> r;
38 LocalMatrix<float> A;
40 A.SetDataPtrCSR(&row_ptr, &col, &data,
41 "matrix", row_ptr[n], n, n);
42 A.Check();
44 A.MoveToAccelerator();
45 x.MoveToAccelerator();
46 b.MoveToAccelerator();
47 r.MoveToAccelerator();
49 x.Allocate("x", n);
50 b.Allocate("b", n);
51 r.Allocate("r", n);
53 CG<LocalMatrix<float>, LocalVector<float>, float> ls;
55 b.SetRandomUniform(2342359);
56 x.Zeros();
57 r.CopyFrom(b);
59 A.Info();
61 ls.InitTol(1e-6, 5e-4, 1e3);
62 ls.SetOperator(A);
64 ls.Build();
65 ls.Verbose(1);
67 ls.Solve(b, &x);
69 A.Apply(x, &r);
71 r.ScaleAdd(-1.0, b);
73 float nrm = r.Norm();
74 float tol = 0.001f;
75 if(nrm > tol){
76 std::cout << "Solver failed with tolerance " << tol << std::endl;
77 return 1;
80 std::cout << "TESTS PASSED!" << std::endl;
82 stop_rocalution();