
Commit e2fc703

myron and wyli authored
cachedataset small fix (#5565)
Fixes #5573. When using `CacheDataset` with `DataLoader` and `num_workers==0`, we convert the shared-memory `ListProxy` cache to a plain `list` in `disable_share_memory_cache()`. But if we are already inside DDP (DistributedDataParallel), we should not do that conversion, as it will crash; we should keep using the `ListProxy`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: myron <amyronenko@nvidia.com>
Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent e15da1c commit e2fc703

2 files changed

Lines changed: 18 additions & 8 deletions
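For context, the failing scenario can be sketched as follows. This is a hedged repro outline, not code from this commit: the data, transform chain, and launcher setup are placeholder assumptions (e.g. it assumes `torchrun` or a similar launcher has set the rendezvous environment variables).

```python
import torch.distributed as dist
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose

# assumes an external launcher (e.g. torchrun) set RANK/WORLD_SIZE/MASTER_ADDR
dist.init_process_group(backend="nccl")

# with runtime_cache=True, rank 0 creates a Manager().list() (a ListProxy)
# and broadcasts it, so all ranks share one cache filled during the first epoch
ds = CacheDataset(
    data=[{"idx": i} for i in range(8)],  # placeholder data
    transform=Compose([]),                # placeholder transform chain
    runtime_cache=True,
)

# before this fix, num_workers==0 made DataLoader unconditionally replace the
# broadcast ListProxy with a plain local list, which crashes under DDP
loader = DataLoader(ds, batch_size=2, num_workers=0)
```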


monai/data/dataloader.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -83,7 +83,7 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
             # disable unnecessary multiprocessing caching
             from monai.data.dataset import CacheDataset  # avoid circular import
 
-            if isinstance(dataset, CacheDataset) and dataset.runtime_cache:
+            if isinstance(dataset, CacheDataset):
                 dataset.disable_share_memory_cache()
 
             _g.manual_seed(init_seed)
```
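With this change the `runtime_cache` check moves out of the loader: `DataLoader` now calls `disable_share_memory_cache()` for every `CacheDataset` when `num_workers == 0`, and the dataset decides internally whether the conversion is safe. A rough sketch of the resulting call pattern (single process, no DDP; the data and transform are placeholders):

```python
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose

ds = CacheDataset(data=list(range(4)), transform=Compose([]), runtime_cache=True)

# num_workers==0 triggers ds.disable_share_memory_cache() inside __init__;
# outside DDP this converts the shared ListProxy cache to a plain list,
# inside DDP it is now a warning-only no-op (see the dataset.py diff below)
loader = DataLoader(ds, num_workers=0)
```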

monai/data/dataset.py

Lines changed: 17 additions & 7 deletions
```diff
@@ -763,7 +763,7 @@ def __init__(
                 will take the minimum of (cache_num, data_length x cache_rate, data_length).
             num_workers: the number of worker threads if computing cache in the initialization.
                 If num_workers is None then the number returned by os.cpu_count() is used.
-                If a value less than 1 is speficied, 1 will be used instead.
+                If a value less than 1 is specified, 1 will be used instead.
             progress: whether to display a progress bar.
             copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
                 default to `True`. if the random transforms don't modify the cached content
```
```diff
@@ -778,15 +778,18 @@ def __init__(
             hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
             runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-                the cache content at initializaiton, if `True`, it will cache during the first epoch
+                the cache content at initialization, if `True`, it will cache during the first epoch
                 of model training, so it can start the first mini-batch earlier. please note that:
                 1. when using this option in multi-gpu distributed training,
                 `torch.cuda.set_device()` must be called before initializing this class.
-                2. to execute `runtime cache` on GPU memory, must co-work with
+                2. if caching data that is in GPU memory during multi-gpu distributed training, this option
+                should not be used, since the underlying shared cache only works for CPU shared memory.
+                3. to execute `runtime cache` on GPU memory, must co-work with
                 `monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
                 as GPU Tensor usually can't be shared in the multiprocessing context.
                 (try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)
+
         """
         if not isinstance(transform, Compose):
             transform = Compose(transform)
```
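The updated docstring note is easiest to read with a usage sketch. Below is a minimal, hedged example of `runtime_cache=True` driven by `monai.data.DataLoader` as the note requires; the data and transform chain are invented for illustration and are not part of this diff.

```python
import torch
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, EnsureTyped

data = [{"image": torch.rand(1, 16, 16)} for _ in range(8)]
transform = Compose([EnsureTyped(keys="image")])

# the cache is NOT prepared here; it is filled lazily during the first epoch
ds = CacheDataset(data=data, transform=transform, runtime_cache=True)

# per note 3, runtime cache must co-work with monai.data.DataLoader
loader = DataLoader(ds, batch_size=2, num_workers=2)
for epoch in range(2):
    for batch in loader:  # epoch 0 fills the shared cache, epoch 1 reads it
        pass
```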
```diff
@@ -827,7 +830,7 @@ def _compute_cache(indices=None):
                 cache = Manager().list([None for _ in range(self.cache_num)])
                 if self._is_dist:
                     obj_list = [cache]
-                    # broadcast the ProxyList to all the ranks, then share the same cache content at runtime
+                    # broadcast the ListProxy to all the ranks, then share the same cache content at runtime
                     dist.broadcast_object_list(obj_list, src=0)
                     cache = obj_list[0]
             else:
```
```diff
@@ -848,11 +851,18 @@ def _compute_cache(indices=None):
 
     def disable_share_memory_cache(self):
         """
-        If the cache content is multiprocessing share memory list, convert it to a regular ptython list.
-        Because multiprocessing ProxyList is not supported for the GPU caching, may need to explicitly diasble it.
+        If the cache content is a multiprocessing shared memory ListProxy, convert it to a regular python list.
+        Because multiprocessing ListProxy is not supported for the GPU caching, explicitly disable it.
 
         """
-        self._cache = list(self._cache)
+        if self.runtime_cache:
+            if not self._is_dist:
+                self._cache = list(self._cache)
+            else:
+                warnings.warn(
+                    "Unable to disable shared cache in DDP, when runtime_cache==True."
+                    "Please use runtime_cache=False option to explicitly not use the shared cache."
+                )
 
     def _fill_cache(self, indices=None) -> List:
         """
```
