/* IMPORTANT! This code is partially taken from GSL, so everything must be GPL'd
 * or this has to be rewritten (or removed; the only reason to use it is problems
 * with OpenBLAS) when distributed.
12 cblas_xerbla (int p
, const char *rout
, const char *form
, ...)
20 fprintf (stderr
, "Parameter %d to routine %s was incorrect\n", p
, rout
);
23 vfprintf (stderr
, form
, ap
);
/* Integer type used for dimensions, strides and loop counters in this file. */
#define INDEX QPMS_BLAS_INDEX_T
/* Start offset for walking N elements with stride incX: 0 when incX is
 * positive, otherwise (N - 1) * (-incX). */
#define OFFSET(N, incX) (!((incX) > 0) ? ((N) - 1) * (-(incX)) : 0)
/* Reports message x through cblas_xerbla, passing 0 as the parameter number
 * and the current source file as the routine name.
 * The expansion is a single expression with no trailing semicolon: the caller
 * writes the ';', which keeps `if (c) BLAS_ERROR("..."); else ...` well-formed. */
#define BLAS_ERROR(x) cblas_xerbla(0, __FILE__, (x))
/* Larger of x and y; yields x when neither compares greater.
 * Each argument appears twice in the expansion, so avoid side effects. */
#define MAX(x,y) ((y) > (x) ? (y) : (x))
/* Predicates on the CBLAS enum arguments. */
/* Non-zero when x is CblasConjTrans. */
#define CONJUGATE(x) (CblasConjTrans == (x))
/* Non-zero when x is CblasTrans or CblasConjTrans. */
#define TRANSPOSE(x) (CblasTrans == (x) || CblasConjTrans == (x))
/* Non-zero when x is CblasUpper. */
#define UPPER(x) (CblasUpper == (x))
/* Non-zero when x is CblasLower. */
#define LOWER(x) (CblasLower == (x))
/* Handling of packed complex types: the pointer a is reinterpreted as a flat
 * array of BASE in which complex element i has its real part at index 2*i and
 * its imaginary part at index 2*i+1.
 * The pointer argument is parenthesised before the cast so that expression
 * arguments such as (p + 1) are cast as a whole, not just their first operand. */

/* Writable real / imaginary part of element i. */
#define REAL(a,i) (((BASE *) (a))[2*(i)])
#define IMAG(a,i) (((BASE *) (a))[2*(i)+1])

/* Writable real / imaginary part of the single element a points to. */
#define REAL0(a) (((BASE *) (a))[0])
#define IMAG0(a) (((BASE *) (a))[1])

/* Read-only counterparts of REAL / IMAG. */
#define CONST_REAL(a,i) (((const BASE *) (a))[2*(i)])
#define CONST_IMAG(a,i) (((const BASE *) (a))[2*(i)+1])

/* Read-only counterparts of REAL0 / IMAG0. */
#define CONST_REAL0(a) (((const BASE *) (a))[0])
#define CONST_IMAG0(a) (((const BASE *) (a))[1])
/* Index computed as (KU + 1 + (i - j)) * lda + j.
 * KL is accepted but does not appear in the expansion.
 * Every argument is parenthesised so expression arguments (e.g. lda = n + 1)
 * expand with the intended precedence. */
#define GB(KU,KL,lda,i,j) (((KU)+1+((i)-(j)))*(lda) + (j))
/* ((i + 1) * (2N - i)) / 2, i.e. the sum of (N - r) for r = 0..i.
 * NOTE(review): presumably the number of stored entries in rows 0..i of a
 * packed triangle of order N - confirm against the callers. */
#define TRCOUNT(N,i) (((2*(N)-(i))*((i)+1))/2)

/* #define TBUP(N,i,j) */
/* #define TBLO(N,i,j) */

/* Packed-storage indices (names suggest upper / lower triangle): */
/* TRCOUNT(N, i - 1) plus the offset (j - i) from the diagonal. */
#define TPUP(N,i,j) (TRCOUNT(N,(i)-1)+(j)-(i))
/* i * (i + 1) / 2 plus j; N does not appear in the expansion. */
#define TPLO(N,i,j) ((j) + ((i)*((i)+1))/2)
69 /* check if CBLAS_ORDER is correct */
70 #define CHECK_ORDER(pos,posIfError,order) \
71 if(((order)!=CblasRowMajor)&&((order)!=CblasColMajor)) \
74 /* check if CBLAS_TRANSPOSE is correct */
75 #define CHECK_TRANSPOSE(pos,posIfError,Trans) \
76 if(((Trans)!=CblasNoTrans)&&((Trans)!=CblasTrans)&&((Trans)!=CblasConjTrans)) \
79 /* check if a dimension argument is correct */
80 #define CHECK_DIM(pos,posIfError,dim) \
85 #define CBLAS_ERROR_GEMM(pos,Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc) \
87 CBLAS_TRANSPOSE __transF=CblasNoTrans,__transG=CblasNoTrans; \
88 if((Order)==CblasRowMajor) { \
89 __transF = ((TransA)!=CblasConjTrans) ? (TransA) : CblasTrans; \
90 __transG = ((TransB)!=CblasConjTrans) ? (TransB) : CblasTrans; \
92 __transF = ((TransB)!=CblasConjTrans) ? (TransB) : CblasTrans; \
93 __transG = ((TransA)!=CblasConjTrans) ? (TransA) : CblasTrans; \
95 CHECK_ORDER(pos,1,Order); \
96 CHECK_TRANSPOSE(pos,2,TransA); \
97 CHECK_TRANSPOSE(pos,3,TransB); \
100 CHECK_DIM(pos,6,K); \
101 if((Order)==CblasRowMajor) { \
102 if(__transF==CblasNoTrans) { \
103 if((lda)<MAX(1,(K))) { \
107 if((lda)<MAX(1,(M))) { \
111 if(__transG==CblasNoTrans) { \
112 if((ldb)<MAX(1,(N))) { \
116 if((ldb)<MAX(1,(K))) { \
120 if((ldc)<MAX(1,(N))) { \
123 } else if((Order)==CblasColMajor) { \
124 if(__transF==CblasNoTrans) { \
125 if((ldb)<MAX(1,(K))) { \
129 if((ldb)<MAX(1,(N))) { \
133 if(__transG==CblasNoTrans) { \
134 if((lda)<MAX(1,(M))) { \
138 if((lda)<MAX(1,(K))) { \
142 if((ldc)<MAX(1,(M))) { \
/* Argument-check driver.
 * Declares a local int named VAR initialised to 0, expands the checker macro
 * CBLAS_ERROR_<FUNCTION> with the parenthesised argument list ARGS (the
 * checker is expected to assign a non-zero value to VAR on a bad argument),
 * and calls cblas_xerbla with that value if it ended up non-zero.
 * The value is read through VAR, so any variable name works. */
#define CHECK_ARGS_X(FUNCTION,VAR,ARGS) do { int VAR = 0 ; \
    CBLAS_ERROR_##FUNCTION ARGS ; \
    if (VAR) cblas_xerbla(VAR,__FILE__,""); } while (0)

/* Fourteen-argument front end: forwards A1..A14 to CBLAS_ERROR_<FUNCTION>,
 * prepending the error variable, which is named pos. */
#define CHECK_ARGS14(FUNCTION,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14) \
    CHECK_ARGS_X(FUNCTION,pos,(pos,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14))
157 void qpms_zgemm(CBLAS_LAYOUT Order
, CBLAS_TRANSPOSE TransA
, CBLAS_TRANSPOSE TransB
,
158 const INDEX M
, const INDEX N
, const INDEX K
,
159 const _Complex
double *alpha
, const _Complex
double *A
, const INDEX lda
,
160 const _Complex
double *B
, const INDEX ldb
,
161 const _Complex
double *beta
, _Complex
double *C
, const INDEX ldc
)
166 int conjF
, conjG
, TransF
, TransG
;
169 CHECK_ARGS14(GEMM
,Order
,TransA
,TransB
,M
,N
,K
,alpha
,A
,lda
,B
,ldb
,beta
,C
,ldc
);
172 const BASE alpha_real
= CONST_REAL0(alpha
);
173 const BASE alpha_imag
= CONST_IMAG0(alpha
);
175 const BASE beta_real
= CONST_REAL0(beta
);
176 const BASE beta_imag
= CONST_IMAG0(beta
);
178 if ((alpha_real
== 0.0 && alpha_imag
== 0.0)
179 && (beta_real
== 1.0 && beta_imag
== 0.0))
182 if (Order
== CblasRowMajor
) {
187 conjF
= (TransA
== CblasConjTrans
) ? -1 : 1;
188 TransF
= (TransA
== CblasNoTrans
) ? CblasNoTrans
: CblasTrans
;
191 conjG
= (TransB
== CblasConjTrans
) ? -1 : 1;
192 TransG
= (TransB
== CblasNoTrans
) ? CblasNoTrans
: CblasTrans
;
198 conjF
= (TransB
== CblasConjTrans
) ? -1 : 1;
199 TransF
= (TransB
== CblasNoTrans
) ? CblasNoTrans
: CblasTrans
;
202 conjG
= (TransA
== CblasConjTrans
) ? -1 : 1;
203 TransG
= (TransA
== CblasNoTrans
) ? CblasNoTrans
: CblasTrans
;
/* form C := beta*C */
207 if (beta_real
== 0.0 && beta_imag
== 0.0) {
208 for (i
= 0; i
< n1
; i
++) {
209 for (j
= 0; j
< n2
; j
++) {
210 REAL(C
, ldc
* i
+ j
) = 0.0;
211 IMAG(C
, ldc
* i
+ j
) = 0.0;
214 } else if (!(beta_real
== 1.0 && beta_imag
== 0.0)) {
215 for (i
= 0; i
< n1
; i
++) {
216 for (j
= 0; j
< n2
; j
++) {
217 const BASE Cij_real
= REAL(C
, ldc
* i
+ j
);
218 const BASE Cij_imag
= IMAG(C
, ldc
* i
+ j
);
219 REAL(C
, ldc
* i
+ j
) = beta_real
* Cij_real
- beta_imag
* Cij_imag
;
220 IMAG(C
, ldc
* i
+ j
) = beta_real
* Cij_imag
+ beta_imag
* Cij_real
;
225 if (alpha_real
== 0.0 && alpha_imag
== 0.0)
228 if (TransF
== CblasNoTrans
&& TransG
== CblasNoTrans
) {
230 /* form C := alpha*A*B + C */
232 for (k
= 0; k
< K
; k
++) {
233 for (i
= 0; i
< n1
; i
++) {
234 const BASE Fik_real
= CONST_REAL(F
, ldf
* i
+ k
);
235 const BASE Fik_imag
= conjF
* CONST_IMAG(F
, ldf
* i
+ k
);
236 const BASE temp_real
= alpha_real
* Fik_real
- alpha_imag
* Fik_imag
;
237 const BASE temp_imag
= alpha_real
* Fik_imag
+ alpha_imag
* Fik_real
;
238 if (!(temp_real
== 0.0 && temp_imag
== 0.0)) {
239 for (j
= 0; j
< n2
; j
++) {
240 const BASE Gkj_real
= CONST_REAL(G
, ldg
* k
+ j
);
241 const BASE Gkj_imag
= conjG
* CONST_IMAG(G
, ldg
* k
+ j
);
242 REAL(C
, ldc
* i
+ j
) += temp_real
* Gkj_real
- temp_imag
* Gkj_imag
;
243 IMAG(C
, ldc
* i
+ j
) += temp_real
* Gkj_imag
+ temp_imag
* Gkj_real
;
249 } else if (TransF
== CblasNoTrans
&& TransG
== CblasTrans
) {
251 /* form C := alpha*A*B' + C */
253 for (i
= 0; i
< n1
; i
++) {
254 for (j
= 0; j
< n2
; j
++) {
255 BASE temp_real
= 0.0;
256 BASE temp_imag
= 0.0;
257 for (k
= 0; k
< K
; k
++) {
258 const BASE Fik_real
= CONST_REAL(F
, ldf
* i
+ k
);
259 const BASE Fik_imag
= conjF
* CONST_IMAG(F
, ldf
* i
+ k
);
260 const BASE Gjk_real
= CONST_REAL(G
, ldg
* j
+ k
);
261 const BASE Gjk_imag
= conjG
* CONST_IMAG(G
, ldg
* j
+ k
);
262 temp_real
+= Fik_real
* Gjk_real
- Fik_imag
* Gjk_imag
;
263 temp_imag
+= Fik_real
* Gjk_imag
+ Fik_imag
* Gjk_real
;
265 REAL(C
, ldc
* i
+ j
) += alpha_real
* temp_real
- alpha_imag
* temp_imag
;
266 IMAG(C
, ldc
* i
+ j
) += alpha_real
* temp_imag
+ alpha_imag
* temp_real
;
270 } else if (TransF
== CblasTrans
&& TransG
== CblasNoTrans
) {
272 for (k
= 0; k
< K
; k
++) {
273 for (i
= 0; i
< n1
; i
++) {
274 const BASE Fki_real
= CONST_REAL(F
, ldf
* k
+ i
);
275 const BASE Fki_imag
= conjF
* CONST_IMAG(F
, ldf
* k
+ i
);
276 const BASE temp_real
= alpha_real
* Fki_real
- alpha_imag
* Fki_imag
;
277 const BASE temp_imag
= alpha_real
* Fki_imag
+ alpha_imag
* Fki_real
;
278 if (!(temp_real
== 0.0 && temp_imag
== 0.0)) {
279 for (j
= 0; j
< n2
; j
++) {
280 const BASE Gkj_real
= CONST_REAL(G
, ldg
* k
+ j
);
281 const BASE Gkj_imag
= conjG
* CONST_IMAG(G
, ldg
* k
+ j
);
282 REAL(C
, ldc
* i
+ j
) += temp_real
* Gkj_real
- temp_imag
* Gkj_imag
;
283 IMAG(C
, ldc
* i
+ j
) += temp_real
* Gkj_imag
+ temp_imag
* Gkj_real
;
289 } else if (TransF
== CblasTrans
&& TransG
== CblasTrans
) {
291 for (i
= 0; i
< n1
; i
++) {
292 for (j
= 0; j
< n2
; j
++) {
293 BASE temp_real
= 0.0;
294 BASE temp_imag
= 0.0;
295 for (k
= 0; k
< K
; k
++) {
296 const BASE Fki_real
= CONST_REAL(F
, ldf
* k
+ i
);
297 const BASE Fki_imag
= conjF
* CONST_IMAG(F
, ldf
* k
+ i
);
298 const BASE Gjk_real
= CONST_REAL(G
, ldg
* j
+ k
);
299 const BASE Gjk_imag
= conjG
* CONST_IMAG(G
, ldg
* j
+ k
);
301 temp_real
+= Fki_real
* Gjk_real
- Fki_imag
* Gjk_imag
;
302 temp_imag
+= Fki_real
* Gjk_imag
+ Fki_imag
* Gjk_real
;
304 REAL(C
, ldc
* i
+ j
) += alpha_real
* temp_real
- alpha_imag
* temp_imag
;
305 IMAG(C
, ldc
* i
+ j
) += alpha_real
* temp_imag
+ alpha_imag
* temp_real
;
310 BLAS_ERROR("unrecognized operation");