Skip to content

Commit 9ee3b8d

Browse files
committed
Minor Fix
1 parent 02a0cac commit 9ee3b8d

1 file changed

Lines changed: 21 additions & 8 deletions

File tree

py_css/models/base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33
import logging
44
from typing import Optional, List, Tuple, TypeAlias
5+
import warnings
56

67
import models.T5Rewriter as t5_rewriter_module
78

@@ -189,11 +190,17 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame]
189190
result = self.transform(query_df)
190191

191192
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
192-
query.query = result[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN].iloc[0]
193-
print("QUERY COMES FROM REWRITER")
193+
temp_result = result[result["qid"] == query.query_id]
194+
if not temp_result.empty:
195+
query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN]
196+
else:
197+
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
194198
else:
195-
query.query = result["query"].iloc[0]
196-
print("QUERY COMES FROM ORIGINAL?!?!?!?!")
199+
temp_result = result[result["qid"] == query.query_id]
200+
if not temp_result.empty:
201+
query.query = temp_result.at[temp_result.index[0], "query"]
202+
else:
203+
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
197204

198205
doc_list: List[Document] = []
199206
for _, entry in result.iterrows():
@@ -229,12 +236,18 @@ def batch_search(
229236

230237
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
231238
for query, _ in inputs:
232-
query.query = result[result["qid"] == query.query_id][
233-
t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN
234-
].iloc[0]
239+
temp_result = result[result["qid"] == query.query_id]
240+
if not temp_result.empty:
241+
query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN]
242+
else:
243+
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
235244
else:
236245
for query, _ in inputs:
237-
query.query = result[result["qid"] == query.query_id]["query"].iloc[0]
246+
temp_result = result[result["qid"] == query.query_id]
247+
if not temp_result.empty:
248+
query.query = temp_result.at[temp_result.index[0], "query"]
249+
else:
250+
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
238251

239252
contexts: List[Context] = []
240253
for query, context in inputs:

0 commit comments

Comments
 (0)