Skip to content

Commit 849c1dc

Browse files
committed
Fixes
1 parent 9ee3b8d commit 849c1dc

3 files changed

Lines changed: 40 additions & 15 deletions

File tree

py_css/models/base.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,21 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame]
192192
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
193193
temp_result = result[result["qid"] == query.query_id]
194194
if not temp_result.empty:
195-
query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN]
195+
query.query = temp_result[
196+
t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN
197+
].iloc[0]
196198
else:
197-
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
199+
warnings.warn(
200+
f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}"
201+
)
198202
else:
199203
temp_result = result[result["qid"] == query.query_id]
200204
if not temp_result.empty:
201-
query.query = temp_result.at[temp_result.index[0], "query"]
205+
query.query = temp_result["query"].iloc[0]
202206
else:
203-
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
207+
warnings.warn(
208+
f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}"
209+
)
204210

205211
doc_list: List[Document] = []
206212
for _, entry in result.iterrows():
@@ -238,16 +244,22 @@ def batch_search(
238244
for query, _ in inputs:
239245
temp_result = result[result["qid"] == query.query_id]
240246
if not temp_result.empty:
241-
query.query = temp_result.at[temp_result.index[0], t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN]
247+
query.query = temp_result[
248+
t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN
249+
].iloc[0]
242250
else:
243-
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
251+
warnings.warn(
252+
f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}"
253+
)
244254
else:
245255
for query, _ in inputs:
246256
temp_result = result[result["qid"] == query.query_id]
247257
if not temp_result.empty:
248-
query.query = temp_result.at[temp_result.index[0], "query"]
258+
query.query = temp_result["query"].iloc[0]
249259
else:
250-
warnings.warn(f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}")
260+
warnings.warn(
261+
f"Query {query.query_id} not found in result. This should not happen. All query-ids: {result['qid'].unique()}"
262+
)
251263

252264
contexts: List[Context] = []
253265
for query, context in inputs:

py_css/models/baseline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import pyterrier as pt
88
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker
99

10+
import torch
11+
12+
BATCH_SIZE = 128 if torch.cuda.is_available() else 8
13+
1014

1115
class Baseline(base_module.Pipeline):
1216
"""
@@ -42,8 +46,8 @@ def __init__(
4246
t5_qr = t5_rewriter.T5Rewriter()
4347
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
4448
self.top_docs = (t5_qr >> bm25, bm25_docs)
45-
self.mono_t5 = (MonoT5ReRanker(), mono_t5_docs)
46-
self.duo_t5 = (DuoT5ReRanker(), duo_t5_docs)
49+
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
50+
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)
4751

4852
def transform_input(
4953
self, query: base_module.Query, context: base_module.Context

py_css/models/baseline_prf.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import pyterrier as pt
88
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker
99

10+
import torch
11+
12+
BATCH_SIZE = 128 if torch.cuda.is_available() else 8
13+
1014

1115
class BaselinePRF(base_module.Pipeline):
1216
"""
@@ -50,8 +54,8 @@ def __init__(
5054
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
5155
rm3 = pt.rewrite.RM3(index, fb_docs=rm3_fb_docs, fb_terms=rm3_fb_terms)
5256
self.top_docs = ((bm25 % rm3_fb_docs) >> rm3 >> bm25, bm25_docs)
53-
self.mono_t5 = (MonoT5ReRanker(), mono_t5_docs)
54-
self.duo_t5 = (DuoT5ReRanker(), duo_t5_docs)
57+
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
58+
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)
5559

5660
def transform_input(
5761
self, query: base_module.Query, context: base_module.Context
@@ -78,9 +82,14 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
7882
)
7983

8084
# Now add in the rewritten queries to the top docs
81-
top_docs_df = top_docs_df.merge(rewritten_queries_df, on="qid", how="left")
82-
# And overwrite the "query" column again
83-
top_docs_df["query"] = top_docs_df[t5_rewriter.COPY_REWRITTEN_QUERY_COLUMN]
85+
top_docs_df = pt.model.push_queries(top_docs_df, inplace=True)
86+
top_docs_df = pd.merge(
87+
top_docs_df,
88+
rewritten_queries_df[["qid", "rewritten_query"]],
89+
on="qid",
90+
how="left",
91+
)
92+
top_docs_df["query"] = top_docs_df["rewritten_query"]
8493

8594
mono_t5_df = self.mono_t5[0].transform(
8695
top_docs_df.groupby("qid").head(self.mono_t5[1])

0 commit comments

Comments
 (0)