@@ -118,10 +118,12 @@ def compute_batch_deltas(self, batch: Batch) -> tuple[float, float]:
118118 # Split the embeddings back into their components
119119 n_neg = len (batch .negative_examples )
120120 n_pos = len (batch .positive_examples )
121- negative_examples_embeddings = all_embeddings [:n_neg ]
122- positive_examples_embeddings = all_embeddings [n_neg : n_neg + n_pos ]
123- positive_query_embedding = all_embeddings [- 2 ].unsqueeze (0 )
124- negative_query_embedding = all_embeddings [- 1 ].unsqueeze (0 )
121+ negative_examples_embeddings = torch .tensor (all_embeddings [:n_neg ])
122+ positive_examples_embeddings = torch .tensor (
123+ all_embeddings [n_neg : n_neg + n_pos ]
124+ )
125+ positive_query_embedding = torch .tensor (all_embeddings [- 2 ]).unsqueeze (0 )
126+ negative_query_embedding = torch .tensor (all_embeddings [- 1 ]).unsqueeze (0 )
125127
126128 # Compute the similarity between the query and the examples
127129 negative_similarities = self .model .similarity (
@@ -165,9 +167,11 @@ def _create_batches(
165167 # which are going to be used as "explanations"
166168 positive_train_examples = record .train
167169
170+ number_samples = min (len (positive_train_examples ), len (record .not_active ))
171+
168172 # Sample from the not_active examples
169173 not_active_index = self .random .sample (
170- range (len (record .not_active )), len ( positive_train_examples )
174+ range (len (record .not_active )), number_samples
171175 )
172176 negative_train_examples = [record .not_active [i ] for i in not_active_index ]
173177
@@ -192,6 +196,7 @@ def _create_batches(
192196 positive_query_str , _ = _prepare_text (
193197 positive_query , n_incorrect = 0 , threshold = 0.3 , highlighted = True
194198 )
199+
195200 # Prepare the negative query
196201 if self .method == "default" :
197202 # In the default method, we just sample a random negative example
@@ -205,6 +210,7 @@ def _create_batches(
205210 threshold = 0.3 ,
206211 highlighted = True ,
207212 )
213+
208214 elif self .method == "internal" :
209215 # In the internal method, we sample a negative example
210216 # that has a different quantile as the positive query
@@ -216,13 +222,13 @@ def _create_batches(
216222 range (len (positive_test_examples )), 1
217223 )[0 ]
218224 negative_query_temp = positive_test_examples [negative_query_idx ]
219- negative_query_quantile = negative_query . distance
225+ negative_query_quantile = negative_query_temp . quantile
220226
221227 negative_query = NonActivatingExample (
222228 str_tokens = negative_query_temp .str_tokens ,
223229 tokens = negative_query_temp .tokens ,
224230 activations = negative_query_temp .activations ,
225- distance = negative_query_temp .quantile ,
231+ distance = float ( negative_query_temp .quantile ) ,
226232 )
227233 # Because it is a converted activating example, it will highlight
228234 # the activating tokens
@@ -234,15 +240,18 @@ def _create_batches(
234240 # that have the same quantile as the positive_query
235241 positive_examples = [
236242 e
237- for e in positive_train_examples
243+ for e in positive_test_examples
238244 if e .quantile == positive_query .quantile
239245 ]
240246 if len (positive_examples ) > 10 :
241- positive_examples = self .random .sample (positive_examples , 10 )
247+ positive_examples = self .random .sample (positive_examples , 11 )
242248 positive_examples_str = [
243249 _prepare_text (e , n_incorrect = 0 , threshold = 0.3 , highlighted = True )[0 ]
244250 for e in positive_examples
245251 ]
252+ # if one example is the same as the positive query, remove it
253+ if positive_query_str in positive_examples_str :
254+ positive_examples_str .remove (positive_query_str )
246255
247256 # negative examples
248257 if self .method == "default" :
@@ -259,7 +268,7 @@ def _create_batches(
259268 # that has the same quantile as the negative_query
260269 negative_examples = [
261270 e
262- for e in positive_train_examples
271+ for e in positive_test_examples
263272 if e .quantile == negative_query .distance
264273 ]
265274 if len (negative_examples ) > 10 :
@@ -275,7 +284,7 @@ def _create_batches(
275284 positive_query = positive_query_str ,
276285 negative_query = negative_query_str ,
277286 quantile_positive_query = positive_query .quantile ,
278- distance_negative_query = negative_query .distance ,
287+ distance_negative_query = float ( negative_query .distance ) ,
279288 )
280289 batches .append (batch )
281290 return batches
0 commit comments