Skip to content

Commit 805fb1a

Browse files
authored
Merge pull request #174 from EleutherAI/pr/latents-cache-cleanup
Move to cpu when computing firing counts
2 parents 2795871 + 2ddfb21 commit 805fb1a

1 file changed

Lines changed: 3 additions & 7 deletions

File tree

delphi/latents/cache.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,14 @@ def run(self, n_tokens: int, tokens: token_tensor_type):
282282
latents
283283
)
284284
self.cache.add(sae_latents, batch, batch_number, hookpoint)
285-
firing_counts = (sae_latents > 0).sum((0, 1))
285+
firing_counts = (sae_latents.cpu() > 0).sum((0, 1))
286286
if self.width is None:
287287
self.width = sae_latents.shape[2]
288288

289289
if hookpoint not in self.hookpoint_firing_counts:
290-
self.hookpoint_firing_counts[hookpoint] = (
291-
firing_counts.cpu()
292-
)
290+
self.hookpoint_firing_counts[hookpoint] = firing_counts
293291
else:
294-
self.hookpoint_firing_counts[
295-
hookpoint
296-
] += firing_counts.cpu()
292+
self.hookpoint_firing_counts[hookpoint] += firing_counts
297293

298294
# Update the progress bar
299295
pbar.update(1)

0 commit comments

Comments
 (0)