// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12 int EIGEN_BLAS_FUNC(gemm
)(char *opa
, char *opb
, int *m
, int *n
, int *k
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
14 // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
15 typedef void (*functype
)(DenseIndex
, DenseIndex
, DenseIndex
, const Scalar
*, DenseIndex
, const Scalar
*, DenseIndex
, Scalar
*, DenseIndex
, Scalar
, internal::level3_blocking
<Scalar
,Scalar
>&, Eigen::internal::GemmParallelInfo
<DenseIndex
>*);
16 static functype func
[12];
18 static bool init
= false;
21 for(int k
=0; k
<12; ++k
)
23 func
[NOTR
| (NOTR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,ColMajor
,false,ColMajor
>::run
);
24 func
[TR
| (NOTR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,false,Scalar
,ColMajor
,false,ColMajor
>::run
);
25 func
[ADJ
| (NOTR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,ColMajor
,false,ColMajor
>::run
);
26 func
[NOTR
| (TR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,false,ColMajor
>::run
);
27 func
[TR
| (TR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,false,Scalar
,RowMajor
,false,ColMajor
>::run
);
28 func
[ADJ
| (TR
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,RowMajor
,false,ColMajor
>::run
);
29 func
[NOTR
| (ADJ
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,Conj
, ColMajor
>::run
);
30 func
[TR
| (ADJ
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,false,Scalar
,RowMajor
,Conj
, ColMajor
>::run
);
31 func
[ADJ
| (ADJ
<< 2)] = (internal::general_matrix_matrix_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,RowMajor
,Conj
, ColMajor
>::run
);
35 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
36 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
37 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
38 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
39 Scalar beta
= *reinterpret_cast<Scalar
*>(pbeta
);
42 if(OP(*opa
)==INVALID
) info
= 1;
43 else if(OP(*opb
)==INVALID
) info
= 2;
44 else if(*m
<0) info
= 3;
45 else if(*n
<0) info
= 4;
46 else if(*k
<0) info
= 5;
47 else if(*lda
<std::max(1,(OP(*opa
)==NOTR
)?*m
:*k
)) info
= 8;
48 else if(*ldb
<std::max(1,(OP(*opb
)==NOTR
)?*k
:*n
)) info
= 10;
49 else if(*ldc
<std::max(1,*m
)) info
= 13;
51 return xerbla_(SCALAR_SUFFIX_UP
"GEMM ",&info
,6);
55 if(beta
==Scalar(0)) matrix(c
, *m
, *n
, *ldc
).setZero();
56 else matrix(c
, *m
, *n
, *ldc
) *= beta
;
59 internal::gemm_blocking_space
<ColMajor
,Scalar
,Scalar
,Dynamic
,Dynamic
,Dynamic
> blocking(*m
,*n
,*k
);
61 int code
= OP(*opa
) | (OP(*opb
) << 2);
62 func
[code
](*m
, *n
, *k
, a
, *lda
, b
, *ldb
, c
, *ldc
, alpha
, blocking
, 0);
66 int EIGEN_BLAS_FUNC(trsm
)(char *side
, char *uplo
, char *opa
, char *diag
, int *m
, int *n
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
)
68 // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
69 typedef void (*functype
)(DenseIndex
, DenseIndex
, const Scalar
*, DenseIndex
, Scalar
*, DenseIndex
, internal::level3_blocking
<Scalar
,Scalar
>&);
70 static functype func
[32];
72 static bool init
= false;
75 for(int k
=0; k
<32; ++k
)
78 func
[NOTR
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|0, false,ColMajor
,ColMajor
>::run
);
79 func
[TR
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|0, false,RowMajor
,ColMajor
>::run
);
80 func
[ADJ
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|0, Conj
, RowMajor
,ColMajor
>::run
);
82 func
[NOTR
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|0, false,ColMajor
,ColMajor
>::run
);
83 func
[TR
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|0, false,RowMajor
,ColMajor
>::run
);
84 func
[ADJ
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|0, Conj
, RowMajor
,ColMajor
>::run
);
86 func
[NOTR
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|0, false,ColMajor
,ColMajor
>::run
);
87 func
[TR
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|0, false,RowMajor
,ColMajor
>::run
);
88 func
[ADJ
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|0, Conj
, RowMajor
,ColMajor
>::run
);
90 func
[NOTR
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|0, false,ColMajor
,ColMajor
>::run
);
91 func
[TR
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|0, false,RowMajor
,ColMajor
>::run
);
92 func
[ADJ
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|0, Conj
, RowMajor
,ColMajor
>::run
);
95 func
[NOTR
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|UnitDiag
,false,ColMajor
,ColMajor
>::run
);
96 func
[TR
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|UnitDiag
,false,RowMajor
,ColMajor
>::run
);
97 func
[ADJ
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|UnitDiag
,Conj
, RowMajor
,ColMajor
>::run
);
99 func
[NOTR
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|UnitDiag
,false,ColMajor
,ColMajor
>::run
);
100 func
[TR
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|UnitDiag
,false,RowMajor
,ColMajor
>::run
);
101 func
[ADJ
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|UnitDiag
,Conj
, RowMajor
,ColMajor
>::run
);
103 func
[NOTR
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Lower
|UnitDiag
,false,ColMajor
,ColMajor
>::run
);
104 func
[TR
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|UnitDiag
,false,RowMajor
,ColMajor
>::run
);
105 func
[ADJ
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheLeft
, Upper
|UnitDiag
,Conj
, RowMajor
,ColMajor
>::run
);
107 func
[NOTR
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Lower
|UnitDiag
,false,ColMajor
,ColMajor
>::run
);
108 func
[TR
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|UnitDiag
,false,RowMajor
,ColMajor
>::run
);
109 func
[ADJ
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::triangular_solve_matrix
<Scalar
,DenseIndex
,OnTheRight
,Upper
|UnitDiag
,Conj
, RowMajor
,ColMajor
>::run
);
114 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
115 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
116 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
119 if(SIDE(*side
)==INVALID
) info
= 1;
120 else if(UPLO(*uplo
)==INVALID
) info
= 2;
121 else if(OP(*opa
)==INVALID
) info
= 3;
122 else if(DIAG(*diag
)==INVALID
) info
= 4;
123 else if(*m
<0) info
= 5;
124 else if(*n
<0) info
= 6;
125 else if(*lda
<std::max(1,(SIDE(*side
)==LEFT
)?*m
:*n
)) info
= 9;
126 else if(*ldb
<std::max(1,*m
)) info
= 11;
128 return xerbla_(SCALAR_SUFFIX_UP
"TRSM ",&info
,6);
130 int code
= OP(*opa
) | (SIDE(*side
) << 2) | (UPLO(*uplo
) << 3) | (DIAG(*diag
) << 4);
132 if(SIDE(*side
)==LEFT
)
134 internal::gemm_blocking_space
<ColMajor
,Scalar
,Scalar
,Dynamic
,Dynamic
,Dynamic
,4> blocking(*m
,*n
,*m
);
135 func
[code
](*m
, *n
, a
, *lda
, b
, *ldb
, blocking
);
139 internal::gemm_blocking_space
<ColMajor
,Scalar
,Scalar
,Dynamic
,Dynamic
,Dynamic
,4> blocking(*m
,*n
,*n
);
140 func
[code
](*n
, *m
, a
, *lda
, b
, *ldb
, blocking
);
144 matrix(b
,*m
,*n
,*ldb
) *= alpha
;
150 // b = alpha*op(a)*b for side = 'L'or'l'
151 // b = alpha*b*op(a) for side = 'R'or'r'
152 int EIGEN_BLAS_FUNC(trmm
)(char *side
, char *uplo
, char *opa
, char *diag
, int *m
, int *n
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
)
154 // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
155 typedef void (*functype
)(DenseIndex
, DenseIndex
, DenseIndex
, const Scalar
*, DenseIndex
, const Scalar
*, DenseIndex
, Scalar
*, DenseIndex
, const Scalar
&, internal::level3_blocking
<Scalar
,Scalar
>&);
156 static functype func
[32];
157 static bool init
= false;
160 for(int k
=0; k
<32; ++k
)
163 func
[NOTR
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, true, ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
164 func
[TR
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, true, RowMajor
,false,ColMajor
,false,ColMajor
>::run
);
165 func
[ADJ
| (LEFT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, true, RowMajor
,Conj
, ColMajor
,false,ColMajor
>::run
);
167 func
[NOTR
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, false,ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
168 func
[TR
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, false,ColMajor
,false,RowMajor
,false,ColMajor
>::run
);
169 func
[ADJ
| (RIGHT
<< 2) | (UP
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, false,ColMajor
,false,RowMajor
,Conj
, ColMajor
>::run
);
171 func
[NOTR
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, true, ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
172 func
[TR
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, true, RowMajor
,false,ColMajor
,false,ColMajor
>::run
);
173 func
[ADJ
| (LEFT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, true, RowMajor
,Conj
, ColMajor
,false,ColMajor
>::run
);
175 func
[NOTR
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|0, false,ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
176 func
[TR
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, false,ColMajor
,false,RowMajor
,false,ColMajor
>::run
);
177 func
[ADJ
| (RIGHT
<< 2) | (LO
<< 3) | (NUNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|0, false,ColMajor
,false,RowMajor
,Conj
, ColMajor
>::run
);
179 func
[NOTR
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,true, ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
180 func
[TR
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,true, RowMajor
,false,ColMajor
,false,ColMajor
>::run
);
181 func
[ADJ
| (LEFT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,true, RowMajor
,Conj
, ColMajor
,false,ColMajor
>::run
);
183 func
[NOTR
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,false,ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
184 func
[TR
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,false,ColMajor
,false,RowMajor
,false,ColMajor
>::run
);
185 func
[ADJ
| (RIGHT
<< 2) | (UP
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,false,ColMajor
,false,RowMajor
,Conj
, ColMajor
>::run
);
187 func
[NOTR
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,true, ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
188 func
[TR
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,true, RowMajor
,false,ColMajor
,false,ColMajor
>::run
);
189 func
[ADJ
| (LEFT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,true, RowMajor
,Conj
, ColMajor
,false,ColMajor
>::run
);
191 func
[NOTR
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Lower
|UnitDiag
,false,ColMajor
,false,ColMajor
,false,ColMajor
>::run
);
192 func
[TR
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,false,ColMajor
,false,RowMajor
,false,ColMajor
>::run
);
193 func
[ADJ
| (RIGHT
<< 2) | (LO
<< 3) | (UNIT
<< 4)] = (internal::product_triangular_matrix_matrix
<Scalar
,DenseIndex
,Upper
|UnitDiag
,false,ColMajor
,false,RowMajor
,Conj
, ColMajor
>::run
);
198 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
199 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
200 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
203 if(SIDE(*side
)==INVALID
) info
= 1;
204 else if(UPLO(*uplo
)==INVALID
) info
= 2;
205 else if(OP(*opa
)==INVALID
) info
= 3;
206 else if(DIAG(*diag
)==INVALID
) info
= 4;
207 else if(*m
<0) info
= 5;
208 else if(*n
<0) info
= 6;
209 else if(*lda
<std::max(1,(SIDE(*side
)==LEFT
)?*m
:*n
)) info
= 9;
210 else if(*ldb
<std::max(1,*m
)) info
= 11;
212 return xerbla_(SCALAR_SUFFIX_UP
"TRMM ",&info
,6);
214 int code
= OP(*opa
) | (SIDE(*side
) << 2) | (UPLO(*uplo
) << 3) | (DIAG(*diag
) << 4);
219 // FIXME find a way to avoid this copy
220 Matrix
<Scalar
,Dynamic
,Dynamic
,ColMajor
> tmp
= matrix(b
,*m
,*n
,*ldb
);
221 matrix(b
,*m
,*n
,*ldb
).setZero();
223 if(SIDE(*side
)==LEFT
)
225 internal::gemm_blocking_space
<ColMajor
,Scalar
,Scalar
,Dynamic
,Dynamic
,Dynamic
,4> blocking(*m
,*n
,*m
);
226 func
[code
](*m
, *n
, *m
, a
, *lda
, tmp
.data(), tmp
.outerStride(), b
, *ldb
, alpha
, blocking
);
230 internal::gemm_blocking_space
<ColMajor
,Scalar
,Scalar
,Dynamic
,Dynamic
,Dynamic
,4> blocking(*m
,*n
,*n
);
231 func
[code
](*m
, *n
, *n
, tmp
.data(), tmp
.outerStride(), a
, *lda
, b
, *ldb
, alpha
, blocking
);
236 // c = alpha*a*b + beta*c for side = 'L'or'l'
237 // c = alpha*b*a + beta*c for side = 'R'or'r
238 int EIGEN_BLAS_FUNC(symm
)(char *side
, char *uplo
, int *m
, int *n
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
240 // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
241 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
242 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
243 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
244 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
245 Scalar beta
= *reinterpret_cast<Scalar
*>(pbeta
);
248 if(SIDE(*side
)==INVALID
) info
= 1;
249 else if(UPLO(*uplo
)==INVALID
) info
= 2;
250 else if(*m
<0) info
= 3;
251 else if(*n
<0) info
= 4;
252 else if(*lda
<std::max(1,(SIDE(*side
)==LEFT
)?*m
:*n
)) info
= 7;
253 else if(*ldb
<std::max(1,*m
)) info
= 9;
254 else if(*ldc
<std::max(1,*m
)) info
= 12;
256 return xerbla_(SCALAR_SUFFIX_UP
"SYMM ",&info
,6);
260 if(beta
==Scalar(0)) matrix(c
, *m
, *n
, *ldc
).setZero();
261 else matrix(c
, *m
, *n
, *ldc
) *= beta
;
270 // FIXME add support for symmetric complex matrix
271 int size
= (SIDE(*side
)==LEFT
) ? (*m
) : (*n
);
272 Matrix
<Scalar
,Dynamic
,Dynamic
,ColMajor
> matA(size
,size
);
275 matA
.triangularView
<Upper
>() = matrix(a
,size
,size
,*lda
);
276 matA
.triangularView
<Lower
>() = matrix(a
,size
,size
,*lda
).transpose();
278 else if(UPLO(*uplo
)==LO
)
280 matA
.triangularView
<Lower
>() = matrix(a
,size
,size
,*lda
);
281 matA
.triangularView
<Upper
>() = matrix(a
,size
,size
,*lda
).transpose();
283 if(SIDE(*side
)==LEFT
)
284 matrix(c
, *m
, *n
, *ldc
) += alpha
* matA
* matrix(b
, *m
, *n
, *ldb
);
285 else if(SIDE(*side
)==RIGHT
)
286 matrix(c
, *m
, *n
, *ldc
) += alpha
* matrix(b
, *m
, *n
, *ldb
) * matA
;
288 if(SIDE(*side
)==LEFT
)
289 if(UPLO(*uplo
)==UP
) internal::product_selfadjoint_matrix
<Scalar
, DenseIndex
, RowMajor
,true,false, ColMajor
,false,false, ColMajor
>::run(*m
, *n
, a
, *lda
, b
, *ldb
, c
, *ldc
, alpha
);
290 else if(UPLO(*uplo
)==LO
) internal::product_selfadjoint_matrix
<Scalar
, DenseIndex
, ColMajor
,true,false, ColMajor
,false,false, ColMajor
>::run(*m
, *n
, a
, *lda
, b
, *ldb
, c
, *ldc
, alpha
);
292 else if(SIDE(*side
)==RIGHT
)
293 if(UPLO(*uplo
)==UP
) internal::product_selfadjoint_matrix
<Scalar
, DenseIndex
, ColMajor
,false,false, RowMajor
,true,false, ColMajor
>::run(*m
, *n
, b
, *ldb
, a
, *lda
, c
, *ldc
, alpha
);
294 else if(UPLO(*uplo
)==LO
) internal::product_selfadjoint_matrix
<Scalar
, DenseIndex
, ColMajor
,false,false, ColMajor
,true,false, ColMajor
>::run(*m
, *n
, b
, *ldb
, a
, *lda
, c
, *ldc
, alpha
);
303 // c = alpha*a*a' + beta*c for op = 'N'or'n'
304 // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
305 int EIGEN_BLAS_FUNC(syrk
)(char *uplo
, char *op
, int *n
, int *k
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
307 // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
309 typedef void (*functype
)(DenseIndex
, DenseIndex
, const Scalar
*, DenseIndex
, const Scalar
*, DenseIndex
, Scalar
*, DenseIndex
, const Scalar
&);
310 static functype func
[8];
312 static bool init
= false;
315 for(int k
=0; k
<8; ++k
)
318 func
[NOTR
| (UP
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,ColMajor
,Conj
, Upper
>::run
);
319 func
[TR
| (UP
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,false,Scalar
,ColMajor
,ColMajor
,Conj
, Upper
>::run
);
320 func
[ADJ
| (UP
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,ColMajor
,ColMajor
,false,Upper
>::run
);
322 func
[NOTR
| (LO
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,ColMajor
,Conj
, Lower
>::run
);
323 func
[TR
| (LO
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,false,Scalar
,ColMajor
,ColMajor
,Conj
, Lower
>::run
);
324 func
[ADJ
| (LO
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,ColMajor
,ColMajor
,false,Lower
>::run
);
330 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
331 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
332 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
333 Scalar beta
= *reinterpret_cast<Scalar
*>(pbeta
);
336 if(UPLO(*uplo
)==INVALID
) info
= 1;
337 else if(OP(*op
)==INVALID
) info
= 2;
338 else if(*n
<0) info
= 3;
339 else if(*k
<0) info
= 4;
340 else if(*lda
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 7;
341 else if(*ldc
<std::max(1,*n
)) info
= 10;
343 return xerbla_(SCALAR_SUFFIX_UP
"SYRK ",&info
,6);
348 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>().setZero();
349 else matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>() *= beta
;
351 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>().setZero();
352 else matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>() *= beta
;
356 // FIXME add support for symmetric complex matrix
360 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>() += alpha
* matrix(a
,*n
,*k
,*lda
) * matrix(a
,*n
,*k
,*lda
).transpose();
362 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>() += alpha
* matrix(a
,*k
,*n
,*lda
).transpose() * matrix(a
,*k
,*n
,*lda
);
367 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>() += alpha
* matrix(a
,*n
,*k
,*lda
) * matrix(a
,*n
,*k
,*lda
).transpose();
369 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>() += alpha
* matrix(a
,*k
,*n
,*lda
).transpose() * matrix(a
,*k
,*n
,*lda
);
372 int code
= OP(*op
) | (UPLO(*uplo
) << 2);
373 func
[code
](*n
, *k
, a
, *lda
, a
, *lda
, c
, *ldc
, alpha
);
379 // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
380 // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
381 int EIGEN_BLAS_FUNC(syr2k
)(char *uplo
, char *op
, int *n
, int *k
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
383 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
384 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
385 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
386 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
387 Scalar beta
= *reinterpret_cast<Scalar
*>(pbeta
);
390 if(UPLO(*uplo
)==INVALID
) info
= 1;
391 else if(OP(*op
)==INVALID
) info
= 2;
392 else if(*n
<0) info
= 3;
393 else if(*k
<0) info
= 4;
394 else if(*lda
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 7;
395 else if(*ldb
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 9;
396 else if(*ldc
<std::max(1,*n
)) info
= 12;
398 return xerbla_(SCALAR_SUFFIX_UP
"SYR2K",&info
,6);
403 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>().setZero();
404 else matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>() *= beta
;
406 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>().setZero();
407 else matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>() *= beta
;
417 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>()
418 += alpha
*matrix(a
, *n
, *k
, *lda
)*matrix(b
, *n
, *k
, *ldb
).transpose()
419 + alpha
*matrix(b
, *n
, *k
, *ldb
)*matrix(a
, *n
, *k
, *lda
).transpose();
421 else if(UPLO(*uplo
)==LO
)
422 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>()
423 += alpha
*matrix(a
, *n
, *k
, *lda
)*matrix(b
, *n
, *k
, *ldb
).transpose()
424 + alpha
*matrix(b
, *n
, *k
, *ldb
)*matrix(a
, *n
, *k
, *lda
).transpose();
426 else if(OP(*op
)==TR
|| OP(*op
)==ADJ
)
429 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>()
430 += alpha
*matrix(a
, *k
, *n
, *lda
).transpose()*matrix(b
, *k
, *n
, *ldb
)
431 + alpha
*matrix(b
, *k
, *n
, *ldb
).transpose()*matrix(a
, *k
, *n
, *lda
);
432 else if(UPLO(*uplo
)==LO
)
433 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>()
434 += alpha
*matrix(a
, *k
, *n
, *lda
).transpose()*matrix(b
, *k
, *n
, *ldb
)
435 + alpha
*matrix(b
, *k
, *n
, *ldb
).transpose()*matrix(a
, *k
, *n
, *lda
);
444 // c = alpha*a*b + beta*c for side = 'L'or'l'
445 // c = alpha*b*a + beta*c for side = 'R'or'r
446 int EIGEN_BLAS_FUNC(hemm
)(char *side
, char *uplo
, int *m
, int *n
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
448 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
449 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
450 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
451 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
452 Scalar beta
= *reinterpret_cast<Scalar
*>(pbeta
);
454 // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
457 if(SIDE(*side
)==INVALID
) info
= 1;
458 else if(UPLO(*uplo
)==INVALID
) info
= 2;
459 else if(*m
<0) info
= 3;
460 else if(*n
<0) info
= 4;
461 else if(*lda
<std::max(1,(SIDE(*side
)==LEFT
)?*m
:*n
)) info
= 7;
462 else if(*ldb
<std::max(1,*m
)) info
= 9;
463 else if(*ldc
<std::max(1,*m
)) info
= 12;
465 return xerbla_(SCALAR_SUFFIX_UP
"HEMM ",&info
,6);
467 if(beta
==Scalar(0)) matrix(c
, *m
, *n
, *ldc
).setZero();
468 else if(beta
!=Scalar(1)) matrix(c
, *m
, *n
, *ldc
) *= beta
;
475 if(SIDE(*side
)==LEFT
)
477 if(UPLO(*uplo
)==UP
) internal::product_selfadjoint_matrix
<Scalar
,DenseIndex
,RowMajor
,true,Conj
, ColMajor
,false,false, ColMajor
>
478 ::run(*m
, *n
, a
, *lda
, b
, *ldb
, c
, *ldc
, alpha
);
479 else if(UPLO(*uplo
)==LO
) internal::product_selfadjoint_matrix
<Scalar
,DenseIndex
,ColMajor
,true,false, ColMajor
,false,false, ColMajor
>
480 ::run(*m
, *n
, a
, *lda
, b
, *ldb
, c
, *ldc
, alpha
);
483 else if(SIDE(*side
)==RIGHT
)
485 if(UPLO(*uplo
)==UP
) matrix(c
,*m
,*n
,*ldc
) += alpha
* matrix(b
,*m
,*n
,*ldb
) * matrix(a
,*n
,*n
,*lda
).selfadjointView
<Upper
>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
486 ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);*/
487 else if(UPLO(*uplo
)==LO
) internal::product_selfadjoint_matrix
<Scalar
,DenseIndex
,ColMajor
,false,false, ColMajor
,true,false, ColMajor
>
488 ::run(*m
, *n
, b
, *ldb
, a
, *lda
, c
, *ldc
, alpha
);
499 // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
500 // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
501 int EIGEN_BLAS_FUNC(herk
)(char *uplo
, char *op
, int *n
, int *k
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
503 typedef void (*functype
)(DenseIndex
, DenseIndex
, const Scalar
*, DenseIndex
, const Scalar
*, DenseIndex
, Scalar
*, DenseIndex
, const Scalar
&);
504 static functype func
[8];
506 static bool init
= false;
509 for(int k
=0; k
<8; ++k
)
512 func
[NOTR
| (UP
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,Conj
, ColMajor
,Upper
>::run
);
513 func
[ADJ
| (UP
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,ColMajor
,false,ColMajor
,Upper
>::run
);
515 func
[NOTR
| (LO
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,ColMajor
,false,Scalar
,RowMajor
,Conj
, ColMajor
,Lower
>::run
);
516 func
[ADJ
| (LO
<< 2)] = (internal::general_matrix_matrix_triangular_product
<DenseIndex
,Scalar
,RowMajor
,Conj
, Scalar
,ColMajor
,false,ColMajor
,Lower
>::run
);
521 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
522 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
523 RealScalar alpha
= *palpha
;
524 RealScalar beta
= *pbeta
;
526 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
529 if(UPLO(*uplo
)==INVALID
) info
= 1;
530 else if((OP(*op
)==INVALID
) || (OP(*op
)==TR
)) info
= 2;
531 else if(*n
<0) info
= 3;
532 else if(*k
<0) info
= 4;
533 else if(*lda
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 7;
534 else if(*ldc
<std::max(1,*n
)) info
= 10;
536 return xerbla_(SCALAR_SUFFIX_UP
"HERK ",&info
,6);
538 int code
= OP(*op
) | (UPLO(*uplo
) << 2);
540 if(beta
!=RealScalar(1))
543 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>().setZero();
544 else matrix(c
, *n
, *n
, *ldc
).triangularView
<StrictlyUpper
>() *= beta
;
546 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>().setZero();
547 else matrix(c
, *n
, *n
, *ldc
).triangularView
<StrictlyLower
>() *= beta
;
551 matrix(c
, *n
, *n
, *ldc
).diagonal().real() *= beta
;
552 matrix(c
, *n
, *n
, *ldc
).diagonal().imag().setZero();
556 if(*k
>0 && alpha
!=RealScalar(0))
558 func
[code
](*n
, *k
, a
, *lda
, a
, *lda
, c
, *ldc
, alpha
);
559 matrix(c
, *n
, *n
, *ldc
).diagonal().imag().setZero();
564 // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
565 // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
566 int EIGEN_BLAS_FUNC(her2k
)(char *uplo
, char *op
, int *n
, int *k
, RealScalar
*palpha
, RealScalar
*pa
, int *lda
, RealScalar
*pb
, int *ldb
, RealScalar
*pbeta
, RealScalar
*pc
, int *ldc
)
568 Scalar
* a
= reinterpret_cast<Scalar
*>(pa
);
569 Scalar
* b
= reinterpret_cast<Scalar
*>(pb
);
570 Scalar
* c
= reinterpret_cast<Scalar
*>(pc
);
571 Scalar alpha
= *reinterpret_cast<Scalar
*>(palpha
);
572 RealScalar beta
= *pbeta
;
575 if(UPLO(*uplo
)==INVALID
) info
= 1;
576 else if((OP(*op
)==INVALID
) || (OP(*op
)==TR
)) info
= 2;
577 else if(*n
<0) info
= 3;
578 else if(*k
<0) info
= 4;
579 else if(*lda
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 7;
580 else if(*lda
<std::max(1,(OP(*op
)==NOTR
)?*n
:*k
)) info
= 9;
581 else if(*ldc
<std::max(1,*n
)) info
= 12;
583 return xerbla_(SCALAR_SUFFIX_UP
"HER2K",&info
,6);
585 if(beta
!=RealScalar(1))
588 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>().setZero();
589 else matrix(c
, *n
, *n
, *ldc
).triangularView
<StrictlyUpper
>() *= beta
;
591 if(beta
==Scalar(0)) matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>().setZero();
592 else matrix(c
, *n
, *n
, *ldc
).triangularView
<StrictlyLower
>() *= beta
;
596 matrix(c
, *n
, *n
, *ldc
).diagonal().real() *= beta
;
597 matrix(c
, *n
, *n
, *ldc
).diagonal().imag().setZero();
600 else if(*k
>0 && alpha
!=Scalar(0))
601 matrix(c
, *n
, *n
, *ldc
).diagonal().imag().setZero();
610 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>()
611 += alpha
*matrix(a
, *n
, *k
, *lda
)*matrix(b
, *n
, *k
, *ldb
).adjoint()
612 + numext::conj(alpha
)*matrix(b
, *n
, *k
, *ldb
)*matrix(a
, *n
, *k
, *lda
).adjoint();
614 else if(UPLO(*uplo
)==LO
)
615 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>()
616 += alpha
*matrix(a
, *n
, *k
, *lda
)*matrix(b
, *n
, *k
, *ldb
).adjoint()
617 + numext::conj(alpha
)*matrix(b
, *n
, *k
, *ldb
)*matrix(a
, *n
, *k
, *lda
).adjoint();
619 else if(OP(*op
)==ADJ
)
622 matrix(c
, *n
, *n
, *ldc
).triangularView
<Upper
>()
623 += alpha
*matrix(a
, *k
, *n
, *lda
).adjoint()*matrix(b
, *k
, *n
, *ldb
)
624 + numext::conj(alpha
)*matrix(b
, *k
, *n
, *ldb
).adjoint()*matrix(a
, *k
, *n
, *lda
);
625 else if(UPLO(*uplo
)==LO
)
626 matrix(c
, *n
, *n
, *ldc
).triangularView
<Lower
>()
627 += alpha
*matrix(a
, *k
, *n
, *lda
).adjoint()*matrix(b
, *k
, *n
, *ldb
)
628 + numext::conj(alpha
)*matrix(b
, *k
, *n
, *ldb
).adjoint()*matrix(a
, *k
, *n
, *lda
);