1+ # This code is part of Kessler, a machine learning library for spacecraft collision avoidance.
2+ #
3+ # Copyright (c) 2020-
4+ # Trillium Technologies
5+ # University of Oxford
6+ # Giacomo Acciarini (giacomo.acciarini@gmail.com)
7+ # and other contributors, see README in root of repository.
8+ #
9+ # GNU General Public License version 3. See LICENSE in root of repository.
10+
111import matplotlib as mpl
212import matplotlib .pyplot as plt
313import matplotlib .image as mpimg
414import os
515import uuid
616import tempfile
7- import pyprob
17+ import pyro
18+ import dsgp4
819from pyprob .distributions import Empirical
920import numpy as np
1021import torch
@@ -24,7 +35,7 @@ def plot_mix(mix, min_val=-10, max_val=10, resolution=1000, figsize=(10, 5), xla
2435 fig , ax = plt .subplots (figsize = figsize )
2536 fig .tight_layout ()
2637 ax .grid ()
27- xvals = np .linspace (min_val , max_val , resolution )
38+ xvals = torch .linspace (min_val , max_val , resolution )
2839 ax .plot (xvals , [torch .exp (mix .log_prob (x )) for x in xvals ], * args , ** kwargs )
2940 if log_xscale :
3041 ax .set_xscale ('log' )
@@ -188,8 +199,6 @@ def plot_dist(dists, file_name=None, n_bins=30, num_resample=None, trace=None, f
188199 if len (marginal_dists [i ]['dist_conj' ]) > 0 :
189200 marginal_dists [i ]['dist_time_cdm' ] = marginal_dists [i ]['dist_events_with_conjunction' ].map (lambda t :t ['time_cdm' ])
190201
191- pyprob .set_verbosity (2 )
192-
193202 fig , axs = plt .subplots (8 , 4 , figsize = figsize )
194203
195204 t_color = 'green'
@@ -580,41 +589,44 @@ def plot_dist(dists, file_name=None, n_bins=30, num_resample=None, trace=None, f
580589def plot_trace_orbit (trace , time_upsample_factor = 100 , figsize = (10 , 8 ), file_name = None ):
581590 t_color , c_color = 'red' , 'forestgreen'
582591
583- time0 = float (trace ['time0' ])
584- max_duration_days = float (trace ['max_duration_days' ])
585- delta_time = float (trace ['delta_time' ])
592+ time0 = float (trace . nodes ['time0' ][ 'value ' ])
593+ max_duration_days = float (trace . nodes ['max_duration_days' ][ 'value ' ])
594+ delta_time = float (trace . nodes ['delta_time' ][ 'value ' ])
586595 times = np .arange (time0 , time0 + max_duration_days , delta_time )
587596
588- t_mean_motion = float (trace ['t_mean_motion' ])
589- t_mean_motion_first_derivative = float (trace ['t_mean_motion_first_derivative' ])
590- t_mean_motion_second_derivative = float (trace ['t_mean_motion_second_derivative' ])
591- t_eccentricity = float (trace ['t_eccentricity' ])
592- t_inclination = float (trace ['t_inclination' ])
593- t_argument_of_perigee = float (trace ['t_argument_of_perigee' ])
594- t_raan = float (trace ['t_raan' ])
595- t_mean_anomaly = float (trace ['t_mean_anomaly' ])
596- t_b_star = float (trace ['t_b_star' ])
597-
598- util . lpop_init ( trace [ 't_tle0' ])
597+ t_mean_motion = float (trace . nodes ['t_mean_motion' ][ 'value ' ])
598+ t_mean_motion_first_derivative = float (trace . nodes ['t_mean_motion_first_derivative' ][ 'value ' ])
599+ t_mean_motion_second_derivative = float (trace . nodes ['t_mean_motion_second_derivative' ][ 'value ' ])
600+ t_eccentricity = float (trace . nodes ['t_eccentricity' ][ 'value ' ])
601+ t_inclination = float (trace . nodes ['t_inclination' ][ 'value ' ])
602+ t_argument_of_perigee = float (trace . nodes ['t_argument_of_perigee' ][ 'value ' ])
603+ t_raan = float (trace . nodes ['t_raan' ][ 'value ' ])
604+ t_mean_anomaly = float (trace . nodes ['t_mean_anomaly' ][ 'value ' ])
605+ t_b_star = float (trace . nodes ['t_b_star' ][ 'value ' ])
606+
607+ t_tle = trace . nodes [ 't_tle' ][ 'infer' ][ 't_tle' ]
599608 try :
600- t_states = util .lpop_sequence_upsample (times , time_upsample_factor )
609+ dsgp4 .initialize_tle (t_tle )
610+ t_states = util .propagate_upsample (tle = t_tle , times_mjd = times , upsample_factor = time_upsample_factor )
601611 t_prop_error = False
602612 except RuntimeError as e :
613+ print (f'Error during target propagation: { e } ' )
603614 t_prop_error = True
604615
605- c_mean_motion = float (trace ['c_mean_motion' ])
606- c_mean_motion_first_derivative = float (trace ['c_mean_motion_first_derivative' ])
607- c_mean_motion_second_derivative = float (trace ['c_mean_motion_second_derivative' ])
608- c_eccentricity = float (trace ['c_eccentricity' ])
609- c_inclination = float (trace ['c_inclination' ])
610- c_argument_of_perigee = float (trace ['c_argument_of_perigee' ])
611- c_raan = float (trace ['c_raan' ])
612- c_mean_anomaly = float (trace ['c_mean_anomaly' ])
613- c_b_star = float (trace ['c_b_star' ])
614-
615- util . lpop_init ( trace [ 'c_tle0' ])
616+ c_mean_motion = float (trace . nodes ['c_mean_motion' ][ 'value ' ])
617+ c_mean_motion_first_derivative = float (trace . nodes ['c_mean_motion_first_derivative' ][ 'value ' ])
618+ c_mean_motion_second_derivative = float (trace . nodes ['c_mean_motion_second_derivative' ][ 'value ' ])
619+ c_eccentricity = float (trace . nodes ['c_eccentricity' ][ 'value ' ])
620+ c_inclination = float (trace . nodes ['c_inclination' ][ 'value ' ])
621+ c_argument_of_perigee = float (trace . nodes ['c_argument_of_perigee' ][ 'value ' ])
622+ c_raan = float (trace . nodes ['c_raan' ][ 'value ' ])
623+ c_mean_anomaly = float (trace . nodes ['c_mean_anomaly' ][ 'value ' ])
624+ c_b_star = float (trace . nodes ['c_b_star' ][ 'value ' ])
625+
626+ c_tle = trace . nodes [ 'c_tle' ][ 'infer' ][ 'c_tle' ]
616627 try :
617- c_states = util .lpop_sequence_upsample (times , time_upsample_factor )
628+ dsgp4 .initialize_tle (c_tle )
629+ c_states = util .propagate_upsample (tle = c_tle , times_mjd = times , upsample_factor = time_upsample_factor )
618630 c_prop_error = False
619631 except RuntimeError as e :
620632 c_prop_error = True
@@ -635,8 +647,8 @@ def plot_trace_orbit(trace, time_upsample_factor=100, figsize=(10, 8), file_name
635647 if not c_prop_error :
636648 ax .plot (c_states [:,0 ,0 ], c_states [:,0 ,1 ], c_states [:,0 ,2 ], alpha = 0.75 , color = c_color )
637649# set_axes_equal(ax)
638- if trace ['conj' ]:
639- i_conj = int (trace ['i_conj' ])
650+ if trace . nodes ['conj' ][ 'value ' ]:
651+ i_conj = int (trace . nodes ['i_conj' ][ 'value ' ])
640652 if not t_prop_error :
641653 t_pos_conj = t_states [i_conj , 0 ]
642654 ax .scatter (t_pos_conj [0 ], t_pos_conj [1 ], t_pos_conj [2 ], s = 1e3 , marker = '*' , color = 'green' )
0 commit comments