Skip to content

Commit c845cc8

Browse files
authored
Merge pull request #173 from EleutherAI/pr/scorers-contract-fixes
Fix typing in classifier
2 parents 805fb1a + fb3aa86 commit c845cc8

10 files changed

Lines changed: 275 additions & 22 deletions

File tree

delphi/scorers/classifier/classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(
2121
client: Client,
2222
verbose: bool,
2323
n_examples_shown: int,
24-
log_prob: bool,
2524
seed: int = 42,
25+
log_prob: bool = False,
2626
**generation_kwargs,
2727
):
2828
"""
@@ -143,7 +143,8 @@ def _parse(
143143
match = re.search(pattern, string)
144144
if match is None:
145145
raise ValueError("No match found in string")
146-
predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0))
146+
raw_predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0))
147+
predictions = [bool(prediction) for prediction in raw_predictions]
147148
assert len(predictions) == self.n_examples_shown
148149
probabilities = (
149150
self._parse_logprobs(logprobs)

delphi/scorers/classifier/detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
temperature=temperature,
4040
**generation_kwargs,
4141
)
42+
self.log_prob = log_prob
4243

4344
def prompt(self, examples: str, explanation: str) -> list[dict]:
4445
return detection_prompt(examples, explanation)

delphi/scorers/classifier/fuzz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
temperature=temperature,
5252
**generation_kwargs,
5353
)
54-
54+
self.log_prob = log_prob
5555
self.threshold = threshold
5656
self.fuzz_type = fuzz_type
5757

delphi/scorers/classifier/intruder.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
170170
active_examples = self.rng.sample(all_active_examples, num_active_examples)
171171

172172
# highlights the active tokens with <<>> markers
173-
majority_examples = []
173+
formatted_examples = []
174+
chosen_examples = []
174175
num_active_tokens = 0
175176
for example in active_examples:
176177
text, _str_tokens = _prepare_text(
177178
example, n_incorrect=0, threshold=0.3, highlighted=True
178179
)
179-
majority_examples.append(text)
180+
formatted_examples.append(text)
181+
chosen_examples.append(example)
180182
num_active_tokens += (example.activations > 0).sum().item()
181183

182184
avg_active_tokens_per_example = num_active_tokens // len(active_examples)
@@ -193,6 +195,7 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
193195
threshold=0.3,
194196
highlighted=True,
195197
)
198+
196199
elif self.type == "internal":
197200
# randomly select a quantile to be the intruder, make sure it's not
198201
# the same as the source quantile
@@ -224,18 +227,23 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
224227

225228
# select a random index to insert the intruder sentence
226229
intruder_index = self.rng.randint(0, num_active_examples)
227-
examples = (
228-
majority_examples[:intruder_index]
230+
formatted_examples = (
231+
formatted_examples[:intruder_index]
229232
+ [intruder_sentence]
230-
+ majority_examples[intruder_index:]
233+
+ formatted_examples[intruder_index:]
234+
)
235+
examples = (
236+
chosen_examples[:intruder_index]
237+
+ [intruder]
238+
+ chosen_examples[intruder_index:]
231239
)
232240

233241
example_activations = [example.activations.tolist() for example in examples]
234242
example_tokens = [example.str_tokens for example in examples]
235243

236244
batches.append(
237245
IntruderSentence(
238-
examples=examples,
246+
examples=formatted_examples,
239247
intruder_index=intruder_index,
240248
chosen_quantile=active_quantile,
241249
activations=example_activations,

delphi/scorers/classifier/sample.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _prepare_text(
112112
if n_incorrect == 0:
113113

114114
def is_above_activation_threshold(i: int) -> bool:
115-
return example.activations[i] >= abs_threshold
115+
return bool((example.activations[i] >= abs_threshold).item())
116116

117117
return _highlight(str_toks, is_above_activation_threshold), str_toks
118118

@@ -137,6 +137,7 @@ def is_above_activation_threshold(i: int) -> bool:
137137
# The activating token is always ctx_len - ctx_len//4
138138
# so we always highlight this one, and if num_tokens_to_highlight > 1
139139
# we highlight num_tokens_to_highlight - 1 random ones
140+
# TODO: This is wrong
140141
token_pos = len(str_toks) - len(str_toks) // 4
141142
if token_pos in tokens_below_threshold:
142143
random_indices = [token_pos]

delphi/scorers/embedding/example_embedding.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ def compute_batch_deltas(self, batch: Batch) -> tuple[float, float]:
118118
# Split the embeddings back into their components
119119
n_neg = len(batch.negative_examples)
120120
n_pos = len(batch.positive_examples)
121-
negative_examples_embeddings = all_embeddings[:n_neg]
122-
positive_examples_embeddings = all_embeddings[n_neg : n_neg + n_pos]
123-
positive_query_embedding = all_embeddings[-2].unsqueeze(0)
124-
negative_query_embedding = all_embeddings[-1].unsqueeze(0)
121+
negative_examples_embeddings = torch.tensor(all_embeddings[:n_neg])
122+
positive_examples_embeddings = torch.tensor(
123+
all_embeddings[n_neg : n_neg + n_pos]
124+
)
125+
positive_query_embedding = torch.tensor(all_embeddings[-2]).unsqueeze(0)
126+
negative_query_embedding = torch.tensor(all_embeddings[-1]).unsqueeze(0)
125127

126128
# Compute the similarity between the query and the examples
127129
negative_similarities = self.model.similarity(
@@ -165,9 +167,11 @@ def _create_batches(
165167
# which are going to be used as "explanations"
166168
positive_train_examples = record.train
167169

170+
number_samples = min(len(positive_train_examples), len(record.not_active))
171+
168172
# Sample from the not_active examples
169173
not_active_index = self.random.sample(
170-
range(len(record.not_active)), len(positive_train_examples)
174+
range(len(record.not_active)), number_samples
171175
)
172176
negative_train_examples = [record.not_active[i] for i in not_active_index]
173177

@@ -192,6 +196,7 @@ def _create_batches(
192196
positive_query_str, _ = _prepare_text(
193197
positive_query, n_incorrect=0, threshold=0.3, highlighted=True
194198
)
199+
195200
# Prepare the negative query
196201
if self.method == "default":
197202
# In the default method, we just sample a random negative example
@@ -205,6 +210,7 @@ def _create_batches(
205210
threshold=0.3,
206211
highlighted=True,
207212
)
213+
208214
elif self.method == "internal":
209215
# In the internal method, we sample a negative example
210216
# that has a different quantile as the positive query
@@ -216,13 +222,13 @@ def _create_batches(
216222
range(len(positive_test_examples)), 1
217223
)[0]
218224
negative_query_temp = positive_test_examples[negative_query_idx]
219-
negative_query_quantile = negative_query.distance
225+
negative_query_quantile = negative_query_temp.quantile
220226

221227
negative_query = NonActivatingExample(
222228
str_tokens=negative_query_temp.str_tokens,
223229
tokens=negative_query_temp.tokens,
224230
activations=negative_query_temp.activations,
225-
distance=negative_query_temp.quantile,
231+
distance=float(negative_query_temp.quantile),
226232
)
227233
# Because it is a converted activating example, it will highlight
228234
# the activating tokens
@@ -234,15 +240,18 @@ def _create_batches(
234240
# that have the same quantile as the positive_query
235241
positive_examples = [
236242
e
237-
for e in positive_train_examples
243+
for e in positive_test_examples
238244
if e.quantile == positive_query.quantile
239245
]
240246
if len(positive_examples) > 10:
241-
positive_examples = self.random.sample(positive_examples, 10)
247+
positive_examples = self.random.sample(positive_examples, 11)
242248
positive_examples_str = [
243249
_prepare_text(e, n_incorrect=0, threshold=0.3, highlighted=True)[0]
244250
for e in positive_examples
245251
]
252+
# if one example is the same as the positive query, remove it
253+
if positive_query_str in positive_examples_str:
254+
positive_examples_str.remove(positive_query_str)
246255

247256
# negative examples
248257
if self.method == "default":
@@ -259,7 +268,7 @@ def _create_batches(
259268
# that has the same quantile as the negative_query
260269
negative_examples = [
261270
e
262-
for e in positive_train_examples
271+
for e in positive_test_examples
263272
if e.quantile == negative_query.distance
264273
]
265274
if len(negative_examples) > 10:
@@ -275,7 +284,7 @@ def _create_batches(
275284
positive_query=positive_query_str,
276285
negative_query=negative_query_str,
277286
quantile_positive_query=positive_query.quantile,
278-
distance_negative_query=negative_query.distance,
287+
distance_negative_query=float(negative_query.distance),
279288
)
280289
batches.append(batch)
281290
return batches

delphi/scorers/scorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ class ScorerResult(NamedTuple):
1414

1515
class Scorer(ABC):
1616
@abstractmethod
17-
def __call__(self, record: LatentRecord) -> ScorerResult:
17+
async def __call__(self, record: LatentRecord) -> ScorerResult:
1818
pass
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
import torch
3+
4+
from delphi.clients.client import Client, Response
5+
from delphi.latents import ActivatingExample, Latent, LatentRecord, NonActivatingExample
6+
from delphi.scorers import DetectionScorer, FuzzingScorer
7+
from delphi.scorers.scorer import ScorerResult
8+
9+
10+
class ConstantResponseClient(Client):
11+
def __init__(self, text: str):
12+
super().__init__(model="dummy")
13+
self.text = text
14+
15+
async def generate(self, prompt, **kwargs):
16+
return Response(text=self.text)
17+
18+
19+
def _activating_example() -> ActivatingExample:
20+
return ActivatingExample(
21+
tokens=torch.tensor([1, 2, 3], dtype=torch.int64),
22+
activations=torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32),
23+
str_tokens=["a", "b", "c"],
24+
quantile=1,
25+
)
26+
27+
28+
def _non_activating_example() -> NonActivatingExample:
29+
return NonActivatingExample(
30+
tokens=torch.tensor([1, 2, 3], dtype=torch.int64),
31+
activations=torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32),
32+
str_tokens=["x", "y", "z"],
33+
distance=0.0,
34+
)
35+
36+
37+
def _record() -> LatentRecord:
38+
return LatentRecord(
39+
latent=Latent(module_name="layers.0", latent_index=0),
40+
test=[_activating_example()],
41+
not_active=[_non_activating_example()],
42+
explanation="test explanation",
43+
)
44+
45+
46+
@pytest.mark.asyncio
47+
async def test_detection_scorer_async_contract_returns_scorer_result():
48+
scorer = DetectionScorer(
49+
client=ConstantResponseClient("[1]"),
50+
n_examples_shown=1,
51+
verbose=False,
52+
)
53+
54+
result = await scorer(_record())
55+
56+
assert isinstance(result, ScorerResult)
57+
assert result.record.explanation == "test explanation"
58+
assert len(result.score) > 0
59+
60+
61+
def test_detection_parse_casts_binary_ints_to_bool():
62+
scorer = DetectionScorer(
63+
client=ConstantResponseClient("[0, 1]"),
64+
n_examples_shown=2,
65+
verbose=False,
66+
)
67+
68+
predictions, probabilities = scorer._parse("[0, 1]")
69+
70+
assert predictions == [False, True]
71+
assert probabilities == [None, None]
72+
73+
74+
def test_fuzzing_call_sync_contract_and_log_prob_flag():
75+
scorer = FuzzingScorer(
76+
client=ConstantResponseClient("[1]"),
77+
n_examples_shown=1,
78+
verbose=False,
79+
log_prob=True,
80+
)
81+
82+
result = scorer.call_sync(_record())
83+
84+
assert scorer.log_prob is True
85+
assert isinstance(result, ScorerResult)
86+
assert len(result.score) > 0

0 commit comments

Comments
 (0)