Skip to content

Commit f3291e6

Browse files
committed
updates
1 parent 0d75ed5 commit f3291e6

6 files changed

Lines changed: 35 additions & 131 deletions

File tree

kessler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
# GNU General Public License version 3. See LICENSE in root of repository.
1010

1111

12-
__version__ = '0.1.2.dev2'
12+
__version__ = '0.1.2.dev3'
1313

14+
from .util import seed
1415
from .cdm import ConjunctionDataMessage, CDM
1516
from .event import Event, EventDataset

kessler/data.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,14 @@
88
#
99
# GNU General Public License version 3. See LICENSE in root of repository.
1010

11-
12-
import uuid
13-
import os
14-
import torch
1511
import pandas as pd
1612
from datetime import datetime, timedelta
1713

1814
from . import util
19-
from .models import Conjunction
2015
from .cdm import CDM
2116
from .event import Event, EventDataset
2217

2318

24-
def generate_event_dataset(dataset_dir, num_events, save_traces=False, verbosity=True, *args, **kwargs):
25-
model = Conjunction(*args, **kwargs)
26-
if verbosity:
27-
print('Generating CDM dataset')
28-
print('Directory: {}'.format(dataset_dir))
29-
30-
util.create_path(dataset_dir, directory=True)
31-
for i in range(num_events):
32-
if verbosity:
33-
print('Generating event {} / {}'.format(i+1, num_events))
34-
file_name_event = os.path.join(dataset_dir, 'event_{}'.format(str(uuid.uuid4())))
35-
36-
trace = model.get_conjunction()
37-
if save_traces:
38-
file_name_trace = file_name_event + '.trace'
39-
if verbosity:
40-
print('Saving trace: {}'.format(file_name_trace))
41-
torch.save(trace, file_name_trace)
42-
43-
cdms = trace['cdms']
44-
for j, cdm in enumerate(cdms):
45-
file_name_suffix = '{}'.format(j).rjust(len('{}'.format(len(cdms))), '0')
46-
file_name_cdm = file_name_event + '_{}.cdm.kvn.txt'.format(file_name_suffix)
47-
if verbosity:
48-
print('Saving cdm : {}'.format(file_name_cdm))
49-
cdm.save(file_name_cdm)
50-
51-
5219
def kelvins_to_event_dataset(file_name, num_events=None, date_tca=None, remove_outliers=True, drop_features=['c_rcs_estimate', 't_rcs_estimate']):
5320
print('Loading Kelvins dataset from file name: {}'.format(file_name))
5421
kelvins = pd.read_csv(file_name)

kessler/event.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import re
2020

2121
from . import util
22-
from .cdm import ConjunctionDataMessage
2322
from .cdm import CDM
2423

2524
mpl.rcParams['axes.unicode_minus'] = False
@@ -33,7 +32,7 @@ def __init__(self, cdms=None, cdm_file_names=None):
3332
raise RuntimeError('Expecting only one of cdms, cdm_file_names, not both')
3433
self._cdms = cdms
3534
elif cdm_file_names is not None:
36-
self._cdms = [ConjunctionDataMessage(file_name) for file_name in cdm_file_names]
35+
self._cdms = [CDM(file_name) for file_name in cdm_file_names]
3736
else:
3837
self._cdms = []
3938
self._update_cdm_extra_features()
@@ -48,7 +47,7 @@ def _update_cdm_extra_features(self):
4847
cdm._values_extra['__DAYS_TO_TCA'] = cdm._values_extra['__TCA'] - cdm._values_extra['__CREATION_DATE']
4948

5049
def add(self, cdm, return_result=False):
51-
if isinstance(cdm, ConjunctionDataMessage):
50+
if isinstance(cdm, CDM):
5251
self._cdms.append(cdm)
5352
elif isinstance(cdm, list):
5453
for c in cdm:

kessler/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def learn(self, event_set, epochs=2, lr=1e-3, batch_size=8, device='cpu', valid_
216216
self._hist_train_loss_iters.append(total_iters)
217217
self._hist_train_loss.append(train_loss)
218218

219-
print('iter {} | minibatch {}/{} | epoch {}/{} | train loss {:.4e} | valid loss {:.4e}'.format(total_iters, i_minibatch+1, len(train_loader), epoch+1, epochs, train_loss, valid_loss), end='\r')
219+
print('iter {} | minibatch {}/{} | epoch {}/{} | train loss {:.4e} | valid loss {:.4e} '.format(total_iters, i_minibatch+1, len(train_loader), epoch+1, epochs, train_loss, valid_loss), end='\r')
220220
sys.stdout.flush()
221221

222222
if file_name_prefix is not None:
@@ -233,7 +233,7 @@ def predict(self, event):
233233
self.reset(1)
234234
output = self.forward(input.unsqueeze(0), input_length.unsqueeze(0)).squeeze()
235235
if util.has_nan_or_inf(output):
236-
raise RuntimeError('Network output has nan or inf:\n'.format(output))
236+
raise RuntimeError('Network output has nan or inf: {}\n'.format(output))
237237
if output.ndim == 1:
238238
output_last = output
239239
else:

kessler/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,27 @@
2020
import skyfield.sgp4lib
2121
import datetime
2222
import functools
23+
import random
2324

2425

2526
_print_refresh_rate = 0.25 #
2627

2728

29+
def seed(seed=None):
30+
if seed is None:
31+
seed = int((time.time()*1e6) % 1e8)
32+
global _random_seed
33+
_random_seed = seed
34+
random.seed(seed)
35+
np.random.seed(seed)
36+
torch.manual_seed(seed)
37+
if torch.cuda.is_available():
38+
torch.cuda.manual_seed(seed)
39+
40+
41+
seed()
42+
43+
2844
# This function is from python-sgp4 released under MIT License, (c) 2012–2016 Brandon Rhodes
2945
def compute_checksum(line):
3046
return sum((int(c) if c.isdigit() else c == '-') for c in line[0:68]) % 10

notebooks/001_get_started.ipynb

Lines changed: 13 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6-
"id": "micro-minneapolis",
76
"metadata": {},
87
"outputs": [],
98
"source": [
109
"import kessler\n",
1110
"from kessler import EventDataset\n",
1211
"from kessler.nn import LSTMPredictor\n",
1312
"from kessler.data import kelvins_to_event_dataset\n",
13+
"import pandas as pd\n",
1414
"\n",
15-
"import pandas as pd"
15+
"# Set the random number generator seed for reproducibility\n",
16+
"kessler.seed(1)"
1617
]
1718
},
1819
{
1920
"cell_type": "markdown",
20-
"id": "flexible-algorithm",
2121
"metadata": {},
2222
"source": [
2323
"# Data Loading\n",
@@ -28,21 +28,19 @@
2828
{
2929
"cell_type": "code",
3030
"execution_count": null,
31-
"id": "ahead-beach",
3231
"metadata": {},
3332
"outputs": [],
3433
"source": [
3534
"#As an example, we first show the case in which the data comes from the Kelvins competition.\n",
3635
"#For this, we built a specific converter that takes care of the conversion from Kelvins format\n",
3736
"#to standard CDM format (the data can be downloaded at https://kelvins.esa.int/collision-avoidance-challenge/data/):\n",
38-
"file_name = 'path_to_csv/train_data.csv'\n",
39-
"events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=200) #we use only 200 events"
37+
"file_name = '/home/gunes/data/kelvins/train_data/train_data.csv'\n",
38+
"events = kelvins_to_event_dataset(file_name, drop_features=['c_rcs_estimate', 't_rcs_estimate'], num_events=1000) #we use only 1000 events"
4039
]
4140
},
4241
{
4342
"cell_type": "code",
4443
"execution_count": null,
45-
"id": "formed-recognition",
4644
"metadata": {},
4745
"outputs": [],
4846
"source": [
@@ -55,7 +53,6 @@
5553
},
5654
{
5755
"cell_type": "markdown",
58-
"id": "weekly-baltimore",
5956
"metadata": {},
6057
"source": [
6158
"# Descriptive Statistics"
@@ -64,7 +61,6 @@
6461
{
6562
"cell_type": "code",
6663
"execution_count": null,
67-
"id": "demonstrated-clothing",
6864
"metadata": {},
6965
"outputs": [],
7066
"source": [
@@ -75,7 +71,6 @@
7571
},
7672
{
7773
"cell_type": "markdown",
78-
"id": "upper-columbus",
7974
"metadata": {},
8075
"source": [
8176
"# LSTM Training"
@@ -84,7 +79,6 @@
8479
{
8580
"cell_type": "code",
8681
"execution_count": null,
87-
"id": "intense-massage",
8882
"metadata": {},
8983
"outputs": [],
9084
"source": [
@@ -98,7 +92,6 @@
9892
{
9993
"cell_type": "code",
10094
"execution_count": null,
101-
"id": "norman-value",
10295
"metadata": {},
10396
"outputs": [],
10497
"source": [
@@ -117,17 +110,20 @@
117110
{
118111
"cell_type": "code",
119112
"execution_count": null,
120-
"id": "corporate-gardening",
121113
"metadata": {},
122114
"outputs": [],
123115
"source": [
124116
"# Create an LSTM predictor, specialized to the nn_features we extracted above\n",
125-
"model = LSTMPredictor(features=nn_features)\n",
117+
"model = LSTMPredictor(\n",
118+
" lstm_size=256, # Number of hidden units per LSTM layer\n",
119+
" lstm_depth=2, # Number of stacked LSTM layers\n",
120+
" dropout=0.2, # Dropout probability\n",
121+
" features=nn_features) # The list of feature names to use in the LSTM\n",
126122
"\n",
127123
"# Start training\n",
128124
"model.learn(events_train_and_val, \n",
129-
" epochs=3, # Number of epochs (one epoch is one full pass through the training dataset)\n",
130-
" lr=1e-4, # Learning rate, can decrease it if training diverges\n",
125+
" epochs=10, # Number of epochs (one epoch is one full pass through the training dataset)\n",
126+
" lr=1e-3, # Learning rate, can decrease it if training diverges\n",
131127
" batch_size=16, # Minibatch size, can be decreased if there are issues with memory use\n",
132128
" device='cpu', # Can be 'cuda' if there is a GPU available\n",
133129
" valid_proportion=0.15, # Proportion of the data to use as a validation set internally\n",
@@ -138,7 +134,6 @@
138134
{
139135
"cell_type": "code",
140136
"execution_count": null,
141-
"id": "egyptian-yemen",
142137
"metadata": {},
143138
"outputs": [],
144139
"source": [
@@ -149,7 +144,6 @@
149144
{
150145
"cell_type": "code",
151146
"execution_count": null,
152-
"id": "alert-furniture",
153147
"metadata": {},
154148
"outputs": [],
155149
"source": [
@@ -160,7 +154,6 @@
160154
{
161155
"cell_type": "code",
162156
"execution_count": null,
163-
"id": "compressed-democracy",
164157
"metadata": {},
165158
"outputs": [],
166159
"source": [
@@ -171,7 +164,6 @@
171164
{
172165
"cell_type": "code",
173166
"execution_count": null,
174-
"id": "contemporary-professional",
175167
"metadata": {},
176168
"outputs": [],
177169
"source": [
@@ -187,7 +179,6 @@
187179
{
188180
"cell_type": "code",
189181
"execution_count": null,
190-
"id": "collected-chaos",
191182
"metadata": {},
192183
"outputs": [],
193184
"source": [
@@ -200,83 +191,13 @@
200191
{
201192
"cell_type": "code",
202193
"execution_count": null,
203-
"id": "grateful-billion",
204194
"metadata": {},
205195
"outputs": [],
206196
"source": [
207197
"#we now plot the uncertainty prediction for all the covariance matrix elements of both OBJECT1 and OBJECT2:\n",
208198
"axs = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
209199
"event.plot_uncertainty(axs=axs, label='Real', diagonal=False)"
210200
]
211-
},
212-
{
213-
"cell_type": "markdown",
214-
"id": "graphic-impression",
215-
"metadata": {},
216-
"source": [
217-
"# Plotting loop over all the events & CDMs\n",
218-
"You can here customize the features to be plotted: we use relative speed, miss distance, and a covariance value:"
219-
]
220-
},
221-
{
222-
"cell_type": "code",
223-
"execution_count": null,
224-
"id": "going-memory",
225-
"metadata": {},
226-
"outputs": [],
227-
"source": [
228-
"#we loop over the test set events:\n",
229-
"predict_full_event=False\n",
230-
"for i in range(0,len(events_test)):\n",
231-
" event=events_test[i]\n",
232-
" len_ev=len(event)\n",
233-
" for j in range(1,len_ev):\n",
234-
" #print(j)\n",
235-
" if predict_full_event:\n",
236-
" event_evolution = model.predict_event(event[0:j],num_samples=10)\n",
237-
" else:\n",
238-
" event_evolution = model.predict_event_step(event[0:j],num_samples=10)\n",
239-
"\n",
240-
" #we plot the features (ground truth & prediction)\n",
241-
" axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
242-
" event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'features_event_{i}_cdm_{j}.pdf')\n",
243-
" #we plot the uncertainties (ground truth & prediction)\n",
244-
" axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
245-
" event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'uncertainties_event_{i}_cdm_{j}.pdf')"
246-
]
247-
},
248-
{
249-
"cell_type": "markdown",
250-
"id": "actual-effectiveness",
251-
"metadata": {},
252-
"source": [
253-
"# Training set test\n",
254-
"We check if the model is able to predict the CDMs on the training set"
255-
]
256-
},
257-
{
258-
"cell_type": "code",
259-
"execution_count": null,
260-
"id": "enclosed-europe",
261-
"metadata": {},
262-
"outputs": [],
263-
"source": [
264-
"\n",
265-
"#we loop over some training set events, to check the NN performances:\n",
266-
"num_events=10\n",
267-
"for i in range(0,num_events):\n",
268-
" event=events_train_and_val[i]\n",
269-
" len_ev=len(event)\n",
270-
" for j in range(1,len_ev):\n",
271-
" print(j)\n",
272-
" event_evolution = model.predict_event(event[0:j],num_samples=10)\n",
273-
" #we plot the features (ground truth & prediction)\n",
274-
" axs_1 = event_evolution.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], return_axs=True, linewidth=0.1, color='red', alpha=0.33, label='Prediction')\n",
275-
" event.plot_features(['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T'], axs=axs_1, label='Real', legend=True,file_name=f'training_set_features_event_{i}_cdm_{j}.pdf')\n",
276-
" #we plot the uncertainties (ground truth & prediction)\n",
277-
" axs_2 = event_evolution.plot_uncertainty(return_axs=True, linewidth=0.5, label='Prediction', alpha=0.5, color='red', legend=True, diagonal=False)\n",
278-
" event.plot_uncertainty(axs=axs_2, label='Real', diagonal=False, file_name=f'training_set_uncertainties_event_{i}_cdm_{j}.pdf')"
279-
]
280201
}
281202
],
282203
"metadata": {
@@ -295,7 +216,7 @@
295216
"name": "python",
296217
"nbconvert_exporter": "python",
297218
"pygments_lexer": "ipython3",
298-
"version": "3.7.9"
219+
"version": "3.8.5"
299220
}
300221
},
301222
"nbformat": 4,

0 commit comments

Comments
 (0)