Skip to content

Commit be12fd4

Browse files
committed
Add support for worker state callbacks
1 parent bf86b16 commit be12fd4

4 files changed

Lines changed: 301 additions & 14 deletions

File tree

docs/src/_changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ This documents notable changes in DistributedNext.jl. The format is based on
99

1010
## Unreleased
1111

12+
### Added
13+
- Implemented callback support for workers being added/removed, etc. ([#17]).
14+
1215
### Fixed
1316
- Modified the default implementations of methods like `take!` and `wait` on
1417
[`AbstractWorkerPool`](@ref) to be threadsafe and behave more consistently

docs/src/index.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ DistributedNext.cluster_cookie()
5353
DistributedNext.cluster_cookie(::Any)
5454
```
5555

56+
## Callbacks
57+
58+
```@docs
59+
DistributedNext.add_worker_starting_callback
60+
DistributedNext.remove_worker_starting_callback
61+
DistributedNext.add_worker_started_callback
62+
DistributedNext.remove_worker_started_callback
63+
DistributedNext.add_worker_exiting_callback
64+
DistributedNext.remove_worker_exiting_callback
65+
DistributedNext.add_worker_exited_callback
66+
DistributedNext.remove_worker_exited_callback
67+
```
68+
5669
## Cluster Manager Interface
5770

5871
This interface provides a mechanism to launch and manage Julia workers on different cluster environments.

src/cluster.jl

Lines changed: 214 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -479,20 +479,28 @@ end
479479
```
480480
"""
481481
function addprocs(manager::ClusterManager; kwargs...)
482+
params = merge(default_addprocs_params(manager), Dict{Symbol, Any}(kwargs))
483+
482484
init_multi()
483485

484486
cluster_mgmt_from_master_check()
485487

486-
lock(worker_lock)
487-
try
488-
addprocs_locked(manager::ClusterManager; kwargs...)
489-
finally
490-
unlock(worker_lock)
491-
end
488+
# Call worker-starting callbacks
489+
warning_interval = params[:callback_warning_interval]
490+
_run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
491+
warning_interval, [(manager, params)])
492+
493+
# Add new workers
494+
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)
495+
496+
# Call worker-started callbacks
497+
_run_callbacks_concurrently("worker-started", worker_started_callbacks,
498+
warning_interval, new_workers)
499+
500+
return new_workers
492501
end
493502

494-
function addprocs_locked(manager::ClusterManager; kwargs...)
495-
params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
503+
function addprocs_locked(manager::ClusterManager, params)
496504
topology(Symbol(params[:topology]))
497505

498506
if PGRP.topology !== :all_to_all
@@ -579,7 +587,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
579587
:exeflags => ``,
580588
:env => [],
581589
:enable_threaded_blas => false,
582-
:lazy => true)
590+
:lazy => true,
591+
:callback_warning_interval => 10)
583592

584593

585594
function setup_launched_worker(manager, wconfig, launched_q)
@@ -888,13 +897,174 @@ const HDR_COOKIE_LEN = 16
888897
const map_pid_wrkr = Lockable(Dict{Int, Union{Worker, LocalProcess}}())
889898
const map_sock_wrkr = Lockable(IdDict())
890899
const map_del_wrkr = Lockable(Set{Int}())
900+
const worker_starting_callbacks = Dict{Any, Base.Callable}()
901+
const worker_started_callbacks = Dict{Any, Base.Callable}()
902+
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
903+
const worker_exited_callbacks = Dict{Any, Base.Callable}()
891904

892905
# whether process is a master or worker in a distributed setup
893906
myrole() = LPROCROLE[]
894907
function myrole!(proctype::Symbol)
895908
LPROCROLE[] = proctype
896909
end
897910

911+
# Callbacks
912+
913+
function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist; catch_exceptions=false)
914+
callback_tasks = Tuple{Any, Task}[]
915+
for args in arglist
916+
for (name, callback) in callbacks_dict
917+
push!(callback_tasks, (name, Threads.@spawn callback(args...)))
918+
end
919+
end
920+
921+
running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
922+
while timedwait(() -> isempty(running_callbacks()), warning_interval) === :timed_out
923+
callbacks_str = join(running_callbacks(), ", ")
924+
@warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
925+
end
926+
927+
if catch_exceptions
928+
for (key, task) in callback_tasks
929+
try
930+
wait(task)
931+
catch ex
932+
@error "Error when running $(callbacks_name) callback '$(key)'" exception=(ex, catch_backtrace())
933+
end
934+
end
935+
else
936+
# Wait on the tasks so that exceptions bubble up
937+
foreach(wait, [x[2] for x in callback_tasks])
938+
end
939+
end
940+
941+
function _add_callback(f, key, dict; arg_types=Tuple{Int})
942+
if isnothing(key)
943+
key = Symbol(gensym(), nameof(f))
944+
end
945+
946+
desired_signature = "f(" * join(["::$(t)" for t in arg_types.types], ", ") * ")"
947+
948+
if !hasmethod(f, arg_types)
949+
throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
950+
elseif haskey(dict, key)
951+
throw(ArgumentError("A callback function with key '$(key)' already exists"))
952+
end
953+
954+
dict[key] = f
955+
return key
956+
end
957+
958+
_remove_callback(key, dict) = delete!(dict, key)
959+
960+
"""
961+
add_worker_starting_callback(f::Base.Callable; key=nothing) -> key
962+
963+
Register a callback to be called on the master worker immediately before new
964+
workers are started. Chooses and returns a unique key for the callback if `key`
965+
is not specified. The callback `f` will be called with the `ClusterManager`
966+
instance that is being used and a dictionary of parameters related to adding
967+
workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
968+
`manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
969+
in DistributedNext are not fully documented yet, see the
970+
[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
971+
file for their definitions.
972+
973+
!!! warning
974+
Adding workers can fail so it is not guaranteed that the workers requested
975+
in `manager` will exist in the future. e.g. if a worker is requested on a
976+
node that is unreachable then the worker-starting callbacks will be called
977+
but the worker will never be added.
978+
979+
The worker-starting callbacks will be executed concurrently. If one throws an
980+
exception it will not be caught and will be rethrown by [`addprocs`](@ref).
981+
982+
Keep in mind that the callbacks will add to the time taken to launch workers; so
983+
try to either keep the callbacks fast to execute, or do the actual work
984+
asynchronously by spawning a task in the callback (beware of race conditions if
985+
you do this).
986+
"""
987+
add_worker_starting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_starting_callbacks;
988+
arg_types=Tuple{ClusterManager, Dict})
989+
"""
990+
remove_worker_starting_callback(key)
991+
992+
Remove the callback for `key` that was added with [`add_worker_starting_callback()`](@ref).
993+
"""
994+
remove_worker_starting_callback(key) = _remove_callback(key, worker_starting_callbacks)
995+
996+
"""
997+
add_worker_started_callback(f::Base.Callable; key=nothing) -> key
998+
999+
Register a callback to be called on the master worker whenever a worker has
1000+
been added. The callback will be called with the added worker ID,
1001+
e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
1002+
not specified.
1003+
1004+
The worker-started callbacks will be executed concurrently. If one throws an
1005+
exception it will not be caught and will be rethrown by [`addprocs()`](@ref).
1006+
1007+
Keep in mind that the callbacks will add to the time taken to launch workers; so
1008+
try to either keep the callbacks fast to execute, or do the actual
1009+
initialization asynchronously by spawning a task in the callback (beware of race
1010+
conditions if you do this).
1011+
"""
1012+
add_worker_started_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_started_callbacks)
1013+
1014+
"""
1015+
remove_worker_started_callback(key)
1016+
1017+
Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
1018+
"""
1019+
remove_worker_started_callback(key) = _remove_callback(key, worker_started_callbacks)
1020+
1021+
"""
1022+
add_worker_exiting_callback(f::Base.Callable; key=nothing) -> key
1023+
1024+
Register a callback to be called on the master worker immediately before a
1025+
worker is removed with [`rmprocs()`](@ref). The callback will be called with the
1026+
worker ID, e.g. `f(w::Int)`. Chooses and returns a unique key for the callback
1027+
if `key` is not specified.
1028+
1029+
All worker-exiting callbacks will be executed concurrently and if they don't
1030+
all finish before the `callback_timeout` passed to `rmprocs()` then the workers
1031+
will be removed anyway.
1032+
"""
1033+
add_worker_exiting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exiting_callbacks)
1034+
1035+
"""
1036+
remove_worker_exiting_callback(key)
1037+
1038+
Remove the callback for `key` that was added with [`add_worker_exiting_callback()`](@ref).
1039+
"""
1040+
remove_worker_exiting_callback(key) = _remove_callback(key, worker_exiting_callbacks)
1041+
1042+
"""
1043+
add_worker_exited_callback(f::Base.Callable; key=nothing) -> key
1044+
1045+
Register a callback to be called on the master worker when a worker has exited
1046+
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
1047+
segfaulting etc). Chooses and returns a unique key for the callback if `key` is
1048+
not specified.
1049+
1050+
The callback will be called with the worker ID and the final
1051+
`DistributedNext.WorkerState` of the worker, e.g. `f(w::Int, state)`. `state` is an
1052+
enum, a value of `WorkerState_terminated` means a graceful exit and a value of
1053+
`WorkerState_exterminated` means the worker died unexpectedly.
1054+
1055+
All worker-exited callbacks will be executed concurrently. If a callback throws
1056+
an exception it will be caught and printed.
1057+
"""
1058+
add_worker_exited_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exited_callbacks;
1059+
arg_types=Tuple{Int, WorkerState})
1060+
1061+
"""
1062+
remove_worker_exited_callback(key)
1063+
1064+
Remove the callback for `key` that was added with [`add_worker_exited_callback()`](@ref).
1065+
"""
1066+
remove_worker_exited_callback(key) = _remove_callback(key, worker_exited_callbacks)
1067+
8981068
# cluster management related API
8991069
"""
9001070
myid()
@@ -1081,7 +1251,7 @@ function cluster_mgmt_from_master_check()
10811251
end
10821252

10831253
"""
1084-
rmprocs(pids...; waitfor=typemax(Int))
1254+
rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10851255
10861256
Remove the specified workers. Note that only process 1 can add or remove
10871257
workers.
@@ -1095,6 +1265,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10951265
returned. The user should call [`wait`](@ref) on the task before invoking any other
10961266
parallel calls.
10971267
1268+
The `callback_timeout` specifies how long to wait for any callbacks to execute
1269+
before continuing to remove the workers (see
1270+
[`add_worker_exiting_callback()`](@ref)).
1271+
10981272
# Examples
10991273
```julia-repl
11001274
\$ julia -p 5
@@ -1111,24 +1285,38 @@ julia> workers()
11111285
6
11121286
```
11131287
"""
1114-
function rmprocs(pids...; waitfor=typemax(Int))
1288+
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
11151289
cluster_mgmt_from_master_check()
11161290

11171291
pids = vcat(pids...)
11181292
if waitfor == 0
1119-
t = @async _rmprocs(pids, typemax(Int))
1293+
t = @async _rmprocs(pids, typemax(Int), callback_timeout)
11201294
yield()
11211295
return t
11221296
else
1123-
_rmprocs(pids, waitfor)
1297+
_rmprocs(pids, waitfor, callback_timeout)
11241298
# return a dummy task object that user code can wait on.
11251299
return @async nothing
11261300
end
11271301
end
11281302

1129-
function _rmprocs(pids, waitfor)
1303+
function _rmprocs(pids, waitfor, callback_timeout)
11301304
lock(worker_lock)
11311305
try
1306+
# Run the callbacks
1307+
callback_tasks = Tuple{Any, Task}[]
1308+
for pid in pids
1309+
for (name, callback) in worker_exiting_callbacks
1310+
push!(callback_tasks, (name, Threads.@spawn callback(pid)))
1311+
end
1312+
end
1313+
1314+
if timedwait(() -> all(istaskdone, [x[2] for x in callback_tasks]), callback_timeout) === :timed_out
1315+
timedout_callbacks = ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
1316+
callbacks_str = join(timedout_callbacks, ", ")
1317+
@warn "Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str)"
1318+
end
1319+
11321320
rmprocset = Union{LocalProcess, Worker}[]
11331321
for p in pids
11341322
if p == 1
@@ -1280,6 +1468,18 @@ function deregister_worker(pg, pid)
12801468
delete!(pg.refs, id)
12811469
end
12821470
end
1471+
1472+
# Call callbacks on the master
1473+
if myid() == 1
1474+
for (name, callback) in worker_exited_callbacks
1475+
try
1476+
callback(pid, w.state)
1477+
catch ex
1478+
@error "Error when running worker-exited callback '$(name)'" exception=(ex, catch_backtrace())
1479+
end
1480+
end
1481+
end
1482+
12831483
return
12841484
end
12851485

0 commit comments

Comments
 (0)