Skip to content

Commit 971344b

Browse files
authored
Merge pull request #444 from FinlaySanders/3.0
continuous action sampling sanitisation
2 parents 33661e7 + 07af00f commit 971344b

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

pufferlib/pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def sample_logits(logits, action=None):
191191
if isinstance(logits, torch.distributions.Normal):
192192
batch = logits.loc.shape[0]
193193
if action is None:
194+
mean = torch.nan_to_num(logits.loc, 0.0, 0.0, 0.0)
195+
std = torch.nan_to_num(logits.scale, 1.0, 1.0, 1.0)
196+
logits = torch.distributions.Normal(mean, std)
194197
action = logits.sample().view(batch, -1)
195198

196199
log_probs = logits.log_prob(action.view(batch, -1)).sum(1)

0 commit comments

Comments
 (0)