@@ -53,7 +53,7 @@ def __init__(
5353 self .t5_qr = t5_rewriter .T5Rewriter ()
5454 bm25 = pt .BatchRetrieve (index , wmodel = "BM25" , metadata = ["docno" , "text" ])
5555 rm3 = pt .rewrite .RM3 (index , fb_docs = rm3_fb_docs , fb_terms = rm3_fb_terms )
56- self .top_docs = ((bm25 % rm3_fb_docs ) >> rm3 >> bm25 , bm25_docs )
56+ self .top_docs = ((( bm25 % rm3_fb_docs ) >> rm3 >> bm25 ) % bm25_docs , bm25_docs )
5757 self .mono_t5 = (MonoT5ReRanker (batch_size = BATCH_SIZE ), mono_t5_docs )
5858 self .duo_t5 = (DuoT5ReRanker (batch_size = BATCH_SIZE ), duo_t5_docs )
5959
@@ -115,6 +115,11 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
115115 top_docs_df ["qid" ].unique ()
116116 ), f"{ unique_qids } != { set (top_docs_df ['qid' ].unique ())} "
117117
118+ # assert that each qid is present 1000 times
119+ assert (
120+ top_docs_df .groupby ("qid" ).size () == 1000
121+ ).all (), f"{ top_docs_df .groupby ('qid' ).size ().unique ()} "
122+
118123 top_docs_df = (
119124 top_docs_df .sort_values (["qid" , "score" ], ascending = False )
120125 .groupby ("qid" )
@@ -125,6 +130,10 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame:
125130 top_docs_df ["qid" ].unique ()
126131 ), f"{ unique_qids } != { set (top_docs_df ['qid' ].unique ())} "
127132
133+ assert (
134+ top_docs_df .groupby ("qid" ).size () == 1000
135+ ).all (), f"{ top_docs_df .groupby ('qid' ).size ().unique ()} "
136+
128137 # Now add in the rewritten queries to the top docs
129138 top_docs_df = pt .model .push_queries (top_docs_df , inplace = True )
130139 top_docs_df = pd .merge (
0 commit comments