We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 33661e7 + 07af00f commit 971344bCopy full SHA for 971344b
1 file changed
pufferlib/pytorch.py
@@ -191,6 +191,9 @@ def sample_logits(logits, action=None):
191
if isinstance(logits, torch.distributions.Normal):
192
batch = logits.loc.shape[0]
193
if action is None:
194
+ mean = torch.nan_to_num(logits.loc, 0.0, 0.0, 0.0)
195
+ std = torch.nan_to_num(logits.scale, 1.0, 1.0, 1.0)
196
+ logits = torch.distributions.Normal(mean, std)
197
action = logits.sample().view(batch, -1)
198
199
log_probs = logits.log_prob(action.view(batch, -1)).sum(1)
0 commit comments