Skip to content

Commit f556964

Browse files
committed
Fixed the bug that occurred when a query does not have N relevant documents (or even none)
1 parent 2bf879a commit f556964

4 files changed

Lines changed: 192 additions & 60 deletions

File tree

py_css/models/T5Rewriter.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EARLY_STOPPING: bool = True
1414

1515
COPY_REWRITTEN_QUERY_COLUMN: str = "rewritten_query"
16+
SEPERATOR_TOKEN: str = " ||| "
1617

1718

1819
class T5Rewriter(pt.Transformer):
@@ -46,26 +47,26 @@ def __init__(self):
4647
)
4748
super().__init__()
4849

49-
# the query has multiply " <sep> " in it. Create a list of the split with a maximum of 3 elements (last element is last, second last is middle, and the first n are joined)
50+
# the query has multiply SEPERATOR_TOKEN in it. Create a list of the split with a maximum of 3 elements (last element is last, second last is middle, and the first n are joined)
5051
def __split_query_tokenize_join(self, q):
5152
"""
5253
Split the query, tokenize the parts, and join them back together.
5354
"""
54-
l = q.split(" <sep> ")
55+
l = q.split(SEPERATOR_TOKEN)
5556
if len(l) < 3:
5657
tokens = []
5758
for ll in l:
5859
tokens.extend(self.tokenizer.tokenize(ll))
59-
tokens.append(" <sep> ")
60+
tokens.append(SEPERATOR_TOKEN)
6061
if len(tokens) > 0:
6162
tokens.pop()
6263
return tokens
6364
else:
6465
tokens = []
65-
tokens.extend(self.tokenizer.tokenize(" <sep> ".join(l[:-2])))
66-
tokens.append(" <sep> ")
66+
tokens.extend(self.tokenizer.tokenize(SEPERATOR_TOKEN.join(l[:-2])))
67+
tokens.append(SEPERATOR_TOKEN)
6768
tokens.extend(self.tokenizer.tokenize(l[-2]))
68-
tokens.append(" <sep> ")
69+
tokens.append(SEPERATOR_TOKEN)
6970
tokens.extend(self.tokenizer.tokenize(l[-1]))
7071
return tokens
7172

py_css/models/base.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
33
import logging
4-
from typing import Optional, List, Tuple, TypeAlias
4+
from typing import Optional, List, Tuple, TypeAlias, Set, Any, Dict, Generator
55
import warnings
66

77
import models.T5Rewriter as t5_rewriter_module
@@ -57,6 +57,9 @@ def __str__(self) -> str:
5757

5858
Context: TypeAlias = List[Tuple[Query, Optional[List[Document]]]]
5959

60+
# If the retrieval model did not find any suitable or all N-required documents, this document shall be used as a placeholder
61+
EMPTY_PLACEHOLDER_DOC: Document = Document("-1", "")
62+
6063

6164
class Pipeline(ABC):
6265
"""
@@ -99,6 +102,148 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
99102
"""
100103
...
101104

105+
def pad_empty_documents(
106+
self, df: pd.DataFrame, qids: Set[str], N: int, queries_df: pd.DataFrame
107+
) -> pd.DataFrame:
108+
"""
109+
Pad the dataframe with empty documents.
110+
111+
Parameters
112+
----------
113+
df : pd.DataFrame
114+
The dataframe to be padded.
115+
qids : Set[str]
116+
The query ids.
117+
N : int
118+
The number of documents each qid should have.
119+
queries_df : pd.DataFrame
120+
The queries dataframe.
121+
122+
Returns
123+
-------
124+
pd.DataFrame
125+
The padded dataframe.
126+
"""
127+
df_has_rewritten_queries: bool = (
128+
t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in df.columns
129+
)
130+
rows_to_add: List[Dict[str, Any]] = []
131+
for qid in qids:
132+
# Check if qid is in top_doc_df (if not --> no documents at all were found)
133+
if qid not in df["qid"].unique():
134+
for i in range(1, N + 1):
135+
row = {
136+
"qid": qid,
137+
"docid": EMPTY_PLACEHOLDER_DOC.docno,
138+
"docno": EMPTY_PLACEHOLDER_DOC.docno,
139+
"text": EMPTY_PLACEHOLDER_DOC.content,
140+
"score": -i,
141+
"rank": i,
142+
"query_0": queries_df[queries_df["qid"] == qid]["query"].iloc[
143+
0
144+
],
145+
"query": queries_df[queries_df["qid"] == qid]["query"].iloc[0],
146+
}
147+
if df_has_rewritten_queries:
148+
row[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] = row[
149+
"query"
150+
]
151+
rows_to_add.append(row)
152+
else:
153+
# if there are less than N occurrences of a qid, add base_module.EMPTY_PLACEHOLDER_DOC to fill up (need to adjust score)
154+
if df.groupby("qid").size()[qid] < N:
155+
lowest_score = df[df["qid"] == qid]["score"].min()
156+
rank = int(round(df[df["qid"] == qid]["rank"].max()))
157+
for i in range(
158+
rank + 1,
159+
N + 1,
160+
):
161+
row = {
162+
"qid": qid,
163+
"docid": EMPTY_PLACEHOLDER_DOC.docno,
164+
"docno": EMPTY_PLACEHOLDER_DOC.docno,
165+
"text": EMPTY_PLACEHOLDER_DOC.content,
166+
"score": lowest_score - i,
167+
"rank": i,
168+
"query_0": queries_df[queries_df["qid"] == qid][
169+
"query"
170+
].iloc[0],
171+
"query": queries_df[queries_df["qid"] == qid]["query"].iloc[
172+
0
173+
],
174+
}
175+
if df_has_rewritten_queries:
176+
row[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] = row[
177+
"query"
178+
]
179+
rows_to_add.append(row)
180+
if len(rows_to_add) > 0:
181+
df = pd.concat([df, pd.DataFrame(rows_to_add)])
182+
df = df.sort_values(["qid", "rank"], ascending=[True, True])
183+
184+
return df
185+
186+
def replace_empty_placeholder_docs(
187+
self, df: pd.DataFrame, context_list: List[Tuple[Query, Context]]
188+
) -> pd.DataFrame:
189+
"""
190+
Replace any empty placeholder documents with documents from the context, if possible.
191+
192+
Parameters
193+
----------
194+
df : pd.DataFrame
195+
The dataframe to be replaced.
196+
context_list : List[Tuple[Query, Context]]
197+
The context of the queries.
198+
199+
Returns
200+
-------
201+
pd.DataFrame
202+
The dataframe with the replaced documents.
203+
"""
204+
205+
def gen_context_docs(context: Context) -> Generator[Document, None, None]:
206+
for _, docs in reversed(context):
207+
if docs is not None:
208+
for doc in docs:
209+
if doc.docno != EMPTY_PLACEHOLDER_DOC.docno:
210+
yield doc
211+
212+
for query, context in context_list:
213+
# check if there is a row in the df with "qid" == query.query_id, where "docno" == EMPTY_PLACEHOLDER_DOC.docno
214+
# if yes, replace it with the top document from the context
215+
while True:
216+
if not df[
217+
(df["qid"] == query.query_id)
218+
& (df["docno"] == EMPTY_PLACEHOLDER_DOC.docno)
219+
].empty:
220+
# Check if gen_docs has next element
221+
doc: Document
222+
doc_gen = gen_context_docs(context)
223+
try:
224+
doc = next(doc_gen)
225+
while (
226+
doc.docno
227+
in df[(df["qid"] == query.query_id)]["docno"].unique()
228+
):
229+
doc = next(doc_gen)
230+
except StopIteration:
231+
break
232+
# Get the row index of the row to be replaced (of all of the rows satisfying the condition, take the one with min "rank" value)
233+
row_index = df[
234+
(df["qid"] == query.query_id)
235+
& (df["docno"] == EMPTY_PLACEHOLDER_DOC.docno)
236+
]["rank"].idxmin()
237+
238+
# Of that row, set docno and docid to doc.no, and text to doc.content
239+
df.loc[row_index, "docno"] = doc.docno
240+
df.loc[row_index, "docid"] = doc.docno
241+
df.loc[row_index, "text"] = doc.content
242+
else:
243+
break
244+
245+
return df
246+
102247
def combine_result_stages(self, results: List[pd.DataFrame]) -> pd.DataFrame:
103248
"""
104249
Combine the results of the stages.
@@ -188,6 +333,7 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame]
188333
query_str = self.transform_input(query, context)
189334
query_df = pd.DataFrame([{"qid": query.query_id, "query": query_str}])
190335
result = self.transform(query_df)
336+
result = self.replace_empty_placeholder_docs(result, [(query, context)])
191337

192338
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
193339
temp_result = result[result["qid"] == query.query_id]
@@ -239,6 +385,7 @@ def batch_search(
239385
]
240386
)
241387
result = self.transform(query_df)
388+
result = self.replace_empty_placeholder_docs(result, inputs)
242389

243390
if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns:
244391
for query, _ in inputs:

py_css/models/baseline.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Baseline(base_module.Pipeline):
1717
A class to represent the baseline retrieval method.
1818
"""
1919

20+
t5_qr: t5_rewriter.T5Rewriter
2021
top_docs: Tuple[pt.Transformer, int]
2122
mono_t5: Tuple[MonoT5ReRanker, int]
2223
duo_t5: Tuple[DuoT5ReRanker, int]
@@ -43,9 +44,9 @@ def __init__(
4344
duo_t5_docs : int
4445
The number of documents to retrieve with DuoT5.
4546
"""
46-
t5_qr = t5_rewriter.T5Rewriter()
47+
self.t5_qr = t5_rewriter.T5Rewriter()
4748
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
48-
self.top_docs = ((t5_qr >> bm25) % bm25_docs, bm25_docs)
49+
self.top_docs = ((bm25 % bm25_docs).compile(), bm25_docs)
4950
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
5051
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)
5152

@@ -58,7 +59,11 @@ def transform_input(
5859
doc_was_added = False
5960
if len(context) > 0:
6061
last_docs = context[-1][1]
61-
if last_docs is not None and len(last_docs) > 0:
62+
if (
63+
last_docs is not None
64+
and len(last_docs) > 0
65+
and last_docs[0].docno != base_module.EMPTY_PLACEHOLDER_DOC.docno
66+
):
6267
history.append(last_docs[0].content)
6368
doc_was_added = True
6469
sum_of_lengths = sum([len(q) for q in history]) + len(query.query)
@@ -89,11 +94,15 @@ def transform_input(
8994
remaining = 0
9095

9196
history.append(query.query)
92-
new_query = " <sep> ".join(history)
97+
new_query = t5_rewriter.SEPERATOR_TOKEN.join(history)
9398
return new_query
9499

95100
def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
96-
top_docs_df = self.top_docs[0].transform(query_df)
101+
unique_qids = set(query_df["qid"].unique())
102+
103+
rewritten_queries_df = self.t5_qr.transform(query_df)
104+
105+
top_docs_df = self.top_docs[0].transform(rewritten_queries_df.copy())
97106
top_docs_df = (
98107
top_docs_df.sort_values(["qid", "score"], ascending=False)
99108
.groupby("qid")
@@ -118,4 +127,9 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
118127
.head(self.duo_t5[1])
119128
)
120129

121-
return self.combine_result_stages([top_docs_df, mono_t5_df, duo_t5_df])
130+
result = self.combine_result_stages([top_docs_df, mono_t5_df, duo_t5_df])
131+
result = self.pad_empty_documents(
132+
result, unique_qids, self.top_docs[1], rewritten_queries_df
133+
)
134+
135+
return result

0 commit comments

Comments
 (0)