Skip to content

Commit e100cd7

Browse files
committed
fix batch size
1 parent 94c5830 commit e100cd7

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

pina/_src/data/creator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ def _compute_batch_sizes(self, datasets):
7979
"""
8080
batch_sizes = {}
8181
if self.batching_mode == "common_batch_size":
82+
83+
if self.batch_size is None:
84+
batch_size = max(dataset.length for dataset in datasets.values())
85+
else:
86+
batch_size = self.batch_size
87+
8288
for name in datasets.keys():
83-
if self.batch_size is None:
84-
batch_sizes[name] = len(datasets[name])
85-
else:
86-
batch_sizes[name] = min(
87-
self.batch_size, len(datasets[name])
88-
)
89+
batch_sizes[name] = min(batch_size, len(datasets[name]))
8990
return batch_sizes
9091
if self.batching_mode == "proportional":
9192
return self._compute_proportional_batch_sizes(datasets)
@@ -168,8 +169,9 @@ def __call__(self, datasets):
168169
dataloaders = {}
169170
if self.batching_mode == "common_batch_size":
170171
max_len = max(len(dataset) for dataset in datasets.values())
172+
171173
for name, dataset in datasets.items():
172-
if self.batching_mode == "common_batch_size":
174+
if self.batching_mode == "common_batch_size" and dataset.length != batch_sizes[name]:
173175
dataset.max_len = max_len
174176
dataloaders[name] = self.conditions[name].create_dataloader(
175177
dataset=dataset,

0 commit comments

Comments
 (0)