new forecast incorporation as 3D data points
[JPSSData.git] / synthetic.py
blob5269d4d6f2bc0477c8cdb7c7eceadc5202165ec5
1 import numpy as np
2 import matplotlib.pyplot as plt
3 from mpl_toolkits.mplot3d import axes3d
4 import matplotlib.colors as colors
5 from svm import SVM3
6 from scipy.io import savemat
7 from scipy import interpolate
def plot_case(xx, yy, tign_g, X_satellite=None, filename='syn_case.png'):
    """Draw the synthetic ignition-time surface as 3D contours and save it.

    :param xx: grid x coordinates (e.g. the first output of ``np.meshgrid``)
    :param yy: grid y coordinates, same shape as ``xx``
    :param tign_g: ignition time on the grid, same shape as ``xx``
    :param X_satellite: optional array whose first three columns are
        (x, y, t); drawn as small red points on top of the contours
    :param filename: path of the image written (default ``'syn_case.png'``)
    """
    fig = plt.figure()
    # Figure.gca(projection=...) was removed in Matplotlib 3.6;
    # add_subplot(..., projection='3d') works on old and new versions.
    ax = fig.add_subplot(111, projection='3d')
    # 30 contour levels of the ignition-time surface
    ax.contour(xx, yy, tign_g, 30)
    if X_satellite is not None:
        ax.scatter(X_satellite[:, 0], X_satellite[:, 1], X_satellite[:, 2],
                   s=5, color='r')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("T")
    fig.savefig(filename)
def plot_data(X, y, filename='syn_data.png'):
    """Scatter the labelled SVM input points in 3D and save the figure.

    :param X: array with at least three columns, plotted as (X, Y, T)
    :param y: per-point labels used to colour the points
    :param filename: path of the image written (default ``'syn_data.png'``)
    """
    # two-entry colormap: dark green for the lowest label, dark red for the highest
    col = [(0, .5, 0), (.5, 0, 0)]
    cm_GR = colors.LinearSegmentedColormap.from_list('GrRd', col, N=2)
    fig = plt.figure()
    # Figure.gca(projection=...) was removed in Matplotlib 3.6;
    # add_subplot(..., projection='3d') works on old and new versions.
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(X[:, 0], X[:, 1], X[:, 2],
               c=y, cmap=cm_GR, s=1, alpha=.5,
               vmin=y.min(), vmax=y.max())
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("T")
    fig.savefig(filename)
def _cone(xx, yy, nx, ny):
    """Build the cone-shaped ignition-time surface shared by the cone experiments.

    Returns ``(tign_g, xx1d, yy1d, tt1d, cx, cy, tsat)``: the full grid of
    ignition times, the flattened coordinates/times with the maximum-valued
    cells removed, the cone centre, and half the ignition-time range.
    """
    cx = nx * .5
    cy = ny * .5
    # ignition time grows linearly with distance from the grid centre,
    # starting at 10 and capped at 1e3
    tign_g = np.minimum(1e3, 10 + (2e3 / cx) *
                        np.sqrt(((xx - cx)**2 + (yy - cy)**2) / 2))
    tsat = (tign_g.max() - tign_g.min()) * .5
    tt1d = np.ravel(tign_g)
    # drop the cells at the maximum value (the capped plateau, presumably
    # the area the fire never reaches) from the forecast points
    mask = tt1d < tt1d.max()
    xx1d = np.ravel(xx)[mask]
    yy1d = np.ravel(yy)[mask]
    tt1d = tt1d[mask]
    return tign_g, xx1d, yy1d, tt1d, cx, cy, tsat


def cone_point(xx, yy, nx, ny):
    """Cone experiment with a single satellite fire detection.

    The detection sits at 70% of the way to the centre in x and y, at time
    ``tsat`` (half of the ignition-time range).

    :returns: ``(tign_g, xx1d, yy1d, tt1d, X_satellite)`` with
        ``X_satellite`` of shape (1, 3)
    """
    tign_g, xx1d, yy1d, tt1d, cx, cy, tsat = _cone(xx, yy, nx, ny)
    X_satellite = np.array([[cx * .7, cy * .7, tsat]])
    return tign_g, xx1d, yy1d, tt1d, X_satellite


def cone_points(xx, yy, nx, ny, npoints=10):
    """Cone experiment with a line of satellite fire detections.

    ``npoints`` detections are spaced evenly from (0.7*cx, 0.7*cy, tsat)
    toward the cone centre (cx, cy, min ignition time). The end point is
    excluded.

    :param npoints: number of satellite detections (default 10)
    :returns: ``(tign_g, xx1d, yy1d, tt1d, X_satellite)`` with
        ``X_satellite`` of shape (npoints, 3)
    """
    tign_g, xx1d, yy1d, tt1d, cx, cy, tsat = _cone(xx, yy, nx, ny)
    X_satellite = np.c_[np.linspace(cx * .7, cx, npoints + 1),
                        np.linspace(cy * .7, cy, npoints + 1),
                        np.linspace(tsat, tign_g.min(), npoints + 1)][:-1]
    return tign_g, xx1d, yy1d, tt1d, X_satellite
def slope(xx, yy, nx, ny):
    """Slope experiment: ignition time increases column by column along x.

    The first ``round(nx/2)`` columns step by the first increment in ``ros``.
    The remaining columns step by the second (faster-growing) increment.
    Every row of the grid carries the same profile. No satellite data is
    produced.

    :returns: ``(tign_g, xx1d, yy1d, tt1d, None)``
    """
    ros = (10, 30)  # rate of spread: per-column ignition-time increments
    slow, fast = ros
    half = round(nx * .5)
    # first segment: 10, 10+slow, 10+2*slow, ...
    left = 10 + slow * np.arange(half)
    # second segment continues from slow*half, stepping by fast
    right = slow * half + fast + fast * np.arange(nx - half)
    profile = np.concatenate((left, right))
    # same profile on every one of the ny rows -> shape (ny, nx)
    tign_g = np.tile(profile, (ny, 1))
    return tign_g, np.ravel(xx), np.ravel(yy), np.ravel(tign_g), None
def preprocess_svm(xx, yy, tt, epsilon, weights, X_satellite=None):
    """Turn a forecast (and optional satellite detections) into SVM inputs.

    Each forecast point (x, y, t) becomes two labelled samples:
    ``(x, y, t - epsilon)`` labelled -1 (before ignition) and
    ``(x, y, t + epsilon)`` labelled +1 (after ignition). Satellite points
    are appended with label +1.

    :param xx, yy, tt: forecast coordinates and ignition times (any shape;
        flattened)
    :param epsilon: time offset used to create the two artificial samples
    :param weights: ``(wforecastg, wforecastf, wsatellite)``: per-sample
        weights for the -1 forecast samples, the +1 forecast samples, and
        the satellite samples
    :param X_satellite: optional (n, 3) array of (x, y, t) detections; a
        single point may be given as a flat length-3 sequence
    :returns: ``(X, y, c)``: sample matrix, labels and per-sample weights
    :raises ValueError: if ``X_satellite`` does not have 3 columns
    """
    wforecastg, wforecastf, wsatellite = weights
    xs, ys, ts = np.ravel(xx), np.ravel(yy), np.ravel(tt)
    for_fire = np.c_[xs, ys, ts + epsilon]
    for_nofire = np.c_[xs, ys, ts - epsilon]
    n = len(for_fire)
    X_forecast = np.concatenate((for_nofire, for_fire))
    y_forecast = np.concatenate((-np.ones(n), np.ones(n)))
    c_forecast = np.concatenate((wforecastg * np.ones(n),
                                 wforecastf * np.ones(n)))
    if X_satellite is None:
        return X_forecast, y_forecast, c_forecast
    # accept a single detection given as a flat (x, y, t) triple
    X_satellite = np.atleast_2d(X_satellite)
    if X_satellite.ndim != 2 or X_satellite.shape[1] != X_forecast.shape[1]:
        raise ValueError('X_satellite must have shape (n, 3), got %s'
                         % (X_satellite.shape,))
    ns = len(X_satellite)
    X = np.concatenate((X_forecast, X_satellite))
    y = np.concatenate((y_forecast, np.ones(ns)))
    c = np.concatenate((c_forecast, wsatellite * np.ones(ns)))
    return X, y, c
if __name__ == "__main__":
    ## SETTINGS
    # Experiments: 1) Cone with point, 2) Slope, 3) Cone with points
    exp = 2
    # hyperparameter settings (per-sample weights passed to preprocess_svm)
    wforecastg = 50  # weight of forecast samples labelled -1 (before ignition)
    wforecastf = 50  # weight of forecast samples labelled +1 (after ignition)
    wsatellite = 50  # weight of satellite samples (labelled +1)
    kgam = 1
    # epsilon for artificial forecast in seconds
    epsilon = 1
    # dimensions
    nx, ny = 50, 50
    # plotting data before svm?
    plot = True

    ## CASE
    xx, yy = np.meshgrid(np.arange(0, nx, 1),
                         np.arange(0, ny, 1))
    # select experiment
    experiments = {1: cone_point, 2: slope, 3: cone_points}
    tign_g, xx1d, yy1d, tt1d, X_satellite = experiments[exp](xx, yy, nx, ny)
    if plot:
        plot_case(xx, yy, tign_g, X_satellite)

    ## PREPROCESS
    # without satellite data its weight is meaningless; zeroing it also
    # selects the shorter output filename below
    if X_satellite is None:
        wsatellite = 0
    X, y, c = preprocess_svm(xx1d, yy1d, tt1d, epsilon,
                             (wforecastg, wforecastf, wsatellite),
                             X_satellite)
    if plot:
        plot_data(X, y)

    ## SVM
    # options for SVM
    options = {'downarti': False, 'plot_data': True,
               'plot_scaled': True, 'plot_supports': True,
               'plot_result': True, 'plot_decision': True,
               'artiu': False, 'hartiu': .2,
               'artil': False, 'hartil': .2,
               'notnan': True}
    # When every sample carries the same weight, pass a scalar C instead of
    # the per-sample vector. Test for "no satellite samples" directly: a
    # user-set wsatellite == 0 with satellite data present must keep the
    # per-sample vector, otherwise those samples would get weight wforecastg.
    if (wforecastg == wforecastf and
            (X_satellite is None or wsatellite == wforecastg)):
        c = wforecastg
    # running SVM
    F = SVM3(X, y, C=c, kgam=kgam, **options)

    ## POSTPROCESS
    # interpolation to validate: resample the SVM result onto the original
    # grid (F[0], F[1] presumably the surface coordinates and F[2] its
    # values; stored below as 'fxlon', 'fxlat', 'Z')
    points = np.c_[np.ravel(F[0]), np.ravel(F[1])]
    values = np.ravel(F[2])
    zz_svm = interpolate.griddata(points, values, (xx, yy))
    # output dictionary
    svm = {'xx': xx, 'yy': yy, 'zz': tign_g,
           'zz_svm': zz_svm, 'X': X, 'y': y, 'c': c,
           'fxlon': F[0], 'fxlat': F[1], 'Z': F[2],
           'epsilon': epsilon, 'options': options}
    # output file
    if wsatellite:
        filename = 'syn_fg%d_ff%d_s%d_k%d_e%d.mat' % (wforecastg, wforecastf,
                                                      wsatellite, kgam, epsilon)
    else:
        filename = 'syn_fg%d_ff%d_k%d_e%d.mat' % (wforecastg, wforecastf,
                                                  kgam, epsilon)
    savemat(filename, mdict=svm)
    # call form runs on both Python 2 and 3 (the print statement is a
    # SyntaxError on Python 3)
    print('plot_svm %s' % filename)