Skip to content

Commit 2f9bbde

Browse files
authored
Merge pull request #172 from EleutherAI/pr/sparse-random-contract
feat(sparse-coders): allow for sparse-coder random initialization
2 parents 222a984 + 9b9079c commit 2f9bbde

4 files changed

Lines changed: 94 additions & 7 deletions

File tree

delphi/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class RunConfig(Serializable):
120120
directory containing their weights. Models must be loadable with sparsify
121121
or gemmascope."""
122122

123+
random: bool = False
124+
"""Whether to initialize the sparse models with random weights."""
125+
123126
hookpoints: list[str] = list_field()
124127
"""list of model hookpoints to attach sparse models to."""
125128

delphi/sparse_coders/load_sparsify.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def load_sparsify_sparse_coders(
6565
name: str,
6666
hookpoints: list[str],
6767
device: str | torch.device,
68+
random: bool = False,
6869
compile: bool = False,
6970
) -> dict[str, PotentiallyWrappedSparseCoder]:
7071
"""
@@ -88,18 +89,45 @@ def load_sparsify_sparse_coders(
8889
name_path = Path(name)
8990
if name_path.exists():
9091
for hookpoint in hookpoints:
91-
sparse_model_dict[hookpoint] = SparseCoder.load_from_disk(
92-
name_path / hookpoint, device=device
92+
sparse_model = SparseCoder.load_from_disk(
93+
name_path / hookpoint, device="cpu"
9394
)
95+
# if random, initialize a new sparse model with random weights
96+
if random:
97+
config = sparse_model.cfg
98+
d_in = sparse_model.d_in
99+
dtype = sparse_model.dtype
100+
sparse_model = SparseCoder(
101+
d_in,
102+
config,
103+
device=device,
104+
dtype=dtype,
105+
decoder=False,
106+
)
107+
108+
sparse_model_dict[hookpoint] = sparse_model
94109
if compile:
95110
sparse_model_dict[hookpoint] = torch.compile(
96111
sparse_model_dict[hookpoint]
97112
)
98113
else:
99114
# Load on CPU first to not run out of memory
100115
sparse_models = SparseCoder.load_many(name, device="cpu")
116+
101117
for hookpoint in hookpoints:
102-
sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device)
118+
sparse_model = sparse_models[hookpoint]
119+
if random:
120+
config = sparse_model.cfg
121+
d_in = sparse_model.d_in
122+
dtype = sparse_model.dtype
123+
sparse_model = SparseCoder(
124+
d_in,
125+
config,
126+
device=device,
127+
dtype=dtype,
128+
decoder=False,
129+
)
130+
sparse_model_dict[hookpoint] = sparse_model.to(device)
103131
if compile:
104132
sparse_model_dict[hookpoint] = torch.compile(
105133
sparse_model_dict[hookpoint]
@@ -113,6 +141,7 @@ def load_sparsify_hooks(
113141
model: PreTrainedModel,
114142
name: str,
115143
hookpoints: list[str],
144+
random: bool = False,
116145
device: str | torch.device | None = None,
117146
compile: bool = False,
118147
) -> tuple[dict[str, Callable], bool]:
@@ -136,6 +165,7 @@ def load_sparsify_hooks(
136165
name,
137166
hookpoints,
138167
device,
168+
random,
139169
compile,
140170
)
141171
hookpoint_to_sparse_encode = {}
@@ -145,7 +175,6 @@ def load_sparsify_hooks(
145175
path_segments = resolve_path(model, hookpoint.split("."))
146176
if path_segments is None:
147177
raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")
148-
149178
hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
150179
sae_dense_latents, sae=sparse_model
151180
)

delphi/sparse_coders/sparse_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def load_hooks_sparse_coders(
3636
model,
3737
run_cfg.sparse_model,
3838
run_cfg.hookpoints,
39+
random=run_cfg.random,
3940
compile=compile,
4041
)
4142
else:
@@ -96,7 +97,8 @@ def load_sparse_coders(
9697
run_cfg.sparse_model,
9798
run_cfg.hookpoints,
9899
device,
99-
compile,
100+
random=run_cfg.random,
101+
compile=compile,
100102
)
101103
else:
102104
# model path will always be of the form google/gemma-scope-<size>-pt-<type>/

tests/test_autoencoders/test_sparse_coders.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from delphi.config import RunConfig
77

88
# Import the function to be tested
9-
from delphi.sparse_coders import load_hooks_sparse_coders
9+
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
1010

1111

1212
# A simple dummy run configuration for testing.
1313
class DummyRunConfig:
14-
def __init__(self, sparse_model, hookpoints):
14+
def __init__(self, sparse_model, hookpoints, random=False):
1515
self.sparse_model = sparse_model
1616
self.hookpoints = hookpoints
17+
self.random = random
1718
# Additional required fields can be added here if needed.
1819
self.model = "dummy_model"
1920
self.hf_token = ""
@@ -62,6 +63,7 @@ def run_cfg_sparsify():
6263
return DummyRunConfig(
6364
sparse_model="EleutherAI/sae-pythia-70m-32k",
6465
hookpoints=["layers.4.mlp", "layers.0"],
66+
random=False,
6567
)
6668

6769

@@ -75,6 +77,7 @@ def run_cfg_gemma():
7577
"layer_12/width_131k/average_l0_67",
7678
"layer_12/width_16k/average_l0_22",
7779
],
80+
random=False,
7881
)
7982

8083

@@ -127,3 +130,53 @@ def test_retrieve_autoencoders_from_gemma(dummy_model, run_cfg_gemma):
127130
f"Autoencoder '{key}' from the Gemma branch failed when called:"
128131
f"\n{repr(e)}"
129132
)
133+
134+
135+
def test_load_sparse_coders_forwards_random_and_compile(monkeypatch):
    """Check that ``random`` and ``compile`` reach the sparsify loader.

    The real sparsify loader is swapped for a stub that records the
    arguments it is called with, so no weights are loaded.
    """
    seen: dict[str, object] = {}

    def fake_loader(name, hookpoints, device, random=False, compile=False):
        # Record every argument so the whole call can be compared at once.
        seen.update(
            name=name,
            hookpoints=hookpoints,
            device=device,
            random=random,
            compile=compile,
        )
        return {"layers.0": object()}

    monkeypatch.setattr(
        "delphi.sparse_coders.sparse_model.load_sparsify_sparse_coders",
        fake_loader,
    )

    run_cfg = DummyRunConfig(
        sparse_model="EleutherAI/sae-pythia-70m-32k",
        hookpoints=["layers.0"],
        random=True,
    )
    loaded = load_sparse_coders(run_cfg, device="cpu", compile=True)

    expected_call = {
        "name": "EleutherAI/sae-pythia-70m-32k",
        "hookpoints": ["layers.0"],
        "device": "cpu",
        "random": True,
        "compile": True,
    }
    assert isinstance(loaded, dict)
    assert seen == expected_call
168+
169+
170+
def test_load_sparse_coders_requires_random_field():
    """The run config must explicitly provide a ``random`` field.

    A config object that defines ``sparse_model`` and ``hookpoints`` but no
    ``random`` attribute is passed to ``load_sparse_coders``; the call is
    expected to fail with an ``AttributeError`` that names ``random``.
    """

    class MissingRandomRunConfig:
        sparse_model = "EleutherAI/sae-pythia-70m-32k"
        hookpoints = ["layers.0"]

        # Reporting RunConfig as the class makes isinstance(obj, RunConfig)
        # succeed for this stand-in.
        # NOTE(review): presumably needed because load_sparse_coders checks
        # the config type -- confirm against the implementation.
        @property
        def __class__(self) -> type:  # type: ignore
            return RunConfig

    # Pin the failure to the missing ``random`` attribute; without ``match``
    # any unrelated AttributeError would let this test pass.
    with pytest.raises(AttributeError, match="'random'"):
        load_sparse_coders(MissingRandomRunConfig(), device="cpu")

0 commit comments

Comments
 (0)