77import warnings
88from lightning .pytorch import LightningDataModule
99import torch
10- from torch_geometric .data import Data
11- from torch .utils .data import DataLoader , SequentialSampler
12- from torch .utils .data .distributed import DistributedSampler
13- from pina ._src .core .label_tensor import LabelTensor
14- from pina ._src .data .dataset import PinaDatasetFactory , PinaTensorDataset
10+ from torch_geometric .data import Batch
1511from pina ._src .data .creator import _Creator
12+ from pina ._src .core .graph import LabelBatch , Graph
1613from pina ._src .data .aggregator import _Aggregator
1714
1815
@@ -131,6 +128,7 @@ def __init__(
131128 self .shuffle = shuffle
132129 self .batching_mode = batching_mode
133130 self .automatic_batching = automatic_batching
131+ self .batching_mode = batching_mode
134132
135133 # If batch size is None, num_workers has no effect
136134 if batch_size is None and num_workers != 0 :
@@ -244,7 +242,6 @@ def setup(self, stage=None):
244242 :raises ValueError: If the stage is neither "fit" nor "test".
245243 """
246244 if stage == "fit" or stage is None :
247- print ("Sono qui" )
248245 self .train_datasets = {
249246 name : _ConditionSubset (
250247 condition ,
@@ -254,7 +251,7 @@ def setup(self, stage=None):
254251 for name , condition in self .problem .conditions .items ()
255252 if len (self .split_idxs [name ]["train" ]) > 0
256253 }
257- print ( self . train_datasets )
254+
258255 self .val_datasets = {
259256 name : _ConditionSubset (
260257 condition ,
@@ -265,6 +262,7 @@ def setup(self, stage=None):
265262 if len (self .split_idxs [name ]["val" ]) > 0
266263 }
267264 return
265+
268266 if stage == "test" or stage is None :
269267 self .test_datasets = {
270268 name : _ConditionSubset (
@@ -281,22 +279,20 @@ def setup(self, stage=None):
281279 )
282280
def train_dataloader(self):
    """Return the dataloader used during training.

    Wraps the per-condition training datasets (built in ``setup`` for the
    ``"fit"`` stage) in an :class:`_Aggregator`, using the batching mode
    configured on this datamodule rather than a hardcoded value.

    :return: The aggregated training dataloader.
    :rtype: _Aggregator
    """
    # NOTE(review): self.creator is a _Creator instance — it maps the dict of
    # {condition_name: _ConditionSubset} into whatever _Aggregator consumes.
    return _Aggregator(
        self.creator(self.train_datasets),
        batching_mode=self.batching_mode,
    )
289286
def val_dataloader(self):
    """Return the dataloader used during validation.

    Wraps the per-condition validation datasets (built in ``setup`` for the
    ``"fit"`` stage) in an :class:`_Aggregator`, using the batching mode
    configured on this datamodule rather than a hardcoded value.

    :return: The aggregated validation dataloader.
    :rtype: _Aggregator
    """
    return _Aggregator(
        self.creator(self.val_datasets),
        batching_mode=self.batching_mode,
    )
295291
def test_dataloader(self):
    """Return the dataloader used during testing.

    Wraps the per-condition test datasets (built in ``setup`` for the
    ``"test"`` stage) in an :class:`_Aggregator`, using the batching mode
    configured on this datamodule rather than a hardcoded value.

    :return: The aggregated test dataloader.
    :rtype: _Aggregator
    """
    return _Aggregator(
        self.creator(self.test_datasets),
        batching_mode=self.batching_mode,
    )
301297
302298 @staticmethod
0 commit comments