Fix saving lists of arrays with recent versions of numpy
[qpms.git] / qpms / own_zgemm.c
blob7de5baa61a8ab2469901d48c73173bfa01620fc5
/* IMPORTANT! This code is partially taken from GSL, so everything must be GPL'd
 * or this has to be rewritten (or removed; the only reason for using it is
 * problems with OpenBLAS) when distributed.
 */
6 #include "qpmsblas.h"
7 #include <stdlib.h>
8 #include <stdarg.h>
9 #include <stdio.h>
/*
 * Report a BLAS error on stderr and terminate the process.
 *
 * p     1-based position of the offending argument; 0 means no specific
 *       argument is blamed, and only the formatted message is printed.
 * rout  text identifying where the error was detected; it is printed only
 *       when p is non-zero.
 * form  printf-style format string for an extra message, followed by the
 *       values it consumes.
 *
 * This function never returns: it always ends in abort().
 */
void
cblas_xerbla (int p, const char *rout, const char *form, ...)
{
  va_list args;

  if (p != 0)
    {
      fprintf (stderr, "Parameter %d to routine %s was incorrect\n", p, rout);
    }

  va_start (args, form);
  vfprintf (stderr, form, args);
  va_end (args);

  abort ();
}
/* Scalar type of the real and imaginary parts handled by the kernels below. */
#define BASE double
/* Integer type for dimensions and indices (QPMS_BLAS_INDEX_T is presumably
 * provided by qpmsblas.h -- not visible here). */
#define INDEX QPMS_BLAS_INDEX_T

/* Offset of the first element visited by a strided vector of length N with
 * increment incX; negative increments start from the far end. */
#define OFFSET(N, incX) ((incX) > 0 ? 0 : ((N) - 1) * (-(incX)))

/* Report a generic (non-argument) error; cblas_xerbla() aborts.
 * The message goes through "%s" so a '%' inside it is printed verbatim
 * instead of being parsed as a conversion. The macro has no trailing
 * semicolon so that it behaves like a single function call (e.g. in an
 * unbraced if/else); the call site supplies the ';'. */
#define BLAS_ERROR(x) cblas_xerbla(0, __FILE__, "%s", (x))

/* Larger of two values; the selected argument is evaluated twice, so do not
 * pass expressions with side effects. */
#define MAX(x,y) (((x) < (y)) ? (y) : (x))

/* Classification of CBLAS enum arguments (enums presumably from qpmsblas.h). */
#define CONJUGATE(x) ((x) == CblasConjTrans)
#define TRANSPOSE(x) ((x) == CblasTrans || (x) == CblasConjTrans)
#define UPPER(x) ((x) == CblasUpper)
#define LOWER(x) ((x) == CblasLower)

/* Handling of packed complex types: a complex array is viewed as an array of
 * BASE in which element i has its real part at index 2*i and its imaginary
 * part at 2*i+1. The pointer argument is parenthesized before the cast so
 * that e.g. REAL(z + 1, 0) refers to complex element 1 of z, not to the
 * imaginary part of element 0. */
#define REAL(a,i) (((BASE *)(a))[2*(i)])
#define IMAG(a,i) (((BASE *)(a))[2*(i)+1])

#define REAL0(a) (((BASE *)(a))[0])
#define IMAG0(a) (((BASE *)(a))[1])

#define CONST_REAL(a,i) (((const BASE *)(a))[2*(i)])
#define CONST_IMAG(a,i) (((const BASE *)(a))[2*(i)+1])

#define CONST_REAL0(a) (((const BASE *)(a))[0])
#define CONST_IMAG0(a) (((const BASE *)(a))[1])

/* Index formula used for general band (GB) storage in GSL; KL does not
 * appear in the formula. */
#define GB(KU,KL,lda,i,j) (((KU)+1+((i)-(j)))*(lda) + (j))

/* (i+1)*(2N-i)/2 = N + (N-1) + ... + (N-i): the number of elements in rows
 * 0..i of an N x N upper triangle. */
#define TRCOUNT(N,i) ((((i)+1)*(2*(N)-(i)))/2)

/* Packed-triangle indices with rows stored one after another:
 * TPUP for the upper triangle (j >= i), TPLO for the lower one (j <= i). */
#define TPUP(N,i,j) (TRCOUNT(N,(i)-1)+(j)-(i))
#define TPLO(N,i,j) (((i)*((i)+1))/2 + (j))


/* Argument checks: each sets the lvalue pos to posIfError when the check
 * fails and leaves it unchanged otherwise. do/while(0) makes every macro a
 * single statement, safe inside an unbraced if/else. */

/* check if CBLAS_ORDER is correct */
#define CHECK_ORDER(pos,posIfError,order) \
  do { \
    if (((order) != CblasRowMajor) && ((order) != CblasColMajor)) \
      (pos) = (posIfError); \
  } while (0)

/* check if CBLAS_TRANSPOSE is correct */
#define CHECK_TRANSPOSE(pos,posIfError,Trans) \
  do { \
    if (((Trans) != CblasNoTrans) && ((Trans) != CblasTrans) \
        && ((Trans) != CblasConjTrans)) \
      (pos) = (posIfError); \
  } while (0)

/* check if a dimension argument is correct
 * NOTE(review): if INDEX is an unsigned type this check can never fire. */
#define CHECK_DIM(pos,posIfError,dim) \
  do { \
    if ((dim) < 0) \
      (pos) = (posIfError); \
  } while (0)
/* cblas_xgemm() argument checker.
 *
 * Expands to a compound statement that stores into the lvalue `pos` the
 * 1-based position of an invalid argument of a CBLAS-style ?gemm call, with
 * the parameter order (Order=1, TransA=2, TransB=3, M=4, N=5, K=6, alpha=7,
 * A=8, lda=9, B=10, ldb=11, beta=12, C=13, ldc=14). `pos` is left unchanged
 * when every check passes; when several checks fail, the last one evaluated
 * wins (note that lda/ldb are tested in a different order for each layout).
 *
 * Leading dimensions are checked against the stored shape of each operand,
 * e.g. a non-transposed row-major A is M x K, so lda >= max(1, K). For
 * column-major order the roles of A and B are swapped in chk_opF_/chk_opG_.
 * alpha, A, B, beta and C are accepted for signature uniformity but unused.
 *
 * The temporaries avoid double-underscore names, which C reserves for the
 * implementation.
 */
#define CBLAS_ERROR_GEMM(pos,Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc) \
{ \
  CBLAS_TRANSPOSE chk_opF_ = CblasNoTrans, chk_opG_ = CblasNoTrans; \
  if ((Order) == CblasRowMajor) { \
    chk_opF_ = ((TransA) != CblasConjTrans) ? (TransA) : CblasTrans; \
    chk_opG_ = ((TransB) != CblasConjTrans) ? (TransB) : CblasTrans; \
  } else { \
    chk_opF_ = ((TransB) != CblasConjTrans) ? (TransB) : CblasTrans; \
    chk_opG_ = ((TransA) != CblasConjTrans) ? (TransA) : CblasTrans; \
  } \
  CHECK_ORDER(pos,1,Order); \
  CHECK_TRANSPOSE(pos,2,TransA); \
  CHECK_TRANSPOSE(pos,3,TransB); \
  CHECK_DIM(pos,4,M); \
  CHECK_DIM(pos,5,N); \
  CHECK_DIM(pos,6,K); \
  if ((Order) == CblasRowMajor) { \
    if (chk_opF_ == CblasNoTrans) { \
      if ((lda) < MAX(1,(K))) { (pos) = 9; } \
    } else { \
      if ((lda) < MAX(1,(M))) { (pos) = 9; } \
    } \
    if (chk_opG_ == CblasNoTrans) { \
      if ((ldb) < MAX(1,(N))) { (pos) = 11; } \
    } else { \
      if ((ldb) < MAX(1,(K))) { (pos) = 11; } \
    } \
    if ((ldc) < MAX(1,(N))) { (pos) = 14; } \
  } else if ((Order) == CblasColMajor) { \
    if (chk_opF_ == CblasNoTrans) { \
      if ((ldb) < MAX(1,(K))) { (pos) = 11; } \
    } else { \
      if ((ldb) < MAX(1,(N))) { (pos) = 11; } \
    } \
    if (chk_opG_ == CblasNoTrans) { \
      if ((lda) < MAX(1,(M))) { (pos) = 9; } \
    } else { \
      if ((lda) < MAX(1,(K))) { (pos) = 9; } \
    } \
    if ((ldc) < MAX(1,(M))) { (pos) = 14; } \
  } \
}
/* Argument-checking front end used at the top of a BLAS routine.
 *
 * CHECK_ARGS_X(FUNCTION, VAR, ARGS) declares a fresh `int VAR = 0`, expands
 * CBLAS_ERROR_<FUNCTION> ARGS (which sets VAR to the position of an invalid
 * argument, if any) and, when VAR is non-zero, reports that position through
 * cblas_xerbla(), which aborts. The report uses VAR itself rather than a
 * hard-coded name, so any variable name works. __func__ expands in the
 * routine that uses the macro, so the message names that routine.
 */
#define CHECK_ARGS_X(FUNCTION,VAR,ARGS) do { int VAR = 0 ; \
    CBLAS_ERROR_##FUNCTION ARGS ; \
    if (VAR) cblas_xerbla((VAR), __func__, ""); } while (0)

/* Wrapper for 14-argument routines (?gemm): the error-position variable
 * `pos` is passed to the checker ahead of the routine's own arguments. */
#define CHECK_ARGS14(FUNCTION,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14) \
  CHECK_ARGS_X(FUNCTION,pos,(pos,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14))
157 void qpms_zgemm(CBLAS_LAYOUT Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
158 const INDEX M, const INDEX N, const INDEX K,
159 const _Complex double *alpha, const _Complex double *A, const INDEX lda,
160 const _Complex double *B, const INDEX ldb,
161 const _Complex double *beta, _Complex double *C, const INDEX ldc)
163 INDEX i, j, k;
164 INDEX n1, n2;
165 INDEX ldf, ldg;
166 int conjF, conjG, TransF, TransG;
167 const BASE *F, *G;
169 CHECK_ARGS14(GEMM,Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc);
172 const BASE alpha_real = CONST_REAL0(alpha);
173 const BASE alpha_imag = CONST_IMAG0(alpha);
175 const BASE beta_real = CONST_REAL0(beta);
176 const BASE beta_imag = CONST_IMAG0(beta);
178 if ((alpha_real == 0.0 && alpha_imag == 0.0)
179 && (beta_real == 1.0 && beta_imag == 0.0))
180 return;
182 if (Order == CblasRowMajor) {
183 n1 = M;
184 n2 = N;
185 F = (const BASE *)A;
186 ldf = lda;
187 conjF = (TransA == CblasConjTrans) ? -1 : 1;
188 TransF = (TransA == CblasNoTrans) ? CblasNoTrans : CblasTrans;
189 G = (const BASE *)B;
190 ldg = ldb;
191 conjG = (TransB == CblasConjTrans) ? -1 : 1;
192 TransG = (TransB == CblasNoTrans) ? CblasNoTrans : CblasTrans;
193 } else {
194 n1 = N;
195 n2 = M;
196 F = (const BASE *)B;
197 ldf = ldb;
198 conjF = (TransB == CblasConjTrans) ? -1 : 1;
199 TransF = (TransB == CblasNoTrans) ? CblasNoTrans : CblasTrans;
200 G = (const BASE *)A;
201 ldg = lda;
202 conjG = (TransA == CblasConjTrans) ? -1 : 1;
203 TransG = (TransA == CblasNoTrans) ? CblasNoTrans : CblasTrans;
206 /* form y := beta*y */
207 if (beta_real == 0.0 && beta_imag == 0.0) {
208 for (i = 0; i < n1; i++) {
209 for (j = 0; j < n2; j++) {
210 REAL(C, ldc * i + j) = 0.0;
211 IMAG(C, ldc * i + j) = 0.0;
214 } else if (!(beta_real == 1.0 && beta_imag == 0.0)) {
215 for (i = 0; i < n1; i++) {
216 for (j = 0; j < n2; j++) {
217 const BASE Cij_real = REAL(C, ldc * i + j);
218 const BASE Cij_imag = IMAG(C, ldc * i + j);
219 REAL(C, ldc * i + j) = beta_real * Cij_real - beta_imag * Cij_imag;
220 IMAG(C, ldc * i + j) = beta_real * Cij_imag + beta_imag * Cij_real;
225 if (alpha_real == 0.0 && alpha_imag == 0.0)
226 return;
228 if (TransF == CblasNoTrans && TransG == CblasNoTrans) {
230 /* form C := alpha*A*B + C */
232 for (k = 0; k < K; k++) {
233 for (i = 0; i < n1; i++) {
234 const BASE Fik_real = CONST_REAL(F, ldf * i + k);
235 const BASE Fik_imag = conjF * CONST_IMAG(F, ldf * i + k);
236 const BASE temp_real = alpha_real * Fik_real - alpha_imag * Fik_imag;
237 const BASE temp_imag = alpha_real * Fik_imag + alpha_imag * Fik_real;
238 if (!(temp_real == 0.0 && temp_imag == 0.0)) {
239 for (j = 0; j < n2; j++) {
240 const BASE Gkj_real = CONST_REAL(G, ldg * k + j);
241 const BASE Gkj_imag = conjG * CONST_IMAG(G, ldg * k + j);
242 REAL(C, ldc * i + j) += temp_real * Gkj_real - temp_imag * Gkj_imag;
243 IMAG(C, ldc * i + j) += temp_real * Gkj_imag + temp_imag * Gkj_real;
249 } else if (TransF == CblasNoTrans && TransG == CblasTrans) {
251 /* form C := alpha*A*B' + C */
253 for (i = 0; i < n1; i++) {
254 for (j = 0; j < n2; j++) {
255 BASE temp_real = 0.0;
256 BASE temp_imag = 0.0;
257 for (k = 0; k < K; k++) {
258 const BASE Fik_real = CONST_REAL(F, ldf * i + k);
259 const BASE Fik_imag = conjF * CONST_IMAG(F, ldf * i + k);
260 const BASE Gjk_real = CONST_REAL(G, ldg * j + k);
261 const BASE Gjk_imag = conjG * CONST_IMAG(G, ldg * j + k);
262 temp_real += Fik_real * Gjk_real - Fik_imag * Gjk_imag;
263 temp_imag += Fik_real * Gjk_imag + Fik_imag * Gjk_real;
265 REAL(C, ldc * i + j) += alpha_real * temp_real - alpha_imag * temp_imag;
266 IMAG(C, ldc * i + j) += alpha_real * temp_imag + alpha_imag * temp_real;
270 } else if (TransF == CblasTrans && TransG == CblasNoTrans) {
272 for (k = 0; k < K; k++) {
273 for (i = 0; i < n1; i++) {
274 const BASE Fki_real = CONST_REAL(F, ldf * k + i);
275 const BASE Fki_imag = conjF * CONST_IMAG(F, ldf * k + i);
276 const BASE temp_real = alpha_real * Fki_real - alpha_imag * Fki_imag;
277 const BASE temp_imag = alpha_real * Fki_imag + alpha_imag * Fki_real;
278 if (!(temp_real == 0.0 && temp_imag == 0.0)) {
279 for (j = 0; j < n2; j++) {
280 const BASE Gkj_real = CONST_REAL(G, ldg * k + j);
281 const BASE Gkj_imag = conjG * CONST_IMAG(G, ldg * k + j);
282 REAL(C, ldc * i + j) += temp_real * Gkj_real - temp_imag * Gkj_imag;
283 IMAG(C, ldc * i + j) += temp_real * Gkj_imag + temp_imag * Gkj_real;
289 } else if (TransF == CblasTrans && TransG == CblasTrans) {
291 for (i = 0; i < n1; i++) {
292 for (j = 0; j < n2; j++) {
293 BASE temp_real = 0.0;
294 BASE temp_imag = 0.0;
295 for (k = 0; k < K; k++) {
296 const BASE Fki_real = CONST_REAL(F, ldf * k + i);
297 const BASE Fki_imag = conjF * CONST_IMAG(F, ldf * k + i);
298 const BASE Gjk_real = CONST_REAL(G, ldg * j + k);
299 const BASE Gjk_imag = conjG * CONST_IMAG(G, ldg * j + k);
301 temp_real += Fki_real * Gjk_real - Fki_imag * Gjk_imag;
302 temp_imag += Fki_real * Gjk_imag + Fki_imag * Gjk_real;
304 REAL(C, ldc * i + j) += alpha_real * temp_real - alpha_imag * temp_imag;
305 IMAG(C, ldc * i + j) += alpha_real * temp_imag + alpha_imag * temp_real;
309 } else {
310 BLAS_ERROR("unrecognized operation");