File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212
1313
1414MODEL_NAME : str = "castorini/doc2query-t5-large-msmarco"
15- MAX_LENGTH : int = 512
16- BATCH_SIZE : int = 64
15+ MAX_LENGTH : int = 256
16+ BATCH_SIZE : int = 32
1717DEVICE : str = "cuda" if torch.cuda.is_available() else "cpu"
1818NUM_SAMPLES : int = 3
1919
@@ -31,8 +31,12 @@ class Doc2Query:
3131 The T5 model to use
3232 tokenizer : T5TokenizerFast
3333 The T5 tokenizer to use
34+ device : torch.device
35+ The device to use
3436 max_length : int
3537 The maximum length of the input
38+ num_samples : int
39+ The number of samples to generate
3640 batch_size : int
3741 The batch size to use
3842 input_file : Path
@@ -41,6 +45,8 @@ class Doc2Query:
4145 The output file
4246 output_df : pd.DataFrame
4347 The output dataframe
48+ pattern : re.Pattern
49+ The pattern to remove URLs from the input
4450 """
4551
4652 model : T5ForConditionalGeneration
You can’t perform that action at this time.
0 commit comments