Skip to content

Commit ddb166a

Browse files
committed
initial updates to plotting module to suppor pyro, docs update, test
1 parent b803593 commit ddb166a

4 files changed

Lines changed: 303 additions & 123 deletions

File tree

docs/notebooks/plotting.ipynb

Lines changed: 46 additions & 34 deletions
Large diffs are not rendered by default.

kessler/plot.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1+
# This code is part of Kessler, a machine learning library for spacecraft collision avoidance.
2+
#
3+
# Copyright (c) 2020-
4+
# Trillium Technologies
5+
# University of Oxford
6+
# Giacomo Acciarini (giacomo.acciarini@gmail.com)
7+
# and other contributors, see README in root of repository.
8+
#
9+
# GNU General Public License version 3. See LICENSE in root of repository.
10+
111
import matplotlib as mpl
212
import matplotlib.pyplot as plt
313
import matplotlib.image as mpimg
414
import os
515
import uuid
616
import tempfile
7-
import pyprob
17+
import pyro
18+
import dsgp4
819
from pyprob.distributions import Empirical
920
import numpy as np
1021
import torch
@@ -24,7 +35,7 @@ def plot_mix(mix, min_val=-10, max_val=10, resolution=1000, figsize=(10, 5), xla
2435
fig, ax = plt.subplots(figsize=figsize)
2536
fig.tight_layout()
2637
ax.grid()
27-
xvals = np.linspace(min_val, max_val, resolution)
38+
xvals = torch.linspace(min_val, max_val, resolution)
2839
ax.plot(xvals, [torch.exp(mix.log_prob(x)) for x in xvals], *args, **kwargs)
2940
if log_xscale:
3041
ax.set_xscale('log')
@@ -188,8 +199,6 @@ def plot_dist(dists, file_name=None, n_bins=30, num_resample=None, trace=None, f
188199
if len(marginal_dists[i]['dist_conj']) > 0:
189200
marginal_dists[i]['dist_time_cdm'] = marginal_dists[i]['dist_events_with_conjunction'].map(lambda t:t['time_cdm'])
190201

191-
pyprob.set_verbosity(2)
192-
193202
fig, axs = plt.subplots(8, 4, figsize=figsize)
194203

195204
t_color = 'green'
@@ -580,41 +589,44 @@ def plot_dist(dists, file_name=None, n_bins=30, num_resample=None, trace=None, f
580589
def plot_trace_orbit(trace, time_upsample_factor=100, figsize=(10, 8), file_name=None):
581590
t_color, c_color = 'red', 'forestgreen'
582591

583-
time0 = float(trace['time0'])
584-
max_duration_days = float(trace['max_duration_days'])
585-
delta_time = float(trace['delta_time'])
592+
time0 = float(trace.nodes['time0']['value'])
593+
max_duration_days = float(trace.nodes['max_duration_days']['value'])
594+
delta_time = float(trace.nodes['delta_time']['value'])
586595
times = np.arange(time0, time0 + max_duration_days, delta_time)
587596

588-
t_mean_motion = float(trace['t_mean_motion'])
589-
t_mean_motion_first_derivative = float(trace['t_mean_motion_first_derivative'])
590-
t_mean_motion_second_derivative= float(trace['t_mean_motion_second_derivative'])
591-
t_eccentricity = float(trace['t_eccentricity'])
592-
t_inclination = float(trace['t_inclination'])
593-
t_argument_of_perigee = float(trace['t_argument_of_perigee'])
594-
t_raan = float(trace['t_raan'])
595-
t_mean_anomaly = float(trace['t_mean_anomaly'])
596-
t_b_star = float(trace['t_b_star'])
597-
598-
util.lpop_init(trace['t_tle0'])
597+
t_mean_motion = float(trace.nodes['t_mean_motion']['value'])
598+
t_mean_motion_first_derivative = float(trace.nodes['t_mean_motion_first_derivative']['value'])
599+
t_mean_motion_second_derivative= float(trace.nodes['t_mean_motion_second_derivative']['value'])
600+
t_eccentricity = float(trace.nodes['t_eccentricity']['value'])
601+
t_inclination = float(trace.nodes['t_inclination']['value'])
602+
t_argument_of_perigee = float(trace.nodes['t_argument_of_perigee']['value'])
603+
t_raan = float(trace.nodes['t_raan']['value'])
604+
t_mean_anomaly = float(trace.nodes['t_mean_anomaly']['value'])
605+
t_b_star = float(trace.nodes['t_b_star']['value'])
606+
607+
t_tle=trace.nodes['t_tle']['infer']['t_tle']
599608
try:
600-
t_states = util.lpop_sequence_upsample(times, time_upsample_factor)
609+
dsgp4.initialize_tle(t_tle)
610+
t_states = util.propagate_upsample(tle=t_tle, times_mjd=times, upsample_factor=time_upsample_factor)
601611
t_prop_error = False
602612
except RuntimeError as e:
613+
print(f'Error during target propagation: {e}')
603614
t_prop_error = True
604615

605-
c_mean_motion = float(trace['c_mean_motion'])
606-
c_mean_motion_first_derivative = float(trace['c_mean_motion_first_derivative'])
607-
c_mean_motion_second_derivative= float(trace['c_mean_motion_second_derivative'])
608-
c_eccentricity = float(trace['c_eccentricity'])
609-
c_inclination = float(trace['c_inclination'])
610-
c_argument_of_perigee = float(trace['c_argument_of_perigee'])
611-
c_raan = float(trace['c_raan'])
612-
c_mean_anomaly = float(trace['c_mean_anomaly'])
613-
c_b_star = float(trace['c_b_star'])
614-
615-
util.lpop_init(trace['c_tle0'])
616+
c_mean_motion = float(trace.nodes['c_mean_motion']['value'])
617+
c_mean_motion_first_derivative = float(trace.nodes['c_mean_motion_first_derivative']['value'])
618+
c_mean_motion_second_derivative= float(trace.nodes['c_mean_motion_second_derivative']['value'])
619+
c_eccentricity = float(trace.nodes['c_eccentricity']['value'])
620+
c_inclination = float(trace.nodes['c_inclination']['value'])
621+
c_argument_of_perigee = float(trace.nodes['c_argument_of_perigee']['value'])
622+
c_raan = float(trace.nodes['c_raan']['value'])
623+
c_mean_anomaly = float(trace.nodes['c_mean_anomaly']['value'])
624+
c_b_star = float(trace.nodes['c_b_star']['value'])
625+
626+
c_tle=trace.nodes['c_tle']['infer']['c_tle']
616627
try:
617-
c_states = util.lpop_sequence_upsample(times, time_upsample_factor)
628+
dsgp4.initialize_tle(c_tle)
629+
c_states = util.propagate_upsample(tle=c_tle, times_mjd=times, upsample_factor=time_upsample_factor)
618630
c_prop_error = False
619631
except RuntimeError as e:
620632
c_prop_error = True
@@ -635,8 +647,8 @@ def plot_trace_orbit(trace, time_upsample_factor=100, figsize=(10, 8), file_name
635647
if not c_prop_error:
636648
ax.plot(c_states[:,0,0], c_states[:,0,1], c_states[:,0,2], alpha=0.75, color=c_color)
637649
# set_axes_equal(ax)
638-
if trace['conj']:
639-
i_conj = int(trace['i_conj'])
650+
if trace.nodes['conj']['value']:
651+
i_conj = int(trace.nodes['i_conj']['value'])
640652
if not t_prop_error:
641653
t_pos_conj = t_states[i_conj, 0]
642654
ax.scatter(t_pos_conj[0], t_pos_conj[1], t_pos_conj[2], s=1e3, marker='*', color='green')

0 commit comments

Comments
 (0)