Skip to content

Commit 02a0cac

Browse files
committed
New Retrieval Method: Extending Baseline with RM3
1 parent 139e4c2 commit 02a0cac

6 files changed

Lines changed: 260 additions & 52 deletions

File tree

py_css/main.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def setup() -> None:
1616
Set up the necessary configurations.
1717
"""
1818
if not pt.started():
19-
pt.init()
19+
pt.init(boot_packages=["com.github.terrierteam:terrier-prf:-SNAPSHOT"])
2020

2121

2222
def main():
@@ -44,7 +44,7 @@ def main():
4444
global_args.add_argument(
4545
"--method",
4646
type=str,
47-
choices=["baseline"],
47+
choices=["baseline", "baseline-prf"],
4848
default="baseline",
4949
help="Set the retrieval method",
5050
)
@@ -56,6 +56,13 @@ def main():
5656
help="Parameters for baseline method as tuple (bm25_docs, mono_t5_docs, duo_t5_docs)",
5757
)
5858

59+
global_args.add_argument(
60+
"--baseline-prf-params",
61+
type=lambda s: tuple(map(int, s.split(","))),
62+
default=(1000, 17, 26, 100, 10),
63+
help="Parameters for baseline method as tuple (bm25_docs, rm3_fb_docs, rm3_fb_terms, mono_t5_docs, duo_t5_docs)",
64+
)
65+
5966
# Command argument
6067
parser.add_argument(
6168
"command",
@@ -109,6 +116,10 @@ def main():
109116
model_parameters = model_parameters_module.BaselineParameters.from_tuple(
110117
args.baseline_params
111118
)
119+
case "baseline-prf":
120+
model_parameters = model_parameters_module.BaselinePRFParameters.from_tuple(
121+
args.baseline_prf_params
122+
)
112123
case _:
113124
raise NotImplementedError
114125

py_css/models/T5Rewriter.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import string
22
import logging
3-
from typing import List, Any, Callable
3+
from typing import List, Any, Callable, Optional
44

55
import pyterrier as pt
66
import pandas as pd
@@ -12,6 +12,8 @@
1212
NUM_BEAMS: int = 10
1313
EARLY_STOPPING: bool = True
1414

15+
COPY_REWRITTEN_QUERY_COLUMN: str = "rewritten_query"
16+
1517

1618
class T5Rewriter(pt.Transformer):
1719
"""
@@ -31,7 +33,10 @@ class T5Rewriter(pt.Transformer):
3133
tokenizer: T5Tokenizer
3234
model: T5ForConditionalGeneration
3335

34-
def __init__(self, index):
36+
def __init__(self):
37+
"""
38+
Constructs all the necessary attributes for the T5 Query Rewriter.
39+
"""
3540
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3641
self.tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
3742
self.model = (
@@ -99,7 +104,7 @@ def __remove_punctuation(self, s):
99104

100105
def transform(self, topics_or_res: pd.DataFrame) -> pd.DataFrame:
101106
# save qid and query columns as dict (qid -> query) query is same for same qid, so sufficient to select first
102-
qid_query_dict = dict(zip(topics_or_res["qid"], topics_or_res["query"]))
107+
rewritten_queries_df = topics_or_res[["qid", "query"]].drop_duplicates()
103108

104109
pipeline: List[Callable] = [
105110
self.__split_query_tokenize_join,
@@ -109,18 +114,21 @@ def transform(self, topics_or_res: pd.DataFrame) -> pd.DataFrame:
109114
self.__remove_punctuation,
110115
]
111116

112-
rewritten_queries = {
113-
qid: _call_list_of_functions(q, pipeline)
114-
for qid, q in qid_query_dict.items()
115-
}
117+
rewritten_queries_df["query"] = rewritten_queries_df["query"].apply(
118+
lambda q: _call_list_of_functions(q, pipeline)
119+
)
120+
116121
# overwrite the query column with the decoded output token ids
117-
topics_or_res["query"] = topics_or_res["qid"].map(
118-
lambda qid: rewritten_queries[qid]
122+
rewritten_queries_df.merge(
123+
pt.model.push_queries(topics_or_res, "query"), on="qid"
119124
)
125+
rewritten_queries_df[COPY_REWRITTEN_QUERY_COLUMN] = rewritten_queries_df[
126+
"query"
127+
]
120128

121-
logging.info(f"Rewritten queries: {topics_or_res['query'].unique()}")
129+
logging.info(f"Rewritten queries: {rewritten_queries_df['query'].unique()}")
122130

123-
return topics_or_res
131+
return rewritten_queries_df
124132

125133

126134
def _call_list_of_functions(x: Any, pipeline: List[Callable]) -> Any:

py_css/models/base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import logging
44
from typing import Optional, List, Tuple, TypeAlias
55

6+
import models.T5Rewriter as t5_rewriter_module
7+
68
import pandas as pd
7-
import pyterrier as pt
89

910

1011
@dataclass
@@ -187,7 +188,12 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame]
187188
query_df = pd.DataFrame([{"qid": query.query_id, "query": query_str}])
188189
result = self.transform(query_df)
189190

190-
query.query = result["query"].iloc[0]
191+
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
192+
query.query = result[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN].iloc[0]
193+
print("QUERY COMES FROM REWRITER")
194+
else:
195+
query.query = result["query"].iloc[0]
196+
print("QUERY COMES FROM ORIGINAL?!?!?!?!")
191197

192198
doc_list: List[Document] = []
193199
for _, entry in result.iterrows():
@@ -221,8 +227,14 @@ def batch_search(
221227
)
222228
result = self.transform(query_df)
223229

224-
for query, _ in inputs:
225-
query.query = result[result["qid"] == query.query_id]["query"].iloc[0]
230+
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
231+
for query, _ in inputs:
232+
query.query = result[result["qid"] == query.query_id][
233+
t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN
234+
].iloc[0]
235+
else:
236+
for query, _ in inputs:
237+
query.query = result[result["qid"] == query.query_id]["query"].iloc[0]
226238

227239
contexts: List[Context] = []
228240
for query, context in inputs:

py_css/models/baseline.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,19 @@
1111
class Baseline(base_module.Pipeline):
1212
"""
1313
A class to represent the baseline retrieval method.
14-
15-
Attributes
16-
----------
17-
stages : List[Tuple[pt.Transformer, int]]
18-
The stages of the pipeline.
1914
"""
2015

21-
stages: List[Tuple[pt.Transformer, int]]
16+
top_docs: Tuple[pt.Transformer, int]
17+
mono_t5: Tuple[MonoT5ReRanker, int]
18+
duo_t5: Tuple[DuoT5ReRanker, int]
2219

2320
def __init__(
2421
self,
2522
index,
2623
*,
27-
bm25_docs,
28-
mono_t5_docs,
29-
duo_t5_docs,
24+
bm25_docs: int,
25+
mono_t5_docs: int,
26+
duo_t5_docs: int,
3027
):
3128
"""
3229
Constructs all the necessary attributes for the baseline retrieval method.
@@ -42,18 +39,11 @@ def __init__(
4239
duo_t5_docs : int
4340
The number of documents to retrieve with DuoT5.
4441
"""
45-
t5_qr = t5_rewriter.T5Rewriter(index)
42+
t5_qr = t5_rewriter.T5Rewriter()
4643
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
47-
mono_t5 = MonoT5ReRanker()
48-
duo_t5 = DuoT5ReRanker()
49-
50-
top_docs = t5_qr >> bm25
51-
52-
self.stages = [
53-
(top_docs, bm25_docs),
54-
(mono_t5, mono_t5_docs),
55-
(duo_t5, duo_t5_docs),
56-
]
44+
self.top_docs = (t5_qr >> bm25, bm25_docs)
45+
self.mono_t5 = (MonoT5ReRanker(), mono_t5_docs)
46+
self.duo_t5 = (DuoT5ReRanker(), duo_t5_docs)
5747

5848
def transform_input(
5949
self, query: base_module.Query, context: base_module.Context
@@ -70,21 +60,29 @@ def transform_input(
7060
return new_query
7161

7262
def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
73-
# We basically do the pyterrier Concatenate transformer operator here, but more efficiently, since we dont have to do the entire pipeline for each component of the operator.
74-
results = []
75-
current_df = query_df
63+
top_docs_df = self.top_docs[0].transform(query_df)
64+
top_docs_df = (
65+
top_docs_df.sort_values(["qid", "score"], ascending=False)
66+
.groupby("qid")
67+
.head(self.top_docs[1])
68+
)
7669

77-
is_first: bool = True
70+
mono_t5_df = self.mono_t5[0].transform(
71+
top_docs_df.groupby("qid").head(self.mono_t5[1])
72+
)
73+
mono_t5_df = (
74+
mono_t5_df.sort_values(["qid", "score"], ascending=False)
75+
.groupby("qid")
76+
.head(self.mono_t5[1])
77+
)
7878

79-
for stage, num_docs in self.stages:
80-
df = current_df
81-
if not is_first:
82-
df = df.groupby("qid").head(num_docs)
83-
else:
84-
is_first = False
85-
transformed_df = stage.transform(df)
86-
transformed_df = transformed_df.groupby("qid").head(num_docs)
87-
results.append(transformed_df)
88-
current_df = transformed_df
79+
duo_t5_df = self.duo_t5[0].transform(
80+
mono_t5_df.groupby("qid").head(self.duo_t5[1])
81+
)
82+
duo_t5_df = (
83+
duo_t5_df.sort_values(["qid", "score"], ascending=False)
84+
.groupby("qid")
85+
.head(self.duo_t5[1])
86+
)
8987

90-
return self.combine_result_stages(results)
88+
return self.combine_result_stages([top_docs_df, mono_t5_df, duo_t5_df])

py_css/models/baseline_prf.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import models.base as base_module
2+
import models.T5Rewriter as t5_rewriter
3+
4+
from typing import List, Tuple
5+
6+
import pandas as pd
7+
import pyterrier as pt
8+
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker
9+
10+
11+
class BaselinePRF(base_module.Pipeline):
    """
    A class to represent the method that extends the baseline retrieval
    method with pseudo relevance feedback (RM3).

    Attributes
    ----------
    t5_qr : t5_rewriter.T5Rewriter
        The T5 query rewriter applied to the raw conversational query.
    top_docs : Tuple[pt.Transformer, int]
        The BM25 -> RM3 -> BM25 retrieval stage and how many documents to keep.
    mono_t5 : Tuple[MonoT5ReRanker, int]
        The MonoT5 re-ranking stage and how many documents to keep.
    duo_t5 : Tuple[DuoT5ReRanker, int]
        The DuoT5 re-ranking stage and how many documents to keep.
    """

    t5_qr: t5_rewriter.T5Rewriter
    top_docs: Tuple[pt.Transformer, int]
    mono_t5: Tuple[MonoT5ReRanker, int]
    duo_t5: Tuple[DuoT5ReRanker, int]

    def __init__(
        self,
        index,
        *,
        bm25_docs: int,
        rm3_fb_docs: int,
        rm3_fb_terms: int,
        mono_t5_docs: int,
        duo_t5_docs: int,
    ):
        """
        Constructs all the necessary attributes for the baseline-with-PRF
        retrieval method.

        Parameters
        ----------
        index : pt.Index
            The PyTerrier index.
        bm25_docs : int
            The number of documents to retrieve with BM25.
        rm3_fb_docs : int
            The number of feedback documents to use for RM3.
        rm3_fb_terms : int
            The number of expansion terms to use for RM3.
        mono_t5_docs : int
            The number of documents to re-rank with MonoT5.
        duo_t5_docs : int
            The number of documents to re-rank with DuoT5.
        """
        self.t5_qr = t5_rewriter.T5Rewriter()
        bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
        rm3 = pt.rewrite.RM3(index, fb_docs=rm3_fb_docs, fb_terms=rm3_fb_terms)
        # The first BM25 pass is truncated to the RM3 feedback set before query
        # expansion; the expanded query is then re-run against the full BM25.
        self.top_docs = ((bm25 % rm3_fb_docs) >> rm3 >> bm25, bm25_docs)
        self.mono_t5 = (MonoT5ReRanker(), mono_t5_docs)
        self.duo_t5 = (DuoT5ReRanker(), duo_t5_docs)

    def transform_input(
        self, query: base_module.Query, context: base_module.Context
    ) -> str:
        """
        Flatten the conversation context and the current query into a single
        " <sep> "-joined string for the T5 rewriter.

        The history contains every previous query, plus the content of the
        top-ranked document of the most recent turn (if any).
        """
        history = []
        for q, _ in context:
            history.append(q.query)
        if len(context) > 0:
            last_docs = context[-1][1]
            if last_docs is not None:
                # Only the first (top-ranked) document of the last turn is used.
                history.append(last_docs[0].content)
        history.append(query.query)
        new_query = " <sep> ".join(history)
        return new_query

    def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
        """
        Run the full PRF pipeline over a frame of (qid, query) rows and return
        the combined results of all stages.

        Stages: T5 query rewriting -> BM25+RM3+BM25 retrieval -> MonoT5
        re-ranking -> DuoT5 re-ranking. Each stage's output is truncated per
        qid to its configured document count.
        """
        rewritten_queries_df = self.t5_qr.transform(query_df)

        top_docs_df = self.top_docs[0].transform(rewritten_queries_df.copy())
        top_docs_df = (
            top_docs_df.sort_values(["qid", "score"], ascending=False)
            .groupby("qid")
            .head(self.top_docs[1])
        )

        # Re-attach the rewritten queries to the retrieved documents. Merge
        # only qid + the rewritten-query column: merging the whole frame would
        # clash on the existing "query" column and leave suffixed
        # query_x/query_y cruft columns in the result.
        top_docs_df = top_docs_df.merge(
            rewritten_queries_df[["qid", t5_rewriter.COPY_REWRITTEN_QUERY_COLUMN]],
            on="qid",
            how="left",
        )
        # Overwrite "query" so the downstream re-rankers score against the
        # rewritten query rather than the RM3-expanded one.
        top_docs_df["query"] = top_docs_df[t5_rewriter.COPY_REWRITTEN_QUERY_COLUMN]

        mono_t5_df = self.mono_t5[0].transform(
            top_docs_df.groupby("qid").head(self.mono_t5[1])
        )
        mono_t5_df = (
            mono_t5_df.sort_values(["qid", "score"], ascending=False)
            .groupby("qid")
            .head(self.mono_t5[1])
        )

        duo_t5_df = self.duo_t5[0].transform(
            mono_t5_df.groupby("qid").head(self.duo_t5[1])
        )
        duo_t5_df = (
            duo_t5_df.sort_values(["qid", "score"], ascending=False)
            .groupby("qid")
            .head(self.duo_t5[1])
        )

        return self.combine_result_stages([top_docs_df, mono_t5_df, duo_t5_df])

0 commit comments

Comments
 (0)