66from delphi .config import RunConfig
77
88# Import the function to be tested
9- from delphi .sparse_coders import load_hooks_sparse_coders
9+ from delphi .sparse_coders import load_hooks_sparse_coders , load_sparse_coders
1010
1111
1212# A simple dummy run configuration for testing.
1313class DummyRunConfig :
14- def __init__ (self , sparse_model , hookpoints ):
14+ def __init__ (self , sparse_model , hookpoints , random = False ):
1515 self .sparse_model = sparse_model
1616 self .hookpoints = hookpoints
17+ self .random = random
1718 # Additional required fields can be added here if needed.
1819 self .model = "dummy_model"
1920 self .hf_token = ""
@@ -62,6 +63,7 @@ def run_cfg_sparsify():
6263 return DummyRunConfig (
6364 sparse_model = "EleutherAI/sae-pythia-70m-32k" ,
6465 hookpoints = ["layers.4.mlp" , "layers.0" ],
66+ random = False ,
6567 )
6668
6769
@@ -75,6 +77,7 @@ def run_cfg_gemma():
7577 "layer_12/width_131k/average_l0_67" ,
7678 "layer_12/width_16k/average_l0_22" ,
7779 ],
80+ random = False ,
7881 )
7982
8083
@@ -127,3 +130,53 @@ def test_retrieve_autoencoders_from_gemma(dummy_model, run_cfg_gemma):
127130 f"Autoencoder '{ key } ' from the Gemma branch failed when called:"
128131 f"\n { repr (e )} "
129132 )
133+
134+
135+ def test_load_sparse_coders_forwards_random_and_compile (monkeypatch ):
136+ """Ensure random and compile flags are forwarded for the sparsify path."""
137+ captured : dict [str , object ] = {}
138+
139+ def fake_loader (name , hookpoints , device , random = False , compile = False ):
140+ captured ["name" ] = name
141+ captured ["hookpoints" ] = hookpoints
142+ captured ["device" ] = device
143+ captured ["random" ] = random
144+ captured ["compile" ] = compile
145+ return {"layers.0" : object ()}
146+
147+ monkeypatch .setattr (
148+ "delphi.sparse_coders.sparse_model.load_sparsify_sparse_coders" ,
149+ fake_loader ,
150+ )
151+
152+ cfg = DummyRunConfig (
153+ sparse_model = "EleutherAI/sae-pythia-70m-32k" ,
154+ hookpoints = ["layers.0" ],
155+ random = True ,
156+ )
157+
158+ result = load_sparse_coders (cfg , device = "cpu" , compile = True )
159+
160+ assert isinstance (result , dict )
161+ assert captured == {
162+ "name" : "EleutherAI/sae-pythia-70m-32k" ,
163+ "hookpoints" : ["layers.0" ],
164+ "device" : "cpu" ,
165+ "random" : True ,
166+ "compile" : True ,
167+ }
168+
169+
170+ def test_load_sparse_coders_requires_random_field ():
171+ """The run config must explicitly provide a random field."""
172+
173+ class MissingRandomRunConfig :
174+ sparse_model = "EleutherAI/sae-pythia-70m-32k"
175+ hookpoints = ["layers.0" ]
176+
177+ @property
178+ def __class__ (self ) -> type : # type: ignore
179+ return RunConfig
180+
181+ with pytest .raises (AttributeError ):
182+ load_sparse_coders (MissingRandomRunConfig (), device = "cpu" )
0 commit comments