File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -79,13 +79,14 @@ def _compute_batch_sizes(self, datasets):
7979 """
8080 batch_sizes = {}
8181 if self .batching_mode == "common_batch_size" :
82+
83+ if self .batch_size is None :
84+ batch_size = max (dataset .length for dataset in datasets .values ())
85+ else :
86+ batch_size = self .batch_size
87+
8288 for name in datasets .keys ():
83- if self .batch_size is None :
84- batch_sizes [name ] = len (datasets [name ])
85- else :
86- batch_sizes [name ] = min (
87- self .batch_size , len (datasets [name ])
88- )
89+ batch_sizes [name ] = min (batch_size , len (datasets [name ]))
8990 return batch_sizes
9091 if self .batching_mode == "proportional" :
9192 return self ._compute_proportional_batch_sizes (datasets )
@@ -168,8 +169,9 @@ def __call__(self, datasets):
168169 dataloaders = {}
169170 if self .batching_mode == "common_batch_size" :
170171 max_len = max (len (dataset ) for dataset in datasets .values ())
172+
171173 for name , dataset in datasets .items ():
172- if self .batching_mode == "common_batch_size" :
174+ if self .batching_mode == "common_batch_size" and dataset . length != batch_sizes [ name ] :
173175 dataset .max_len = max_len
174176 dataloaders [name ] = self .conditions [name ].create_dataloader (
175177 dataset = dataset ,
You can’t perform that action at this time.
0 commit comments