Skip to content

Commit f36efe2

Browse files
committed
fix bugs
1 parent 47ec090 commit f36efe2

2 files changed

Lines changed: 23 additions & 6 deletions

File tree

pina/_src/core/trainer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,30 @@ def __init__(
131131
automatic_batching if automatic_batching is not None else False
132132
)
133133

134+
if batch_size is None and batching_mode != "common_batch_size":
135+
warnings.warn(
136+
"Batching mode is set to "
137+
f"{batching_mode} but batch_size is None. "
138+
"Batching mode will be set to common_batch_size.",
139+
UserWarning,
140+
)
141+
batching_mode = "common_batch_size"
142+
143+
if batch_size == 1 and batching_mode == "proportional":
144+
warnings.warn(
145+
"Batching mode is set to proportional but batch_size is 1. "
146+
"Batching mode will be set to common_batch_size.",
147+
UserWarning,
148+
)
149+
batching_mode = "common_batch_size"
150+
134151
# set attributes
135152
self.compile = compile
136153
self.solver = solver
137154
self.batch_size = batch_size
138155
self._move_to_device()
139156
self.data_module = None
157+
140158
self._create_datamodule(
141159
train_size=train_size,
142160
test_size=test_size,

pina/_src/data/data_module.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
val_size=0.1,
7373
batch_size=None,
7474
shuffle=True,
75-
batching_mode="separate_conditions",
75+
batching_mode="common_batch_size",
7676
automatic_batching=None,
7777
num_workers=0,
7878
pin_memory=False,
@@ -95,7 +95,7 @@ def __init__(
9595
Default ``True``.
9696
:param str batching_mode: The batching mode to use. Options are
9797
``"common_batch_size"``, ``"proportional"``, and
98-
``"separate_conditions"``. Default is ``"separate_conditions"``.
98+
``"separate_conditions"``. Default is ``"common_batch_size"``.
9999
:param automatic_batching: If ``True``, automatic PyTorch batching
100100
is performed, which consists of extracting one element at a time
101101
from the dataset and collating them into a batch. This is useful
@@ -241,7 +241,7 @@ def setup(self, stage=None):
241241
242242
:raises ValueError: If the stage is neither "fit" nor "test".
243243
"""
244-
if stage == "fit" or stage is None:
244+
if stage in ("fit", None):
245245
self.train_datasets = {
246246
name: _ConditionSubset(
247247
condition,
@@ -261,9 +261,8 @@ def setup(self, stage=None):
261261
for name, condition in self.problem.conditions.items()
262262
if len(self.split_idxs[name]["val"]) > 0
263263
}
264-
return
265264

266-
if stage == "test" or stage is None:
265+
if stage in ("test", None):
267266
self.test_datasets = {
268267
name: _ConditionSubset(
269268
condition,
@@ -273,7 +272,7 @@ def setup(self, stage=None):
273272
for name, condition in self.problem.conditions.items()
274273
if len(self.split_idxs[name]["test"]) > 0
275274
}
276-
else:
275+
if stage not in ("fit", "test", None):
277276
raise ValueError(
278277
f"Invalid stage {stage}. Stage must be either 'fit' or 'test'."
279278
)

0 commit comments

Comments (0)