2 from ..utils
import MATRIX
, MATRIX_DICT
3 from .algebra
import Matrix
, MatrixBase
, MatrixDict
5 from ..core
import init_module
6 init_module
.import_lowlevel_operations()
8 def MATRIX_DICT_iadd(self
, other
):
9 """ Inplace matrix add.
12 if t
is list or t
is tuple:
20 head1
, data1
= ret
.pair
21 head2
, data2
= other
.pair
22 assert head1
.shape
==head2
.shape
,`head1
, head2`
23 if head1
.is_transpose
:
24 if head2
.is_transpose
:
25 iadd_MATRIX_MATRIX_TT(data1
, data2
)
27 iadd_MATRIX_MATRIX_TA(data1
, data2
)
28 elif head2
.is_transpose
:
29 iadd_MATRIX_MATRIX_AT(data1
, data2
)
31 iadd_MATRIX_MATRIX_AA(data1
, data2
)
35 elif isinstance(other
, MatrixBase
):
36 raise NotImplementedError(`
type(other
)`
)
44 rows
, cols
= head
.shape
46 iadd_MATRIX_T_SCALAR(rows
, cols
, data
, other
)
48 iadd_MATRIX_SCALAR(rows
, cols
, data
, other
)
53 def MATRIX_DICT_imul(self
, other
):
54 """ Inplace matrix multiplication.
57 if t
is list or t
is tuple:
61 head1
, data1
= self
.pair
62 head2
, data2
= other
.pair
63 if head1
.is_array
or head2
.is_array
:
64 assert head1
.shape
==head2
.shape
,`head1
, head2`
70 if head1
.is_transpose
:
71 if head2
.is_transpose
:
72 imul_MATRIX_MATRIX_ATT(data1
, data2
)
74 imul_MATRIX_MATRIX_TA(data1
, data2
)
75 elif head2
.is_transpose
:
76 imul_MATRIX_MATRIX_AT(data1
, data2
)
78 imul_MATRIX_MATRIX_AA(data1
, data2
)
83 assert head1
.cols
==head2
.rows
,`head1
, head2`
84 args
= data1
, data2
, head1
.rows
, head2
.cols
, head1
.cols
85 if head1
.is_transpose
:
86 if head2
.is_transpose
:
87 ret
= mul_MATRIX_MATRIX_MTT(*args
)
89 ret
= mul_MATRIX_MATRIX_TM(*args
)
90 elif head2
.is_transpose
:
91 ret
= mul_MATRIX_MATRIX_MT(*args
)
93 ret
= mul_MATRIX_MATRIX_MM(*args
)
95 elif isinstance(other
, MatrixBase
):
96 raise NotImplementedError(`
type(other
)`
)
103 head
, data
= ret
.pair
107 head
, data
= ret
.pair
111 def iadd_MATRIX_SCALAR(rows
, cols
, data
, value
):
112 col_indices
= range(cols
)
113 for i
in xrange(rows
):
114 for j
in col_indices
:
126 def iadd_MATRIX_T_SCALAR(rows
, cols
, data
, value
):
127 col_indices
= range(cols
)
128 for i
in xrange(rows
):
129 for j
in col_indices
:
141 def iadd_MATRIX_MATRIX_AA(data1
, data2
):
142 for key
,x
in data2
.items():
153 def iadd_MATRIX_MATRIX_AT(data1
, data2
):
154 for (j
,i
),x
in data2
.items():
166 def iadd_MATRIX_MATRIX_TA(data1
, data2
):
167 for (i
,j
),x
in data2
.items():
179 iadd_MATRIX_MATRIX_TT
= iadd_MATRIX_MATRIX_AA
181 def imul_MATRIX_MATRIX_AA(data1
, data2
):
189 def imul_MATRIX_MATRIX_AT(data1
, data2
):
198 imul_MATRIX_MATRIX_TA
= imul_MATRIX_MATRIX_AT
199 imul_MATRIX_MATRIX_ATT
= imul_MATRIX_MATRIX_AA
201 def mul_MATRIX_MATRIX_AA(data1
, data2
, rows
, cols
):
204 data1_get
= data1
.get
205 data2_get
= data2
.get
206 for i
in xrange(rows
):
207 for j
in xrange(cols
):
209 a_ij
= data1_get(key
)
212 b_ij
= data2_get(key
)
216 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
218 def mul_MATRIX_MATRIX_AT(data1
, data2
, rows
, cols
):
221 data1_get
= data1
.get
222 data2_get
= data2
.get
223 for i
in xrange(rows
):
224 for j
in xrange(cols
):
226 a_ij
= data1_get(key
)
229 b_ij
= data2_get((j
,i
))
233 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
235 def mul_MATRIX_MATRIX_TA(data1
, data2
, rows
, cols
):
238 data1_get
= data1
.get
239 data2_get
= data2
.get
240 for i
in xrange(rows
):
241 for j
in xrange(cols
):
243 a_ij
= data1_get((j
,i
))
246 b_ij
= data2_get(key
)
250 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
252 def mul_MATRIX_MATRIX_ATT(data1
, data2
, rows
, cols
):
255 data1_get
= data1
.get
256 data2_get
= data2
.get
257 for i
in xrange(rows
):
258 for j
in xrange(cols
):
260 a_ij
= data1_get(ikey
)
263 b_ij
= data2_get(ikey
)
267 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
269 def mul_MATRIX_MATRIX_MM(data1
, data2
, rows
, cols
, n
):
273 for i
,j
in data1
.keys():
274 s
= left_indices
.get(i
)
276 s
= left_indices
[i
] = set()
278 for j
,k
in data2
.keys():
279 s
= right_indices
.get(k
)
281 s
= right_indices
[k
] = set()
284 for i
,ji
in left_indices
.items ():
285 for k
,jk
in right_indices
.items ():
286 for j
in ji
.intersection(jk
):
287 dict_add_item(None, d
, (i
,k
), data1
[(i
,j
)] * data2
[(j
,k
)])
288 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
290 def mul_MATRIX_MATRIX_TM(data1
, data2
, rows
, cols
, n
):
294 for j
,i
in data1
.keys():
295 s
= left_indices
.get(i
)
297 s
= left_indices
[i
] = set()
299 for j
,k
in data2
.keys():
300 s
= right_indices
.get(k
)
302 s
= right_indices
[k
] = set()
305 for i
,ji
in left_indices
.items ():
306 for k
,jk
in right_indices
.items ():
307 for j
in ji
.intersection(jk
):
308 dict_add_item(None, d
, (i
,k
), data1
[(j
,i
)] * data2
[(j
,k
)])
309 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
311 def mul_MATRIX_MATRIX_MT(data1
, data2
, rows
, cols
, n
):
315 for i
,j
in data1
.keys():
316 s
= left_indices
.get(i
)
318 s
= left_indices
[i
] = set()
320 for k
,j
in data2
.keys():
321 s
= right_indices
.get(k
)
323 s
= right_indices
[k
] = set()
326 for i
,ji
in left_indices
.items ():
327 for k
,jk
in right_indices
.items ():
328 for j
in ji
.intersection(jk
):
329 dict_add_item(None, d
, (i
,k
), data1
[(i
,j
)] * data2
[(k
,j
)])
330 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)
332 def mul_MATRIX_MATRIX_MTT(data1
, data2
, rows
, cols
, n
):
336 for j
,i
in data1
.keys():
337 s
= left_indices
.get(i
)
339 s
= left_indices
[i
] = set()
341 for k
,j
in data2
.keys():
342 s
= right_indices
.get(k
)
344 s
= right_indices
[k
] = set()
347 for i
,ji
in left_indices
.items ():
348 for k
,jk
in right_indices
.items ():
349 for j
in ji
.intersection(jk
):
350 dict_add_item(None, d
, (i
,k
), data1
[(j
,i
)] * data2
[(k
,j
)])
351 return MatrixDict(MATRIX(rows
, cols
, MATRIX_DICT
), d
)