|
2 | 2 | from dataclasses import dataclass |
3 | 3 | import logging |
4 | 4 | from typing import Optional, List, Tuple, TypeAlias |
| 5 | +import warnings |
5 | 6 |
|
6 | 7 | import models.T5Rewriter as t5_rewriter_module |
7 | 8 |
|
@@ -189,11 +190,17 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame] |
189 | 190 | result = self.transform(query_df) |
190 | 191 |
|
191 | 192 | if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns: |
192 | | - query.query = result[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN].iloc[0] |
193 | | - print("QUERY COMES FROM REWRITER") |
| 193 | + temp_result = result[result["qid"] == query.query_id] |
| 194 | + if not temp_result.empty: |
| 195 | + query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] |
| 196 | + else: |
| 197 | + warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}") |
194 | 198 | else: |
195 | | - query.query = result["query"].iloc[0] |
196 | | - print("QUERY COMES FROM ORIGINAL?!?!?!?!") |
| 199 | + temp_result = result[result["qid"] == query.query_id] |
| 200 | + if not temp_result.empty: |
| 201 | + query.query = temp_result.at[temp_result.index[0], "query"] |
| 202 | + else: |
| 203 | + warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}") |
197 | 204 |
|
198 | 205 | doc_list: List[Document] = [] |
199 | 206 | for _, entry in result.iterrows(): |
@@ -229,12 +236,18 @@ def batch_search( |
229 | 236 |
|
230 | 237 | if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns: |
231 | 238 | for query, _ in inputs: |
232 | | - query.query = result[result["qid"] == query.query_id][ |
233 | | - t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN |
234 | | - ].iloc[0] |
| 239 | + temp_result = result[result["qid"] == query.query_id] |
| 240 | + if not temp_result.empty: |
| 241 | + query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] |
| 242 | + else: |
| 243 | + warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}") |
235 | 244 | else: |
236 | 245 | for query, _ in inputs: |
237 | | - query.query = result[result["qid"] == query.query_id]["query"].iloc[0] |
| 246 | + temp_result = result[result["qid"] == query.query_id] |
| 247 | + if not temp_result.empty: |
| 248 | + query.query = temp_result.at[temp_result.index[0], "query"] |
| 249 | + else: |
| 250 | + warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}") |
238 | 251 |
|
239 | 252 | contexts: List[Context] = [] |
240 | 253 | for query, context in inputs: |
|
0 commit comments