Skip to content

Commit a316e5d

Browse files
committed
Preparation for advanced retrieval methods
1 parent 5aa7857 commit a316e5d

10 files changed

Lines changed: 171 additions & 42 deletions

File tree

py_css/interface/cli.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import logging
2-
from typing import Tuple
32

43
from rich.prompt import Prompt
54
from rich.style import Style
65
from rich.console import Console
76

87
import indexer.index as index_module
98
import models.base as base_module
10-
import models.baseline as baseline_module
9+
import models.model_parameters as model_parameters_module
1110

1211
index = None
1312
pipeline: base_module.Pipeline
@@ -63,7 +62,12 @@ def process_input(input_str: str, *, top_n: int) -> str:
6362
return "\n".join(contents)
6463

6564

66-
def main(*, recreate: bool, top_n: int, baseline_params: Tuple[int, int, int]) -> None:
65+
def main(
66+
*,
67+
recreate: bool,
68+
top_n: int,
69+
model_parameters: model_parameters_module.ParametersBase,
70+
) -> None:
6771
"""
6872
The main function of the CLI interface.
6973
@@ -73,16 +77,14 @@ def main(*, recreate: bool, top_n: int, baseline_params: Tuple[int, int, int]) -
7377
Whether to recreate the index.
7478
top_n : int
7579
The number of top-ranked documents to return.
76-
baseline_params : Tuple[int, int, int]
77-
The parameters for the baseline model.
80+
model_parameters : model_parameters_module.ParametersBase
81+
The model parameters to use.
7882
"""
7983
global index
8084
global pipeline
8185

8286
index = index_module.get_index(recreate=recreate)
83-
pipeline = baseline_module.Baseline(
84-
index, baseline_params[0], baseline_params[1], baseline_params[2]
85-
)
87+
pipeline = model_parameters.create_Pipeline(index=index)
8688

8789
# Initialize the rich console
8890
console = Console()

py_css/interface/eval.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
import tempfile
55
import subprocess
66

7-
import pyterrier as pt
8-
97
import indexer.index as index_module
108
import models.base as base_model
11-
import models.baseline as baseline_module
9+
import models.model_parameters as model_parameters_module
1210
import interface.run_queries as run_queries_module
1311

1412
index = None
@@ -20,7 +18,7 @@ def main(
2018
recreate: bool,
2119
queries_file_path: str,
2220
qrels_file_path: str,
23-
baseline_params: Tuple[int, int, int],
21+
model_parameters: model_parameters_module.ParametersBase,
2422
) -> None:
2523
"""
2624
The main function of the eval interface.
@@ -33,16 +31,14 @@ def main(
3331
The path to the queries file.
3432
qrels_file_path : str
3533
The path to the qrels file.
36-
baseline_params : Tuple[int, int, int]
37-
The parameters for the baseline model.
34+
model_parameters : model_parameters_module.ParametersBase
35+
The model parameters.
3836
"""
3937
global index
4038
global pipeline
4139

4240
index = index_module.get_index(recreate=recreate)
43-
pipeline = baseline_module.Baseline(
44-
index, baseline_params[0], baseline_params[1], baseline_params[2]
45-
)
41+
pipeline = model_parameters.create_Pipeline(index=index)
4642

4743
logging.info("Loading queries...")
4844
queries: Dict[int, Dict[int, base_model.Query]] = {} # topic_id -> (turn_id, query)

py_css/interface/kaggle.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import indexer.index as index_module
99
import models.base as base_model
10-
import models.baseline as baseline_module
10+
import models.model_parameters as model_parameters_module
1111

1212
index = None
1313
pipeline: base_model.Pipeline
@@ -41,7 +41,7 @@ def main(
4141
recreate: bool,
4242
queries_file_path: str,
4343
output_file_path: str,
44-
baseline_params: Tuple[int, int, int],
44+
model_parameters: model_parameters_module.ParametersBase,
4545
) -> None:
4646
"""
4747
The main function of the eval interface.
@@ -54,16 +54,14 @@ def main(
5454
The path to the queries file.
5555
qrels_file_path : str
5656
The path to the qrels file.
57-
baseline_params : Tuple[int, int, int]
58-
The parameters for the baseline model.
57+
model_parameters : model_parameters_module.ParametersBase
58+
The model parameters.
5959
"""
6060
global index
6161
global pipeline
6262

6363
index = index_module.get_index(recreate=recreate)
64-
pipeline = baseline_module.Baseline(
65-
index, baseline_params[0], baseline_params[1], baseline_params[2]
66-
)
64+
pipeline = model_parameters.create_Pipeline(index=index)
6765

6866
logging.info("Loading queries...")
6967
queries: Dict[int, Dict[int, base_model.Query]] = {} # topic_id -> (turn_id, query)

py_css/interface/run_queries.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import indexer.index as index_module
99
import models.base as base_model
10-
import models.baseline as baseline_module
10+
import models.model_parameters as model_parameters_module
1111

1212
index = None
1313
pipeline: base_model.Pipeline
@@ -42,7 +42,7 @@ def main(
4242
recreate: bool,
4343
queries_file_path: str,
4444
output_file_path: str,
45-
baseline_params: Tuple[int, int, int],
45+
model_parameters: model_parameters_module.ParametersBase,
4646
) -> None:
4747
"""
4848
The main function of the eval interface.
@@ -55,16 +55,14 @@ def main(
5555
The path to the queries file.
5656
qrels_file_path : str
5757
The path to the qrels file.
58-
baseline_params : Tuple[int, int, int]
59-
The parameters for the baseline model.
58+
model_parameters : model_parameters_module.ParametersBase
59+
The model parameters.
6060
"""
6161
global index
6262
global pipeline
6363

6464
index = index_module.get_index(recreate=recreate)
65-
pipeline = baseline_module.Baseline(
66-
index, baseline_params[0], baseline_params[1], baseline_params[2]
67-
)
65+
pipeline = model_parameters.create_Pipeline(index=index)
6866

6967
logging.info("Loading queries...")
7068
queries: Dict[int, Dict[int, base_model.Query]] = {} # topic_id -> (turn_id, query)

py_css/main.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import interface.eval as eval_module
99
import interface.kaggle as kaggle_module
1010

11-
import models.base as base_module
12-
import models.baseline as baseline_module
11+
import models.model_parameters as model_parameters_module
1312

1413

1514
def setup() -> None:
@@ -45,7 +44,7 @@ def main():
4544
global_args.add_argument(
4645
"--method",
4746
type=str,
48-
choices=["baseline", "advanced"],
47+
choices=["baseline"],
4948
default="baseline",
5049
help="Set the retrieval method",
5150
)
@@ -104,6 +103,15 @@ def main():
104103
# Log Level
105104
logging.basicConfig(level=args.log)
106105

106+
model_parameters: model_parameters_module.ParametersBase
107+
match args.method:
108+
case "baseline":
109+
model_parameters = model_parameters_module.BaselineParameters.from_tuple(
110+
args.baseline_params
111+
)
112+
case _:
113+
raise NotImplementedError
114+
107115
# Call the setup function
108116
setup()
109117

@@ -112,28 +120,28 @@ def main():
112120
cli_module.main(
113121
recreate=args.recreate,
114122
top_n=args.top_n,
115-
baseline_params=args.baseline_params,
123+
model_parameters=model_parameters,
116124
)
117125
elif args.command == "run_file":
118126
run_queries_module.main(
119127
recreate=args.recreate,
120128
queries_file_path=args.queries,
121129
output_file_path=args.output,
122-
baseline_params=args.baseline_params,
130+
model_parameters=model_parameters,
123131
)
124132
elif args.command == "eval":
125133
eval_module.main(
126134
recreate=args.recreate,
127135
queries_file_path=args.queries,
128136
qrels_file_path=args.qrels,
129-
baseline_params=args.baseline_params,
137+
model_parameters=model_parameters,
130138
)
131139
elif args.command == "kaggle":
132140
kaggle_module.main(
133141
recreate=args.recreate,
134142
queries_file_path=args.queries,
135143
output_file_path=args.output,
136-
baseline_params=args.baseline_params,
144+
model_parameters=model_parameters,
137145
)
138146

139147

py_css/models/T5DocumentExpander.py

Whitespace-only changes.

py_css/models/baseline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ class Baseline(base_module.Pipeline):
2323
def __init__(
2424
self,
2525
index,
26-
bm25_docs: int = 1000,
27-
mono_t5_docs: int = 100,
28-
duo_t5_docs: int = 10,
26+
*,
27+
bm25_docs,
28+
mono_t5_docs,
29+
duo_t5_docs,
2930
):
3031
"""
3132
Constructs all the necessary attributes for the baseline retrieval method.

py_css/models/model_parameters.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Tuple
6+
7+
import models.base as base_model
8+
import models.baseline as baseline_model
9+
10+
11+
class ParametersBase(ABC):
    """
    Abstract interface for the parameters of a retrieval model.

    A concrete subclass must be able to build its pipeline from an index
    (``create_Pipeline``) and to construct itself from a plain tuple
    (``from_tuple``).
    """

    @abstractmethod
    def create_Pipeline(self, index) -> base_model.Pipeline:
        """
        Build the pipeline described by these parameters on the given index.

        Parameters
        ----------
        index : pt.Index
            The PyTerrier index.

        Returns
        -------
        base_model.Pipeline
            The pipeline.
        """

    @staticmethod
    @abstractmethod
    def from_tuple(tup: Tuple) -> ParametersBase:
        """
        Build a parameters object from a tuple of raw values.

        Parameters
        ----------
        tup : Tuple
            The tuple.

        Returns
        -------
        ParametersBase
            The ParametersBase object.
        """
50+
51+
52+
@dataclass
class BaselineParameters(ParametersBase):
    """
    A class to represent the parameters of the baseline retrieval method.

    Attributes
    ----------
    bm_25_docs : int
        The number of documents to retrieve with BM25. Passed to the
        baseline pipeline as its ``bm25_docs`` keyword argument.
    mono_t5_docs : int
        The number of documents to rerank with MonoT5.
    duo_t5_docs : int
        The number of documents to rerank with DuoT5.
    """

    bm_25_docs: int
    mono_t5_docs: int
    duo_t5_docs: int

    def create_Pipeline(self, index) -> base_model.Pipeline:
        """
        Creates the baseline pipeline with the given index.

        Parameters
        ----------
        index : pt.Index
            The PyTerrier index.

        Returns
        -------
        base_model.Pipeline (baseline_model.Baseline)
            The baseline pipeline.
        """
        return baseline_model.Baseline(
            index,
            bm25_docs=self.bm_25_docs,
            mono_t5_docs=self.mono_t5_docs,
            duo_t5_docs=self.duo_t5_docs,
        )

    @staticmethod
    def from_tuple(tup: Tuple) -> ParametersBase:
        """
        Creates a BaselineParameters object from a tuple.

        Parameters
        ----------
        tup : Tuple[int, int, int]
            The tuple (bm25_docs, mono_t5_docs, duo_t5_docs)

        Returns
        -------
        ParametersBase (BaselineParameters)
            The ParametersBase object.

        Raises
        ------
        AssertionError
            If the tuple does not have 3 elements or if any of the elements
            is not a positive integer (booleans are rejected as well).
        """
        # Raise explicitly instead of using ``assert`` statements: asserts are
        # removed when Python runs with -O, which would silently disable the
        # validation. The exception type and messages are kept unchanged.
        if len(tup) != 3:
            raise AssertionError("The tuple must have 3 elements.")
        for value in tup:
            # bool is a subclass of int, so it has to be excluded explicitly.
            if isinstance(value, bool) or not isinstance(value, int):
                raise AssertionError("All parameters must be integers.")
            if value <= 0:
                raise AssertionError("All parameters must be positive integers.")
        return BaselineParameters(tup[0], tup[1], tup[2])

report/main.pdf

747 Bytes
Binary file not shown.

report/main.tex

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,17 @@ \section{Problem Statement}\label{sec:problem}
7171

7272

7373
\section{Related Work}\label{sec:related}
74-
Identify and discuss related work (5-10 relevant papers).
74+
\subsection*{T5}
75+
76+
\subsection*{doc2query}
77+
78+
\subsection*{doc2query-T5}
79+
80+
\subsection*{monoT5 \& duoT5}
81+
82+
\subsection*{T5 Query Rewriting}
83+
84+
7585

7686

7787
\section{Baseline Method}\label{sec:baseline}

0 commit comments

Comments
 (0)