1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.short-circuit fry kernel locals macros
3 math math.blas.cblas math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order functors words
5 sequences sequences.merged sequences.private shuffle
6 specialized-arrays.direct.float specialized-arrays.direct.double
7 specialized-arrays.float specialized-arrays.double ;
! Common parent of the typed BLAS matrix tuples defined by the functor
! further down in this file. Slots:
!   underlying - raw element storage; (Msub) addresses into it with
!                <displaced-alien>
!   ld         - leading dimension; (Msub) computes an element offset as
!                ld * col + row
!   rows, cols - dimensions of the stored data
!   transpose  - flag; when true (blas-transpose) yields CblasTrans, and
!                Mrows/Mcols exchange their roles
10 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
12 : Mtransposed? ( matrix -- ? )
! Width of the matrix: the rows slot when Mtransposed? answers true,
! otherwise the cols slot.
: Mwidth ( matrix -- width )
    ! Read the flag and both dimensions, then let `?` select one.
    [ Mtransposed? ] [ rows>> ] [ cols>> ] tri ? ; inline
! Height of the matrix: the cols slot when Mtransposed? answers true,
! otherwise the rows slot.
: Mheight ( matrix -- height )
    ! Read the flag and both dimensions, then let `?` select one.
    [ Mtransposed? ] [ cols>> ] [ rows>> ] tri ? ; inline
! In-place BLAS kernels. As the stack effects spell out, each one stores
! its result into (and returns) its last input.

! Matrix-vector product: y = alpha*A.x + beta*y
GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+beta*y )
! Rank-1 update: A = alpha*x(*)y + A
GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
! Rank-1 update with y conjugated: A = alpha*x(*)yconj + A
GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
! Matrix-matrix product: C = alpha*A.B + beta*C
GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
! Map a matrix's transpose flag to the CBLAS transpose constant:
! CblasTrans when the flag is set, CblasNoTrans otherwise.
: (blas-transpose) ( matrix -- integer )
    transpose>> CblasTrans CblasNoTrans ? ;
! Build a matrix from raw parts, choosing the concrete class from the
! exemplar. Dispatch is on the exemplar (top of the stack); the methods
! generated in this file drop it and call the typed matrix constructor.
29 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
31 : (validate-gemv) ( A x y -- )
33 [ drop [ Mwidth ] [ length>> ] bi* = ]
34 [ nip [ Mheight ] [ length>> ] bi* = ]
36 [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
40 ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
57 : (validate-ger) ( x y A -- )
59 [ nip [ length>> ] [ Mheight ] bi* = ]
60 [ nipd [ length>> ] [ Mwidth ] bi* = ]
62 [ "Mismatched vertices and matrix in vector outer product" throw ]
66 ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld
79 A f >>transpose ; inline
81 : (validate-gemm) ( A B C -- )
83 [ drop [ Mwidth ] [ Mheight ] bi* = ]
84 [ nip [ Mheight ] bi@ = ]
85 [ nipd [ Mwidth ] bi@ = ]
87 [ "Mismatched matrices in matrix multiplication" throw ]
91 ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
108 C f >>transpose ; inline
! Turn a sequence of arrays into the parts of a matrix.
!   c-array   - the arrays interleaved with <merged>, then passed through
!               the `>c-array` quotation
!   ld, rows  - both the number of arrays
!   cols      - length of the first array
!   transpose - always f
: (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
    ! Put the interleaving step in front of the caller's converter.
    [ <merged> ] prepose
    [ length dup ] [ first length ] tri
    f ; inline
115 ! XXX should do a dense clone
116 M: blas-matrix-base clone
122 [ element-type heap-size ]
123 } cleave * * memory>byte-array ]
131 ] keep (blas-matrix-like) ;
133 ! XXX try rounding the stride up to the next 128-bit boundary for better vectorization
134 : <empty-matrix> ( rows cols exemplar -- matrix )
135 [ element-type [ * ] dip <c-array> ]
137 [ f swap (blas-matrix-like) ] 3tri ;
139 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
141 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
! Non-destructive form of n*V(*)Vconj+M!: A is cloned first and the
! in-place word is applied to the clone, which is the value returned.
143 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
144 clone n*V(*)Vconj+M! ;
145 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
148 : n*M.V ( alpha A x -- alpha*A.x )
149 1.0 2over [ Mheight ] dip <empty-vector>
153 1.0 -rot n*M.V ; inline
155 : n*V(*)V ( alpha x y -- alpha*x(*)y )
156 2dup [ length>> ] bi@ pick <empty-matrix>
158 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
159 2dup [ length>> ] bi@ pick <empty-matrix>
! x(*)y without scaling: n*V(*)V with the factor fixed at 1.0.
: V(*) ( x y -- x(*)y )
    ! Slip the 1.0 underneath the two vectors.
    [ 1.0 ] 2dip n*V(*)V ; inline
! x(*)yconj without scaling: n*V(*)Vconj with the factor fixed at 1.0.
: V(*)conj ( x y -- x(*)yconj )
    ! Slip the 1.0 underneath the two vectors.
    [ 1.0 ] 2dip n*V(*)Vconj ; inline
! alpha*A.B into a freshly allocated matrix.
! The result C gets A's height and B's width and is created with B as the
! exemplar; the work is then delegated to n*M.M+n*M! with beta = 1.0.
! NOTE(review): with beta = 1.0 the initial contents of C contribute to the
! result, so this presumably relies on <empty-matrix> returning zero-filled
! storage - confirm.
: n*M.M ( alpha A B -- alpha*A.B )
    2dup [ Mheight ] dip Mwidth     ! alpha A B height width
    pick <empty-matrix>             ! alpha A B C
    [ 1.0 ] dip n*M.M+n*M! ;
172 1.0 -rot n*M.M ; inline
174 :: (Msub) ( matrix row col height width -- data ld rows cols )
175 matrix ld>> col * row + matrix element-type heap-size *
176 matrix underlying>> <displaced-alien>
! Sub-matrix of `matrix` described by row, col, height and width.
! When the matrix's transpose flag is set, row/col and height/width are
! exchanged before being handed to (Msub). The result is built with
! (blas-matrix-like) using the parent's transpose flag and the parent
! itself as exemplar.
:: Msub ( matrix row col height width -- sub )
    matrix
    matrix transpose>>
    [ col row width height ]
    [ row col height width ] if
    (Msub)
    matrix transpose>> matrix (blas-matrix-like) ;
! Sequence over the rows or the columns of a matrix. Slots, as filled in by
! (Mcols) and (Mrows) in this file:
!   parent        - the matrix being traversed
!   inc           - 1 for columns, the parent's ld for rows
!   rowcol-length - the parent's rows for columns, its cols for rows
!   rowcol-jump   - the parent's ld for columns, 1 for rows
!   length        - the parent's cols for columns, its rows for rows
187 TUPLE: blas-matrix-rowcol-sequence
188 parent inc rowcol-length rowcol-jump length ;
! Positional constructor: takes the slot values in declaration order.
189 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
! Registers the tuple as a member of the sequence mixin.
191 INSTANCE: blas-matrix-rowcol-sequence sequence
193 M: blas-matrix-rowcol-sequence length
195 M: blas-matrix-rowcol-sequence nth-unsafe
199 [ parent>> element-type heap-size ]
200 [ parent>> underlying>> ] tri
201 [ * * ] dip <displaced-alien>
206 } cleave (blas-vector-like) ;
! Column sequence of A, ignoring its transpose flag. Constructor arguments:
! parent = A, inc = 1, rowcol-length = rows, rowcol-jump = ld, length = cols.
: (Mcols) ( A -- columns )
    1 over [ rows>> ] [ ld>> ] [ cols>> ] tri   ! A 1 rows ld cols
    <blas-matrix-rowcol-sequence> ;
! Row sequence of A, ignoring its transpose flag. Constructor arguments:
! parent = A, inc = ld, rowcol-length = cols, rowcol-jump = 1, length = rows.
: (Mrows) ( A -- rows )
    dup [ ld>> ] [ cols>> ] [ rows>> ] tri      ! A ld cols rows
    [ 1 ] dip                                   ! A ld cols 1 rows
    <blas-matrix-rowcol-sequence> ;
! Rows of A as a sequence. When the transpose flag is set, the rows are
! produced by (Mcols); otherwise by (Mrows).
: Mrows ( A -- rows )
    [ transpose>> ] keep swap       ! A transpose?
    [ (Mcols) ] [ (Mrows) ] if ;
! Columns of A as a sequence. When the transpose flag is set, the columns
! are produced by (Mrows); otherwise by (Mcols).
: Mcols ( A -- cols )
    [ transpose>> ] keep swap       ! A transpose?
    [ (Mrows) ] [ (Mcols) ] if ;
! Scale A by n in place and return A. Each element of the (Mcols) sequence
! is scaled with n*V!; the transpose flag is not consulted.
: n*M! ( n A -- A=n*A )
    dup (Mcols) swapd               ! A n columns
    [ n*V! drop ] with each ;
229 recip swap n*M ; inline
231 : Mtranspose ( matrix -- matrix^T )
237 } cleave ] keep (blas-matrix-like) ;
239 M: blas-matrix-base equal?
242 [ [ Mcols ] bi@ [ = ] 2all? ]
247 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
249 VECTOR IS ${TYPE}-blas-vector
250 <VECTOR> IS <${TYPE}-blas-vector>
251 >ARRAY IS >${TYPE}-array
252 TYPE>ARG IS ${TYPE}>arg
253 XGEMV IS cblas_${T}gemv
254 XGEMM IS cblas_${T}gemm
255 XGERU IS cblas_${T}ger${U}
256 XGERC IS cblas_${T}ger${C}
258 MATRIX DEFINES ${TYPE}-blas-matrix
259 <MATRIX> DEFINES <${TYPE}-blas-matrix>
260 >MATRIX DEFINES >${TYPE}-blas-matrix
264 TUPLE: MATRIX < blas-matrix-base ;
265 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
268 M: MATRIX element-type
! Matrix exemplar: discard it and build a MATRIX from the remaining five
! inputs, invoking the functor-bound <MATRIX> word through execute.
270 M: MATRIX (blas-matrix-like)
271 drop <MATRIX> execute ;
! Vector exemplar: a matrix built "like" a VECTOR is the MATRIX generated
! for the same TYPE. The exemplar itself is discarded.
272 M: VECTOR (blas-matrix-like)
273 drop <MATRIX> execute ;
! Matrix exemplar for (blas-vector-like): discard it and build the VECTOR
! for the same TYPE, invoking <VECTOR> through execute.
274 M: MATRIX (blas-vector-like)
275 drop <VECTOR> execute ;
277 : >MATRIX ( arrays -- matrix )
278 [ >ARRAY execute underlying>> ] (>matrix)
282 [ TYPE>ARG execute ] (prepare-gemv)
283 [ XGEMV execute ] dip ;
285 [ TYPE>ARG execute ] (prepare-gemm)
286 [ XGEMM execute ] dip ;
288 [ TYPE>ARG execute ] (prepare-ger)
289 [ XGERU execute ] dip ;
! Conjugated rank-1 update for this element type.
! The TYPE>ARG quotation is handed to (prepare-ger), which presumably lays
! out the CBLAS argument list (its definition is not fully visible here -
! confirm). `dip` holds back the value (prepare-ger) leaves on top while
! XGERC (cblas_${T}ger${C}) runs, so that value is what the method returns.
290 M: MATRIX n*V(*)Vconj+M!
291 [ TYPE>ARG execute ] (prepare-ger)
292 [ XGERC execute ] dip ;
! Instantiate the matrix functor for a real element type. The U and C
! suffixes are both empty, so XGERU and XGERC both resolve to
! cblas_${T}ger.
297 : define-real-blas-matrix ( TYPE T -- )
298 "" "" (define-blas-matrix) ;
! Instantiate the matrix functor for a complex element type. With U = "u"
! and C = "c", XGERU resolves to cblas_${T}geru and XGERC to cblas_${T}gerc.
299 : define-complex-blas-matrix ( TYPE T -- )
300 "u" "c" (define-blas-matrix) ;
! Instantiate the functor once per supported element type. TYPE names the
! generated words (e.g. "float" gives float-blas-matrix); T is the letter
! interpolated into the cblas_${T}... routine names.
302 "float" "s" define-real-blas-matrix
303 "double" "d" define-real-blas-matrix
304 "float-complex" "c" define-complex-blas-matrix
305 "double-complex" "z" define-complex-blas-matrix