create a_linalg_mul
[liba.git] / include / a / linalg.h
blobdc989c7cbd00749b26518c70224341b08f207ad9
1 /*!
2 @file linalg.h
3 @brief linear algebra functions
4 */
6 #ifndef LIBA_LINALG_H
7 #define LIBA_LINALG_H
9 #include "a.h"
11 /*!
12 @ingroup liba
13 @addtogroup a_linalg linear algebra functions
17 #if defined(__cplusplus)
18 extern "C" {
19 #endif /* __cplusplus */
21 /*!
22 @brief transpose an n x n square matrix in-place.
23 @param[in,out] A an n x n square matrix; on output, it contains its own transpose.
24 @param[in] n the order of the square matrix A (number of rows and columns).
26 A_EXTERN void a_linalg_T1(a_float *A, a_uint n);
28 /*!
29 @brief transpose a given m x n matrix A into an n x m matrix T.
30 @param[in] A the input matrix A (m x n), stored in row-major order.
31 @param[in] m the number of rows in the input matrix A.
32 @param[in] n the number of columns in the input matrix A.
33 @param[out] T the output matrix (n x m) where the transpose of A will be stored; it must not overlap A (both pointers are restrict-qualified).
35 A_EXTERN void a_linalg_T2(a_float const *__restrict A, a_uint m, a_uint n, a_float *__restrict T);
37 /*!
38 @brief multiply two matrices X and Y, storing the result in Z.
39 \f{aligned}{
40 \pmb Z_{rc}&=\pmb X_{rn}\pmb Y_{nc}
41 \\&=
42 \begin{bmatrix}
43 x_{11} & \cdots & x_{1n} \\
44 \vdots & \ddots & \vdots \\
45 x_{r1} & \cdots & x_{rn} \\
46 \end{bmatrix}
47 \begin{bmatrix}
48 y_{11} & \cdots & y_{1c} \\
49 \vdots & \ddots & \vdots \\
50 y_{n1} & \cdots & y_{nc} \\
51 \end{bmatrix}
52 \\&=
53 \begin{bmatrix}
54 (x_{11}y_{11}+\ldots+x_{1n}y_{n1}) & \cdots & (x_{11}y_{1c}+\ldots+x_{1n}y_{nc}) \\
55 \vdots & \ddots & \vdots \\
56 (x_{r1}y_{11}+\ldots+x_{rn}y_{n1}) & \cdots & (x_{r1}y_{1c}+\ldots+x_{rn}y_{nc}) \\
57 \end{bmatrix}
58 \f}
59 @param[out] Z the output matrix (row x col) where the result will be stored.
60 @param[in] X the first input matrix (row x c_r).
61 @param[in] Y the second input matrix (c_r x col).
62 @param[in] row the number of rows in matrix Z and in matrix X (r in the formula).
63 @param[in] c_r the number of columns in matrix X and rows in matrix Y (n in the formula).
64 @param[in] col the number of columns in matrix Z and in matrix Y (c in the formula).
66 A_EXTERN void a_linalg_mulmm(a_float *Z, a_float const *X, a_float const *Y, a_uint row, a_uint c_r, a_uint col);
68 /*!
69 @brief multiply the transpose of matrix X with matrix Y, storing the result in Z.
70 \f{aligned}{
71 \pmb Z_{rc}&=\pmb X_{nr}^{T}\pmb Y_{nc}
72 \\&=
73 \begin{bmatrix}
74 x_{11} & \cdots & x_{1r} \\
75 \vdots & \ddots & \vdots \\
76 x_{n1} & \cdots & x_{nr} \\
77 \end{bmatrix}^T
78 \begin{bmatrix}
79 y_{11} & \cdots & y_{1c} \\
80 \vdots & \ddots & \vdots \\
81 y_{n1} & \cdots & y_{nc} \\
82 \end{bmatrix}
83 \\&=
84 \begin{bmatrix}
85 (x_{11}y_{11}+\ldots+x_{n1}y_{n1}) & \cdots & (x_{11}y_{1c}+\ldots+x_{n1}y_{nc}) \\
86 \vdots & \ddots & \vdots \\
87 (x_{1r}y_{11}+\ldots+x_{nr}y_{n1}) & \cdots & (x_{1r}y_{1c}+\ldots+x_{nr}y_{nc}) \\
88 \end{bmatrix}
89 \\&=
90 \begin{bmatrix}
91 x_{11}y_{11} & \cdots & x_{11}y_{1c} \\
92 \vdots & \ddots & \vdots \\
93 x_{1r}y_{11} & \cdots & x_{1r}y_{1c} \\
94 \end{bmatrix}+\cdots+
95 \begin{bmatrix}
96 x_{n1}y_{n1} & \cdots & x_{n1}y_{nc} \\
97 \vdots & \ddots & \vdots \\
98 x_{nr}y_{n1} & \cdots & x_{nr}y_{nc} \\
99 \end{bmatrix}
100 \f}
101 @param[out] Z the output matrix (row x col) where the result will be stored.
102 @param[in] X the first input matrix (c_r x row), which is transposed during the multiplication.
103 @param[in] Y the second input matrix (c_r x col).
104 @param[in] c_r the number of rows in matrix X and in matrix Y (n in the formula).
105 @param[in] row the number of rows in matrix Z and columns in matrix X (r in the formula).
106 @param[in] col the number of columns in matrix Z and in matrix Y (c in the formula).
108 A_EXTERN void a_linalg_mulTm(a_float *Z, a_float const *X, a_float const *Y, a_uint c_r, a_uint row, a_uint col);
111 @brief multiply matrix X with the transpose of matrix Y, storing the result in Z.
112 \f{aligned}{
113 \pmb Z_{rc}&=\pmb X_{rn}\pmb Y_{cn}^T
114 \\&=
115 \begin{bmatrix}
116 x_{11} & \cdots & x_{1n} \\
117 \vdots & \ddots & \vdots \\
118 x_{r1} & \cdots & x_{rn} \\
119 \end{bmatrix}
120 \begin{bmatrix}
121 y_{11} & \cdots & y_{1n} \\
122 \vdots & \ddots & \vdots \\
123 y_{c1} & \cdots & y_{cn} \\
124 \end{bmatrix}^T
125 \\&=
126 \begin{bmatrix}
127 (x_{11}y_{11}+\ldots+x_{1n}y_{1n}) & \cdots & (x_{11}y_{c1}+\ldots+x_{1n}y_{cn}) \\
128 \vdots & \ddots & \vdots \\
129 (x_{r1}y_{11}+\ldots+x_{rn}y_{1n}) & \cdots & (x_{r1}y_{c1}+\ldots+x_{rn}y_{cn}) \\
130 \end{bmatrix}
131 \f}
132 @param[out] Z the output matrix (row x col) where the result will be stored.
133 @param[in] X the first input matrix (row x c_r).
134 @param[in] Y the second input matrix (col x c_r), which is transposed during the multiplication.
135 @param[in] row the number of rows in matrix Z and in matrix X (r in the formula).
136 @param[in] col the number of columns in matrix Z and rows in matrix Y (c in the formula).
137 @param[in] c_r the number of columns in matrix X and in matrix Y (n in the formula).
139 A_EXTERN void a_linalg_mulmT(a_float *Z, a_float const *X, a_float const *Y, a_uint row, a_uint col, a_uint c_r);
142 @brief multiply the transpose of matrix X with the transpose of matrix Y, storing the result in Z.
143 \f{aligned}{
144 \pmb Z_{rc}&=\pmb X_{nr}^T\pmb Y_{cn}^T
145 \\&=
146 \begin{bmatrix}
147 x_{11} & \cdots & x_{1r} \\
148 \vdots & \ddots & \vdots \\
149 x_{n1} & \cdots & x_{nr} \\
150 \end{bmatrix}^T
151 \begin{bmatrix}
152 y_{11} & \cdots & y_{1n} \\
153 \vdots & \ddots & \vdots \\
154 y_{c1} & \cdots & y_{cn} \\
155 \end{bmatrix}^T
156 \\&=
157 \begin{bmatrix}
158 (x_{11}y_{11}+\ldots+x_{n1}y_{1n}) & \cdots & (x_{11}y_{c1}+\ldots+x_{n1}y_{cn}) \\
159 \vdots & \ddots & \vdots \\
160 (x_{1r}y_{11}+\ldots+x_{nr}y_{1n}) & \cdots & (x_{1r}y_{c1}+\ldots+x_{nr}y_{cn}) \\
161 \end{bmatrix}
162 \f}
163 @param[out] Z the output matrix (row x col) where the result will be stored.
164 @param[in] X the first input matrix (c_r x row), which is transposed during the multiplication.
165 @param[in] Y the second input matrix (col x c_r), which is transposed during the multiplication.
166 @param[in] row the number of rows in matrix Z and columns in matrix X (r in the formula).
167 @param[in] c_r the number of rows in matrix X and columns in matrix Y (n in the formula).
168 @param[in] col the number of columns in matrix Z and rows in matrix Y (c in the formula).
170 A_EXTERN void a_linalg_mulTT(a_float *Z, a_float const *X, a_float const *Y, a_uint row, a_uint c_r, a_uint col);
173 @brief compute the LU decomposition of a square matrix with partial pivoting.
174 @details This function performs an LU decomposition on the given square matrix A,
175 where L is a lower triangular matrix, and U is an upper triangular matrix.
176 Partial pivoting is used to improve numerical stability during the decomposition process.
177 The result is stored in place in A: L is stored below the diagonal, and U is stored on and above the diagonal.
178 Additionally, it records the row exchanges made during partial pivoting in the permutation index vector p (the explicit matrix P can be built from it with a_linalg_plu_get_P),
179 and determines the sign of the permutation (which can be used to find the determinant's sign).
180 @param[in,out] A an n x n square matrix.
181 on input, contains the matrix to decompose. on output, contains the L and U factors in compact form.
182 @param[in] n the order of the square matrix A (number of rows and columns).
183 @param[out] p receives the n row permutation indices produced by partial pivoting.
184 @param[out] sign receives the sign of the permutation (+1 or -1).
185 @return 0 on success, or a non-zero error code if the decomposition fails.
186 @retval -1 on failure: A is a singular matrix.
188 A_EXTERN int a_linalg_plu(a_float *A, a_uint n, a_uint *p, int *sign);
191 @brief construct the permutation matrix P from the permutation index vector p.
192 @param[in] p the row permutation indices produced by a_linalg_plu.
193 @param[in] n the order of the square matrix that was decomposed.
194 @param[out] P the n x n output matrix where the permutation matrix will be stored.
196 A_EXTERN void a_linalg_plu_get_P(a_uint const *p, a_uint n, a_float *P);
199 @brief extract the lower triangular matrix L from the compact LU matrix A.
200 @param[in] A the n x n matrix containing L and U in compact form after LU decomposition.
201 @param[in] n the order of the square matrix that was decomposed.
202 @param[out] L the n x n output matrix where the lower triangular matrix will be stored.
204 A_EXTERN void a_linalg_plu_get_L(a_float const *A, a_uint n, a_float *L);
207 @brief extract the upper triangular matrix U from the compact LU matrix A.
208 @param[in] A the n x n matrix containing L and U in compact form after LU decomposition.
209 @param[in] n the order of the square matrix that was decomposed.
210 @param[out] U the n x n output matrix where the upper triangular matrix will be stored.
212 A_EXTERN void a_linalg_plu_get_U(a_float const *A, a_uint n, a_float *U);
215 @brief apply the row permutation P (given as indices p) to the vector b, producing Pb.
216 @param[in] p the row permutation indices produced by a_linalg_plu.
217 @param[in] n the order of the square matrix that was decomposed.
218 @param[in] b the input vector of size n that will be permuted.
219 @param[out] Pb the output vector of size n where the permuted result will be stored.
221 A_EXTERN void a_linalg_plu_apply(a_uint const *p, a_uint n, a_float const *b, a_float *Pb);
224 @brief solve the lower triangular system Ly = Pb for y, in place.
225 @param[in] L the n x n lower triangular matrix L, stored in row-major order.
226 @param[in] n the order of the square matrix L (number of rows and columns).
227 @param[in,out] y a vector of size n. on input, contains the permuted vector Pb. on output, contains the solution vector y.
229 A_EXTERN void a_linalg_plu_lower(a_float const *L, a_uint n, a_float *y);
232 @brief solve the upper triangular system Ux = y for x, in place.
233 @param[in] U the n x n upper triangular matrix U, stored in row-major order.
234 @param[in] n the order of the square matrix U (number of rows and columns).
235 @param[in,out] x a vector of size n. on input, contains the vector y. on output, contains the solution vector x.
237 A_EXTERN void a_linalg_plu_upper(a_float const *U, a_uint n, a_float *x);
240 @brief solve the linear system Ax = b using LU decomposition with partial pivoting.
241 @param[in] A the n x n matrix containing L and U in compact form, as produced by a_linalg_plu.
242 @param[in] n the order of the square matrix A (number of rows and columns).
243 @param[in] p the permutation indices produced by a_linalg_plu.
244 @param[in] b the right-hand side vector b of the linear system (size n).
245 @param[out] x the output vector x (size n) where the solution will be stored.
247 A_EXTERN void a_linalg_plu_solve(a_float const *A, a_uint n, a_uint const *p, a_float const *b, a_float *x);
250 @brief compute the inverse of a matrix using its LU decomposition and permutation indices.
251 @param[in] A the n x n matrix containing L and U in compact form, as produced by a_linalg_plu.
252 @param[in] n the order of the square matrix A (number of rows and columns).
253 @param[in] p the permutation indices produced by a_linalg_plu.
254 @param[in,out] b a pre-allocated scratch buffer of n elements used for intermediate computations.
255 @param[out] I the n x n output matrix where the inverse of A will be stored.
257 A_EXTERN void a_linalg_plu_inv(a_float const *A, a_uint n, a_uint const *p, a_float *b, a_float *I);
260 @brief compute the determinant of a matrix using its LU decomposition.
261 @param[in] A the n x n matrix containing L and U in compact form, as produced by a_linalg_plu.
262 @param[in] n the order of the square matrix A (number of rows and columns).
263 @param[in] sign the sign of the permutation (+1 or -1), as reported by a_linalg_plu.
264 @return the determinant of matrix A.
266 A_EXTERN a_float a_linalg_plu_det(a_float const *A, a_uint n, int sign);
269 @brief compute the natural logarithm of the absolute value of the determinant of a matrix using its LU decomposition.
270 @param[in] A the n x n matrix containing L and U in compact form, as produced by a_linalg_plu.
271 @param[in] n the order of the square matrix A (number of rows and columns).
272 @return the natural logarithm of the absolute value of the determinant; the sign is not included (see a_linalg_plu_sgndet).
274 A_EXTERN a_float a_linalg_plu_lndet(a_float const *A, a_uint n);
277 @brief compute the sign of the determinant of a matrix using its LU decomposition.
278 @param[in] A the n x n matrix containing L and U in compact form, as produced by a_linalg_plu.
279 @param[in] n the order of the square matrix A (number of rows and columns).
280 @param[in] sign the sign of the permutation (+1 or -1), as reported by a_linalg_plu.
281 @return the sign of the determinant: -1, 0, or +1.
283 A_EXTERN int a_linalg_plu_sgndet(a_float const *A, a_uint n, int sign);
285 #if defined(__cplusplus)
286 } /* extern "C" */
287 #endif /* __cplusplus */
289 /*! @} a_linalg */
291 #endif /* a/linalg.h */