Skip to content

Commit c3bdefa

Browse files
committed
add kelvins tutorial
1 parent 6c6296b commit c3bdefa

3 files changed

Lines changed: 227 additions & 1 deletion

File tree

docs/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@
9696
"LSTM_training.ipynb",
9797
"basics.ipynb",
9898
"probabilistic_programming_module.ipynb",
99-
"plotting.ipynb"
99+
"plotting.ipynb",
100+
"kelvins_dataset.ipynb"
100101
]
101102

102103
latex_engine = "xelatex"
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import kessler\n",
10+
"from kessler import EventDataset\n",
11+
"from kessler.nn import LSTMPredictor\n",
12+
"from kessler.data import kelvins_to_event_dataset\n",
13+
"import pandas as pd\n",
14+
"\n",
15+
"# Set the random number generator seed for reproducibility\n",
16+
"kessler.seed(1)"
17+
]
18+
},
19+
{
20+
"cell_type": "markdown",
21+
"metadata": {},
22+
"source": [
23+
"# Data Loading\n",
24+
"\n",
25+
"Kessler accepts CDMs either in KVN format or as pandas dataframes. Here we show an example of loading CDMs from a pandas dataframe:"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"#As an example, we first show the case in which the data comes from the Kelvins competition.\n",
35+
"#For this, we built a specific converter that takes care of the conversion from Kelvins format\n",
36+
"#to standard CDM format (the data can be downloaded at https://kelvins.esa.int/collision-avoidance-challenge/data/):\n",
37+
"file_name = '/home/gunes/data/kelvins/train_data/train_data.csv'\n",
38+
"events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=1000) #we use only 1000 events"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"#Alternatively, this is a generic loader for real CDM data stored in a pandas dataframe (uncomment the following lines if needed):\n",
48+
"#file_name = 'path_to_csv/file.csv'\n",
49+
"\n",
50+
"#df=pd.read_csv(file_name)\n",
51+
"#events = EventDataset.from_pandas(df)"
52+
]
53+
},
54+
{
55+
"cell_type": "markdown",
56+
"metadata": {},
57+
"source": [
58+
"# Descriptive Statistics"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"#Descriptive statistics of the events in the dataset:\n",
68+
"kessler_stats = events.to_dataframe().describe()\n",
69+
"print(kessler_stats)\n"
70+
]
71+
},
72+
{
73+
"cell_type": "markdown",
74+
"metadata": {},
75+
"source": [
76+
"# LSTM Training"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"#We only use features with numeric content for the training\n",
86+
"#nn_features is a list of the feature names taken into account for the training:\n",
87+
"#it can be edited if features need to be added or removed\n",
88+
"nn_features = events.common_features(only_numeric=True)\n",
89+
"print(nn_features)"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": null,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"# Split data into a test set (5% of the total number of events)\n",
99+
"len_test_set=int(0.05*len(events))\n",
100+
"print('Test data:', len_test_set)\n",
101+
"events_test=events[-len_test_set:]\n",
102+
"print(events_test)\n",
103+
"\n",
104+
"# The rest of the data will be used for training and validation\n",
105+
"print('Training and validation data:', len(events)-len_test_set)\n",
106+
"events_train_and_val=events[:-len_test_set]\n",
107+
"print(events_train_and_val)"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": null,
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"# Create an LSTM predictor, specialized to the nn_features we extracted above\n",
117+
"model = LSTMPredictor(\n",
118+
" lstm_size=256, # Number of hidden units per LSTM layer\n",
119+
" lstm_depth=2, # Number of stacked LSTM layers\n",
120+
" dropout=0.2, # Dropout probability\n",
121+
" features=nn_features) # The list of feature names to use in the LSTM\n",
122+
"\n",
123+
"# Start training\n",
124+
"model.learn(events_train_and_val, \n",
125+
" epochs=10, # Number of epochs (one epoch is one full pass through the training dataset)\n",
126+
" lr=1e-3, # Learning rate, can decrease it if training diverges\n",
127+
" batch_size=16, # Minibatch size, can be decreased if there are issues with memory use\n",
128+
" device='cpu', # Can be 'cuda' if there is a GPU available\n",
129+
" valid_proportion=0.15, # Proportion of the data to use as a validation set internally\n",
130+
" num_workers=4, # Number of multithreaded dataloader workers, 4 is good for performance, but if there are any issues or errors, please try num_workers=1 as this solves issues with PyTorch most of the time\n",
131+
" event_samples_for_stats=1000) # Number of events to use to compute NN normalization factors, make this number as large as possible (ideally at least a few thousand)"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"metadata": {},
138+
"outputs": [],
139+
"source": [
140+
"#Save the model to a file after training:\n",
141+
"model.save(file_name=\"LSTM_20epochs_lr10-4_batchsize16\")"
142+
]
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": null,
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"#NN loss plotted to a file:\n",
151+
"model.plot_loss(file_name='plot_loss.pdf')"
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"metadata": {},
158+
"outputs": [],
159+
"source": [
160+
"#we show an example CDM from the set:\n",
161+
"events_train_and_val[0][0]"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": null,
167+
"metadata": {},
168+
"outputs": [],
169+
"source": [
170+
"#we take a single event, we remove the last CDM and try to predict it\n",
171+
"event=events_test[3]\n",
172+
"event_len = len(event)\n",
173+
"print(event)\n",
174+
"event_beginning = event[0:event_len-1]\n",
175+
"print(event_beginning)\n",
176+
"event_evolution = model.predict_event(event_beginning, num_samples=100, max_length=14)"
177+
]
178+
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": null,
182+
"metadata": {},
183+
"outputs": [],
184+
"source": [
185+
"#We plot the prediction in red:\n",
186+
"axs = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
187+
"#and the ground truth value in blue:\n",
188+
"event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs, label='Real', legend=True)"
189+
]
190+
},
191+
{
192+
"cell_type": "code",
193+
"execution_count": null,
194+
"metadata": {},
195+
"outputs": [],
196+
"source": [
197+
"#we now plot the uncertainty prediction for all the covariance matrix elements of both OBJECT1 and OBJECT2:\n",
198+
"axs = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
199+
"event.plot_uncertainty(axs=axs, label='Real', diagonal=False)"
200+
]
201+
}
202+
],
203+
"metadata": {
204+
"kernelspec": {
205+
"display_name": "Python 3",
206+
"language": "python",
207+
"name": "python3"
208+
},
209+
"language_info": {
210+
"codemirror_mode": {
211+
"name": "ipython",
212+
"version": 3
213+
},
214+
"file_extension": ".py",
215+
"mimetype": "text/x-python",
216+
"name": "python",
217+
"nbconvert_exporter": "python",
218+
"pygments_lexer": "ipython3",
219+
"version": "3.8.5"
220+
}
221+
},
222+
"nbformat": 4,
223+
"nbformat_minor": 5
224+
}

docs/tutorials.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ These tutorials include some basic examples on how to use kessler
1313
notebooks/basics.ipynb
1414
notebooks/cdms_analysis_and_plotting.ipynb
1515
notebooks/plotting.ipynb
16+
notebooks/kelvins_dataset.ipynb
1617

1718

1819
Advanced

0 commit comments

Comments
 (0)