Fix saving lists of arrays with recent versions of numpy
[qpms.git] / qpms / cycommon.pyx
blob88c557d12ff4b170aca27d5ce2fca689b404e6da
1 import numpy as np
2 from .qpms_cdefs cimport *
3 from libc.stdlib cimport malloc
4 cimport cython
5 import enum
7 # Here will be enum and dtype definitions; maybe move these to a separate file
class VSWFType(enum.IntEnum):
    """Python-level enumeration of the VSWF type constants (QPMS_VSWF_*).

    The values are the C-level QPMS_VSWF_* constants, which are not defined in
    this file (presumably brought in by ``from .qpms_cdefs cimport *``).
    """
    ELECTRIC = QPMS_VSWF_ELECTRIC
    MAGNETIC = QPMS_VSWF_MAGNETIC
    LONGITUDINAL = QPMS_VSWF_LONGITUDINAL
    # Short aliases. They reuse values already defined above, so IntEnum turns
    # them into aliases of the canonical members (e.g. VSWFType.M is
    # VSWFType.MAGNETIC) instead of new members. Keep the long names first so
    # they remain the canonical ones.
    M = QPMS_VSWF_MAGNETIC
    N = QPMS_VSWF_ELECTRIC
    L = QPMS_VSWF_LONGITUDINAL
class BesselType(enum.IntEnum):
    """Python-level enumeration of the radial function kinds (QPMS_BESSEL_* / QPMS_HANKEL_*).

    Each member mirrors the C constant of the same suffix. Presumably REGULAR /
    SINGULAR / HANKEL_PLUS / HANKEL_MINUS select j_n, y_n, h_n^(1) and h_n^(2)
    respectively -- confirm against the qpms C headers.
    """
    UNDEF = QPMS_BESSEL_UNDEF  # "undefined" placeholder value from the C side
    REGULAR = QPMS_BESSEL_REGULAR
    SINGULAR = QPMS_BESSEL_SINGULAR
    HANKEL_PLUS = QPMS_HANKEL_PLUS
    HANKEL_MINUS = QPMS_HANKEL_MINUS
class PointGroupClass(enum.IntEnum):
    """Python-level enumeration of the point group classes (QPMS_PGS_* constants).

    Each member mirrors the C constant with the same suffix. The names appear
    to follow Schoenflies notation with N standing for the order n (e.g. CNV
    for C_nv, DND for D_nd) plus the infinite groups (CINF..., DINF...) and
    SO3 / O3 -- confirm against the qpms C headers.
    """
    # Groups parametrised by an order n.
    CN = QPMS_PGS_CN
    S2N = QPMS_PGS_S2N
    CNH = QPMS_PGS_CNH
    CNV = QPMS_PGS_CNV
    DN = QPMS_PGS_DN
    DND = QPMS_PGS_DND
    DNH = QPMS_PGS_DNH
    # Polyhedral groups.
    T = QPMS_PGS_T
    TD = QPMS_PGS_TD
    TH = QPMS_PGS_TH
    O = QPMS_PGS_O
    OH = QPMS_PGS_OH
    I = QPMS_PGS_I
    IH = QPMS_PGS_IH
    # Continuous ("infinite") groups.
    CINF = QPMS_PGS_CINF
    CINFH = QPMS_PGS_CINFH
    CINFV = QPMS_PGS_CINFV
    DINF = QPMS_PGS_DINF
    DINFH = QPMS_PGS_DINFH
    SO3 = QPMS_PGS_SO3
    O3 = QPMS_PGS_O3
# enum.IntFlag only exists on Python >= 3.6. On older enum versions we fall
# back to IntEnum, which cannot represent OR-combined flag values as members;
# that is why the dbgmsg_* helpers below return plain ints when has_IntFlag
# is False.
has_IntFlag = hasattr(enum, 'IntFlag')


class DebugFlags(enum.IntFlag if has_IntFlag else enum.IntEnum):
    """Categories of qpms debug messages (mirrors the QPMS_DBGMSG_* constants)."""
    MISC = QPMS_DBGMSG_MISC
    THREADS = QPMS_DBGMSG_THREADS
    INTEGRATION = QPMS_DBGMSG_INTEGRATION
def dbgmsg_enable(qpms_dbgmsg_flags types):
    """Enable the given debug-message categories via qpms_dbgmsg_enable().

    Returns whatever flag value the C function reports, wrapped in DebugFlags
    when enum.IntFlag is available, otherwise as a plain integer.
    """
    result = qpms_dbgmsg_enable(types)
    if not has_IntFlag:
        return result
    return DebugFlags(result)
def dbgmsg_disable(qpms_dbgmsg_flags types):
    """Disable the given debug-message categories via qpms_dbgmsg_disable().

    Returns whatever flag value the C function reports, wrapped in DebugFlags
    when enum.IntFlag is available, otherwise as a plain integer.
    """
    result = qpms_dbgmsg_disable(types)
    if not has_IntFlag:
        return result
    return DebugFlags(result)
def dbgmsg_active():
    """Report the currently active debug-message categories.

    Works by calling qpms_dbgmsg_enable() with an empty mask, which enables
    nothing new. Its return value is presumably the active flag set -- confirm
    against the C implementation. The result is wrapped in DebugFlags when
    enum.IntFlag is available, otherwise returned as a plain integer.
    """
    current = qpms_dbgmsg_enable(<qpms_dbgmsg_flags>0)
    if not has_IntFlag:
        return current
    return DebugFlags(current)
69 #import re # TODO for crep methods?
71 #cimport openmp
72 #openmp.omp_set_dynamic(1)
## Auxiliary function for retrieving the "meshgrid-like" indices; inc. nmax
@cython.boundscheck(False)
def get_mn_y(int nmax):
    """
    Auxiliary function for retrieving the 'meshgrid-like' indices from the flat indexing;
    inc. nmax.
    ('y to mn' conversion)

    Parameters
    ----------

    nmax : int
        The maximum order to which the VSWFs / Legendre functions etc. will be evaluated.
        Must be non-negative.

    Returns
    -------

    output : (m, n)
        Tuple of two arrays of type np.array(shape=(nmax*nmax + 2*nmax), dtype=np.int_),
        where [(m[y],n[y]) for y in range(nmax*nmax + 2*nmax)] covers all possible
        integer pairs n >= 1, -n <= m <= n. For each n the orders m run from -n to n.

    Raises
    ------

    ValueError
        If nmax is negative.
    """
    cdef Py_ssize_t nelems
    cdef np.ndarray[np.int_t, ndim=1] m_arr
    cdef np.ndarray[np.int_t, ndim=1] n_arr
    cdef Py_ssize_t i = 0
    cdef np.int_t n, m
    # Without this check some negative nmax values (e.g. -3) give a positive
    # array size while the fill loop below never runs, so uninitialised memory
    # would be returned.
    if nmax < 0:
        raise ValueError("nmax must be non-negative, got %d" % nmax)
    # Widen before multiplying so the element count cannot overflow C int.
    nelems = <Py_ssize_t>nmax * (nmax + 2)
    # np.int_ (not the np.int alias, removed in NumPy 1.24) is the dtype that
    # matches the np.int_t buffer declarations above.
    m_arr = np.empty([nelems], dtype=np.int_)
    n_arr = np.empty([nelems], dtype=np.int_)
    for n in range(1, nmax + 1):
        for m in range(-n, n + 1):
            m_arr[i] = m
            n_arr[i] = n
            i = i + 1
    return (m_arr, n_arr)
def get_nelem(unsigned int lMax):
    """
    Number of flat (y) indices for all degrees 1..lMax, i.e. lMax * (lMax + 2).

    Parameters
    ----------

    lMax : unsigned int
        The maximum degree.

    Returns
    -------

    int
        lMax * (lMax + 2).
    """
    # Widen before multiplying. Evaluated in C unsigned int, the product wraps
    # around for lMax >= 65536. (2**32 - 1) * (2**32 + 1) == 2**64 - 1, so
    # unsigned long long holds the exact result for every unsigned int lMax.
    cdef unsigned long long n = lMax
    return n * (n + 2)
def get_y_mn_unsigned(int nmax):
    """
    Auxiliary function mapping ('unsigned m', n) index pairs to the flat y-indexing.
    Meant for functions such as scipy.special.lpmn, which have to be evaluated
    separately for positive and negative m.

    Parameters
    ----------

    nmax : int
        The maximum order to which the VSWFs / Legendre functions etc. will be evaluated.

    Returns
    -------

    output : (ymn_plus, ymn_minus)
        Two np.intp arrays of shape (nmax+1, nmax+1). ymn_plus[m, n] is the flat
        y-index of (m, n) and ymn_minus[m, n] that of (-m, n). Entries with no
        corresponding index (|m| > n, n == 0, and ymn_minus[0, n]) are -1, so the
        caller must not use elements equal to -1.
    """
    cdef np.ndarray[np.intp_t, ndim=2] ymn_plus = np.full((nmax+1,nmax+1),-1, dtype=np.intp)
    cdef np.ndarray[np.intp_t, ndim=2] ymn_minus = np.full((nmax+1,nmax+1),-1, dtype=np.intp)
    cdef Py_ssize_t n, am, y0
    for n in range(1, nmax + 1):
        # Degrees 1..n-1 occupy n*n - 1 flat slots. Within degree n the orders
        # m = -n..n are contiguous, so (m, n) sits at n*n - 1 + (m + n), i.e.
        # y0 + m with y0 = n*n + n - 1.
        y0 = n * n + n - 1
        for am in range(1, n + 1):
            ymn_minus[am, n] = y0 - am
        for am in range(0, n + 1):
            ymn_plus[am, n] = y0 + am
    return (ymn_plus, ymn_minus)
def tlm2uvswfi(t, l, m):
    ''' Convert VSWF labels (t, l, m) to universal VSWF indices.

    TODO: this should rather be an ufunc.

    Parameters
    ----------

    t, l, m : int or sequence
        Either three integer scalars (Python ints or NumPy integer scalars),
        or three sequences of equal length whose elements have integral values.

    Returns
    -------

    For scalar input, the value returned by qpms_tmn2uvswfi(t, m, l).
    For sequence input, a list of such values, one per position.

    Raises
    ------

    ValueError
        If the sequences have different lengths or contain a non-integral value.
    '''
    # Very low-priority TODO: add some types / cythonize
    integer_types = (int, np.integer)
    if (isinstance(t, integer_types) and isinstance(l, integer_types)
            and isinstance(m, integer_types)):
        # Note the C function's (t, m, l) argument order.
        return qpms_tmn2uvswfi(t, m, l)
    if not (len(t) == len(l) == len(m)):
        raise ValueError("Lengths of the t,l,m arrays must be equal, but they are %d, %d, %d."
                % (len(t), len(l), len(m)))
    u = list()
    for i, (ti, li, mi) in enumerate(zip(t, l, m)):
        # Accept integral floats too; maybe not the best check possible, though.
        if not (ti % 1 == 0 and li % 1 == 0 and mi % 1 == 0):
            raise ValueError("Non-integer label at position %d: t=%r, l=%r, m=%r"
                    % (i, ti, li, mi))
        u.append(qpms_tmn2uvswfi(ti, mi, li))
    return u
def uvswfi2tlm(u):
    ''' Convert universal VSWF index/indices to (t, l, m) labels.

    TODO: this should rather be an ufunc.

    Parameters
    ----------

    u : int or iterable of ints
        A single universal VSWF index (Python int or any NumPy integer scalar),
        or an iterable of such indices.

    Returns
    -------

    (t, l, m) for scalar input, or a tuple of three lists (ta, la, ma) holding
    the labels position by position for iterable input.

    Raises
    ------

    ValueError
        If qpms_uvswfi2tmn() does not report QPMS_SUCCESS for some index.
    '''
    cdef qpms_vswf_type_t t
    cdef qpms_l_t l
    cdef qpms_m_t m
    # Any NumPy integer scalar (not only np.ulonglong) is a single index;
    # previously e.g. np.int64 fell through to the sequence branch and failed.
    if isinstance(u, (int, np.integer)):
        # Note the C function's (t, m, l) output order.
        if qpms_uvswfi2tmn(u, &t, &m, &l) != QPMS_SUCCESS:
            raise ValueError("Invalid uvswf index")
        return (t, l, m)
    ta = list()
    la = list()
    ma = list()
    for ui in u:
        if qpms_uvswfi2tmn(ui, &t, &m, &l) != QPMS_SUCCESS:
            raise ValueError("Invalid uvswf index")
        ta.append(t)
        la.append(l)
        ma.append(m)
    return (ta, la, ma)
cdef char *make_c_string(pythonstring):
    """
    Copies contents of a python string into a newly allocated char[].

    The string is encoded as UTF-8 and NUL-terminated. The buffer comes from
    malloc() and is not freed here, so releasing it (with free()) is the
    caller's responsibility. An embedded NUL character in the input will
    truncate the string as seen by C consumers.

    Raises MemoryError if malloc() fails.
    """
    cdef bytes bytestring = pythonstring.encode('UTF-8')
    # Typed view of the encoded bytes; valid while bytestring is alive.
    cdef const char *src = bytestring
    cdef Py_ssize_t n = len(bytestring)
    cdef Py_ssize_t i
    cdef char *s
    s = <char *>malloc(n+1)
    if not s:
        raise MemoryError
    # Copy C char to C char. Indexing the Python bytes object instead yields
    # Python ints in range(256), and converting values >= 128 to a signed C
    # char raises OverflowError, which broke all non-ASCII strings.
    for i in range(n):
        s[i] = src[i]
    s[n] = <char>0
    return s
def string_c2py(const char* cstring):
    """Decode a NUL-terminated UTF-8 C string into a Python str."""
    # Coercing the char* to bytes copies up to the terminating NUL.
    cdef bytes raw = cstring
    return raw.decode('UTF-8')