Skip to content

Commit 36b313e

Browse files
committed
tests
1 parent 2445062 commit 36b313e

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

kessler/model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,14 @@ def make_chaser(self):
308308
mean_anomaly = pyro.sample('c_mean_anomaly', self._prior_dict['mean_anomaly_prior'])
309309
tle = self._c_tle.copy()
310310
tle.update({'mean_anomaly': mean_anomaly})
311-
pyro.deterministic('c_mean_motion',tle.mean_motion)
312-
pyro.deterministic('c_eccentricity',tle.eccentricity)
313-
pyro.deterministic('c_inclination',tle.inclination)
314-
pyro.deterministic('c_argument_of_perigee',tle.argument_of_perigee)
315-
pyro.deterministic('c_raan',tle.raan)
316-
pyro.deterministic('c_mean_motion_first_derivative',tle.mean_motion_first_derivative)
317-
pyro.deterministic('c_mean_motion_second_derivative',tle.mean_motion_second_derivative)
318-
pyro.deterministic('c_b_star',tle.b_star)
311+
pyro.deterministic('c_mean_motion',torch.tensor(tle.mean_motion))
312+
pyro.deterministic('c_eccentricity',torch.tensor(tle.eccentricity))
313+
pyro.deterministic('c_inclination',torch.tensor(tle.inclination))
314+
pyro.deterministic('c_argument_of_perigee',torch.tensor(tle.argument_of_perigee))
315+
pyro.deterministic('c_raan',torch.tensor(tle.raan))
316+
pyro.deterministic('c_mean_motion_first_derivative',torch.tensor(tle.mean_motion_first_derivative))
317+
pyro.deterministic('c_mean_motion_second_derivative',torch.tensor(tle.mean_motion_second_derivative))
318+
pyro.deterministic('c_b_star',torch.tensor(tle.b_star))
319319
return tle
320320

321321
def generate_cdm(self,

tests/test_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ def test_TruncatedNormal(self):
139139
max=0.68639
140140

141141
categorical=Categorical(probs=probs)
142-
batched_truncated_normal = kessler.util.TruncatedNormal(loc=locs, scale=scales, min=min, max=max)
142+
batched_truncated_normal = kessler.util.TruncatedNormal(loc=locs, scale=scales, low=min, high=max)
143143
mix_truncated=MixtureSameFamily(categorical, batched_truncated_normal)
144-
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.0001)).item(), 7.382209300994873, places=8)
145-
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.001)).item(), 5.485926151275635, places=8)
146-
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.01)).item(), 1.863307237625122, places=8)
147-
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.1)).item(), -2.458112955093384, places=8)
144+
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.0001)).item(), 7.382209300994873, places=6)
145+
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.001)).item(), 5.485926151275635, places=6)
146+
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.01)).item(), 1.863307237625122, places=6)
147+
self.assertAlmostEqual(mix_truncated.log_prob(torch.tensor(0.1)).item(), -2.458112955093384, places=6)
148148

149149

0 commit comments

Comments
 (0)