@@ -72,7 +72,7 @@ def __init__(
7272 val_size = 0.1 ,
7373 batch_size = None ,
7474 shuffle = True ,
75- batching_mode = "separate_conditions " ,
75+ batching_mode = "common_batch_size " ,
7676 automatic_batching = None ,
7777 num_workers = 0 ,
7878 pin_memory = False ,
@@ -95,7 +95,7 @@ def __init__(
9595 Default ``True``.
9696 :param str batching_mode: The batching mode to use. Options are
9797 ``"common_batch_size"``, ``"proportional"``, and
98- ``"separate_conditions"``. Default is ``"separate_conditions "``.
98+ ``"separate_conditions"``. Default is ``"common_batch_size "``.
9999 :param automatic_batching: If ``True``, automatic PyTorch batching
100100 is performed, which consists of extracting one element at a time
101101 from the dataset and collating them into a batch. This is useful
@@ -241,7 +241,7 @@ def setup(self, stage=None):
241241
242242 :raises ValueError: If the stage is neither "fit" nor "test".
243243 """
244- if stage == "fit" or stage is None :
244+ if stage in ( "fit" , None ) :
245245 self .train_datasets = {
246246 name : _ConditionSubset (
247247 condition ,
@@ -261,9 +261,8 @@ def setup(self, stage=None):
261261 for name , condition in self .problem .conditions .items ()
262262 if len (self .split_idxs [name ]["val" ]) > 0
263263 }
264- return
265264
266- if stage == "test" or stage is None :
265+ if stage in ( "test" , None ) :
267266 self .test_datasets = {
268267 name : _ConditionSubset (
269268 condition ,
@@ -273,7 +272,7 @@ def setup(self, stage=None):
273272 for name , condition in self .problem .conditions .items ()
274273 if len (self .split_idxs [name ]["test" ]) > 0
275274 }
276- else :
275+ if stage not in ( "fit" , "test" , None ) :
277276 raise ValueError (
278277 f"Invalid stage { stage } . Stage must be either 'fit' or 'test'."
279278 )
0 commit comments