@@ -479,20 +479,28 @@ end
479479```
480480"""
function addprocs(manager::ClusterManager; kwargs...)
    # Effective launch parameters: the manager's defaults, overridden by any
    # keyword arguments supplied by the caller.
    params = merge(default_addprocs_params(manager), Dict{Symbol, Any}(kwargs))

    init_multi()

    cluster_mgmt_from_master_check()

    # Interval handed to the callback runner, which uses it as the period
    # between "still waiting" warnings. Looked up with a fallback (the same
    # value as the built-in default) so that a manager-specific
    # `default_addprocs_params(manager)` method which does not provide this key
    # cannot make `addprocs` fail with a `KeyError`.
    warning_interval = get(params, :callback_warning_interval, 10)

    # Call worker-starting callbacks. This happens before `worker_lock` is
    # taken; every callback receives `(manager, params)`.
    _run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
                                warning_interval, [(manager, params)])

    # Add new workers while holding `worker_lock`
    new_workers = @lock worker_lock addprocs_locked(manager, params)

    # Call worker-started callbacks, once for each element of `new_workers`
    _run_callbacks_concurrently("worker-started", worker_started_callbacks,
                                warning_interval, new_workers)

    return new_workers
end
493502
494- function addprocs_locked (manager:: ClusterManager ; kwargs... )
495- params = merge (default_addprocs_params (manager), Dict {Symbol,Any} (kwargs))
503+ function addprocs_locked (manager:: ClusterManager , params)
496504 topology (Symbol (params[:topology ]))
497505
498506 if PGRP. topology != = :all_to_all
@@ -579,7 +587,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
579587 :exeflags => ` ` ,
580588 :env => [],
581589 :enable_threaded_blas => false ,
582- :lazy => true )
590+ :lazy => true ,
591+ :callback_warning_interval => 10 )
583592
584593
585594function setup_launched_worker (manager, wconfig, launched_q)
@@ -888,13 +897,174 @@ const HDR_COOKIE_LEN = 16
# Worker bookkeeping tables, each wrapped in a `Lockable`.
# NOTE(review): judging by the names and key types these look like
# pid -> worker, socket -> worker and the set of deleted pids; confirm against
# their uses elsewhere in the file.
888897const map_pid_wrkr = Lockable (Dict {Int, Union{Worker, LocalProcess}} ())
889898const map_sock_wrkr = Lockable (IdDict ())
890899const map_del_wrkr = Lockable (Set {Int} ())
# Registries for the worker lifecycle callbacks, one per event, each mapping a
# callback key to the callback itself. Entries are inserted by `_add_callback`
# and deleted by `_remove_callback`.
# NOTE(review): unlike the tables above these are plain `Dict`s rather than
# `Lockable`s, while the callbacks are run in spawned tasks; confirm that
# concurrent registration/removal is not expected.
900+ const worker_starting_callbacks = Dict {Any, Base.Callable} ()
901+ const worker_started_callbacks = Dict {Any, Base.Callable} ()
902+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
903+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
891904
# Role of this process in a distributed setup (master or worker), i.e. the
# value currently stored in `LPROCROLE`.
function myrole()
    return LPROCROLE[]
end
# Store `proctype` as the role of this process; evaluates to `proctype`.
myrole!(proctype::Symbol) = (LPROCROLE[] = proctype)
897910
911+ # Callbacks
912+
"""
    _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist;
                                catch_exceptions=false)

Call every callback in `callbacks_dict` once for each element of `arglist`, each
call in its own spawned task, and block until all of those tasks have finished.

# Arguments
- `callbacks_name`: label for this group of callbacks, only used in log messages.
- `callbacks_dict`: collection of `key => callback` pairs.
- `warning_interval`: timeout (in seconds) passed to `timedwait`; every time it
  elapses with callbacks still running a warning naming them is logged.
- `arglist`: iterable of argument sets. Each element is splatted into the
  callback, so an element may be a tuple of arguments or a single number.
- `catch_exceptions`: if `true`, a failing callback is logged with `@error` and
  the remaining ones are still waited for. If `false` (the default) the first
  failure found while waiting on the tasks is thrown to the caller.

Returns `nothing`.
"""
function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist; catch_exceptions=false)
    # Take a snapshot of the registry before spawning anything. The tasks
    # spawned below may start running straight away, and a callback that
    # registers or removes callbacks would otherwise mutate the collection
    # while it is still being iterated here.
    callbacks = collect(callbacks_dict)

    callback_tasks = Tuple{Any, Task}[]
    for args in arglist
        for (name, callback) in callbacks
            push!(callback_tasks, (name, Threads.@spawn callback(args...)))
        end
    end

    all_done() = all(istaskdone(task) for (_, task) in callback_tasks)
    while timedwait(all_done, warning_interval) === :timed_out
        # Tasks may have finished between the timeout firing and this check,
        # in which case there is nothing left to warn about.
        pending = ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
        isempty(pending) && break

        callbacks_str = join(pending, ", ")
        @warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
    end

    if catch_exceptions
        for (key, task) in callback_tasks
            try
                wait(task)
            catch ex
                @error "Error when running $(callbacks_name) callback '$(key)'" exception=(ex, catch_backtrace())
            end
        end
    else
        # Wait on the tasks so that exceptions bubble up
        for (_, task) in callback_tasks
            wait(task)
        end
    end

    return nothing
end
940+
# Validate `f` and store it in `dict` under `key`, returning the key that was used.
#
# - `key === nothing` means "pick one for me": a fresh `gensym` combined with
#   the name of `f` is used.
# - `arg_types` is the tuple type that `f` must have a method for.
#
# Throws an `ArgumentError` if `f` has no method for `arg_types`, or if `dict`
# already has an entry for the key.
function _add_callback(f, key, dict; arg_types=Tuple{Int})
    callback_key = isnothing(key) ? Symbol(gensym(), nameof(f)) : key

    # Human-readable form of the required signature, used in the error message
    type_strs = ["::$(t)" for t in arg_types.types]
    desired_signature = string("f(", join(type_strs, ", "), ")")

    if !hasmethod(f, arg_types)
        throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
    end
    if haskey(dict, callback_key)
        throw(ArgumentError("A callback function with key '$(callback_key)' already exists"))
    end

    dict[callback_key] = f
    return callback_key
end
957+
# Delete the entry stored under `key` from `dict` and return `dict`.
function _remove_callback(key, dict)
    return delete!(dict, key)
end
959+
"""
    add_worker_starting_callback(f::Base.Callable; key=nothing) -> key

Register `f` as a callback that is run on the master worker immediately before
new workers are started. If `key` is not specified a unique key is chosen for
the callback; in both cases the key is returned.

`f` is called as `f(manager, params)`, where `manager` is the `ClusterManager`
instance being used and `params` is a dictionary of parameters related to adding
workers, so it must have a method accepting `(::ClusterManager, ::Dict)`. The
contents of `params` are specific to the type of `manager`. Note that the
`LocalManager` and `SSHManager` cluster managers in DistributedNext are not
fully documented yet, see the
[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
file for their definitions.

!!! warning
    Adding workers can fail, so there is no guarantee that the workers
    requested in `manager` will exist in the future. e.g. if a worker is
    requested on a node that is unreachable then the worker-starting callbacks
    will be called but the worker will never be added.

The worker-starting callbacks are executed concurrently. If one throws an
exception it will not be caught and will be rethrown by [`addprocs`](@ref).

Keep in mind that the callbacks add to the time taken to launch workers; so
either keep them fast to execute, or do the actual work asynchronously by
spawning a task in the callback (beware of race conditions if you do this).
"""
function add_worker_starting_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_starting_callbacks;
                         arg_types=Tuple{ClusterManager, Dict})
end
"""
    remove_worker_starting_callback(key)

Unregister the worker-starting callback that was added under `key` with
[`add_worker_starting_callback()`](@ref).
"""
function remove_worker_starting_callback(key)
    return _remove_callback(key, worker_starting_callbacks)
end
995+
"""
    add_worker_started_callback(f::Base.Callable; key=nothing) -> key

Register `f` as a callback that is run on the master worker whenever a worker
has been added. `f` is called with the ID of the added worker,
e.g. `f(w::Int)`. If `key` is not specified a unique key is chosen for the
callback; in both cases the key is returned.

The worker-started callbacks are executed concurrently. If one throws an
exception it will not be caught and will be rethrown by [`addprocs()`](@ref).

Keep in mind that the callbacks add to the time taken to launch workers; so
either keep them fast to execute, or do the actual initialization
asynchronously by spawning a task in the callback (beware of race conditions if
you do this).
"""
function add_worker_started_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_started_callbacks)
end
1013+
"""
    remove_worker_started_callback(key)

Unregister the worker-started callback that was added under `key` with
[`add_worker_started_callback()`](@ref).
"""
function remove_worker_started_callback(key)
    return _remove_callback(key, worker_started_callbacks)
end
1020+
"""
    add_worker_exiting_callback(f::Base.Callable; key=nothing) -> key

Register `f` as a callback that is run on the master worker immediately before
a worker is removed with [`rmprocs()`](@ref). `f` is called with the ID of the
worker, e.g. `f(w::Int)`. If `key` is not specified a unique key is chosen for
the callback; in both cases the key is returned.

All worker-exiting callbacks are executed concurrently, and if they have not
all finished before the `callback_timeout` passed to `rmprocs()` has elapsed
then the worker will be removed anyway.
"""
function add_worker_exiting_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_exiting_callbacks)
end
1034+
"""
    remove_worker_exiting_callback(key)

Unregister the worker-exiting callback that was added under `key` with
[`add_worker_exiting_callback()`](@ref).
"""
function remove_worker_exiting_callback(key)
    return _remove_callback(key, worker_exiting_callbacks)
end
1041+
"""
    add_worker_exited_callback(f::Base.Callable; key=nothing) -> key

Register a callback to be called on the master worker when a worker has exited
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
segfaulting etc). Chooses and returns a unique key for the callback if `key` is
not specified.

The callback will be called with the worker ID and the final `WorkerState` of
the worker, e.g. `f(w::Int, state)`, so `f` must have a method accepting
`(::Int, ::WorkerState)`. `state` is an enum, a value of
`WorkerState_terminated` means a graceful exit and a value of
`WorkerState_exterminated` means the worker died unexpectedly.

The worker-exited callbacks are called one after another, in no particular
order. If a callback throws an exception it will be caught and logged, and the
remaining callbacks are still called.
"""
function add_worker_exited_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_exited_callbacks;
                         arg_types=Tuple{Int, WorkerState})
end
1060+
"""
    remove_worker_exited_callback(key)

Unregister the worker-exited callback that was added under `key` with
[`add_worker_exited_callback()`](@ref).
"""
function remove_worker_exited_callback(key)
    return _remove_callback(key, worker_exited_callbacks)
end
1067+
8981068# cluster management related API
8991069"""
9001070 myid()
@@ -1081,7 +1251,7 @@ function cluster_mgmt_from_master_check()
10811251end
10821252
10831253"""
1084- rmprocs(pids...; waitfor=typemax(Int))
1254+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
10851255
10861256Remove the specified workers. Note that only process 1 can add or remove
10871257workers.
@@ -1095,6 +1265,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10951265 returned. The user should call [`wait`](@ref) on the task before invoking any other
10961266 parallel calls.
10971267
The `callback_timeout` argument specifies how long, in seconds, to wait for the
worker-exiting callbacks to finish before the workers are removed regardless
(see [`add_worker_exiting_callback()`](@ref)).
1271+
10981272# Examples
10991273```julia-repl
11001274\$ julia -p 5
@@ -1111,24 +1285,38 @@ julia> workers()
11111285 6
11121286```
11131287"""
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
    cluster_mgmt_from_master_check()

    # Flatten the arguments into a single vector of pids
    pidlist = vcat(pids...)

    if waitfor != 0
        # Blocking mode: remove the workers before returning
        _rmprocs(pidlist, waitfor, callback_timeout)
        # return a dummy task object that user code can wait on.
        return @async nothing
    end

    # waitfor == 0: do the removal in a background task and hand that task back
    # to the caller, yielding once before returning.
    removal_task = @async _rmprocs(pidlist, typemax(Int), callback_timeout)
    yield()
    return removal_task
end
11281302
1129- function _rmprocs (pids, waitfor)
1303+ function _rmprocs (pids, waitfor, callback_timeout )
11301304 lock (worker_lock)
11311305 try
1306+ # Run the callbacks
1307+ callback_tasks = Tuple{Any, Task}[]
1308+ for pid in pids
1309+ for (name, callback) in worker_exiting_callbacks
1310+ push! (callback_tasks, (name, Threads. @spawn callback (pid)))
1311+ end
1312+ end
1313+
1314+ if timedwait (() -> all (istaskdone, [x[2 ] for x in callback_tasks]), callback_timeout) === :timed_out
1315+ timedout_callbacks = [" '$(key) '" for (key, task) in callback_tasks if ! istaskdone (task)]
1316+ callbacks_str = join (timedout_callbacks, " , " )
1317+ @warn " Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str) "
1318+ end
1319+
11321320 rmprocset = Union{LocalProcess, Worker}[]
11331321 for p in pids
11341322 if p == 1
@@ -1280,6 +1468,18 @@ function deregister_worker(pg, pid)
12801468 delete! (pg. refs, id)
12811469 end
12821470 end
1471+
1472+ # Call callbacks on the master
1473+ if myid () == 1
1474+ for (name, callback) in worker_exited_callbacks
1475+ try
1476+ callback (pid, w. state)
1477+ catch ex
1478+ @error " Error when running worker-exited callback '$(name) '" exception= (ex, catch_backtrace ())
1479+ end
1480+ end
1481+ end
1482+
12831483 return
12841484end
12851485
0 commit comments