|
1 | 1 | from abc import ABC, abstractmethod |
2 | 2 | from dataclasses import dataclass |
3 | 3 | import logging |
4 | | -from typing import Optional, List, Tuple, TypeAlias |
| 4 | +from typing import Optional, List, Tuple, TypeAlias, Set, Any, Dict, Generator |
5 | 5 | import warnings |
6 | 6 |
|
7 | 7 | import models.T5Rewriter as t5_rewriter_module |
@@ -57,6 +57,9 @@ def __str__(self) -> str: |
57 | 57 |
|
58 | 58 | Context: TypeAlias = List[Tuple[Query, Optional[List[Document]]]] |
59 | 59 |
|
| 60 | +# If the retrieval model did not find any suitable or all N-required documents, this document shall be used as a placeholder |
| 61 | +EMPTY_PLACEHOLDER_DOC: Document = Document("-1", "") |
| 62 | + |
60 | 63 |
|
61 | 64 | class Pipeline(ABC): |
62 | 65 | """ |
@@ -99,6 +102,148 @@ def transform(self, query_df: pd.DataFrame) -> pd.DataFrame: |
99 | 102 | """ |
100 | 103 | ... |
101 | 104 |
|
| 105 | + def pad_empty_documents( |
| 106 | + self, df: pd.DataFrame, qids: Set[str], N: int, queries_df: pd.DataFrame |
| 107 | + ) -> pd.DataFrame: |
| 108 | + """ |
| 109 | + Pad the dataframe with empty documents. |
| 110 | +
|
| 111 | + Parameters |
| 112 | + ---------- |
| 113 | + df : pd.DataFrame |
| 114 | + The dataframe to be padded. |
| 115 | + qids : Set[str] |
| 116 | + The query ids. |
| 117 | + N : int |
| 118 | + The number of documents each qid should have. |
| 119 | + queries_df : pd.DataFrame |
| 120 | + The queries dataframe. |
| 121 | +
|
| 122 | + Returns |
| 123 | + ------- |
| 124 | + pd.DataFrame |
| 125 | + The padded dataframe. |
| 126 | + """ |
| 127 | + df_has_rewritten_queries: bool = ( |
| 128 | + t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in df.columns |
| 129 | + ) |
| 130 | + rows_to_add: List[Dict[str, Any]] = [] |
| 131 | + for qid in qids: |
| 132 | + # Check if qid is in top_doc_df (if not --> no documents at all were found) |
| 133 | + if qid not in df["qid"].unique(): |
| 134 | + for i in range(1, N + 1): |
| 135 | + row = { |
| 136 | + "qid": qid, |
| 137 | + "docid": EMPTY_PLACEHOLDER_DOC.docno, |
| 138 | + "docno": EMPTY_PLACEHOLDER_DOC.docno, |
| 139 | + "text": EMPTY_PLACEHOLDER_DOC.content, |
| 140 | + "score": -i, |
| 141 | + "rank": i, |
| 142 | + "query_0": queries_df[queries_df["qid"] == qid]["query"].iloc[ |
| 143 | + 0 |
| 144 | + ], |
| 145 | + "query": queries_df[queries_df["qid"] == qid]["query"].iloc[0], |
| 146 | + } |
| 147 | + if df_has_rewritten_queries: |
| 148 | + row[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] = row[ |
| 149 | + "query" |
| 150 | + ] |
| 151 | + rows_to_add.append(row) |
| 152 | + else: |
| 153 | + # if there are less than N occurrences of a qid, add base_module.EMPTY_PLACEHOLDER_DOC to fill up (need to adjust score) |
| 154 | + if df.groupby("qid").size()[qid] < N: |
| 155 | + lowest_score = df[df["qid"] == qid]["score"].min() |
| 156 | + rank = int(round(df[df["qid"] == qid]["rank"].max())) |
| 157 | + for i in range( |
| 158 | + rank + 1, |
| 159 | + N + 1, |
| 160 | + ): |
| 161 | + row = { |
| 162 | + "qid": qid, |
| 163 | + "docid": EMPTY_PLACEHOLDER_DOC.docno, |
| 164 | + "docno": EMPTY_PLACEHOLDER_DOC.docno, |
| 165 | + "text": EMPTY_PLACEHOLDER_DOC.content, |
| 166 | + "score": lowest_score - i, |
| 167 | + "rank": i, |
| 168 | + "query_0": queries_df[queries_df["qid"] == qid][ |
| 169 | + "query" |
| 170 | + ].iloc[0], |
| 171 | + "query": queries_df[queries_df["qid"] == qid]["query"].iloc[ |
| 172 | + 0 |
| 173 | + ], |
| 174 | + } |
| 175 | + if df_has_rewritten_queries: |
| 176 | + row[t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN] = row[ |
| 177 | + "query" |
| 178 | + ] |
| 179 | + rows_to_add.append(row) |
| 180 | + if len(rows_to_add) > 0: |
| 181 | + df = pd.concat([df, pd.DataFrame(rows_to_add)]) |
| 182 | + df = df.sort_values(["qid", "rank"], ascending=[True, True]) |
| 183 | + |
| 184 | + return df |
| 185 | + |
| 186 | + def replace_empty_placeholder_docs( |
| 187 | + self, df: pd.DataFrame, context_list: List[Tuple[Query, Context]] |
| 188 | + ) -> pd.DataFrame: |
| 189 | + """ |
| 190 | + Replace any empty placeholder documents with documents from the context, if possible. |
| 191 | +
|
| 192 | + Parameters |
| 193 | + ---------- |
| 194 | + df : pd.DataFrame |
| 195 | + The dataframe to be replaced. |
| 196 | + context_list : List[Tuple[Query, Context]] |
| 197 | + The context of the queries. |
| 198 | +
|
| 199 | + Returns |
| 200 | + ------- |
| 201 | + pd.DataFrame |
| 202 | + The dataframe with the replaced documents. |
| 203 | + """ |
| 204 | + |
| 205 | + def gen_context_docs(context: Context) -> Generator[Document, None, None]: |
| 206 | + for _, docs in reversed(context): |
| 207 | + if docs is not None: |
| 208 | + for doc in docs: |
| 209 | + if doc.docno != EMPTY_PLACEHOLDER_DOC.docno: |
| 210 | + yield doc |
| 211 | + |
| 212 | + for query, context in context_list: |
| 213 | + # check if there is a row in the df with "qid" == query.query_id, where "docno" == EMPTY_PLACEHOLDER_DOC.docno |
| 214 | + # if yes, replace it with the top document from the context |
| 215 | + while True: |
| 216 | + if not df[ |
| 217 | + (df["qid"] == query.query_id) |
| 218 | + & (df["docno"] == EMPTY_PLACEHOLDER_DOC.docno) |
| 219 | + ].empty: |
| 220 | + # Check if gen_docs has next element |
| 221 | + doc: Document |
| 222 | + doc_gen = gen_context_docs(context) |
| 223 | + try: |
| 224 | + doc = next(doc_gen) |
| 225 | + while ( |
| 226 | + doc.docno |
| 227 | + in df[(df["qid"] == query.query_id)]["docno"].unique() |
| 228 | + ): |
| 229 | + doc = next(doc_gen) |
| 230 | + except StopIteration: |
| 231 | + break |
| 232 | + # Get the row index of the row to be replaced (of all of the rows satisfying the condition, take the one with min "rank" value) |
| 233 | + row_index = df[ |
| 234 | + (df["qid"] == query.query_id) |
| 235 | + & (df["docno"] == EMPTY_PLACEHOLDER_DOC.docno) |
| 236 | + ]["rank"].idxmin() |
| 237 | + |
| 238 | + # Of that row, set docno and docid to doc.no, and text to doc.content |
| 239 | + df.loc[row_index, "docno"] = doc.docno |
| 240 | + df.loc[row_index, "docid"] = doc.docno |
| 241 | + df.loc[row_index, "text"] = doc.content |
| 242 | + else: |
| 243 | + break |
| 244 | + |
| 245 | + return df |
| 246 | + |
102 | 247 | def combine_result_stages(self, results: List[pd.DataFrame]) -> pd.DataFrame: |
103 | 248 | """ |
104 | 249 | Combine the results of the stages. |
@@ -188,6 +333,7 @@ def search(self, query: Query, context: Context) -> Tuple[Context, pd.DataFrame] |
188 | 333 | query_str = self.transform_input(query, context) |
189 | 334 | query_df = pd.DataFrame([{"qid": query.query_id, "query": query_str}]) |
190 | 335 | result = self.transform(query_df) |
| 336 | + result = self.replace_empty_placeholder_docs(result, [(query, context)]) |
191 | 337 |
|
192 | 338 | if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns: |
193 | 339 | temp_result = result[result["qid"] == query.query_id] |
@@ -239,6 +385,7 @@ def batch_search( |
239 | 385 | ] |
240 | 386 | ) |
241 | 387 | result = self.transform(query_df) |
| 388 | + result = self.replace_empty_placeholder_docs(result, inputs) |
242 | 389 |
|
243 | 390 | if t5_rewriter_module.COPY_REWRITTEN_QUERY_COLUMN in result.columns: |
244 | 391 | for query, _ in inputs: |
|
0 commit comments