Commit 07af00f

FinlaySanders and stmio committed
continuous sampling sanitisation
Co-authored-by: Sam Turner <98767222+stmio@users.noreply.github.com>
1 parent 33661e7 commit 07af00f

1 file changed

Lines changed: 3 additions & 0 deletions

pufferlib/pytorch.py

@@ -191,6 +191,9 @@ def sample_logits(logits, action=None):
     if isinstance(logits, torch.distributions.Normal):
         batch = logits.loc.shape[0]
         if action is None:
+            mean = torch.nan_to_num(logits.loc, 0.0, 0.0, 0.0)
+            std = torch.nan_to_num(logits.scale, 1.0, 1.0, 1.0)
+            logits = torch.distributions.Normal(mean, std)
             action = logits.sample().view(batch, -1)
 
         log_probs = logits.log_prob(action.view(batch, -1)).sum(1)
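
The effect of the three added lines can be illustrated in isolation. The following is a minimal sketch, not code from the repository: `sanitize_normal` is a hypothetical helper name, the tensors are made-up inputs, and building the bad distribution with `validate_args=False` is an assumption about how non-finite parameters might reach `sample_logits` in the first place.

```python
import torch
from torch.distributions import Normal


def sanitize_normal(dist: Normal) -> Normal:
    """Hypothetical helper that mirrors the three lines added in this commit."""
    # Positional arguments of torch.nan_to_num are (input, nan, posinf, neginf):
    # NaN, +inf and -inf in the mean are all replaced by 0.0.
    mean = torch.nan_to_num(dist.loc, 0.0, 0.0, 0.0)
    # NaN, +inf and -inf in the standard deviation are all replaced by 1.0.
    std = torch.nan_to_num(dist.scale, 1.0, 1.0, 1.0)
    return Normal(mean, std)


# Made-up parameters containing NaN and inf entries (batch of 2, 2 action dims).
loc = torch.tensor([[0.5, float("nan")], [float("inf"), -1.0]])
scale = torch.tensor([[1.0, float("inf")], [float("nan"), 2.0]])

# Assumption: argument validation is off, otherwise Normal rejects these values.
bad = Normal(loc, scale, validate_args=False)

clean = sanitize_normal(bad)
batch = clean.loc.shape[0]

# Same sampling and log-probability steps as in the diff.
action = clean.sample().view(batch, -1)
log_probs = clean.log_prob(action.view(batch, -1)).sum(1)
```

Note that `torch.nan_to_num` only rewrites NaN and infinite entries, so a finite zero or negative scale would pass through this step unchanged.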
