|
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | 5 | "execution_count": null, |
6 | | - "id": "micro-minneapolis", |
7 | 6 | "metadata": {}, |
8 | 7 | "outputs": [], |
9 | 8 | "source": [ |
10 | 9 | "import kessler\n", |
11 | 10 | "from kessler import EventDataset\n", |
12 | 11 | "from kessler.nn import LSTMPredictor\n", |
13 | 12 | "from kessler.data import kelvins_to_event_dataset\n", |
| 13 | + "import pandas as pd\n", |
14 | 14 | "\n", |
15 | | - "import pandas as pd" |
| 15 | + "# Set the random number generator seed for reproducibility\n", |
| 16 | + "kessler.seed(1)" |
16 | 17 | ] |
17 | 18 | }, |
18 | 19 | { |
19 | 20 | "cell_type": "markdown", |
20 | | - "id": "flexible-algorithm", |
21 | 21 | "metadata": {}, |
22 | 22 | "source": [ |
23 | 23 | "# Data Loading\n", |
|
28 | 28 | { |
29 | 29 | "cell_type": "code", |
30 | 30 | "execution_count": null, |
31 | | - "id": "ahead-beach", |
32 | 31 | "metadata": {}, |
33 | 32 | "outputs": [], |
34 | 33 | "source": [ |
35 | 34 | "#As an example, we first show the case in which the data comes from the Kelvins competition.\n", |
36 | 35 | "#For this, we built a specific converter that takes care of the conversion from Kelvins format\n", |
37 | 36 | "#to standard CDM format (the data can be downloaded at https://kelvins.esa.int/collision-avoidance-challenge/data/):\n", |
38 | | - "file_name = 'path_to_csv/train_data.csv'\n", |
39 | | - "events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=200) #we use only 200 events" |
| 37 | + "file_name = '/home/gunes/data/kelvins/train_data/train_data.csv'\n", |
| 38 | + "events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=1000) #we use only 1000 events"
40 | 39 | ] |
41 | 40 | }, |
42 | 41 | { |
43 | 42 | "cell_type": "code", |
44 | 43 | "execution_count": null, |
45 | | - "id": "formed-recognition", |
46 | 44 | "metadata": {}, |
47 | 45 | "outputs": [], |
48 | 46 | "source": [ |
|
55 | 53 | }, |
56 | 54 | { |
57 | 55 | "cell_type": "markdown", |
58 | | - "id": "weekly-baltimore", |
59 | 56 | "metadata": {}, |
60 | 57 | "source": [ |
61 | 58 | "# Descriptive Statistics" |
|
64 | 61 | { |
65 | 62 | "cell_type": "code", |
66 | 63 | "execution_count": null, |
67 | | - "id": "demonstrated-clothing", |
68 | 64 | "metadata": {}, |
69 | 65 | "outputs": [], |
70 | 66 | "source": [ |
|
75 | 71 | }, |
76 | 72 | { |
77 | 73 | "cell_type": "markdown", |
78 | | - "id": "upper-columbus", |
79 | 74 | "metadata": {}, |
80 | 75 | "source": [ |
81 | 76 | "# LSTM Training" |
|
84 | 79 | { |
85 | 80 | "cell_type": "code", |
86 | 81 | "execution_count": null, |
87 | | - "id": "intense-massage", |
88 | 82 | "metadata": {}, |
89 | 83 | "outputs": [], |
90 | 84 | "source": [ |
|
98 | 92 | { |
99 | 93 | "cell_type": "code", |
100 | 94 | "execution_count": null, |
101 | | - "id": "norman-value", |
102 | 95 | "metadata": {}, |
103 | 96 | "outputs": [], |
104 | 97 | "source": [ |
|
117 | 110 | { |
118 | 111 | "cell_type": "code", |
119 | 112 | "execution_count": null, |
120 | | - "id": "corporate-gardening", |
121 | 113 | "metadata": {}, |
122 | 114 | "outputs": [], |
123 | 115 | "source": [ |
124 | 116 | "# Create an LSTM predictor, specialized to the nn_features we extracted above\n", |
125 | | - "model = LSTMPredictor(features=nn_features)\n", |
| 117 | + "model = LSTMPredictor(\n", |
| 118 | + " lstm_size=256, # Number of hidden units per LSTM layer\n", |
| 119 | + " lstm_depth=2, # Number of stacked LSTM layers\n", |
| 120 | + " dropout=0.2, # Dropout probability\n", |
| 121 | + " features=nn_features) # The list of feature names to use in the LSTM\n", |
126 | 122 | "\n", |
127 | 123 | "# Start training\n", |
128 | 124 | "model.learn(events_train_and_val, \n", |
129 | | - " epochs=3, # Number of epochs (one epoch is one full pass through the training dataset)\n", |
130 | | - " lr=1e-4, # Learning rate, can decrease it if training diverges\n", |
| 125 | + " epochs=10, # Number of epochs (one epoch is one full pass through the training dataset)\n", |
| 126 | + " lr=1e-3, # Learning rate, can decrease it if training diverges\n", |
131 | 127 | " batch_size=16, # Minibatch size, can be decreased if there are issues with memory use\n", |
132 | 128 | " device='cpu', # Can be 'cuda' if there is a GPU available\n", |
133 | 129 | " valid_proportion=0.15, # Proportion of the data to use as a validation set internally\n", |
|
138 | 134 | { |
139 | 135 | "cell_type": "code", |
140 | 136 | "execution_count": null, |
141 | | - "id": "egyptian-yemen", |
142 | 137 | "metadata": {}, |
143 | 138 | "outputs": [], |
144 | 139 | "source": [ |
|
149 | 144 | { |
150 | 145 | "cell_type": "code", |
151 | 146 | "execution_count": null, |
152 | | - "id": "alert-furniture", |
153 | 147 | "metadata": {}, |
154 | 148 | "outputs": [], |
155 | 149 | "source": [ |
|
160 | 154 | { |
161 | 155 | "cell_type": "code", |
162 | 156 | "execution_count": null, |
163 | | - "id": "compressed-democracy", |
164 | 157 | "metadata": {}, |
165 | 158 | "outputs": [], |
166 | 159 | "source": [ |
|
171 | 164 | { |
172 | 165 | "cell_type": "code", |
173 | 166 | "execution_count": null, |
174 | | - "id": "contemporary-professional", |
175 | 167 | "metadata": {}, |
176 | 168 | "outputs": [], |
177 | 169 | "source": [ |
|
187 | 179 | { |
188 | 180 | "cell_type": "code", |
189 | 181 | "execution_count": null, |
190 | | - "id": "collected-chaos", |
191 | 182 | "metadata": {}, |
192 | 183 | "outputs": [], |
193 | 184 | "source": [ |
|
200 | 191 | { |
201 | 192 | "cell_type": "code", |
202 | 193 | "execution_count": null, |
203 | | - "id": "grateful-billion", |
204 | 194 | "metadata": {}, |
205 | 195 | "outputs": [], |
206 | 196 | "source": [ |
207 | 197 | "#we now plot the uncertainty prediction for all the covariance matrix elements of both OBJECT1 and OBJECT2:\n", |
208 | 198 | "axs = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n", |
209 | 199 | "event.plot_uncertainty(axs=axs, label='Real', diagonal=False)" |
210 | 200 | ] |
211 | | - }, |
212 | | - { |
213 | | - "cell_type": "markdown", |
214 | | - "id": "graphic-impression", |
215 | | - "metadata": {}, |
216 | | - "source": [ |
217 | | - "# Plotting loop over all the events & CDMs\n", |
218 | | - "You can here customize the features to be plotted: we use relative speed, miss distance, and a covariance value:" |
219 | | - ] |
220 | | - }, |
221 | | - { |
222 | | - "cell_type": "code", |
223 | | - "execution_count": null, |
224 | | - "id": "going-memory", |
225 | | - "metadata": {}, |
226 | | - "outputs": [], |
227 | | - "source": [ |
228 | | - "#we loop over the test set events:\n", |
229 | | - "predict_full_event=False\n", |
230 | | - "for i in range(0,len(events_test)):\n", |
231 | | - " event=events_test[i]\n", |
232 | | - " len_ev=len(event)\n", |
233 | | - " for j in range(1,len_ev):\n", |
234 | | - " #print(j)\n", |
235 | | - " if predict_full_event:\n", |
236 | | - " event_evolution = model.predict_event(event[0:j],num_samples=10)\n", |
237 | | - " else:\n", |
238 | | - " event_evolution = model.predict_event_step(event[0:j],num_samples=10)\n", |
239 | | - "\n", |
240 | | - " #we plot the features (ground truth & prediction)\n", |
241 | | - " axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n", |
242 | | - " event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'features_event_{i}_cdm_{j}.pdf')\n", |
243 | | - " #we plot the uncertainties (ground truth & prediction)\n", |
244 | | - " axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n", |
245 | | - " event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'uncertainties_event_{i}_cdm_{j}.pdf')" |
246 | | - ] |
247 | | - }, |
248 | | - { |
249 | | - "cell_type": "markdown", |
250 | | - "id": "actual-effectiveness", |
251 | | - "metadata": {}, |
252 | | - "source": [ |
253 | | - "# Training set test\n", |
254 | | - "We check if the model is able to predict the CDMs on the training set" |
255 | | - ] |
256 | | - }, |
257 | | - { |
258 | | - "cell_type": "code", |
259 | | - "execution_count": null, |
260 | | - "id": "enclosed-europe", |
261 | | - "metadata": {}, |
262 | | - "outputs": [], |
263 | | - "source": [ |
264 | | - "\n", |
265 | | - "#we loop over some training set events, to check the NN performances:\n", |
266 | | - "num_events=10\n", |
267 | | - "for i in range(0,num_events):\n", |
268 | | - " event=events_train_and_val[i]\n", |
269 | | - " len_ev=len(event)\n", |
270 | | - " for j in range(1,len_ev):\n", |
271 | | - " print(j)\n", |
272 | | - " event_evolution = model.predict_event(event[0:j],num_samples=10)\n", |
273 | | - " #we plot the features (ground truth & prediction)\n", |
274 | | - " axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n", |
275 | | - " event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'training_set_features_event_{i}_cdm_{j}.pdf')\n", |
276 | | - " #we plot the uncertainties (ground truth & prediction)\n", |
277 | | - " axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n", |
278 | | - " event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'training_set_uncertainties_event_{i}_cdm_{j}.pdf')" |
279 | | - ] |
280 | 201 | } |
281 | 202 | ], |
282 | 203 | "metadata": { |
|
295 | 216 | "name": "python", |
296 | 217 | "nbconvert_exporter": "python", |
297 | 218 | "pygments_lexer": "ipython3", |
298 | | - "version": "3.7.9" |
| 219 | + "version": "3.8.5" |
299 | 220 | } |
300 | 221 | }, |
301 | 222 | "nbformat": 4, |
|
0 commit comments