Skip to content

Commit a05c202

Browse files
authored
5593 Raise error when GPU number is even greater than the dataset length (#5680)
Fixes #5593 . ### Description This PR added error check in `DistributedSampler` for the case that GPU number is greater than the dataset length and `unevenly divisible`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 3873d23 commit a05c202

2 files changed

Lines changed: 8 additions & 0 deletions

File tree

monai/data/samplers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151

5252
if not even_divisible:
5353
data_len = len(dataset) # type: ignore
54+
if data_len < self.num_replicas:
55+
raise ValueError("the dataset length is less than the number of participating ranks.")
5456
extra_size = self.total_size - data_len
5557
if self.rank + extra_size >= self.num_replicas:
5658
self.num_samples -= 1

tests/test_sampler_dist.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def test_uneven(self):
4646
if dist.get_rank() == 1:
4747
np.testing.assert_allclose(samples, np.array([2, 4]))
4848

49+
@DistCall(nnodes=1, nproc_per_node=2)
50+
def test_uneven_less_data(self):
51+
data = [1]
52+
with self.assertRaises(ValueError):
53+
DistributedSampler(dataset=data, shuffle=False, even_divisible=False)
54+
4955
@DistCall(nnodes=1, nproc_per_node=2, timeout=120)
5056
def test_cachedataset(self):
5157
data = [1, 2, 3, 4, 5]

0 commit comments

Comments
 (0)