Fix saving lists of arrays with recent versions of numpy
[qpms.git] / misc / hexdisplot.py
blob 0c08173de6c83bd6b2784f11c3d003e8d88393d4
1 #!/usr/bin/env python3
3 import argparse, re, random, string
4 from scipy.constants import hbar, e as eV, pi, c
# ---- command-line interface ----
# TODO? use type=argparse.FileType('r') for the input files?
parser = argparse.ArgumentParser()
parser.add_argument('--output', '-o', action='store',
                    help='Path to output PDF')
parser.add_argument('--nSV', action='store', metavar='N', type=int, default=1,
                    help='Draw N minimum singular values')
parser.add_argument('--bitmap', action='store_true',
                    help='Create an interpolated bitmap rather than a scatter plot.')
# Disabled option: --eVfreq (required float, 'Frequency in eV'), which would be
# converted to angular frequency as freq = eVfreq*eV/hbar.
parser.add_argument('inputfile', nargs='+',
                    help='Npz file(s) generated by dispersion_chunks.py or other script')
pargs = parser.parse_args()
print(pargs)

# Output path: --output if given, otherwise the last input file name + '.pdf'.
if pargs.output:
    pdfout = pargs.output
else:
    pdfout = '%s.pdf' % pargs.inputfile[-1]
print(pdfout)

# How many of the smallest singular values to draw (from --nSV).
svn = pargs.nSV

# -----------------finished basic CLI parsing (except for op arguments) ------------------
import time
begtime = time.time()  # wall-clock time at which the heavy part of the script starts
27 import qpms
28 from matplotlib.path import Path
29 import matplotlib.patches as patches
30 import matplotlib.pyplot as plt
31 import numpy as np
32 import os, sys, warnings, math
33 from matplotlib import pyplot as plt
34 from matplotlib.backends.backend_pdf import PdfPages
35 from scipy import interpolate
# Local (modified) copy of the only qpms helper needed here, so that this
# function does not depend on the qpms package.
def nelem2lMax(nelem):
    """Invert ``nelem = lMax * (lMax + 2)`` and return ``lMax``.

    Parameters
    ----------
    nelem : int or float
        Number of elements. A float with an integral value is accepted; the
        loader calls this with ``shape[1] / 2``.

    Returns
    -------
    int
        The integer ``lMax >= 1`` satisfying ``lMax * (lMax + 2) == nelem``.

    Raises
    ------
    ValueError
        If ``nelem`` is not of the form ``lMax * (lMax + 2)`` with ``lMax >= 1``.
    """
    lMax = round(math.sqrt(1 + nelem) - 1)
    if lMax < 1 or (lMax + 2) * lMax != nelem:
        # Raise an explicit, descriptive error: a bare ``raise`` outside an
        # ``except`` clause only yields "No active exception to reraise".
        raise ValueError(
            'nelem = %r is not of the form lMax*(lMax+2) with lMax >= 1' % (nelem,))
    return lMax
nx = None  # alias used like np.newaxis in the index expressions below
s3 = math.sqrt(3)


def _consistent(name, previous, current, filename):
    """Return the value of metadata item *name* to keep across input files.

    The first file's value (``previous is None``) is accepted as is; every
    later file must carry an equal value.

    Raises
    ------
    ValueError
        If *current* differs from the value seen in earlier files.
    """
    if previous is None:
        return current
    if previous != current:
        raise ValueError('%s: %s = %r differs from %r found in earlier input files'
                         % (filename, name, current, previous))
    return previous


# read data from files
lMax = None
epsilon_b = None
hexside = None
karrlist = list()
svTElist = list()
svTMlist = list()
omegalist = list()
records = 0
for filename in pargs.inputfile:
    # 'metadata' is stored as an object array and indexed with [()] to get the
    # dict back, which numpy >= 1.16.3 refuses to load unless allow_pickle=True.
    # NOTE(review): unpickling can execute arbitrary code -- only feed this
    # script .npz files you generated yourself.
    # The context manager closes the NpzFile even if a check below raises.
    with np.load(filename, allow_pickle=True) as npz:
        metadata = npz['metadata'][()]
        if 'lMax' in metadata:
            lMaxRead = metadata['lMax']
        else:
            # Fall back to deducing lMax from the column count of 'sTE'.
            lMaxRead = nelem2lMax(npz['sTE'].shape[1] / 2)
        lMax = _consistent('lMax', lMax, lMaxRead, filename)
        epsilon_b = _consistent('epsilon_b', epsilon_b, metadata['epsilon_b'], filename)
        hexside = _consistent('hexside', hexside, metadata['hexside'], filename)
        omegalist.append(npz['omega'][()])
        # np.array() copies the data so it stays valid after the file is closed.
        karrlist.append(np.array(npz['klist']))
        # Keep only the last svn columns of the singular-value arrays.
        svTElist.append(np.array(npz['sTE'][:, -svn:]))
        svTMlist.append(np.array(npz['sTM'][:, -svn:]))
    records += 1
# Group the per-file records by frequency (omega) and merge each group.
omegas = set(omegalist)
print(omegas)
k = dict()
svTE = dict()
svTM = dict()
# One (k-points, TE values, TM values) bucket per distinct frequency; the
# records of each frequency are appended in input-file order.
buckets = {omega: ([], [], []) for omega in omegas}
for omega, karr, sv_te, sv_tm in zip(omegalist, karrlist, svTElist, svTMlist):
    k_parts, te_parts, tm_parts = buckets[omega]
    k_parts.append(karr)
    te_parts.append(sv_te)
    tm_parts.append(sv_tm)
# Concatenate the arrays of each bucket into a single array per frequency.
for omega, (k_parts, te_parts, tm_parts) in buckets.items():
    k[omega] = np.concatenate(k_parts)
    svTE[omega] = np.concatenate(te_parts)
    svTM[omega] = np.concatenate(tm_parts)
# ... that was for the slices. TODO: also fill the rightmost plot with the
# calculated (which?) modes.

pdf = PdfPages(pdfout)

# c / sqrt(epsilon_b); presumably the light speed in the background medium
# (epsilon_b read from the input metadata) -- confirm.
cdn = c / math.sqrt(epsilon_b)
# Element count for lMax (the inverse of nelem2lMax). The original author
# alternatively derived it as len(my) from qpms.get_mn_y(lMax).
nelem = lMax * (lMax + 2)

# The new pretty diffracted order drawing
maxlayer_reciprocal = 4
# Common scale factor of the Brillouin-zone points below.
kscale = 4 * np.pi / 3 / hexside / s3
bz_0 = np.array((0, 0,))
bz_K1 = np.array((1., 0)) * kscale
bz_K2 = np.array((1. / 2., s3 / 2)) * kscale
bz_M = np.array((3. / 4, s3 / 4)) * kscale

# reciprocal lattice basis
B1 = 2 * bz_K1 - bz_K2
B2 = 2 * bz_K2 - bz_K1

# Piecewise-linear k-path 0 -> M -> K1 -> 0 -> K2 -> M, k2density points per segment.
k2density = 100
_t = np.linspace(0, 1, k2density)[:, nx]  # column of interpolation parameters
k0Mlist = bz_0 + (bz_M - bz_0) * _t
kMK1list = bz_M + (bz_K1 - bz_M) * _t
kK10list = bz_K1 + (bz_0 - bz_K1) * _t
k0K2list = bz_0 + (bz_K2 - bz_0) * _t
kK2Mlist = bz_K2 + (bz_M - bz_K2) * _t
k2list = np.concatenate((k0Mlist, kMK1list, kK10list, k0K2list, kK2Mlist), axis=0)
# Cumulative arc length along k2list, used as the abscissa of the right panel.
kxmaplist = np.concatenate(
    (np.array([0]), np.cumsum(np.linalg.norm(np.diff(k2list, axis=0), axis=-1))))

# Reciprocal-lattice points, rotated by 90 degrees.
# NOTE(review): assumes generate_trianglepoints(..., v3d=False) returns an (N, 2) array.
centers2 = qpms.generate_trianglepoints(
    maxlayer_reciprocal, v3d=False, include_origin=True) * 4 * np.pi / 3 / hexside
rot90 = np.array([[0, -1], [1, 0]])
centers2 = np.dot(centers2, rot90)

import matplotlib.pyplot as plt
import matplotlib
from matplotlib.path import Path
import matplotlib.patches as patches
cmap = matplotlib.cm.prism
# Largest distance of a lattice point from the origin, so that the colour
# value norm(center) / colormax used in the plots lies in [0, 1]. Taking the
# norm along axis=0 would instead give the norms of the x and y *columns*,
# which exceed every individual point's norm and squeeze all colours near 0.
colormax = np.amax(np.linalg.norm(centers2, axis=-1))
def _bz_hexagon_patch():
    """Return a PathPatch outlining the hexagon drawn on the k-space panels."""
    r = 4 * np.pi / 3 / hexside / s3
    verts = [(math.cos(math.pi * i / 3) * r, math.sin(math.pi * i / 3) * r)
             for i in range(6 + 1)]
    codes = [Path.MOVETO] + [Path.LINETO] * 5 + [Path.CLOSEPOLY]
    return patches.PathPatch(Path(verts, codes), facecolor='none', edgecolor='black', lw=1)


def _bitmap_grid(klist):
    """Return meshgrid arrays (X, Y) spanning the bounding box of klist[:, :2]."""
    minx, maxx = np.amin(klist[:, 0]), np.amax(klist[:, 0])
    miny, maxy = np.amin(klist[:, 1]), np.amax(klist[:, 1])
    # Mesh step: 1/9 of the side of the average area per sample point.
    meshstep = math.sqrt((maxy - miny) * (maxx - minx) / klist.shape[0]) / 9
    # np.linspace needs an integer sample count (a float raises TypeError on
    # recent numpy); keep at least 2 points and avoid dividing by a zero step.
    if meshstep > 0:
        nxpts = max(2, int((maxx - minx) / meshstep))
        nypts = max(2, int((maxy - miny) / meshstep))
    else:
        nxpts = nypts = 2
    x = np.linspace(minx, maxx, nxpts)
    y = np.linspace(miny, maxy, nypts)
    return np.meshgrid(x, y)


def _plot_sv_panel(ax, klist, svabs, omega, grid, title):
    """Draw one k-space panel of |singular values| with a colour bar.

    grid is None for a scatter plot, or (X, Y) from _bitmap_grid for an
    interpolated bitmap.
    """
    if grid is not None:
        X, Y = grid
        # Scattered-data interpolation onto the regular grid (interp2d was
        # misused here and is removed in recent SciPy). Points outside the
        # convex hull of the samples come back as NaN and are masked out.
        z = interpolate.griddata((klist[:, 0], klist[:, 1]), svabs, (X, Y), method='linear')
        sc = ax.pcolormesh(X, Y, np.ma.masked_invalid(z))
    else:
        sc = ax.scatter(klist[:, 0], klist[:, 1], c=np.clip(svabs, 0, 1), lw=0)
    # Circles of radius omega/cdn around each reciprocal-lattice point.
    for center in centers2:
        circle = plt.Circle((center[0], center[1]), omega / cdn, facecolor='none',
                            edgecolor=cmap(np.linalg.norm(center) / colormax), lw=0.5)
        ax.add_artist(circle)
    ax.add_patch(_bz_hexagon_patch())
    ax.set_xticks([])
    ax.set_yticks([])
    ax.title.set_text(title)
    ax.figure.colorbar(sc, ax=ax)


def _plot_dispersion_panel(ax, omega, xticklist):
    """Draw |k - center| * cdn along the k-path, with a line at the current omega."""
    for center in centers2:
        ax.plot(kxmaplist, np.linalg.norm(k2list - center, axis=-1) * cdn, '-',
                color=cmap(np.linalg.norm(center) / colormax))
    ax.set_xticks(xticklist)
    for xt in xticklist:
        ax.axvline(xt, ls='dotted', lw=0.3, c='k')
    ax.set_xticklabels(['Γ', 'M', 'K', 'Γ', 'K\'', 'M'])
    ax.axhline(omega, c='black')
    ax.set_ylim([0, 5e15])
    # Secondary axis showing the same range converted with hbar / eV.
    ax2 = ax.twinx()
    ax2.set_ylim([ax.get_ylim()[0] / eV * hbar, ax.get_ylim()[1] / eV * hbar])


# Abscissae of the k-path segment end points; the same for every page.
_segment_ends = np.cumsum([len(k0Mlist), len(kMK1list), len(kK10list),
                           len(k0K2list), len(kK2Mlist)]) - 1
xticklist = [0] + [kxmaplist[i] for i in _segment_ends]

try:
    for omega in sorted(omegas):
        klist = k[omega]
        grid = _bitmap_grid(klist) if pargs.bitmap else None
        minsvTElist = svTE[omega]
        minsvTMlist = svTM[omega]
        for minN in reversed(range(svn)):
            f, axes = plt.subplots(1, 3, figsize=(20, 4.8))
            _plot_sv_panel(axes[0], klist, np.abs(minsvTElist[:, minN]), omega, grid,
                           'E in-plane ("TE"), %d. lowest SV' % minN)
            _plot_sv_panel(axes[1], klist, np.abs(minsvTMlist[:, minN]), omega, grid,
                           'E perpendicular ("TM"), %d. lowest SV' % minN)
            _plot_dispersion_panel(axes[2], omega, xticklist)
            pdf.savefig(f)
            # Release the figure; otherwise every page stays in memory until exit.
            plt.close(f)
finally:
    # Finalize the PDF even if plotting fails part-way.
    pdf.close()