Skip to content

Commit e03e6e6

Browse files
committed
Slightly improve threadsafety by adding locks
This doesn't fix all potential race conditions to global variables, but it's something, and it should have a negligible performance impact.
1 parent 209612d commit e03e6e6

3 files changed

Lines changed: 48 additions & 31 deletions

File tree

src/DistributedNext.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ function _require_callback(mod::Base.PkgId)
106106
end
107107
end
108108

109+
# This is a minimal copy of Base.Lockable we use for backwards compatibility with 1.10
110+
struct Lockable{T, L <: Base.AbstractLock}
111+
value::T
112+
lock::L
113+
end
114+
Lockable(value) = Lockable(value, ReentrantLock())
115+
Base.getindex(l::Lockable) = (Base.assert_havelock(l.lock); l.value)
116+
Base.lock(l::Lockable) = lock(l.lock)
117+
Base.trylock(l::Lockable) = trylock(l.lock)
118+
Base.unlock(l::Lockable) = unlock(l.lock)
119+
109120
const REF_ID = Threads.Atomic{Int}(1)
110121
next_ref_id() = Threads.atomic_add!(REF_ID, 1)
111122

src/cluster.jl

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ mutable struct Worker
139139
Worker(id::Int) = Worker(id, nothing)
140140
function Worker(id::Int, conn_func)
141141
@assert id > 0
142-
if haskey(map_pid_wrkr, id)
143-
return map_pid_wrkr[id]
142+
@lock map_pid_wrkr if haskey(map_pid_wrkr[], id)
143+
return map_pid_wrkr[][id]
144144
end
145145
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
146146
w.initialized = Event()
@@ -407,7 +407,7 @@ function init_worker(cookie::AbstractString, manager::ClusterManager=DefaultClus
407407

408408
# System is started in head node mode, cleanup related entries
409409
empty!(PGRP.workers)
410-
empty!(map_pid_wrkr)
410+
@lock map_pid_wrkr empty!(map_pid_wrkr[])
411411

412412
cluster_cookie(cookie)
413413
nothing
@@ -793,7 +793,7 @@ function check_master_connect()
793793
errormonitor(
794794
Threads.@spawn begin
795795
timeout = worker_timeout()
796-
if timedwait(() -> haskey(map_pid_wrkr, 1), timeout) === :timed_out
796+
if timedwait(() -> @lock(map_pid_wrkr, haskey(map_pid_wrkr[], 1)), timeout) === :timed_out
797797
print(stderr, "Master process (id 1) could not connect within $(timeout) seconds.\nexiting.\n")
798798
exit(1)
799799
end
@@ -868,11 +868,11 @@ end
868868
# globals
869869
const LPROC = LocalProcess()
870870
const LPROCROLE = Ref{Symbol}(:master)
871-
const HDR_VERSION_LEN=16
872-
const HDR_COOKIE_LEN=16
873-
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
874-
const map_sock_wrkr = IdDict()
875-
const map_del_wrkr = Set{Int}()
871+
const HDR_VERSION_LEN = 16
872+
const HDR_COOKIE_LEN = 16
873+
const map_pid_wrkr = Lockable(Dict{Int, Union{Worker, LocalProcess}}())
874+
const map_sock_wrkr = Lockable(IdDict())
875+
const map_del_wrkr = Lockable(Set{Int}())
876876

877877
# whether process is a master or worker in a distributed setup
878878
myrole() = LPROCROLE[]
@@ -1013,7 +1013,7 @@ See also [`other_procs()`](@ref).
10131013
function procs(pid::Integer)
10141014
if myid() == 1
10151015
all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
1016-
if (pid == 1) || (isa(map_pid_wrkr[pid].manager, LocalManager))
1016+
if (pid == 1) || (isa(@lock(map_pid_wrkr, map_pid_wrkr[][pid].manager), LocalManager))
10171017
Int[x.id for x in filter(w -> (w.id==1) || (isa(w.manager, LocalManager)), all_workers)]
10181018
else
10191019
ipatpid = get_bind_addr(pid)
@@ -1119,8 +1119,8 @@ function _rmprocs(pids, waitfor)
11191119
if p == 1
11201120
@warn "rmprocs: process 1 not removed"
11211121
else
1122-
if haskey(map_pid_wrkr, p)
1123-
w = map_pid_wrkr[p]
1122+
w = @lock map_pid_wrkr get(map_pid_wrkr[], p, nothing)
1123+
if !isnothing(w)
11241124
set_worker_state(w, W_TERMINATING)
11251125
kill(w.manager, p, w.config)
11261126
push!(rmprocset, w)
@@ -1160,16 +1160,17 @@ ProcessExitedException() = ProcessExitedException(-1)
11601160

11611161
worker_from_id(i) = worker_from_id(PGRP, i)
11621162
function worker_from_id(pg::ProcessGroup, i)
1163-
if !isempty(map_del_wrkr) && in(i, map_del_wrkr)
1163+
@lock map_del_wrkr if !isempty(map_del_wrkr[]) && in(i, map_del_wrkr[])
11641164
throw(ProcessExitedException(i))
11651165
end
1166-
w = get(map_pid_wrkr, i, nothing)
1166+
1167+
w = @lock map_pid_wrkr get(map_pid_wrkr[], i, nothing)
11671168
if w === nothing
11681169
if myid() == 1
11691170
error("no process with id $i exists")
11701171
end
11711172
w = Worker(i)
1172-
map_pid_wrkr[i] = w
1173+
@lock map_pid_wrkr map_pid_wrkr[][i] = w
11731174
else
11741175
w = w::Union{Worker, LocalProcess}
11751176
end
@@ -1185,7 +1186,7 @@ This is useful when writing custom [`serialize`](@ref) methods for a type,
11851186
which optimizes the data written out depending on the receiving process id.
11861187
"""
11871188
function worker_id_from_socket(s)
1188-
w = get(map_sock_wrkr, s, nothing)
1189+
w = @lock map_sock_wrkr get(map_sock_wrkr[], s, nothing)
11891190
if isa(w,Worker)
11901191
if s === w.r_stream || s === w.w_stream
11911192
return w.id
@@ -1202,23 +1203,28 @@ end
12021203
register_worker(w) = register_worker(PGRP, w)
12031204
function register_worker(pg, w)
12041205
push!(pg.workers, w)
1205-
map_pid_wrkr[w.id] = w
1206+
@lock map_pid_wrkr map_pid_wrkr[][w.id] = w
12061207
end
12071208

12081209
function register_worker_streams(w)
1209-
map_sock_wrkr[w.r_stream] = w
1210-
map_sock_wrkr[w.w_stream] = w
1210+
@lock map_sock_wrkr begin
1211+
map_sock_wrkr[][w.r_stream] = w
1212+
map_sock_wrkr[][w.w_stream] = w
1213+
end
12111214
end
12121215

12131216
deregister_worker(pid) = deregister_worker(PGRP, pid)
12141217
function deregister_worker(pg, pid)
12151218
pg.workers = filter(x -> !(x.id == pid), pg.workers)
1216-
w = pop!(map_pid_wrkr, pid, nothing)
1219+
1220+
w = @lock map_pid_wrkr pop!(map_pid_wrkr[], pid, nothing)
12171221
if isa(w, Worker)
12181222
if isdefined(w, :r_stream)
1219-
pop!(map_sock_wrkr, w.r_stream, nothing)
1220-
if w.r_stream != w.w_stream
1221-
pop!(map_sock_wrkr, w.w_stream, nothing)
1223+
@lock map_sock_wrkr begin
1224+
pop!(map_sock_wrkr[], w.r_stream, nothing)
1225+
if w.r_stream != w.w_stream
1226+
pop!(map_sock_wrkr[], w.w_stream, nothing)
1227+
end
12221228
end
12231229
end
12241230

@@ -1235,7 +1241,7 @@ function deregister_worker(pg, pid)
12351241
end
12361242
end
12371243
end
1238-
push!(map_del_wrkr, pid)
1244+
@lock map_del_wrkr push!(map_del_wrkr[], pid)
12391245

12401246
# delete this worker from our remote reference client sets
12411247
ids = []
@@ -1265,7 +1271,7 @@ end
12651271

12661272
function interrupt(pid::Integer)
12671273
@assert myid() == 1
1268-
w = map_pid_wrkr[pid]
1274+
w = @lock map_pid_wrkr map_pid_wrkr[][pid]
12691275
if isa(w, Worker)
12701276
manage(w.manager, w.id, w.config, :interrupt)
12711277
end
@@ -1305,11 +1311,11 @@ function check_same_host(pids)
13051311
# We checkfirst if all test pids have been started using the local manager,
13061312
# else we check for the same bind_to addr. This handles the special case
13071313
# where the local ip address may change - as during a system sleep/awake
1308-
if all(p -> (p==1) || (isa(map_pid_wrkr[p].manager, LocalManager)), pids)
1314+
@lock map_pid_wrkr if all(p -> (p==1) || (isa(map_pid_wrkr[][p].manager, LocalManager)), pids)
13091315
return true
13101316
else
1311-
first_bind_addr = notnothing(wp_bind_addr(map_pid_wrkr[pids[1]]))
1312-
return all(p -> notnothing(wp_bind_addr(map_pid_wrkr[p])) == first_bind_addr, pids[2:end])
1317+
first_bind_addr = notnothing(wp_bind_addr(map_pid_wrkr[][pids[1]]))
1318+
return all(p -> notnothing(wp_bind_addr(map_pid_wrkr[][p])) == first_bind_addr, pids[2:end])
13131319
end
13141320
end
13151321
end

src/process_messages.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
220220
if wpid < 1
221221
println(stderr, e, CapturedException(e, catch_backtrace()))
222222
println(stderr, "Process($(myid())) - Unknown remote, closing connection.")
223-
elseif !(wpid in map_del_wrkr)
223+
elseif @lock(map_del_wrkr, !(wpid in map_del_wrkr[]))
224224
werr = worker_from_id(wpid)
225225
oldstate = @atomic werr.state
226226
set_worker_state(werr, W_TERMINATED)
@@ -325,7 +325,7 @@ function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version)
325325
end
326326

327327
function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version)
328-
w = map_sock_wrkr[r_stream]
328+
w = @lock map_sock_wrkr map_sock_wrkr[][r_stream]
329329
w.version = version
330330
end
331331

@@ -378,7 +378,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf
378378
end
379379

380380
function handle_msg(msg::JoinCompleteMsg, header, r_stream, w_stream, version)
381-
w = map_sock_wrkr[r_stream]
381+
w = @lock map_sock_wrkr map_sock_wrkr[][r_stream]
382382
environ = something(w.config.environ, Dict())
383383
environ[:cpu_threads] = msg.cpu_threads
384384
w.config.environ = environ

0 commit comments

Comments
 (0)