Add SLURM cluster support to distributed_segmentation #1443
Open
karimi-ali wants to merge 2 commits into MouseLand:main from
Conversation
Adds a new `slurmCluster` class (mirrors the existing `janeliaLSFCluster`),
a `cluster_type` dispatch in the `cluster` decorator, and a `resume_dir`
mechanism so a SLURM walltime kill can be recovered without re-segmenting
already-completed blocks.
Headline additions in `cellpose/contrib/distributed_segmentation.py`:
- `slurmCluster` (new): thin wrapper on `dask_jobqueue.SLURMCluster`,
defaults tuned for GPU jobs (one A100 per worker by default), tested
  on the MPCDF Raven cluster. Supports `gpus_per_job=N` (N > 1) for
multi-GPU jobs that bypass the per-user concurrent-job cap by packing
N dask workers per SLURM job (one bound to each GPU via dask-cuda).
- `_parse_memory_mb` helper that splits a single user-facing memory
string (e.g. `"125GB"` or the MB-int `"125000"`) into:
- `memory=f"{n}MB"` for dask (string with unit, parsed correctly),
- `job_mem=str(n)` for SLURM `--mem` (plain MB integer; avoids the
GiB-vs-GB ambiguity that bites Raven shared-GPU memory caps).
Both representations are needed because dask reads a bare number
as bytes and SLURM treats `G` as GiB on some sites.
- `cluster` decorator: now picks among `myLocalCluster`,
`janeliaLSFCluster`, and `slurmCluster` based on a new `cluster_type`
key in `cluster_kwargs`. Backward-compatible: when `cluster_type` is
absent, the original LSF detection (when `ncpus`+`min_workers`+
`max_workers` are all present) still applies.
- `distributed_eval`:
- SLURM-aware stitching release: branches between LSF and SLURM for
the post-segmentation worker shrink (drops GPU constraints/gres so
the cheap stitching pass runs on CPU-only workers).
- `resume_dir` parameter: when set, the tempdir becomes
caller-provided and persistent (no random suffix, no auto-delete).
On a second invocation pointed at the same path, blocks whose
chunk file already exists in the unstitched temp zarr are skipped
and their `(faces, boxes, remap)` is re-derived from the saved
segmentation. Lets a walltime-killed run resume cleanly.
- `overlap = int(eval_kwargs['diameter'] * 2)`: latent bug fix.
Float diameter produced float zarr slice indices, which recent
zarr versions reject with `TypeError`.
- `compute_trimmed_crop`, `recompute_block_results`, `block_chunk_path`:
  helpers backing the resume code. `recompute_block_results`
  searchsort-relabels to local IDs before calling
  `scipy.ndimage.find_objects`, to avoid a multi-GB `None` array on
  the globally-remapped label space (the driver OOMs without this);
  see the sketch after this list.
- `block_face_adjacency_graph`: `scipy.ndimage.generate_binary_structure(face.ndim, 1)`
instead of the hardcoded `(3, 1)`, for 2D-segmentation safety.
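As referenced in the resume-helpers bullet above, here is a minimal sketch of the searchsorted local-relabel trick, assuming per-block label arrays with background 0; `local_boxes_sketch` is a hypothetical name, and the PR's `recompute_block_results` does more (it also rebuilds the faces and the remap).

```python
import numpy as np
from scipy import ndimage

def local_boxes_sketch(segment_block):
    """Bounding boxes for one block whose labels live in a huge global ID space.

    Illustrative only; assumes background voxels are labeled 0.
    """
    unique = np.unique(segment_block)                 # sorted, starts with 0
    # Compact global IDs to 0..K so find_objects allocates one slot per label
    # present in this block, not one per global ID (which OOMs the driver).
    local = np.searchsorted(unique, segment_block)
    boxes = ndimage.find_objects(local)               # boxes[i-1] is the box of unique[i]
    return unique[1:], boxes
```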
Docs (`docs/distributed.rst`): adds a SLURM example block (Raven
1×A100 per worker) and updates the "supported cluster" intro.
Tested end-to-end on Raven against the GLC-07391_2 X-ray nuclei
volume (4518x5008x4560 uint16) with the cellpose 3.x `CP_20250324_Nuc6`
custom model: see issue #1111 for context.
change_worker_attributes
========================

The previous implementation patched ``self.new_spec['options'][k] = v`` and
called adapt(). On the GLC-07391_2 production run we observed via scontrol
that newly spawned stitching jobs still carried the original GPU directives
(cpu=18, mem=125000M, gres=gpu:a100:1) — the kwargs never made it onto the
queued jobs. Two changes to make this reliable:

* ``self.scale(0)`` is followed by ``self.sync(self._correct_state)`` so the
  cluster blocks until the existing GPU workers have actually left. Without
  this, adapt() can find a worker still in the spec and skip the respawn,
  leaving the run stuck against the original SLURM directives.
* The kwargs are written into ``self._job_kwargs`` (the canonical store
  dask-jobqueue uses to render the job script) rather than just
  ``self.new_spec['options']``. We assert the two are still the same dict;
  if a future dask-jobqueue version breaks that invariant, the failure is
  loud rather than silent.

The function now also prints the freshly-rendered SBATCH header for the next
worker so the directives are visible in the driver log.

merge_all_boxes
===============

Was O(N) per unique id (per-id ``argwhere(==iii)``). For volumes with ~10^7
unique labels after stitching the quadratic blow-up wedged the final box-merge
for hours. Replaced with a single ``argsort`` plus ``np.minimum/maximum.reduceat``
over (N, ndim) start/stop arrays — O(N log N * ndim). Verified bit-for-bit
against the legacy implementation on synthetic inputs of (N=5e3, M=8e2) and
(N=5e4, M=5e3); 0 mismatches in both regimes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add SLURM cluster support to distributed_segmentation (4.x)

Closes/refs #1111.
Summary
cellpose/contrib/distributed_segmentation.py ships single-machine
(myLocalCluster) and Janelia-LSF (janeliaLSFCluster) cluster
backends, with the docs noting that SLURM "is an easy addition" —
this PR is that addition, plus the tooling around it that bigger
runs need (resume after walltime kill, multi-GPU per SLURM job,
robust memory-format handling, faster final merge for volumes with
many millions of labels).
Tested end-to-end on the MPCDF Raven cluster (NVIDIA A100 GPU
nodes) against a 4518 × 5008 × 4560 uint16 X-ray nuclei volume with
both a custom cellpose 3.x model (CP_20250324_Nuc6, via the v3
backport branch) and the built-in cyto3/cpsam models in 4.x.

What changes
All edits in two files: cellpose/contrib/distributed_segmentation.py
and docs/distributed.rst. Backward-compatible — existing local and
LSF code paths are untouched.
slurmCluster (new class)

Thin wrapper on dask_jobqueue.SLURMCluster, mirroring the shape of
janeliaLSFCluster. Defaults are tuned for GPU jobs (one A100 per
worker by default). Cluster-specific extras pass through as
job_extra_directives, and partition/account are first-class kwargs.
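For orientation, a minimal usage sketch of the SLURM backend. The distributed_eval call shape follows the existing docs/distributed.rst example; the SLURM-specific keys (cluster_type, partition, account, gpus_per_job) are the ones this PR describes, and every concrete value is a placeholder rather than a recommended setting.

```python
import zarr
from cellpose.contrib.distributed_segmentation import distributed_eval

input_zarr = zarr.open('/path/to/volume.zarr', mode='r')

segments, boxes = distributed_eval(
    input_zarr=input_zarr,
    blocksize=(512, 512, 512),
    write_path='/path/to/segmentation.zarr',
    eval_kwargs={'diameter': 30},
    cluster_kwargs={
        'cluster_type': 'slurm',   # dispatch to the new slurmCluster
        'partition': 'gpu',        # site-specific; a Raven example lives in the docs
        'account': 'my_project',
        'ncpus': 18,
        'memory': '125GB',         # split into a dask string and a SLURM --mem MB integer
        'min_workers': 4,
        'max_workers': 16,
        # 'gpus_per_job': 4,       # >1 packs N dask-cuda workers per SLURM job
    },
)
```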
Multi-GPU per SLURM job (gpus_per_job=N)

Some clusters cap concurrent jobs per user (Raven default is 8). A
gpus_per_job=N parameter on slurmCluster lets a single SLURM
allocation hold N GPUs and run N dask worker processes per job (one
per GPU). When gpus_per_job > 1 the wrapper switches the worker
command to dask-cuda-worker for per-process CUDA_VISIBLE_DEVICES
isolation. Optional dependency on dask-cuda (only needed when N > 1).
cluster_type dispatch in the cluster decorator

The decorator now picks a constructor by cluster_kwargs['cluster_type']
("local", "lsf", "slurm"). When the key is missing it falls back
to the original LSF detection (LSF if ncpus + min_workers + max_workers
are all present, otherwise local), so existing user code keeps working.
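A sketch of that dispatch, assuming the three constructor names already present in the module (myLocalCluster, janeliaLSFCluster, slurmCluster); the actual decorator in the PR may structure this differently.

```python
# Illustrative only: one plausible shape of the cluster_type dispatch.
# The real decorator also wraps the target function and manages the cluster
# lifetime; this sketch shows just constructor selection and the LSF fallback.
def pick_cluster_constructor(cluster_kwargs):
    cluster_type = cluster_kwargs.get('cluster_type')
    if cluster_type is not None:
        constructors = {
            'local': myLocalCluster,
            'lsf': janeliaLSFCluster,
            'slurm': slurmCluster,
        }
        return constructors[cluster_type]
    # Backward-compatible fallback: the original LSF detection.
    if all(k in cluster_kwargs for k in ('ncpus', 'min_workers', 'max_workers')):
        return janeliaLSFCluster
    return myLocalCluster
```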
distributed_eval resume mechanism (resume_dir)

A new resume_dir argument turns the temp dir into a caller-provided
persistent path. On a second call with the same resume_dir, blocks
whose chunk file already exists in the unstitched temp zarr are
skipped, and their (faces, boxes, remap) is recomputed from the
saved segmentation (helpers compute_trimmed_crop and
recompute_block_results). Stitching then proceeds on the union of
recomputed and freshly-computed results.
This makes a 24-hour SLURM walltime kill recoverable: re-submit with
the same resume_dir; only un-done blocks actually run.
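A sketch of the resume pattern, assuming the resume_dir keyword this PR adds to distributed_eval; every argument other than resume_dir is identical to the first submission, and all values are placeholders.

```python
# Re-submission after a walltime kill: same call as the original run, plus the
# same caller-chosen resume_dir (persistent temp dir, no random suffix, no
# auto-delete). Blocks whose chunk already exists in the unstitched temp zarr
# are skipped; their (faces, boxes, remap) are re-derived from disk.
segments, boxes = distributed_eval(
    input_zarr=input_zarr,
    blocksize=(512, 512, 512),
    write_path='/path/to/segmentation.zarr',
    eval_kwargs={'diameter': 30},
    cluster_kwargs={'cluster_type': 'slurm', 'partition': 'gpu'},
    resume_dir='/ptmp/me/cellpose_run1',
)
```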
Memory-format split for dask vs SLURM
_parse_memory_mb accepts memory as int (MB) or string ("125GB" /
"125000" / "8 GiB") and emits two distinct representations:
memory=f"{n}MB" to dask (string with unit, parsed correctly), and
job_mem=str(n) to SLURM --mem (plain MB integer).

Avoids two failure modes I hit in testing: dask parsing a bare number
as bytes — a tiny per-worker memory budget that makes the nanny kill
the worker immediately on every restart — and SLURM treating G as GiB
on some sites, silently exceeding per-share memory caps.
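A minimal sketch of that split, assuming only the input formats listed above; the PR's _parse_memory_mb is the authoritative version and may handle more edge cases.

```python
import re

def parse_memory_mb_sketch(mem):
    """Return (dask_memory, slurm_job_mem) from an int (MB) or a string
    like '125GB', '8 GiB' or '125000' (MB). Illustrative only."""
    if isinstance(mem, (int, float)):
        n = int(mem)                                   # already MB
    else:
        value, unit = re.match(r'\s*([\d.]+)\s*([A-Za-z]*)', mem).groups()
        to_mb = {'': 1, 'M': 1, 'MB': 1,               # bare number == MB
                 'G': 1000, 'GB': 1000, 'GIB': 1024,
                 'T': 1000**2, 'TB': 1000**2, 'TIB': 1024**2}
        n = int(float(value) * to_mb[unit.upper()])
    # dask gets a string with an explicit unit (a bare number would be read as
    # bytes); SLURM --mem gets a plain MB integer (sidesteps G-vs-GiB ambiguity).
    return f"{n}MB", str(n)

# parse_memory_mb_sketch("125GB") -> ("125000MB", "125000")
```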
SLURM-aware "release GPUs for stitching"
The post-segmentation step that drops GPUs and shrinks workers for
the cheap stitching pass branches on cluster type. LSF keeps the
existing
change_worker_attributes(...)call with LSF-flavoredkwargs; SLURM gets a parallel call with SLURM-flavored kwargs and
empty
job_extra_directives(drops the GPU constraint and gres sothe stitcher runs on cheap CPU jobs).
Reliable change_worker_attributes on SLURM (follow-up commit)

The original LSF-shape implementation patched
self.new_spec['options'][k] = v and called adapt(). On SLURM this
was unreliable: an in-flight production run finished segmentation,
then change_worker_attributes(cores=1, memory='15GB',
job_extra_directives=[], …) was called to drop the GPU constraint
for stitching — scontrol show job showed the new SLURM jobs still
had the original cpu=18, mem=125000M, gres/gpu:a100=1 directives.

The fix:

self.scale(0) is followed by self.sync(self._correct_state), so the
cluster blocks until the existing GPU workers have actually left;
otherwise adapt() can find a worker still in the spec and skip the
respawn.

The kwargs are written into self._job_kwargs (the canonical store
dask-jobqueue uses to render the job script) directly, with an
assertion that self.new_spec['options'] is the same dict — so a
future dask-jobqueue API change fails loudly rather than silently.

The freshly-rendered SBATCH header for the next worker is printed so
the directives are visible in the driver log; future runs can verify
propagation at a glance.
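A sketch of that flow: the min_workers/max_workers signature mirrors the existing LSF-shape method, the attribute names (_job_kwargs, new_spec) are the dask-jobqueue internals named above, and the body is illustrative rather than a verbatim copy of the commit.

```python
# Illustrative sketch of the SLURM-side change_worker_attributes flow.
def change_worker_attributes(self, min_workers, max_workers, **kwargs):
    # 1. Drop the current (GPU) workers and block until they are really gone,
    #    so adapt() cannot reuse a stale entry still sitting in the spec.
    self.scale(0)
    self.sync(self._correct_state)

    # 2. Write the new directives into the canonical job-kwargs store that
    #    dask-jobqueue renders into the SBATCH script; fail loudly if the
    #    new_spec alias ever stops pointing at the same dict.
    assert self.new_spec['options'] is self._job_kwargs
    for k, v in kwargs.items():
        self._job_kwargs[k] = v

    # 3. The real commit also prints the freshly-rendered SBATCH header here
    #    so the new directives show up in the driver log.
    self.adapt(minimum=min_workers, maximum=max_workers)
```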
Vectorized merge_all_boxes (follow-up commit)

The previous loop was
for iii in np.unique(box_ids): np.argwhere(box_ids == iii)
— O(N) per group. On the production volume the stitching tail wedged
for 2+ hours when the global relabeled label count hit ~10^7.
Replaced with one argsort plus np.minimum/maximum.reduceat over
(N, ndim) start/stop arrays — O(N log N · ndim). Verified bit-for-bit
against the legacy implementation on synthetic inputs of (N=5e3,
M=8e2) and (N=5e4, M=5e3); 0 mismatches.
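A minimal sketch of the argsort + reduceat pattern, assuming boxes arrive as (N, ndim) start/stop corner arrays keyed by box_ids; the actual merge_all_boxes works with tuples of slices and does its own packing/unpacking around this step.

```python
import numpy as np

def merge_boxes_by_id_sketch(box_ids, starts, stops):
    """box_ids: (N,) labels; starts/stops: (N, ndim) box corners.
    Returns unique ids and the union (merged) box per id."""
    order = np.argsort(box_ids)                        # O(N log N)
    ids_sorted = box_ids[order]
    # index where each run of identical ids begins
    group_starts = np.flatnonzero(np.r_[True, ids_sorted[1:] != ids_sorted[:-1]])
    merged_starts = np.minimum.reduceat(starts[order], group_starts, axis=0)
    merged_stops = np.maximum.reduceat(stops[order], group_starts, axis=0)
    return ids_sorted[group_starts], merged_starts, merged_stops
```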
End-to-end run on the 8-block, 512³ test subvolume completed
successfully: 51,356 final merged cells, no Traceback. The
follow-up validation captured the freshly-rendered SBATCH header
in the driver log (--cpus-per-task=1 --mem=15GB, no
--gres=gpu / --constraint=gpu) — confirming the
change_worker_attributes fix above also propagated correctly
on this run.
Worth flagging for reviewers: on this subvolume the
dask.array.map_blocks(np.load(new_labeling)[block], …) / to_zarr
step (which runs before merge_all_boxes and is unchanged by this PR)
was very slow — workers spent extended time in a memory-pressure
pause/resume loop before producing output. That is independent of
the merge_all_boxes change but something to investigate as a
separate follow-up.
Bug fixes in the segmentation path
overlap = int(eval_kwargs['diameter'] * 2) — latent: float diameter
produced float zarr slice indices, which recent zarr versions reject
with TypeError: slice indices must be integers....

block_face_adjacency_graph: scipy.ndimage.generate_binary_structure(face.ndim, 1)
instead of the hardcoded (3, 1). Robustness fix for 2D segmentations.
Docs
docs/distributed.rst: adds a SLURM Raven example block; updates
the supported-clusters intro to remove the "SLURM is an easy addition,
please file an issue" line that referenced #1111.
Cluster-portability notes
The patch was developed against MPCDF Raven, but Raven-specific
values appear only in docstring examples — not as hardcoded defaults
users can't override:
/ptmp/<user>scratchslurmCluster.__init__,local_directory/tmp/<user>/automatically when/ptmpdoesn't exist; passlocal_directory=to override explicitly--gres=gpu:a100:N,--constraint=gpujob_extra_directivesdask_cuda_worker_entryshimgpus_per_job > 1dask_cuda.cli.worker(3-line shim)ncpusandmemoryOut of scope
dask-cudainstallation: optional (only needed forgpus_per_job > 1).