Add SLURM cluster support to distributed_segmentation #1443

Open

karimi-ali wants to merge 2 commits into MouseLand:main from karimi-ali:slurm-distributed

Conversation

@karimi-ali

Add SLURM cluster support to distributed_segmentation (4.x)

Closes/refs #1111.

Summary

cellpose/contrib/distributed_segmentation.py ships single-machine
(myLocalCluster) and Janelia-LSF (janeliaLSFCluster) cluster
backends, with the docs noting that SLURM "is an easy addition" —
this PR is that addition, plus the tooling around it that bigger
runs need (resume after walltime kill, multi-GPU per SLURM job,
robust memory-format handling, faster final merge for volumes with
many millions of labels).

Tested end-to-end on the MPCDF Raven cluster (NVIDIA A100 GPU
nodes) against a 4518 × 5008 × 4560 uint16 X-ray nuclei volume with
both a custom cellpose 3.x model (CP_20250324_Nuc6, via the v3
backport branch) and the built-in cyto3/cpsam models in 4.x.

What changes

All edits in two files: cellpose/contrib/distributed_segmentation.py
and docs/distributed.rst. Backward-compatible — existing local and
LSF code paths are untouched.

slurmCluster (new class)

Thin wrapper on dask_jobqueue.SLURMCluster, mirroring the shape of
janeliaLSFCluster. Defaults are tuned for GPU jobs (one A100 per
worker by default). Cluster-specific extras pass through as
job_extra_directives, and partition / account are first-class
kwargs.

cluster_kwargs = {
    'cluster_type': 'slurm',
    'ncpus': 18,
    'min_workers': 1,
    'max_workers': 8,
    'walltime': '24:00:00',
    'memory': '125GB',
    'job_extra_directives': ['--constraint=gpu', '--gres=gpu:a100:1'],
}
segments, boxes = distributed_eval(
    input_zarr=arr, blocksize=(256,256,256), write_path='out.zarr',
    model_kwargs={'gpu': True, 'pretrained_model': '/path/to/model'},
    eval_kwargs={'diameter': 7.5, 'do_3D': True, 'z_axis': 0},
    cluster_kwargs=cluster_kwargs,
)

Multi-GPU per SLURM job (gpus_per_job=N)

Some clusters cap concurrent jobs per user (Raven default is 8). A
gpus_per_job=N parameter on slurmCluster lets a single SLURM
allocation hold N GPUs and run N dask worker processes per job (one
per GPU). When gpus_per_job > 1 the wrapper switches the worker
command to dask-cuda-worker for per-process CUDA_VISIBLE_DEVICES
isolation; dask-cuda becomes an optional dependency, needed only when N > 1.
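A hypothetical cluster_kwargs for that mode (the numbers are illustrative, not
recommendations; the directives mirror the Raven example above):

cluster_kwargs = {
    'cluster_type': 'slurm',
    'ncpus': 72,                  # CPUs for the whole 4-GPU allocation
    'min_workers': 1,
    'max_workers': 8,
    'walltime': '24:00:00',
    'memory': '500GB',
    'gpus_per_job': 4,            # one SLURM job holds 4 GPUs and runs 4 dask workers
    'job_extra_directives': ['--constraint=gpu', '--gres=gpu:a100:4'],
}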

cluster_type dispatch in the cluster decorator

The decorator now picks a constructor by cluster_kwargs['cluster_type']
("local", "lsf", "slurm"). When the key is missing it falls back
to the original LSF detection (LSF if ncpus+min_workers+max_workers
are all present, otherwise local), so existing user code keeps working.
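A minimal sketch of the dispatch (the helper name is hypothetical; the class
names are the ones shipped in this file):

def _pick_cluster_constructor(cluster_kwargs):
    constructors = {'local': myLocalCluster, 'lsf': janeliaLSFCluster, 'slurm': slurmCluster}
    cluster_type = cluster_kwargs.get('cluster_type')
    if cluster_type is not None:
        return constructors[cluster_type]
    # legacy behavior: LSF when the three LSF-ish kwargs are all present, otherwise local
    if all(k in cluster_kwargs for k in ('ncpus', 'min_workers', 'max_workers')):
        return janeliaLSFCluster
    return myLocalCluster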

distributed_eval resume mechanism (resume_dir)

A new resume_dir argument turns the temp dir into a caller-provided
persistent path. On a second call with the same resume_dir, blocks
whose chunk file already exists in the unstitched temp zarr are
skipped, and their (faces, boxes, remap) is recomputed from the
saved segmentation (helpers compute_trimmed_crop and
recompute_block_results). Stitching then proceeds on the union of
recomputed and freshly-computed results.

This makes a 24-hour SLURM walltime kill recoverable: re-submit with
the same resume_dir, only un-done blocks actually run.
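Resubmission then looks like the example above with one extra argument (the
path is hypothetical):

segments, boxes = distributed_eval(
    input_zarr=arr, blocksize=(256,256,256), write_path='out.zarr',
    model_kwargs={'gpu': True, 'pretrained_model': '/path/to/model'},
    eval_kwargs={'diameter': 7.5, 'do_3D': True, 'z_axis': 0},
    cluster_kwargs=cluster_kwargs,
    resume_dir='/ptmp/me/nuclei_run_tmp',  # same path on every submission
)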

Memory-format split for dask vs SLURM

_parse_memory_mb accepts memory as int (MB) or string ("125GB" /
"125000" / "8 GiB") and emits two distinct representations:

  • memory=f"{n}MB" to dask (string with unit, parsed correctly),
  • job_mem=str(n) to SLURM --mem (plain MB integer).

Avoids two failure modes I hit in testing:

  1. dask interprets a bare number as bytes, setting a microscopic
    per-worker memory budget that makes the nanny kill the worker
    immediately on every restart.
  2. SLURM treats G as GiB on some sites, silently exceeding
    per-share memory caps.
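A minimal sketch of the split (the regex and rounding are illustrative, not the
exact helper in the PR):

import re

def _parse_memory_mb(memory):
    # accepts an int (MB), a bare-number string (MB), or a unit-suffixed string
    if isinstance(memory, int):
        mb = memory
    else:
        m = re.fullmatch(r'([\d.]+)\s*([kmgt]?)(i?)b?', str(memory).strip(), re.IGNORECASE)
        value, prefix, binary = float(m.group(1)), m.group(2).lower(), bool(m.group(3))
        if prefix == '':
            mb = int(value)  # bare number is already MB
        else:
            power = {'k': 1, 'm': 2, 'g': 3, 't': 4}[prefix]
            n_bytes = value * (1024 if binary else 1000) ** power
            mb = int(round(n_bytes / 1e6))
    return f'{mb}MB', str(mb)  # (memory string for dask, --mem value for SLURM)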

SLURM-aware "release GPUs for stitching"

The post-segmentation step that drops GPUs and shrinks workers for
the cheap stitching pass branches on cluster type. LSF keeps the
existing change_worker_attributes(...) call with LSF-flavored
kwargs; SLURM gets a parallel call with SLURM-flavored kwargs and
empty job_extra_directives (drops the GPU constraint and gres so
the stitcher runs on cheap CPU jobs).
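In sketch form (the SLURM-flavored kwargs are the ones quoted in the next
section; the LSF branch keeps its existing call):

# inside distributed_eval, after segmentation and before stitching
if isinstance(cluster, slurmCluster):
    cluster.change_worker_attributes(cores=1, memory='15GB', job_extra_directives=[])
# else: the original LSF-flavored change_worker_attributes(...) call, unchanged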

Reliable change_worker_attributes on SLURM (follow-up commit)

The original LSF-shape implementation patched
self.new_spec['options'][k] = v and called adapt(). On SLURM this
was unreliable: an in-flight production run finished segmentation,
then change_worker_attributes(cores=1, memory='15GB', job_extra_directives=[], …)
was called to drop the GPU constraint for stitching, yet
scontrol show job showed the new SLURM jobs still had the original
cpu=18, mem=125000M, gres/gpu:a100=1 directives.

The fix (sketched after this list):

  • Block until existing workers actually leave (self.scale(0) then
    self.sync(self._correct_state)); otherwise adapt() can find a
    worker still in the spec and skip the respawn.
  • Update self._job_kwargs (the canonical store dask-jobqueue uses
    to render the job script) directly, with an assertion that
    self.new_spec['options'] is the same dict — so a future
    dask-jobqueue API change fails loudly rather than silently.
  • Print the freshly-rendered SBATCH header for the next worker so the
    directives are visible in the driver log; future runs can verify
    propagation at a glance.
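A condensed sketch of the SLURM-side method after the fix (the signature and
the re-adapt step are simplified; _job_kwargs and new_spec are the
dask-jobqueue internals described above):

def change_worker_attributes(self, **new_worker_kwargs):
    # block until the current GPU workers have actually left
    self.scale(0)
    self.sync(self._correct_state)
    # write into the dict dask-jobqueue renders the job script from, and fail
    # loudly if it ever stops being the same dict as new_spec['options']
    assert self.new_spec['options'] is self._job_kwargs, \
        'dask-jobqueue job-spec layout changed; update change_worker_attributes'
    self._job_kwargs.update(new_worker_kwargs)
    # make the freshly-rendered SBATCH header visible in the driver log
    print(self.job_script())
    # adaptive scaling is then re-enabled, as in the original implementation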

Vectorized merge_all_boxes (follow-up commit)

The previous loop was for iii in np.unique(box_ids): np.argwhere(box_ids == iii), i.e. O(N) work per unique id. On the production
volume the stitching tail wedged for 2+ hours when the global
relabeled label count hit ~10^7. Replaced with one argsort plus
np.minimum/maximum.reduceat over (N, ndim) start/stop arrays —
O(N log N · ndim). Verified bit-for-bit against the legacy
implementation on synthetic inputs of (N=5e3, M=8e2) and (N=5e4,
M=5e3); 0 mismatches.
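A self-contained sketch of the vectorized merge (argument names are
illustrative; the PR converts find_objects-style slices to start/stop arrays
before this step):

import numpy as np

def merge_boxes_by_id(box_ids, starts, stops):
    # box_ids: (N,) label per box; starts/stops: (N, ndim) box corners
    order = np.argsort(box_ids, kind='stable')            # group equal ids contiguously
    unique_ids, group_starts = np.unique(box_ids[order], return_index=True)
    merged_starts = np.minimum.reduceat(starts[order], group_starts, axis=0)
    merged_stops = np.maximum.reduceat(stops[order], group_starts, axis=0)
    return unique_ids, merged_starts, merged_stops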

The end-to-end run on the 8-block, 512³ test subvolume completed
successfully: 51,356 final merged cells, no Traceback. The
follow-up validation captured the freshly-rendered SBATCH header
in the driver log (--cpus-per-task=1 --mem=15GB, no
--gres=gpu / --constraint=gpu), confirming that the
change_worker_attributes fix above also propagated correctly
on this run.

Worth flagging for reviewers: on this subvolume the
dask.array.map_blocks(np.load(new_labeling)[block], …) /
to_zarr step (which runs before merge_all_boxes and is
unchanged by this PR) was very slow — workers spent extended
time in a memory-pressure pause/resume loop before producing
output. That is independent of the merge_all_boxes change but
something to investigate as a separate follow-up.

Bug fixes in the segmentation path

  • overlap = int(eval_kwargs['diameter'] * 2) — latent bug fix: a float
    diameter produced float zarr slice indices, which recent zarr
    versions reject with TypeError: slice indices must be integers....
  • block_face_adjacency_graph: scipy.ndimage.generate_binary_structure(face.ndim, 1)
    instead of the hardcoded (3, 1). Robustness fix for 2D
    segmentations.

Docs

docs/distributed.rst: adds a SLURM Raven example block; updates
the supported-clusters intro to remove the "SLURM is an easy addition,
please file an issue" line that referenced #1111.

Cluster-portability notes

The patch was developed against MPCDF Raven, but Raven-specific
values appear only in docstring examples — not as hardcoded defaults
users can't override:

Reference | Where | Override
--- | --- | ---
/ptmp/<user> scratch | slurmCluster.__init__ (local_directory) | falls back to /tmp/<user>/ automatically when /ptmp doesn't exist; pass local_directory= to override explicitly
--gres=gpu:a100:N, --constraint=gpu | example only | callers pass their own job_extra_directives
dask_cuda_worker_entry shim | only used when gpus_per_job > 1 | importable name; multi-GPU users point it at dask_cuda.cli.worker (a 3-line shim)
18 CPUs / 125000 MB per A100 share | docstring example | callers pass their own ncpus and memory

Out of scope

  • LSF code path: unchanged (existing logic is preserved verbatim).
  • dask-cuda installation: optional (only needed for gpus_per_job > 1).
  • 3.x backport: filed as a companion PR.

karimi-ali and others added 2 commits May 1, 2026 06:40
Adds a new `slurmCluster` class (mirrors the existing `janeliaLSFCluster`),
a `cluster_type` dispatch in the `cluster` decorator, and a `resume_dir`
mechanism so a SLURM walltime kill can be recovered without re-segmenting
already-completed blocks.

Headline additions in `cellpose/contrib/distributed_segmentation.py`:

- `slurmCluster` (new): thin wrapper on `dask_jobqueue.SLURMCluster`,
  defaults tuned for GPU jobs (one A100 per worker by default), tested
  on the MPCDF Raven cluster. Supports `gpus_per_job=N (>1)` for
  multi-GPU jobs that bypass the per-user concurrent-job cap by packing
  N dask workers per SLURM job (one bound to each GPU via dask-cuda).

- `_parse_memory_mb` helper that splits a single user-facing memory
  string (e.g. `"125GB"` or the MB-int `"125000"`) into:
  - `memory=f"{n}MB"` for dask (string with unit, parsed correctly),
  - `job_mem=str(n)` for SLURM `--mem` (plain MB integer; avoids the
    GiB-vs-GB ambiguity that bites Raven shared-GPU memory caps).
  Both representations are needed because dask reads a bare number
  as bytes and SLURM treats `G` as GiB on some sites.

- `cluster` decorator: now picks among `myLocalCluster`,
  `janeliaLSFCluster`, and `slurmCluster` based on a new `cluster_type`
  key in `cluster_kwargs`. Backward-compatible: when `cluster_type` is
  absent, the original LSF detection (when `ncpus`+`min_workers`+
  `max_workers` are all present) still applies.

- `distributed_eval`:
  - SLURM-aware stitching release: branches between LSF and SLURM for
    the post-segmentation worker shrink (drops GPU constraints/gres so
    the cheap stitching pass runs on CPU-only workers).
  - `resume_dir` parameter: when set, the tempdir becomes
    caller-provided and persistent (no random suffix, no auto-delete).
    On a second invocation pointed at the same path, blocks whose
    chunk file already exists in the unstitched temp zarr are skipped
    and their `(faces, boxes, remap)` is re-derived from the saved
    segmentation. Lets a walltime-killed run resume cleanly.
  - `overlap = int(eval_kwargs['diameter'] * 2)`: latent bug fix.
    Float diameter produced float zarr slice indices, which recent
    zarr versions reject with `TypeError`.

- `compute_trimmed_crop`, `recompute_block_results`, `block_chunk_path`:
  helpers backing the resume code. `recompute_block_results`
  searchsort-relabels to local IDs before calling
  `scipy.ndimage.find_objects` to avoid a multi-GB `None` array on
  the globally-remapped label space (driver OOMs without this).

- `block_face_adjacency_graph`: `scipy.ndimage.generate_binary_structure(face.ndim, 1)`
  instead of the hardcoded `(3, 1)`, for 2D-segmentation safety.

Docs (`docs/distributed.rst`): adds a SLURM example block (Raven
1×A100 per worker) and updates the "supported cluster" intro.

Tested end-to-end on Raven against the GLC-07391_2 X-ray nuclei
volume (4518x5008x4560 uint16) with the cellpose 3.x `CP_20250324_Nuc6`
custom model: see github.com/MouseLand/cellpose/issues/1111 for context.
change_worker_attributes
========================
The previous implementation patched ``self.new_spec['options'][k] = v``
and called adapt(). On the GLC-07391_2 production run we observed via
scontrol that newly spawned stitching jobs still carried the original
GPU directives (cpu=18, mem=125000M, gres=gpu:a100:1) — the kwargs
never made it onto the queued jobs.

Two changes to make this reliable:

* ``self.scale(0)`` is followed by ``self.sync(self._correct_state)``
  so the cluster blocks until the existing GPU workers have actually
  left. Without this, adapt() can find a worker still in the spec and
  skip the respawn, leaving the run stuck against the original SLURM
  directives.
* The kwargs are written into ``self._job_kwargs`` (the canonical
  store dask-jobqueue uses to render the job script) rather than just
  ``self.new_spec['options']``. We assert the two are still the same
  dict; if a future dask-jobqueue version breaks that invariant, the
  failure is loud rather than silent.

The function now also prints the freshly-rendered SBATCH header for
the next worker so the directives are visible in the driver log.

merge_all_boxes
===============
Was O(N) per unique id (per-id ``argwhere(==iii)``). For volumes with
~10^7 unique labels after stitching the quadratic blow-up wedged the
final box-merge for hours. Replaced with a single ``argsort`` plus
``np.minimum/maximum.reduceat`` over (N, ndim) start/stop arrays —
O(N log N * ndim).

Verified bit-for-bit against the legacy implementation on synthetic
inputs of (N=5e3, M=8e2) and (N=5e4, M=5e3); 0 mismatches in both
regimes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>