Skip to content

Commit 42996b3

Browse files
committed
Adjusting parameters to best Google Cloud GPU (Nvidia T4) settings
1 parent 80f0980 commit 42996b3

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

scripts/doc2query-t5.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313

1414
MODEL_NAME: str = "castorini/doc2query-t5-large-msmarco"
15-
MAX_LENGTH: int = 512
16-
BATCH_SIZE: int = 64
15+
MAX_LENGTH: int = 256
16+
BATCH_SIZE: int = 32
1717
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
1818
NUM_SAMPLES: int = 3
1919

@@ -31,8 +31,12 @@ class Doc2Query:
3131
The T5 model to use
3232
tokenizer : T5TokenizerFast
3333
The T5 tokenizer to use
34+
device : torch.device
35+
The device to use
3436
max_length : int
3537
The maximum length of the input
38+
num_samples : int
39+
The number of samples to generate
3640
batch_size : int
3741
The batch size to use
3842
input_file : Path
@@ -41,6 +45,8 @@ class Doc2Query:
4145
The output file
4246
output_df : pd.DataFrame
4347
The output dataframe
48+
pattern : re.Pattern
49+
The pattern to remove URLs from the input
4450
"""
4551

4652
model: T5ForConditionalGeneration

0 commit comments

Comments
 (0)