Skip to content

Commit 0d75ed5

Browse files
author
Sceki
committed
first notebook: get started
1 parent 2b543fb commit 0d75ed5

1 file changed

Lines changed: 303 additions & 0 deletions

File tree

notebooks/001_get_started.ipynb

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "micro-minneapolis",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import kessler\n",
11+
"from kessler import EventDataset\n",
12+
"from kessler.nn import LSTMPredictor\n",
13+
"from kessler.data import kelvins_to_event_dataset\n",
14+
"\n",
15+
"import pandas as pd"
16+
]
17+
},
18+
{
19+
"cell_type": "markdown",
20+
"id": "flexible-algorithm",
21+
"metadata": {},
22+
"source": [
23+
"# Data Loading\n",
24+
"\n",
25+
"Kessler accepts CDMs (Conjunction Data Messages) either in KVN format or as pandas DataFrames. Below we show an example of loading them from a pandas DataFrame:"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"id": "ahead-beach",
32+
"metadata": {},
33+
"outputs": [],
34+
"source": [
35+
"#As an example, we first show the case in which the data comes from the Kelvins competition.\n",
36+
"#For this, we built a specific converter that takes care of the conversion from Kelvins format\n",
37+
"#to standard CDM format (the data can be downloaded at https://kelvins.esa.int/collision-avoidance-challenge/data/):\n",
38+
"file_name = 'path_to_csv/train_data.csv'\n",
39+
"events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=200) #we use only 200 events"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"id": "formed-recognition",
46+
"metadata": {},
47+
"outputs": [],
48+
"source": [
49+
"#Alternatively, this is a generic loader for real CDM data held in a pandas DataFrame (uncomment the following lines if needed):\n",
50+
"#file_name = 'path_to_csv/file.csv'\n",
51+
"\n",
52+
"#df=pd.read_csv(file_name)\n",
53+
"#events = EventDataset.from_pandas(df)"
54+
]
55+
},
56+
{
57+
"cell_type": "markdown",
58+
"id": "weekly-baltimore",
59+
"metadata": {},
60+
"source": [
61+
"# Descriptive Statistics"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"id": "demonstrated-clothing",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"#Descriptive statistics of the events in the dataset:\n",
72+
"kessler_stats = events.to_dataframe().describe()\n",
73+
"print(kessler_stats)\n"
74+
]
75+
},
76+
{
77+
"cell_type": "markdown",
78+
"id": "upper-columbus",
79+
"metadata": {},
80+
"source": [
81+
"# LSTM Training"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"id": "intense-massage",
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"#We only use features with numeric content for the training\n",
92+
"#nn_features is a list of the feature names taken into account for the training:\n",
93+
"#it can be edited to add or remove features\n",
94+
"nn_features = events.common_features(only_numeric=True)\n",
95+
"print(nn_features)"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": null,
101+
"id": "norman-value",
102+
"metadata": {},
103+
"outputs": [],
104+
"source": [
105+
"# Split data into a test set (5% of the total number of events)\n",
106+
"len_test_set=int(0.05*len(events))\n",
107+
"print('Test data:', len_test_set)\n",
108+
"events_test=events[-len_test_set:]\n",
109+
"print(events_test)\n",
110+
"\n",
111+
"# The rest of the data will be used for training and validation\n",
112+
"print('Training and validation data:', len(events)-len_test_set)\n",
113+
"events_train_and_val=events[:-len_test_set]\n",
114+
"print(events_train_and_val)"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "corporate-gardening",
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"# Create an LSTM predictor, specialized to the nn_features we extracted above\n",
125+
"model = LSTMPredictor(features=nn_features)\n",
126+
"\n",
127+
"# Start training\n",
128+
"model.learn(events_train_and_val, \n",
129+
" epochs=3, # Number of epochs (one epoch is one full pass through the training dataset)\n",
130+
" lr=1e-4, # Learning rate, can decrease it if training diverges\n",
131+
" batch_size=16, # Minibatch size, can be decreased if there are issues with memory use\n",
132+
" device='cpu', # Can be 'cuda' if there is a GPU available\n",
133+
" valid_proportion=0.15, # Proportion of the data to use as a validation set internally\n",
134+
" num_workers=4, # Number of multithreaded dataloader workers, 4 is good for performance, but if there are any issues or errors, please try num_workers=1 as this solves issues with PyTorch most of the time\n",
135+
" event_samples_for_stats=1000) # Number of events to use to compute NN normalization factors, have this number as big as possible (and at least a few thousand)"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"id": "egyptian-yemen",
142+
"metadata": {},
143+
"outputs": [],
144+
"source": [
145+
"#Save the model to a file after training:\n",
146+
"model.save(file_name=\"LSTM_20epochs_lr10-4_batchsize16\")"
147+
]
148+
},
149+
{
150+
"cell_type": "code",
151+
"execution_count": null,
152+
"id": "alert-furniture",
153+
"metadata": {},
154+
"outputs": [],
155+
"source": [
156+
"#NN loss plotted to a file:\n",
157+
"model.plot_loss(file_name='plot_loss.pdf')"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"id": "compressed-democracy",
164+
"metadata": {},
165+
"outputs": [],
166+
"source": [
167+
"#we show an example CDM from the set:\n",
168+
"events_train_and_val[0][0]"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"id": "contemporary-professional",
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178+
"#we take a single event, we remove the last CDM and try to predict it\n",
179+
"event=events_test[3]\n",
180+
"event_len = len(event)\n",
181+
"print(event)\n",
182+
"event_beginning = event[0:event_len-1]\n",
183+
"print(event_beginning)\n",
184+
"event_evolution = model.predict_event(event_beginning, num_samples=100, max_length=14)"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": null,
190+
"id": "collected-chaos",
191+
"metadata": {},
192+
"outputs": [],
193+
"source": [
194+
"#We plot the prediction in red:\n",
195+
"axs = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
196+
"#and the ground truth value in blue:\n",
197+
"event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs, label='Real', legend=True)"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"id": "grateful-billion",
204+
"metadata": {},
205+
"outputs": [],
206+
"source": [
207+
"#we now plot the uncertainty prediction for all the covariance matrix elements of both OBJECT1 and OBJECT2:\n",
208+
"axs = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
209+
"event.plot_uncertainty(axs=axs, label='Real', diagonal=False)"
210+
]
211+
},
212+
{
213+
"cell_type": "markdown",
214+
"id": "graphic-impression",
215+
"metadata": {},
216+
"source": [
217+
"# Plotting loop over all the events & CDMs\n",
218+
"Here you can customize the features to be plotted; we use relative speed, miss distance, and a covariance value:"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": null,
224+
"id": "going-memory",
225+
"metadata": {},
226+
"outputs": [],
227+
"source": [
228+
"#we loop over the test set events:\n",
229+
"predict_full_event=False\n",
230+
"for i in range(0,len(events_test)):\n",
231+
" event=events_test[i]\n",
232+
" len_ev=len(event)\n",
233+
" for j in range(1,len_ev):\n",
234+
" #print(j)\n",
235+
" if predict_full_event:\n",
236+
" event_evolution = model.predict_event(event[0:j],num_samples=10)\n",
237+
" else:\n",
238+
" event_evolution = model.predict_event_step(event[0:j],num_samples=10)\n",
239+
"\n",
240+
" #we plot the features (ground truth & prediction)\n",
241+
" axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
242+
" event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'features_event_{i}_cdm_{j}.pdf')\n",
243+
" #we plot the uncertainties (ground truth & prediction)\n",
244+
" axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
245+
" event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'uncertainties_event_{i}_cdm_{j}.pdf')"
246+
]
247+
},
248+
{
249+
"cell_type": "markdown",
250+
"id": "actual-effectiveness",
251+
"metadata": {},
252+
"source": [
253+
"# Training set test\n",
254+
"We check if the model is able to predict the CDMs on the training set"
255+
]
256+
},
257+
{
258+
"cell_type": "code",
259+
"execution_count": null,
260+
"id": "enclosed-europe",
261+
"metadata": {},
262+
"outputs": [],
263+
"source": [
264+
"\n",
265+
"#we loop over some training set events, to check the NN performance:\n",
266+
"num_events=10\n",
267+
"for i in range(0,num_events):\n",
268+
" event=events_train_and_val[i]\n",
269+
" len_ev=len(event)\n",
270+
" for j in range(1,len_ev):\n",
271+
" print(j)\n",
272+
" event_evolution = model.predict_event(event[0:j],num_samples=10)\n",
273+
" #we plot the features (ground truth & prediction)\n",
274+
" axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
275+
" event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'training_set_features_event_{i}_cdm_{j}.pdf')\n",
276+
" #we plot the uncertainties (ground truth & prediction)\n",
277+
" axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
278+
" event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'training_set_uncertainties_event_{i}_cdm_{j}.pdf')"
279+
]
280+
}
281+
],
282+
"metadata": {
283+
"kernelspec": {
284+
"display_name": "Python 3",
285+
"language": "python",
286+
"name": "python3"
287+
},
288+
"language_info": {
289+
"codemirror_mode": {
290+
"name": "ipython",
291+
"version": 3
292+
},
293+
"file_extension": ".py",
294+
"mimetype": "text/x-python",
295+
"name": "python",
296+
"nbconvert_exporter": "python",
297+
"pygments_lexer": "ipython3",
298+
"version": "3.7.9"
299+
}
300+
},
301+
"nbformat": 4,
302+
"nbformat_minor": 5
303+
}

0 commit comments

Comments
 (0)