Skip to content

Commit fea450b

Browse files
committed
bug fix and add tests
1 parent f36efe2 commit fea450b

6 files changed

Lines changed: 324 additions & 557 deletions

File tree

pina/_src/condition/condition_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def switch_dataloader_fn(self, create_dataloader_fn):
142142
:return: The decorated function with the new dataloader function.
143143
:rtype: function
144144
"""
145-
# Replace the create_dataloader method of the ConditionBase class with
145+
# Replace the create_dataloader method of the ConditionBase class with
146146
# the new function
147147
self.has_custom_dataloader_fn = True
148148
self.create_dataloader = create_dataloader_fn

pina/_src/core/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ def __init__(
140140
)
141141
batching_mode = "common_batch_size"
142142

143-
if batch_size == 1 and batching_mode == "proportional":
143+
if (
144+
batch_size is not None
145+
and batch_size <= len(solver.problem.conditions)
146+
and batching_mode == "proportional"
147+
):
144148
warnings.warn(
145149
"Batching mode is set to proportional but batch_size is 1. "
146150
"Batching mode will be set to common_batch_size.",

tests/test_data/test_data_module.py

Lines changed: 0 additions & 331 deletions
This file was deleted.

0 commit comments

Comments
 (0)