Skip to content

Commit 2bf879a

Browse files
committed
Narrowing down the bug
1 parent 8b6818f commit 2bf879a

4 files changed

Lines changed: 17 additions & 3 deletions

File tree

py_css/interface/kaggle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ def to_kaggle_format(df: pd.DataFrame) -> str:
2828
The dataframe in the Kaggle submission format.
2929
"""
3030
# for each query, only keep the best 3 docnos ranked by ascending rank
31-
df = df.groupby("qid").sort_values("rank").head(3)
31+
df = (
32+
df.sort_values(by=["qid", "rank"], ascending=[True, True])
33+
.groupby("qid")
34+
.head(3)
35+
)
3236

3337
output = "qid,docid\n"
3438
for _, row in df.iterrows():

py_css/interface/run_queries.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def to_trec_runfile_format(df: pd.DataFrame, model_name: str) -> str:
2929
str
3030
The dataframe in the TREC runfile format.
3131
"""
32+
df = df.sort_values(by=["qid", "rank"], ascending=[True, True])
3233
return "\n".join(
3334
[
3435
f"{row['qid']} Q0 {row['docno']} {int(row['rank']) + 1} {row['score']} {model_name}"

py_css/models/baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
"""
4646
t5_qr = t5_rewriter.T5Rewriter()
4747
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
48-
self.top_docs = (t5_qr >> bm25, bm25_docs)
48+
self.top_docs = ((t5_qr >> bm25) % bm25_docs, bm25_docs)
4949
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
5050
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)
5151

py_css/models/baseline_prf.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
self.t5_qr = t5_rewriter.T5Rewriter()
5454
bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
5555
rm3 = pt.rewrite.RM3(index, fb_docs=rm3_fb_docs, fb_terms=rm3_fb_terms)
56-
self.top_docs = ((bm25 % rm3_fb_docs) >> rm3 >> bm25, bm25_docs)
56+
self.top_docs = (((bm25 % rm3_fb_docs) >> rm3 >> bm25) % bm25_docs, bm25_docs)
5757
self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs)
5858
self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs)
5959

@@ -115,6 +115,11 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
115115
top_docs_df["qid"].unique()
116116
), f"{unique_qids} != {set(top_docs_df['qid'].unique())}"
117117

118+
# assert that each qid is present 1000 times
119+
assert (
120+
top_docs_df.groupby("qid").size() == 1000
121+
).all(), f"{top_docs_df.groupby('qid').size().unique()}"
122+
118123
top_docs_df = (
119124
top_docs_df.sort_values(["qid", "score"], ascending=False)
120125
.groupby("qid")
@@ -125,6 +130,10 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
125130
top_docs_df["qid"].unique()
126131
), f"{unique_qids} != {set(top_docs_df['qid'].unique())}"
127132

133+
assert (
134+
top_docs_df.groupby("qid").size() == 1000
135+
).all(), f"{top_docs_df.groupby('qid').size().unique()}"
136+
128137
# Now add in the rewritten queries to the top docs
129138
top_docs_df = pt.model.push_queries(top_docs_df, inplace=True)
130139
top_docs_df = pd.merge(

0 commit comments

Comments
 (0)