support for TMatrixInterpolator in TMatrixGenerator
[qpms.git] / qpms / qpms_c.pyx
blobf99095f8976cfc0560668b25634d0ef1f1d61734
1 """@package qpms_c
2 Cythonized parts of QPMS; mostly wrappers over the C data structures
3 to make them available in Python.
4 """
6 # Cythonized parts of QPMS here
7 # -----------------------------
9 import numpy as np
10 from .qpms_cdefs cimport *
11 from .cyquaternions cimport IRot3, CQuat
12 from .cybspec cimport BaseSpec
13 from .cycommon cimport make_c_string
14 from .cycommon import string_c2py, PointGroupClass
15 from .cytmatrices cimport CTMatrix, TMatrixFunction, TMatrixGenerator, TMatrixInterpolator
16 from .cymaterials cimport EpsMuGenerator
17 from libc.stdlib cimport malloc, free, calloc
18 import warnings
# Global state of the custom GSL error handler. N.B. this is obviously not thread-safe.
# pgsl_error_handler only records the most recent GSL error in these variables;
# pgsl_check_err() turns a recorded error into a Python exception.
cdef int pgsl_errno = 0  # 0 means that no error has been recorded
cdef char *pgsl_err_reason
cdef char *pgsl_err_file
cdef int pgsl_err_line
cdef int *pgsl_errno_ignorelist = NULL  # list of ignored error codes, terminated by zero
# GSL error handler that never aborts: it only records the error in the
# pgsl_* globals, unless gsl_errno appears in the zero-terminated
# pgsl_errno_ignorelist, in which case the error is dropped silently.
cdef void pgsl_error_handler(const char *reason, const char *_file, const int line, const int gsl_errno):
    global pgsl_err_reason, pgsl_err_file, pgsl_err_line, pgsl_errno
    cdef const int *ignored = pgsl_errno_ignorelist
    if ignored != NULL:
        # Walk the list up to its zero terminator.
        while ignored[0] != 0:
            if ignored[0] == gsl_errno:
                return
            ignored += 1
    pgsl_errno = gsl_errno
    pgsl_err_line = line
    pgsl_err_reason = reason
    pgsl_err_file = _file
cdef const int* pgsl_set_ignorelist(const int *new_ignorelist):
    # Install new_ignorelist (zero-terminated) and hand back the list it replaces,
    # so that callers can restore it later.
    global pgsl_errno_ignorelist
    cdef const int *previous = pgsl_errno_ignorelist
    pgsl_errno_ignorelist = new_ignorelist
    return previous
cdef class pgsl_ignore_error():
    '''Context manager for setting a temporary list of errnos ignored by pgsl_error_handler.

    Always sets pgsl_error_handler as the GSL error handler for the duration
    of the ``with`` block. On exit, the previous handler and the previous
    ignore list are restored.

    Performs pgsl_check_err() on exit (only when the block itself did not
    raise) unless constructed with ``final_check=False``.

    Positional arguments are the GSL error codes to ignore. Note that the
    ignore list is zero-terminated, so a 0 among them ends the list early.

    Example::

        with pgsl_ignore_error(15): #15 is underflow
            ...
    '''
    cdef const int *ignorelist_old
    cdef gsl_error_handler_t *old_handler
    cdef bint final_check
    cdef object ignorelist_python

    cdef int *ignorelist
    def __cinit__(self, *ignorelist, **kwargs):
        # calloc zero-fills, which also provides the list terminator.
        self.ignorelist = <int*>calloc((len(ignorelist)+1), sizeof(int))
        if not self.ignorelist:
            raise MemoryError
        self.ignorelist_python = ignorelist
        for i, ignored_errno in enumerate(ignorelist):
            self.ignorelist[i] = ignored_errno
        # Previously this assigned a local variable, so self.final_check stayed
        # False and the documented exit check never happened.
        self.final_check = bool(kwargs.get("final_check", True))

    def __enter__(self):
        self.ignorelist_old = pgsl_set_ignorelist(self.ignorelist)
        self.old_handler = gsl_set_error_handler(pgsl_error_handler)
        return

    def __exit__(self, type, value, traceback):
        pgsl_set_ignorelist(self.ignorelist_old)
        gsl_set_error_handler(self.old_handler)
        # Do not replace an exception already propagating out of the with block.
        if self.final_check and type is None:
            pgsl_check_err(retval = None, ignorelist = self.ignorelist_python)

    def __dealloc__(self):
        free(self.ignorelist)
def pgsl_check_err(retval = None, ignorelist = None):
    '''Check for possible errors encountered by pgsl_error_handler.

    If pgsl_error_handler has recorded a GSL error since the last check,
    the record is cleared and a RuntimeError describing it is raised.

    Parameters
    ----------
    retval : int or None
        Optional return value of the checked function. A non-zero value
        that is not contained in `ignorelist` triggers a warning.
    ignorelist : container of int or None
        Non-zero return values that should not trigger the warning.
        None means that no return value is ignored.

    Returns
    -------
    retval if it is not None, otherwise 0.
    '''
    # (The docstring must be the first statement, hence the global below it.)
    global pgsl_err_reason, pgsl_err_file, pgsl_err_line, pgsl_errno
    cdef int errno_was
    if (pgsl_errno != 0):
        errno_was = pgsl_errno
        pgsl_errno = 0
        raise RuntimeError("Error %d in GSL calculation in %s:%d: %s" % (errno_was,
            string_c2py(pgsl_err_file), pgsl_err_line, string_c2py(pgsl_err_reason)))
    # Previously a missing ignorelist suppressed the warning altogether.
    if (retval is not None and retval != 0
            and (ignorelist is None or retval not in ignorelist)):
        warnings.warn("Got non-zero return value %d" % retval)
    return retval if retval is not None else 0
def set_gsl_pythonic_error_handling():
    '''
    Install pgsl_error_handler as the GSL error handler to avoid crashing.

    With this handler GSL errors are only recorded; use pgsl_check_err()
    to turn a recorded error into a Python exception.
    '''
    handler = pgsl_error_handler
    gsl_set_error_handler(handler)
cdef class PointGroup:
    '''Wrapper over the qpms_pointgroup_t structure.

    A point group is given by its class (PointGroupClass), the integer
    parameter n (must be positive for finite axial groups) and an orientation.
    Comparison operators express the subgroup relation.
    '''
    cdef readonly qpms_pointgroup_t G

    def __init__(self, cls, qpms_gmi_t n = 0, IRot3 orientation = IRot3()):
        cls = PointGroupClass(cls)
        self.G.c = cls
        if n <= 0 and qpms_pg_is_finite_axial(cls):
            raise ValueError("For finite axial groups, n argument must be positive")
        self.G.n = n
        self.G.orientation = orientation.qd

    def __len__(self):
        '''Group order as reported by qpms_pg_order.'''
        return qpms_pg_order(self.G.c, self.G.n)

    def __le__(PointGroup self, PointGroup other):
        # self is a subgroup of other
        return bool(qpms_pg_is_subgroup(self.G, other.G))
    def __ge__(PointGroup self, PointGroup other):
        # other is a subgroup of self
        return bool(qpms_pg_is_subgroup(other.G, self.G))
    def __lt__(PointGroup self, PointGroup other):
        # proper subgroup
        return qpms_pg_is_subgroup(self.G, other.G) and not qpms_pg_is_subgroup(other.G, self.G)
    def __eq__(PointGroup self, PointGroup other):
        # mutual subgroups
        return qpms_pg_is_subgroup(self.G, other.G) and qpms_pg_is_subgroup(other.G, self.G)
    def __gt__(PointGroup self, PointGroup other):
        # proper supergroup
        return not qpms_pg_is_subgroup(self.G, other.G) and qpms_pg_is_subgroup(other.G, self.G)

    def elems(self):
        '''Return the group elements as a list of IRot3 objects.'''
        els = list()
        order = len(self)
        cdef qpms_irot3_t *arr
        arr = qpms_pg_elems(NULL, self.G)
        if arr == NULL and order > 0:
            raise MemoryError
        cdef IRot3 q
        try:
            for i in range(order):
                q = IRot3()
                q.cset(arr[i])
                els.append(q)
        finally:
            # Release the C array even if building the Python list fails.
            free(arr)
        return els
cdef class FinitePointGroup:
    '''
    Wrapper over the qpms_finite_group_t structure.

    TODO more functionality to make it better usable in Python
    (group element class at least)
    '''

    def __cinit__(self, info):
        '''Constructs a FinitePointGroup from PointGroupInfo'''
        # First, generate all basic data from info
        permlist = info.deterministic_elemlist()
        cdef int order = len(permlist)
        permindices = {perm: i for i, perm in enumerate(permlist)} # 'invert' permlist
        identity = info.permgroup.identity
        # We use calloc to avoid calling free to unitialized pointers
        self.G = <qpms_finite_group_t *>calloc(1,sizeof(qpms_finite_group_t))
        if not self.G: raise MemoryError
        # Take ownership right away: every pointer member stays NULL (calloc)
        # until it is allocated, so __dealloc__ can safely release a partially
        # constructed structure if anything below raises.
        self.owns_data = True
        self.G[0].name = make_c_string(info.name)
        self.G[0].order = order
        self.G[0].idi = permindices[identity]
        # Multiplication table: mt[i*order + j] is the index of elem_i * elem_j.
        self.G[0].mt = <qpms_gmi_t *>malloc(sizeof(qpms_gmi_t) * order * order)
        if not self.G[0].mt: raise MemoryError
        for i in range(order):
            for j in range(order):
                self.G[0].mt[i*order + j] = permindices[permlist[i] * permlist[j]]
        # Index of the inverse of each element.
        self.G[0].invi = <qpms_gmi_t *>malloc(sizeof(qpms_gmi_t) * order)
        if not self.G[0].invi: raise MemoryError
        for i in range(order):
            self.G[0].invi[i] = permindices[permlist[i]**-1]
        self.G[0].ngens = len(info.permgroupgens)
        self.G[0].gens = <qpms_gmi_t *>malloc(sizeof(qpms_gmi_t) * self.G[0].ngens)
        if not self.G[0].gens: raise MemoryError
        for i in range(self.G[0].ngens):
            self.G[0].gens[i] = permindices[info.permgroupgens[i]]
        self.G[0].permrep = <char **>calloc(order, sizeof(char *))
        if not self.G[0].permrep: raise MemoryError
        for i in range(order):
            self.G[0].permrep[i] = make_c_string(str(permlist[i]))
            if not self.G[0].permrep[i]: raise MemoryError
        self.G[0].permrep_nelem = info.permgroup.degree
        if info.rep3d is not None:
            self.G[0].rep3d = <qpms_irot3_t *>malloc(order * sizeof(qpms_irot3_t))
            if not self.G[0].rep3d: raise MemoryError
            for i in range(order):
                self.G[0].rep3d[i] = info.rep3d[permlist[i]].qd
        self.G[0].nirreps = len(info.irreps)
        self.G[0].irreps = <qpms_finite_group_irrep_t *>calloc(self.G[0].nirreps, sizeof(qpms_finite_group_irrep_t))
        if not self.G[0].irreps: raise MemoryError
        cdef int dim
        for iri, irname in enumerate(sorted(info.irreps.keys())):
            irrep = info.irreps[irname]
            is1d = isinstance(irrep[identity], (int, float, complex))
            dim = 1 if is1d else irrep[identity].shape[0]
            self.G[0].irreps[iri].dim = dim
            self.G[0].irreps[iri].name = <char *>make_c_string(irname)
            if not self.G[0].irreps[iri].name: raise MemoryError
            # order consecutive row-major dim x dim matrices
            self.G[0].irreps[iri].m = <cdouble *>malloc(dim*dim*sizeof(cdouble)*order)
            if not self.G[0].irreps[iri].m: raise MemoryError
            if is1d:
                for i in range(order):
                    self.G[0].irreps[iri].m[i] = irrep[permlist[i]]
            else:
                for i in range(order):
                    for row in range(dim):
                        for col in range(dim):
                            self.G[0].irreps[iri].m[i*dim*dim + row*dim + col] = irrep[permlist[i]][row,col]
        self.G[0].elemlabels = <char **> 0 # Elem labels not yet implemented

    def __dealloc__(self):
        cdef qpms_gmi_t order
        if self.owns_data:
            if self.G:
                order = self.G[0].order
                free(self.G[0].name)
                free(self.G[0].mt)
                free(self.G[0].invi)
                free(self.G[0].gens)
                if self.G[0].permrep:
                    for i in range(order): free(self.G[0].permrep[i])
                    free(self.G[0].permrep)
                if self.G[0].elemlabels: # this is not even contructed right now
                    for i in range(order): free(self.G[0].elemlabels[i])
                    free(self.G[0].elemlabels) # the array itself was leaked before
                free(self.G[0].rep3d) # NULL unless info.rep3d was given; was leaked before
                if self.G[0].irreps:
                    for iri in range(self.G[0].nirreps):
                        free(self.G[0].irreps[iri].name)
                        free(self.G[0].irreps[iri].m)
                    free(self.G[0].irreps)
                free(self.G)
                self.G = <qpms_finite_group_t *>0
                self.owns_data = False
cdef class FinitePointGroupElement:
    '''Pairs a FinitePointGroup G with a group member index gmi.

    TODO
    '''
    cdef readonly FinitePointGroup G
    cdef readonly qpms_gmi_t gmi

    def __cinit__(self, FinitePointGroup G, qpms_gmi_t gmi):
        self.gmi = gmi
        self.G = G
cdef class Particle:
    '''
    Wrapper over the qpms_particle_t structure.
    '''
    cdef qpms_particle_t p
    cdef readonly TMatrixFunction f # Reference to ensure correct reference counting

    def __cinit__(Particle self, pos, t, bspec = None):
        '''
        pos: 2 or 3 cartesian coordinates (z is set to 0 when only 2 are given).
        t: CTMatrix, TMatrixInterpolator or TMatrixGenerator.
        bspec: BaseSpec; required unless the T-matrix generator's holder is a CTMatrix.
        '''
        cdef TMatrixGenerator tgen
        cdef BaseSpec spec
        self.pos = pos # validated by the pos property setter below
        if isinstance(t, CTMatrix):
            tgen = TMatrixGenerator(t)
        elif isinstance(t, TMatrixInterpolator):
            tgen = TMatrixGenerator(t)
            warnings.warn("Initialising a particle with interpolated T-matrix values. Imaginary frequencies will be discarded and mode search algorithm will yield nonsense (just saying).")
        elif isinstance(t, TMatrixGenerator):
            tgen = <TMatrixGenerator>t
        else:
            raise TypeError('t must be either CTMatrix, TMatrixInterpolator or TMatrixGenerator, was %s' % str(type(t)))
        if bspec is not None:
            spec = bspec
        else:
            if isinstance(tgen.holder, CTMatrix):
                spec = (<CTMatrix>tgen.holder).spec
            else:
                raise ValueError("bspec argument must be specified separately for %s" % str(type(t)))
        self.f = TMatrixFunction(tgen, spec)
        self.p.tmg = self.f.rawpointer()
        # TODO non-trivial transformations later; if modified, do not forget to update ScatteringSystem constructor
        self.p.op = qpms_tmatrix_operation_noop

    def __dealloc__(self):
        qpms_tmatrix_operation_clear(&self.p.op)

    cdef qpms_particle_t *rawpointer(Particle self):
        '''Pointer to the qpms_particle_p structure.
        '''
        return &(self.p)
    property rawpointer:
        def __get__(self):
            return <uintptr_t> &(self.p)

    cdef qpms_particle_t cval(Particle self):
        '''Provides a copy for assigning in cython code'''
        return self.p

    property x:
        def __get__(self):
            return self.p.pos.x
        def __set__(self,x):
            self.p.pos.x = x
    property y:
        def __get__(self):
            return self.p.pos.y
        def __set__(self,y):
            self.p.pos.y = y
    property z:
        def __get__(self):
            return self.p.pos.z
        def __set__(self,z):
            self.p.pos.z = z
    property pos:
        def __get__(self):
            return (self.p.pos.x, self.p.pos.y, self.p.pos.z)
        def __set__(self, pos):
            # Single place where positions are validated (also used by __cinit__).
            if 2 <= len(pos) <= 3:
                self.p.pos.x = pos[0]
                self.p.pos.y = pos[1]
                self.p.pos.z = pos[2] if len(pos)==3 else 0
            else:
                raise ValueError("Position argument has to contain 3 or 2 cartesian coordinates")
cpdef void scatsystem_set_nthreads(long n):
    '''Pass n on to qpms_scatsystem_set_nthreads (presumably the thread count
    used by the C scattering-system routines — confirm in the C sources).'''
    qpms_scatsystem_set_nthreads(n)
cdef class ScatteringSystem:
    '''
    Wrapper over the C qpms_scatsys_t structure.

    Instances have to be obtained via ScatteringSystem.create(); the default
    constructor leaves the underlying C pointer unset (see check_s()).
    '''
    cdef list tmgobjs # here we keep the references to occuring TMatrixFunctions (and hence BaseSpecs and TMatrixGenerators)
    #cdef list Tmatrices # Here we keep the references to occuring T-matrices
    cdef EpsMuGenerator medium_holder # Here we keep the reference to medium generator
    cdef qpms_scatsys_t *s
    cdef FinitePointGroup sym

    cdef qpms_iri_t iri_py2c(self, iri, allow_None = True) except *:
        '''Convert a Python irrep index to qpms_iri_t.

        None maps to QPMS_NO_IRREP if allow_None is true. Valid indices are
        0 <= iri < self.nirreps; others raise ValueError. (The except clause
        makes the exception propagate instead of being printed and ignored.)
        '''
        if iri is None and allow_None:
            return QPMS_NO_IRREP
        cdef qpms_iri_t nir = self.nirreps
        cdef qpms_iri_t ciri = iri
        # Irrep indices are zero-based, hence ciri == nir is already invalid.
        if ciri < 0 or ciri >= nir:
            raise ValueError("Invalid irrep index %s (of %d irreps)" % (str(iri), self.nirreps))
        return ciri

    def check_s(self): # cdef instead?
        if self.s == <qpms_scatsys_t *>NULL:
            raise ValueError("ScatteringSystem's s-pointer not set. You must not use the default constructor; use the create() method instead")
        #TODO is there a way to disable the constructor outside this module?

    @staticmethod # We don't have any "standard" constructor for this right now
    def create(particles, medium, FinitePointGroup sym, cdouble omega): # TODO tolerances
        '''Build a scattering system from particles, a medium and a symmetry group.

        Returns a tuple (ScatteringSystem, _ScatteringSystemAtOmega), the
        latter evaluated at frequency omega.
        '''
        # These we are going to construct
        cdef ScatteringSystem self
        cdef _ScatteringSystemAtOmega pyssw

        cdef qpms_scatsys_t orig
        cdef qpms_ss_pi_t pi, p_count = len(particles)
        cdef qpms_ss_tmi_t tmi, tm_count = 0
        cdef qpms_ss_tmgi_t tmgi, tmg_count = 0

        cdef qpms_scatsys_at_omega_t *ssw
        cdef qpms_scatsys_t *ss

        cdef Particle p

        # C struct locals are not zero-initialised; the pointers freed in the
        # finally clause below must start out NULL, otherwise an early
        # MemoryError would make us free garbage.
        orig.tmg = NULL
        orig.tm = NULL
        orig.p = NULL

        tmgindices = dict()
        tmgobjs = list()
        tmindices = dict()
        tmlist = list()
        for p in particles: # find and enumerate unique t-matrix generators
            if p.p.op.typ != QPMS_TMATRIX_OPERATION_NOOP:
                raise NotImplementedError("currently, only no-op T-matrix operations are allowed in ScatteringSystem constructor")
            #tmg_key = id(p.f) # This causes a different generator for each particle -> SUPER SLOW
            tmg_key = (id(p.f.generator), id(p.f.spec))
            if tmg_key not in tmgindices:
                tmgindices[tmg_key] = tmg_count
                tmgobjs.append(p.f) # Save the references on BaseSpecs and TMatrixGenerators (via TMatrixFunctions)
                tmg_count += 1
            # Following lines have to be adjusted when nontrivial operations allowed:
            tm_derived_key = (tmg_key, None) # TODO unique representation of p.p.op instead of None
            if tm_derived_key not in tmindices:
                tmindices[tm_derived_key] = tm_count
                tmlist.append(tm_derived_key)
                tm_count += 1
        cdef EpsMuGenerator mediumgen = EpsMuGenerator(medium)
        orig.medium = mediumgen.g
        orig.tmg_count = tmg_count
        orig.tm_count = tm_count
        orig.p_count = p_count
        try:
            orig.tmg = <qpms_tmatrix_function_t *>malloc(orig.tmg_count * sizeof(orig.tmg[0]))
            if not orig.tmg: raise MemoryError
            orig.tm = <qpms_ss_derived_tmatrix_t *>malloc(orig.tm_count * sizeof(orig.tm[0]))
            if not orig.tm: raise MemoryError
            orig.p = <qpms_particle_tid_t *>malloc(orig.p_count * sizeof(orig.p[0]))
            if not orig.p: raise MemoryError
            for tmgi in range(orig.tmg_count):
                orig.tmg[tmgi] = (<TMatrixFunction?>tmgobjs[tmgi]).raw()
            for tmi in range(tm_count):
                tm_derived_key = tmlist[tmi]
                tmgi = tmgindices[tm_derived_key[0]]
                orig.tm[tmi].tmgi = tmgi
                orig.tm[tmi].op = qpms_tmatrix_operation_noop # TODO adjust when notrivial operations allowed
            for pi in range(p_count):
                p = particles[pi]
                tmg_key = (id(p.f.generator), id(p.f.spec))
                tm_derived_key = (tmg_key, None) # TODO unique representation of p.p.op instead of None
                orig.p[pi].pos = p.cval().pos
                orig.p[pi].tmatrix_id = tmindices[tm_derived_key]
            ssw = qpms_scatsys_apply_symmetry(&orig, sym.rawpointer(), omega, &QPMS_TOLERANCE_DEFAULT)
            if ssw == NULL:
                raise RuntimeError("qpms_scatsys_apply_symmetry failed")
            ss = ssw[0].ss
        finally:
            free(orig.tmg)
            free(orig.tm)
            free(orig.p)
        self = ScatteringSystem()
        self.medium_holder = mediumgen
        self.s = ss
        self.tmgobjs = tmgobjs
        self.sym = sym
        pyssw = _ScatteringSystemAtOmega()
        pyssw.ssw = ssw
        pyssw.ss_pyref = self
        return self, pyssw

    def __call__(self, cdouble omega):
        '''Evaluate the system at frequency omega; returns _ScatteringSystemAtOmega.'''
        self.check_s()
        cdef _ScatteringSystemAtOmega pyssw = _ScatteringSystemAtOmega()
        pyssw.ssw = qpms_scatsys_at_omega(self.s, omega)
        if pyssw.ssw == NULL:
            raise RuntimeError("qpms_scatsys_at_omega failed")
        pyssw.ss_pyref = self
        return pyssw

    def __dealloc__(self):
        if(self.s):
            qpms_scatsys_free(self.s)

    property particles_tmi:
        def __get__(self):
            self.check_s()
            r = list()
            cdef qpms_ss_pi_t pi
            for pi in range(self.s[0].p_count):
                r.append(self.s[0].p[pi])
            return r

    property fecv_size:
        def __get__(self):
            self.check_s()
            return self.s[0].fecv_size
    property saecv_sizes:
        def __get__(self):
            self.check_s()
            return [self.s[0].saecv_sizes[i]
                for i in range(self.s[0].sym[0].nirreps)]
    property irrep_names:
        def __get__(self):
            self.check_s()
            return [string_c2py(self.s[0].sym[0].irreps[iri].name)
                    if (self.s[0].sym[0].irreps[iri].name) else None
                for iri in range(self.s[0].sym[0].nirreps)]
    property nirreps:
        def __get__(self):
            self.check_s()
            return self.s[0].sym[0].nirreps

    def pack_vector(self, vect, iri):
        '''Pack a full-length coefficient vector for irrep iri (qpms_scatsys_irrep_pack_vector).'''
        self.check_s()
        cdef qpms_iri_t ciri = self.iri_py2c(iri, False)
        if len(vect) != self.fecv_size:
            raise ValueError("Length of a full vector has to be %d, not %d"
                    % (self.fecv_size, len(vect)))
        # asarray: copy only when needed (np.array(copy=False) raises on NumPy >= 2 then).
        vect = np.asarray(vect, dtype=complex, order='C')
        cdef cdouble[::1] vect_view = vect
        cdef np.ndarray[np.complex_t, ndim=1] target_np = np.empty(
                (self.saecv_sizes[ciri],), dtype=complex, order='C')
        cdef cdouble[::1] target_view = target_np
        qpms_scatsys_irrep_pack_vector(&target_view[0], &vect_view[0], self.s, ciri)
        return target_np
    def unpack_vector(self, packed, iri):
        '''Unpack an irrep-iri packed vector to full length (qpms_scatsys_irrep_unpack_vector).'''
        self.check_s()
        cdef qpms_iri_t ciri = self.iri_py2c(iri, False)
        if len(packed) != self.saecv_sizes[ciri]:
            raise ValueError("Length of %d. irrep-packed vector has to be %d, not %d"
                    % (ciri, self.saecv_sizes[ciri], len(packed)))
        packed = np.asarray(packed, dtype=complex, order='C')
        cdef cdouble[::1] packed_view = packed
        cdef np.ndarray[np.complex_t, ndim=1] target_np = np.empty(
                (self.fecv_size,), dtype=complex)
        cdef cdouble[::1] target_view = target_np
        qpms_scatsys_irrep_unpack_vector(&target_view[0], &packed_view[0],
                self.s, ciri, 0)
        return target_np
    def pack_matrix(self, fullmatrix, iri):
        '''Pack a full (fecv_size x fecv_size) matrix for irrep iri.'''
        self.check_s()
        cdef qpms_iri_t ciri = self.iri_py2c(iri, False)
        cdef size_t flen = self.s[0].fecv_size
        cdef size_t rlen = self.saecv_sizes[ciri]
        fullmatrix = np.asarray(fullmatrix, dtype=complex, order='C')
        if fullmatrix.shape != (flen, flen):
            raise ValueError("Full matrix shape should be (%d,%d), is %s."
                    % (flen, flen, repr(fullmatrix.shape)))
        cdef cdouble[:,::1] fullmatrix_view = fullmatrix
        cdef np.ndarray[np.complex_t, ndim=2] target_np = np.empty(
                (rlen, rlen), dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target_np
        qpms_scatsys_irrep_pack_matrix(&target_view[0][0], &fullmatrix_view[0][0],
                self.s, ciri)
        return target_np
    def unpack_matrix(self, packedmatrix, iri):
        '''Unpack an irrep-iri packed matrix to full (fecv_size x fecv_size) size.'''
        self.check_s()
        cdef qpms_iri_t ciri = self.iri_py2c(iri, False)
        cdef size_t flen = self.s[0].fecv_size
        cdef size_t rlen = self.saecv_sizes[ciri]
        packedmatrix = np.asarray(packedmatrix, dtype=complex, order='C')
        if packedmatrix.shape != (rlen, rlen):
            raise ValueError("Packed matrix shape should be (%d,%d), is %s."
                    % (rlen, rlen, repr(packedmatrix.shape)))
        cdef cdouble[:,::1] packedmatrix_view = packedmatrix
        cdef np.ndarray[np.complex_t, ndim=2] target_np = np.empty(
                (flen, flen), dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target_np
        qpms_scatsys_irrep_unpack_matrix(&target_view[0][0], &packedmatrix_view[0][0],
                self.s, ciri, 0)
        return target_np

    def translation_matrix_full(self, double k, J = QPMS_HANKEL_PLUS):
        self.check_s()
        cdef size_t flen = self.s[0].fecv_size
        cdef np.ndarray[np.complex_t, ndim=2] target = np.empty(
                (flen,flen),dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target
        qpms_scatsys_build_translation_matrix_e_full(&target_view[0][0], self.s, k, J)
        return target

    def translation_matrix_packed(self, double k, qpms_iri_t iri, J = QPMS_HANKEL_PLUS):
        self.check_s()
        iri = self.iri_py2c(iri, False) # range check before indexing / calling C
        cdef size_t rlen = self.saecv_sizes[iri]
        cdef np.ndarray[np.complex_t, ndim=2] target = np.empty(
                (rlen,rlen),dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target
        qpms_scatsys_build_translation_matrix_e_irrep_packed(&target_view[0][0],
                self.s, iri, k, J)
        return target

    property fullvec_psizes:
        def __get__(self):
            self.check_s()
            cdef np.ndarray[int32_t, ndim=1] ar = np.empty((self.s[0].p_count,), dtype=np.int32)
            cdef int32_t[::1] ar_view = ar
            for pi in range(self.s[0].p_count):
                ar_view[pi] = self.s[0].tm[self.s[0].p[pi].tmatrix_id].spec[0].n
            return ar


    property fullvec_poffsets:
        def __get__(self):
            self.check_s()
            cdef np.ndarray[intptr_t, ndim=1] ar = np.empty((self.s[0].p_count,), dtype=np.intp)
            cdef intptr_t[::1] ar_view = ar
            cdef intptr_t offset = 0
            for pi in range(self.s[0].p_count):
                ar_view[pi] = offset
                offset += self.s[0].tm[self.s[0].p[pi].tmatrix_id].spec[0].n
            return ar

    property positions:
        def __get__(self):
            self.check_s()
            cdef np.ndarray[np.double_t, ndim=2] ar = np.empty((self.s[0].p_count, 3), dtype=float)
            cdef np.double_t[:,::1] ar_view = ar
            for pi in range(self.s[0].p_count):
                ar_view[pi,0] = self.s[0].p[pi].pos.x
                ar_view[pi,1] = self.s[0].p[pi].pos.y
                ar_view[pi,2] = self.s[0].p[pi].pos.z
            return ar

    def planewave_full(self, k_cart, E_cart):
        self.check_s()
        k_cart = np.array(k_cart)
        E_cart = np.array(E_cart)
        if k_cart.shape != (3,) or E_cart.shape != (3,):
            raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)")
        cdef qpms_incfield_planewave_params_t p
        p.use_cartesian = 1
        p.k.cart.x = <cdouble>k_cart[0]
        p.k.cart.y = <cdouble>k_cart[1]
        p.k.cart.z = <cdouble>k_cart[2]
        p.E.cart.x = <cdouble>E_cart[0]
        p.E.cart.y = <cdouble>E_cart[1]
        p.E.cart.z = <cdouble>E_cart[2]
        cdef np.ndarray[np.complex_t, ndim=1] target_np = np.empty(
                (self.fecv_size,), dtype=complex)
        cdef cdouble[::1] target_view = target_np
        qpms_scatsys_incident_field_vector_full(&target_view[0],
                self.s, qpms_incfield_planewave, <void *>&p, 0)
        return target_np

    def find_modes(self, cdouble omega_centre, double omega_rr, double omega_ri, iri = None,
            size_t contour_points = 20, double rank_tol = 1e-4, size_t rank_min_sel=1,
            double res_tol = 0):
        '''
        Attempts to find the eigenvalues and eigenvectors using Beyn's algorithm.

        The contour is centred at omega_centre with semi-axes omega_rr (real)
        and omega_ri (imaginary). Returns a dict with keys 'eigval',
        'eigval_inside_metric', 'eigvec', 'residuals', 'eigval_err',
        'ranktest_SV' and 'iri'.
        '''
        self.check_s()
        cdef size_t neig, vlen, i, j
        cdef np.ndarray[complex, ndim=1] eigval
        cdef np.ndarray[complex, ndim=1] eigval_err
        cdef np.ndarray[double, ndim=1] residuals
        cdef np.ndarray[complex, ndim=2] eigvec
        cdef np.ndarray[double, ndim=1] ranktest_SV
        cdef cdouble[::1] eigval_v
        cdef cdouble[::1] eigval_err_v
        cdef double[::1] residuals_v
        cdef cdouble[:,::1] eigvec_v
        cdef double[::1] ranktest_SV_v

        cdef beyn_result_t *res = qpms_scatsys_finite_find_eigenmodes(self.s,
                self.iri_py2c(iri),
                omega_centre, omega_rr, omega_ri, contour_points,
                rank_tol, rank_min_sel, res_tol)
        if res == NULL: raise RuntimeError

        try: # res must be released even if copying the results fails
            neig = res[0].neig
            vlen = res[0].vlen # should be equal to self.s.fecv_size

            eigval = np.empty((neig,), dtype=complex)
            eigval_v = eigval
            eigval_err = np.empty((neig,), dtype=complex)
            eigval_err_v = eigval_err
            residuals = np.empty((neig,), dtype=np.double)
            residuals_v = residuals
            eigvec = np.empty((neig,vlen),dtype=complex)
            eigvec_v = eigvec
            ranktest_SV = np.empty((vlen), dtype=np.double)
            ranktest_SV_v = ranktest_SV

            for i in range(neig):
                eigval_v[i] = res[0].eigval[i]
                eigval_err_v[i] = res[0].eigval_err[i]
                residuals_v[i] = res[0].residuals[i]
                for j in range(vlen):
                    eigvec_v[i,j] = res[0].eigvec[i*vlen + j]
            for i in range(vlen):
                ranktest_SV_v[i] = res[0].ranktest_SV[i]
        finally:
            beyn_result_free(res)

        # Elliptic distance from the contour centre; < 1 means inside the contour.
        zdist = eigval - omega_centre
        eigval_inside_metric = np.hypot(zdist.real / omega_rr, zdist.imag / omega_ri)

        retdict = {
            'eigval':eigval,
            'eigval_inside_metric':eigval_inside_metric,
            'eigvec':eigvec,
            'residuals':residuals,
            'eigval_err':eigval_err,
            'ranktest_SV':ranktest_SV,
            'iri': iri,
        }

        return retdict
cdef class _ScatteringSystemAtOmega:
    '''
    Wrapper over the C qpms_scatsys_at_omega_t structure
    that keeps the T-matrix and background data evaluated
    at specific frequency.
    '''
    cdef qpms_scatsys_at_omega_t *ssw
    cdef ScatteringSystem ss_pyref

    def check(self): # cdef instead?
        if not self.ssw:
            raise ValueError("_ScatteringSystemAtOmega's ssw-pointer not set. You must not use the default constructor; ScatteringSystem.create() instead")
        self.ss_pyref.check_s()
        #TODO is there a way to disable the constructor outside this module?

    def __dealloc__(self):
        if (self.ssw):
            qpms_scatsys_at_omega_free(self.ssw)

    def apply_Tmatrices_full(self, a):
        '''Apply the particles' T-matrices to a full-length coefficient vector a.'''
        self.check()
        if len(a) != self.fecv_size:
            raise ValueError("Length of a full vector has to be %d, not %d"
                    % (self.fecv_size, len(a)))
        # asarray: copy only when needed (np.array(copy=False) raises on NumPy >= 2 then).
        a = np.asarray(a, dtype=complex, order='C')
        cdef cdouble[::1] a_view = a
        cdef np.ndarray[np.complex_t, ndim=1] target_np = np.empty(
                (self.fecv_size,), dtype=complex, order='C')
        cdef cdouble[::1] target_view = target_np
        qpms_scatsysw_apply_Tmatrices_full(&target_view[0], &a_view[0], self.ssw)
        return target_np

    cdef qpms_scatsys_at_omega_t *rawpointer(self):
        return self.ssw

    def scatter_solver(self, iri=None):
        '''Return a ScatteringMatrix (full if iri is None, else for irrep iri).'''
        self.check()
        return ScatteringMatrix(self, iri)

    property fecv_size:
        def __get__(self): return self.ss_pyref.fecv_size
    property saecv_sizes:
        def __get__(self): return self.ss_pyref.saecv_sizes
    property irrep_names:
        def __get__(self): return self.ss_pyref.irrep_names
    property nirreps:
        def __get__(self): return self.ss_pyref.nirreps

    def modeproblem_matrix_full(self):
        self.check()
        cdef size_t flen = self.ss_pyref.s[0].fecv_size
        cdef np.ndarray[np.complex_t, ndim=2] target = np.empty(
                (flen,flen),dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target
        qpms_scatsysw_build_modeproblem_matrix_full(&target_view[0][0], self.ssw)
        return target

    def modeproblem_matrix_packed(self, qpms_iri_t iri, version='pR'):
        '''Irrep-packed mode problem matrix.

        version 'R' and 'pR' select the orbitorderR and the default (nogil)
        C builders respectively; any other value uses the serial builder.
        '''
        self.check()
        # Validate before indexing saecv_sizes / passing iri to C.
        if iri < 0 or iri >= self.nirreps:
            raise ValueError("Invalid irrep index %d (of %d irreps)" % (iri, self.nirreps))
        cdef size_t rlen = self.saecv_sizes[iri]
        cdef np.ndarray[np.complex_t, ndim=2] target = np.empty(
                (rlen,rlen),dtype=complex, order='C')
        cdef cdouble[:,::1] target_view = target
        if (version == 'R'):
            qpms_scatsysw_build_modeproblem_matrix_irrep_packed_orbitorderR(&target_view[0][0], self.ssw, iri)
        elif (version == 'pR'):
            with nogil:
                qpms_scatsysw_build_modeproblem_matrix_irrep_packed(&target_view[0][0], self.ssw, iri)
        else:
            qpms_scatsysw_build_modeproblem_matrix_irrep_packed_serial(&target_view[0][0], self.ssw, iri)
        return target
cdef class ScatteringMatrix:
    '''
    Wrapper over the C qpms_ss_LU structure that keeps the factorised mode problem matrix.

    Calling the object with an incident-field coefficient vector returns the
    solution computed by qpms_scatsys_scatter_solve.
    '''
    cdef _ScatteringSystemAtOmega ssw # Here we keep the reference to the parent scattering system
    cdef qpms_ss_LU lu

    def __cinit__(self, _ScatteringSystemAtOmega ssw, iri=None):
        ssw.check()
        self.ssw = ssw
        # TODO? pre-allocate the matrix with numpy to make it transparent?
        if iri is None:
            self.lu = qpms_scatsysw_build_modeproblem_matrix_full_LU(
                    NULL, NULL, ssw.rawpointer())
        else:
            # iri is passed straight to C, so check the range here.
            if not 0 <= iri < ssw.nirreps:
                raise ValueError("Invalid irrep index %s (of %d irreps)" % (str(iri), ssw.nirreps))
            self.lu = qpms_scatsysw_build_modeproblem_matrix_irrep_packed_LU(
                    NULL, NULL, ssw.rawpointer(), iri)

    def __dealloc__(self):
        qpms_ss_LU_free(self.lu)

    property iri:
        def __get__(self):
            return None if self.lu.full else self.lu.iri

    def __call__(self, a_inc):
        cdef size_t vlen
        cdef qpms_iri_t iri = -1
        if self.lu.full:
            vlen = self.lu.ssw[0].ss[0].fecv_size
            if len(a_inc) != vlen:
                raise ValueError("Length of a full coefficient vector has to be %d, not %d"
                        % (vlen, len(a_inc)))
        else:
            iri = self.lu.iri
            vlen = self.lu.ssw[0].ss[0].saecv_sizes[iri]
            if len(a_inc) != vlen:
                raise ValueError("Length of a %d. irrep packed coefficient vector has to be %d, not %d"
                        % (iri, vlen, len(a_inc)))
        # asarray: copy only when needed (np.array(copy=False) raises on NumPy >= 2 then).
        a_inc = np.asarray(a_inc, dtype=complex, order='C')
        cdef const cdouble[::1] a_view = a_inc
        cdef np.ndarray f = np.empty((vlen,), dtype=complex, order='C')
        cdef cdouble[::1] f_view = f
        qpms_scatsys_scatter_solve(&f_view[0], &a_view[0], self.lu)
        return f
def pitau(double theta, qpms_l_t lMax, double csphase = -1):
    '''Return the three float arrays filled by qpms_pitau_fill for angle theta.

    Each array has qpms_lMax2nelem(lMax) elements; they are returned in the
    order of qpms_pitau_fill's first three arguments (leg, pi, tau).
    csphase must be +1 or -1.
    '''
    if abs(csphase) != 1:
        raise ValueError("csphase must be 1 or -1, is %g" % csphase)
    cdef size_t count = qpms_lMax2nelem(lMax)
    cdef np.ndarray[np.float_t, ndim=1] leg_np = np.empty((count,), dtype=float)
    cdef np.ndarray[np.float_t, ndim=1] pi_np = np.empty((count,), dtype=float)
    cdef np.ndarray[np.float_t, ndim=1] tau_np = np.empty((count,), dtype=float)
    cdef double[::1] leg_v = leg_np
    cdef double[::1] pi_v = pi_np
    cdef double[::1] tau_v = tau_np
    qpms_pitau_fill(&leg_v[0], &pi_v[0], &tau_v[0], theta, lMax, csphase)
    return leg_np, pi_np, tau_np
def linton_gamma(cdouble x):
    '''Python wrapper around the C function clilgamma (complex argument).'''
    value = clilgamma(x)
    return value
def linton_gamma_real(double x):
    '''Python wrapper around the C function lilgamma (real argument).'''
    value = lilgamma(x)
    return value
def gamma_inc(double a, cdouble x, int m = 0):
    '''Return (val, err) as computed by complex_gamma_inc_e(a, x, m).

    GSL underflow errors (code 15) are ignored during the computation.
    '''
    cdef qpms_csf_result result
    with pgsl_ignore_error(15): #15 is underflow
        complex_gamma_inc_e(a, x, m, &result)
    return result.val, result.err
def gamma_inc_series(double a, cdouble x):
    '''Return (val, err) as computed by cx_gamma_inc_series_e(a, x).

    GSL underflow errors (code 15) are ignored during the computation.
    '''
    cdef qpms_csf_result result
    with pgsl_ignore_error(15): #15 is underflow
        cx_gamma_inc_series_e(a, x, &result)
    return result.val, result.err
def gamma_inc_CF(double a, cdouble x):
    '''Return (val, err) as computed by cx_gamma_inc_CF_e(a, x).

    GSL underflow errors (code 15) are ignored during the computation.
    '''
    cdef qpms_csf_result result
    with pgsl_ignore_error(15): #15 is underflow
        cx_gamma_inc_CF_e(a, x, &result)
    return result.val, result.err
def lll_reduce(basis, double delta=0.75):
    '''
    Lattice basis reduction with the Lenstra-Lenstra-Lovász algorithm.

    basis is array_like with dimensions (n, d), where
    n is the size of the basis (dimensionality of the lattice)
    and d is the dimensionality of the space into which the lattice
    is embedded.

    Returns a reduced copy of basis (the input is not modified).
    Raises ValueError for non-2D input or n > d, RuntimeError if the C
    routine reports failure.
    '''
    basis = np.array(basis, copy=True, order='C', dtype=np.double)
    if len(basis.shape) != 2:
        raise ValueError("Expected two-dimensional array (got %d-dimensional)"
                % len(basis.shape))
    cdef size_t n, d
    n, d = basis.shape
    if n > d:
        raise ValueError("Real space dimensionality (%d) cannot be smaller than "
                "the dimensionality of the lattice (%d) embedded into it."
                % (d, n))
    if n == 0:
        # Nothing to reduce; also avoids indexing element [0,0] of an empty view.
        return basis
    cdef double [:,:] basis_view = basis
    if 0 != qpms_reduce_lattice_basis(&basis_view[0,0], n, d, delta):
        raise RuntimeError("Something weird happened")
    return basis
cdef int _direct_basis_2d(direct_basis, cart2_t *b1, cart2_t *b2) except -1:
    # Fill b1, b2 from the rows of a 2x2 array_like (row i = basis vector i).
    # Only 2D bases are supported; anything else raises NotImplementedError.
    dba = np.array(direct_basis)
    if not (dba.shape == (2,2)):
        raise NotImplementedError
    b1[0].x = dba[0,0]
    b1[0].y = dba[0,1]
    b2[0].x = dba[1,0]
    b2[0].y = dba[1,1] # was dba[0,1] (copy-paste bug) in all three callers
    return 0

cdef PGen get_PGen_direct(direct_basis, bint include_origin=False, double layers=30):
    # Point generator over the direct lattice, cut off at `layers` times the
    # longer basis vector length.
    cdef cart2_t b1, b2
    _direct_basis_2d(direct_basis, &b1, &b2)
    cdef double maxR = layers*max(cart2norm(b1), cart2norm(b2))
    return PGen_xyWeb_new(b1, b2, BASIS_RTOL, CART2_ZERO, 0, include_origin, maxR, False)

cdef double get_unitcell_volume(direct_basis):
    # Unit cell area of the 2D direct lattice (l2d_unitcell_area).
    cdef cart2_t b1, b2
    _direct_basis_2d(direct_basis, &b1, &b2)
    return l2d_unitcell_area(b1, b2)

cdef PGen get_PGen_reciprocal2pi(direct_basis, double layers = 30):
    # Point generator over the reciprocal lattice (basis from
    # l2d_reciprocalBasis2pi), origin included, cut off at `layers` times the
    # longer reciprocal basis vector length.
    cdef cart2_t b1, b2, rb1, rb2
    _direct_basis_2d(direct_basis, &b1, &b2)
    if(l2d_reciprocalBasis2pi(b1, b2, &rb1, &rb2) != 0):
        raise RuntimeError
    cdef double maxK = layers*max(cart2norm(rb1), cart2norm(rb2))
    return PGen_xyWeb_new(rb1, rb2, BASIS_RTOL, CART2_ZERO,
            0, True, maxK, False)
cdef class Ewald3Calculator:
    '''Wrapper class over qpms_ewald3_constants_t.

    Mainly for testing low-level scalar Ewald summation functionality.'''
    cdef qpms_ewald3_constants_t *c

    def __cinit__(self, qpms_l_t lMax, int csphase = -1):
        if (csphase != -1 and csphase != 1):
            raise ValueError("csphase must be +1 or -1, not %d" % csphase)
        self.c = qpms_ewald3_constants_init(lMax, csphase)
        if self.c == NULL:
            raise MemoryError

    def __dealloc__(self):
        # c stays NULL when __cinit__ raised before allocating it.
        if self.c:
            qpms_ewald3_constants_free(self.c)

    def sigma0(self, double eta, cdouble wavenumber, do_err = False):
        '''Return ewald3_sigma0 result (and its error estimate if do_err).'''
        cdef int retval
        cdef double err
        cdef cdouble result
        retval = ewald3_sigma0(&result, &err, self.c, eta, wavenumber)
        if retval:
            raise RuntimeError("ewald3_sigma0 returned non-zero value (%d)" % retval)
        if do_err:
            return (result, err)
        else:
            return result

    def sigma_short(self, double eta, cdouble wavenumber, direct_basis, wavevector, particle_shift, do_err=False):
        '''Short-range Ewald sums (ewald3_sigma_short) for a 2D lattice in the XY plane.

        Only the x, y components of wavevector and particle_shift are used.
        The lattice origin is excluded when particle_shift is zero.
        '''
        # FIXME now only 2d XY lattice in 3D is implemented here, we don't even do proper dimensionality checks.
        cdef cart3_t beta, pshift
        beta.x = wavevector[0]
        beta.y = wavevector[1]
        beta.z = 0
        pshift.x = particle_shift[0]
        pshift.y = particle_shift[1]
        pshift.z = 0
        cdef qpms_l_t n = self.c[0].nelem_sc
        cdef np.ndarray[complex, ndim=1] result = np.empty((n,), dtype=complex)
        cdef cdouble[::1] result_v = result
        cdef np.ndarray[double, ndim=1] err
        cdef double[::1] err_v
        if do_err:
            err = np.empty((n,), dtype=np.double)
            err_v = err
        cdef bint include_origin = not (particle_shift[0] == 0 and particle_shift[1] == 0)
        cdef PGen rgen = get_PGen_direct(direct_basis, include_origin)
        cdef int retval = ewald3_sigma_short(&result_v[0], &err_v[0] if do_err else NULL,
                self.c, eta, wavenumber, LAT_2D_IN_3D_XYONLY, &rgen, False, beta, pshift)
        if rgen.stateData: PGen_destroy(&rgen)
        if retval: raise RuntimeError("ewald3_sigma_short returned %d" % retval)
        if do_err:
            return (result, err)
        else:
            return result

    def sigma_long(self, double eta, cdouble wavenumber, direct_basis, wavevector, particle_shift, do_err=False):
        '''Long-range Ewald sums (ewald3_sigma_long) for a 2D lattice in the XY plane.

        Only the x, y components of wavevector and particle_shift are used.
        '''
        # FIXME now only 2d XY lattice in 3D is implemented here, we don't even do proper dimensionality checks.
        cdef cart3_t beta, pshift
        beta.x = wavevector[0]
        beta.y = wavevector[1]
        beta.z = 0
        pshift.x = particle_shift[0]
        pshift.y = particle_shift[1]
        pshift.z = 0
        cdef qpms_l_t n = self.c[0].nelem_sc
        cdef np.ndarray[complex, ndim=1] result = np.empty((n,), dtype=complex)
        cdef cdouble[::1] result_v = result
        cdef np.ndarray[double, ndim=1] err
        cdef double[::1] err_v
        if do_err:
            err = np.empty((n,), dtype=np.double)
            err_v = err
        cdef PGen kgen = get_PGen_reciprocal2pi(direct_basis)
        cdef double unitcell_volume = get_unitcell_volume(direct_basis)
        cdef int retval = ewald3_sigma_long(&result_v[0], &err_v[0] if do_err else NULL,
                self.c, eta, wavenumber, unitcell_volume, LAT_2D_IN_3D_XYONLY, &kgen, False, beta, pshift)

        if kgen.stateData: PGen_destroy(&kgen)
        if retval: raise RuntimeError("ewald3_sigma_long returned %d" % retval)
        if do_err:
            return (result, err)
        else:
            return result