Skip to content

Commit 4c693d4

Browse files
committed
fix common_batch_size iteration bug
1 parent fea450b commit 4c693d4

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

pina/_src/data/creator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,11 @@ def __call__(self, datasets):
166166
# Compute batch sizes per condition based on batching_mode
167167
batch_sizes = self._compute_batch_sizes(datasets)
168168
dataloaders = {}
169+
if self.batching_mode == "common_batch_size":
170+
max_len = max(len(dataset) for dataset in datasets.values())
169171
for name, dataset in datasets.items():
172+
if self.batching_mode == "common_batch_size":
173+
dataset.max_len = max_len
170174
dataloaders[name] = self.conditions[name].create_dataloader(
171175
dataset=dataset,
172176
batch_size=batch_sizes[name],

pina/_src/data/data_module.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def __init__(self, condition, indices, automatic_batching):
2424
self.condition = condition
2525
self.indices = indices
2626
self.automatic_batching = automatic_batching
27+
self.length = len(self.indices)
28+
self.max_len = self.length
2729

2830
def __len__(self):
29-
return len(self.indices)
31+
return self.max_len
3032

3133
def __getitem__(self, idx):
3234
"""
@@ -36,6 +38,8 @@ def __getitem__(self, idx):
3638
:return: The data corresponding to the given index.
3739
:rtype: dict
3840
"""
41+
if idx >= self.length:
42+
idx = idx % self.length
3943
idx = self.indices[idx]
4044
if not self.automatic_batching:
4145
return idx

0 commit comments

Comments
 (0)