Skip to content

Commit 6eed64a

Browse files
committed
fix bugs
1 parent 975f0ef commit 6eed64a

2 files changed

Lines changed: 15 additions & 16 deletions

File tree

pina/_src/data/aggregator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __len__(self):
3232
:return: The length of the aggregated dataloader.
3333
:rtype: int
3434
"""
35+
if self.batching_mode == "separate_conditions":
36+
return sum(len(dl) for dl in self.dataloaders.values())
3537
return max(len(dl) for dl in self.dataloaders.values())
3638

3739
def __iter__(self):
@@ -42,10 +44,11 @@ def __iter__(self):
4244
:rtype: iterator
4345
"""
4446
if self.batching_mode == "separate_conditions":
45-
for name, dl in self.dataloaders.items():
46-
for batch in dl:
47-
yield {name: batch}
48-
return
47+
# TODO: implement separate_conditions batching mode
48+
raise NotImplementedError(
49+
"Batching mode 'separate_conditions' is not implemented yet."
50+
)
51+
4952
iterators = {name: iter(dl) for name, dl in self.dataloaders.items()}
5053
for _ in range(len(self)):
5154
batch = {}

pina/_src/data/data_module.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77
import warnings
88
from lightning.pytorch import LightningDataModule
99
import torch
10-
from torch_geometric.data import Data
11-
from torch.utils.data import DataLoader, SequentialSampler
12-
from torch.utils.data.distributed import DistributedSampler
13-
from pina._src.core.label_tensor import LabelTensor
14-
from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset
10+
from torch_geometric.data import Batch
1511
from pina._src.data.creator import _Creator
12+
from pina._src.core.graph import LabelBatch, Graph
1613
from pina._src.data.aggregator import _Aggregator
1714

1815

@@ -131,6 +128,7 @@ def __init__(
131128
self.shuffle = shuffle
132129
self.batching_mode = batching_mode
133130
self.automatic_batching = automatic_batching
131+
self.batching_mode = batching_mode
134132

135133
# If batch size is None, num_workers has no effect
136134
if batch_size is None and num_workers != 0:
@@ -244,7 +242,6 @@ def setup(self, stage=None):
244242
:raises ValueError: If the stage is neither "fit" nor "test".
245243
"""
246244
if stage == "fit" or stage is None:
247-
print("Sono qui")
248245
self.train_datasets = {
249246
name: _ConditionSubset(
250247
condition,
@@ -254,7 +251,7 @@ def setup(self, stage=None):
254251
for name, condition in self.problem.conditions.items()
255252
if len(self.split_idxs[name]["train"]) > 0
256253
}
257-
print(self.train_datasets)
254+
258255
self.val_datasets = {
259256
name: _ConditionSubset(
260257
condition,
@@ -265,6 +262,7 @@ def setup(self, stage=None):
265262
if len(self.split_idxs[name]["val"]) > 0
266263
}
267264
return
265+
268266
if stage == "test" or stage is None:
269267
self.test_datasets = {
270268
name: _ConditionSubset(
@@ -281,22 +279,20 @@ def setup(self, stage=None):
281279
)
282280

283281
def train_dataloader(self):
284-
print(self.train_datasets)
285282
return _Aggregator(
286283
self.creator(self.train_datasets),
287-
batching_mode="separate_conditions",
284+
batching_mode=self.batching_mode,
288285
)
289286

290287
def val_dataloader(self):
291-
print(self.val_datasets)
292288
return _Aggregator(
293-
self.creator(self.val_datasets), batching_mode="separate_conditions"
289+
self.creator(self.val_datasets), batching_mode=self.batching_mode
294290
)
295291

296292
def test_dataloader(self):
297293
return _Aggregator(
298294
self.creator(self.test_datasets),
299-
batching_mode="separate_conditions",
295+
batching_mode=self.batching_mode,
300296
)
301297

302298
@staticmethod

0 commit comments

Comments
 (0)