Skip to content

Commit 8d30b83

Browse files
authored
Merge pull request #56 from JuliaParallel/threadsafety
Slightly improve threadsafety by adding locks
2 parents 34de931 + e03e6e6 commit 8d30b83

5 files changed

Lines changed: 61 additions & 60 deletions

File tree

src/DistributedNext.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ function _require_callback(mod::Base.PkgId)
106106
end
107107
end
108108

109+
# This is a minimal copy of Base.Lockable we use for backwards compatibility with 1.10
110+
struct Lockable{T, L <: Base.AbstractLock}
111+
value::T
112+
lock::L
113+
end
114+
# Convenience constructor: pair `value` with a freshly created `ReentrantLock`.
function Lockable(value)
    return Lockable(value, ReentrantLock())
end
115+
# Access the wrapped value. The lock must already be held by the caller:
# `Base.assert_havelock` is evaluated first and raises an error if it is not.
function Base.getindex(l::Lockable)
    Base.assert_havelock(l.lock)
    return l.value
end
116+
# Acquire the lock guarding the wrapped value by forwarding to the inner lock.
function Base.lock(l::Lockable)
    return lock(l.lock)
end
117+
# Attempt to acquire the inner lock and return whatever `trylock` on the
# inner lock returns.
function Base.trylock(l::Lockable)
    return trylock(l.lock)
end
118+
# Release the lock guarding the wrapped value by forwarding to the inner lock.
function Base.unlock(l::Lockable)
    return unlock(l.lock)
end
119+
109120
const REF_ID = Threads.Atomic{Int}(1)
110121
next_ref_id() = Threads.atomic_add!(REF_ID, 1)
111122

src/cluster.jl

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ mutable struct Worker
139139
Worker(id::Int) = Worker(id, nothing)
140140
function Worker(id::Int, conn_func)
141141
@assert id > 0
142-
if haskey(map_pid_wrkr, id)
143-
return map_pid_wrkr[id]
142+
@lock map_pid_wrkr if haskey(map_pid_wrkr[], id)
143+
return map_pid_wrkr[][id]
144144
end
145145
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
146146
w.initialized = Event()
@@ -407,7 +407,7 @@ function init_worker(cookie::AbstractString, manager::ClusterManager=DefaultClus
407407

408408
# System is started in head node mode, cleanup related entries
409409
empty!(PGRP.workers)
410-
empty!(map_pid_wrkr)
410+
@lock map_pid_wrkr empty!(map_pid_wrkr[])
411411

412412
cluster_cookie(cookie)
413413
nothing
@@ -793,7 +793,7 @@ function check_master_connect()
793793
errormonitor(
794794
Threads.@spawn begin
795795
timeout = worker_timeout()
796-
if timedwait(() -> haskey(map_pid_wrkr, 1), timeout) === :timed_out
796+
if timedwait(() -> @lock(map_pid_wrkr, haskey(map_pid_wrkr[], 1)), timeout) === :timed_out
797797
print(stderr, "Master process (id 1) could not connect within $(timeout) seconds.\nexiting.\n")
798798
exit(1)
799799
end
@@ -826,24 +826,19 @@ function cluster_cookie(cookie)
826826
cookie
827827
end
828828

829-
830-
let next_pid = 2 # 1 is reserved for the client (always)
831-
global get_next_pid
832-
function get_next_pid()
833-
retval = next_pid
834-
next_pid += 1
835-
retval
836-
end
837-
end
829+
# 1 is reserved for the client (always)
830+
const next_pid = Threads.Atomic{Int}(2)
831+
# Note that atomic_add!() returns the old value, which is what we want
832+
# Reserve and return the next process id. The global `next_pid` counter is
# bumped by one atomically, and the value it held before the bump is returned.
function get_next_pid()
    return Threads.atomic_add!(next_pid, 1)
end
838833

839834
mutable struct ProcessGroup
840835
name::String
841-
workers::Array{Any,1}
836+
workers::Vector{Union{Worker, LocalProcess}}
842837
refs::Dict{RRID,Any} # global references
843838
topology::Symbol
844839
lazy::Union{Bool, Nothing}
845840

846-
ProcessGroup(w::Array{Any,1}) = new("pg-default", w, Dict(), :all_to_all, nothing)
841+
ProcessGroup(w::Vector) = new("pg-default", w, Dict(), :all_to_all, nothing)
847842
end
848843
const PGRP = ProcessGroup([])
849844

@@ -873,11 +868,11 @@ end
873868
# globals
874869
const LPROC = LocalProcess()
875870
const LPROCROLE = Ref{Symbol}(:master)
876-
const HDR_VERSION_LEN=16
877-
const HDR_COOKIE_LEN=16
878-
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
879-
const map_sock_wrkr = IdDict()
880-
const map_del_wrkr = Set{Int}()
871+
const HDR_VERSION_LEN = 16
872+
const HDR_COOKIE_LEN = 16
873+
const map_pid_wrkr = Lockable(Dict{Int, Union{Worker, LocalProcess}}())
874+
const map_sock_wrkr = Lockable(IdDict())
875+
const map_del_wrkr = Lockable(Set{Int}())
881876

882877
# whether process is a master or worker in a distributed setup
883878
myrole() = LPROCROLE[]
@@ -1018,7 +1013,7 @@ See also [`other_procs()`](@ref).
10181013
function procs(pid::Integer)
10191014
if myid() == 1
10201015
all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
1021-
if (pid == 1) || (isa(map_pid_wrkr[pid].manager, LocalManager))
1016+
if (pid == 1) || (isa(@lock(map_pid_wrkr, map_pid_wrkr[][pid].manager), LocalManager))
10221017
Int[x.id for x in filter(w -> (w.id==1) || (isa(w.manager, LocalManager)), all_workers)]
10231018
else
10241019
ipatpid = get_bind_addr(pid)
@@ -1124,8 +1119,8 @@ function _rmprocs(pids, waitfor)
11241119
if p == 1
11251120
@warn "rmprocs: process 1 not removed"
11261121
else
1127-
if haskey(map_pid_wrkr, p)
1128-
w = map_pid_wrkr[p]
1122+
w = @lock map_pid_wrkr get(map_pid_wrkr[], p, nothing)
1123+
if !isnothing(w)
11291124
set_worker_state(w, W_TERMINATING)
11301125
kill(w.manager, p, w.config)
11311126
push!(rmprocset, w)
@@ -1165,16 +1160,17 @@ ProcessExitedException() = ProcessExitedException(-1)
11651160

11661161
worker_from_id(i) = worker_from_id(PGRP, i)
11671162
function worker_from_id(pg::ProcessGroup, i)
1168-
if !isempty(map_del_wrkr) && in(i, map_del_wrkr)
1163+
@lock map_del_wrkr if !isempty(map_del_wrkr[]) && in(i, map_del_wrkr[])
11691164
throw(ProcessExitedException(i))
11701165
end
1171-
w = get(map_pid_wrkr, i, nothing)
1166+
1167+
w = @lock map_pid_wrkr get(map_pid_wrkr[], i, nothing)
11721168
if w === nothing
11731169
if myid() == 1
11741170
error("no process with id $i exists")
11751171
end
11761172
w = Worker(i)
1177-
map_pid_wrkr[i] = w
1173+
@lock map_pid_wrkr map_pid_wrkr[][i] = w
11781174
else
11791175
w = w::Union{Worker, LocalProcess}
11801176
end
@@ -1190,7 +1186,7 @@ This is useful when writing custom [`serialize`](@ref) methods for a type,
11901186
which optimizes the data written out depending on the receiving process id.
11911187
"""
11921188
function worker_id_from_socket(s)
1193-
w = get(map_sock_wrkr, s, nothing)
1189+
w = @lock map_sock_wrkr get(map_sock_wrkr[], s, nothing)
11941190
if isa(w,Worker)
11951191
if s === w.r_stream || s === w.w_stream
11961192
return w.id
@@ -1207,23 +1203,28 @@ end
12071203
register_worker(w) = register_worker(PGRP, w)
12081204
function register_worker(pg, w)
12091205
push!(pg.workers, w)
1210-
map_pid_wrkr[w.id] = w
1206+
@lock map_pid_wrkr map_pid_wrkr[][w.id] = w
12111207
end
12121208

12131209
function register_worker_streams(w)
1214-
map_sock_wrkr[w.r_stream] = w
1215-
map_sock_wrkr[w.w_stream] = w
1210+
@lock map_sock_wrkr begin
1211+
map_sock_wrkr[][w.r_stream] = w
1212+
map_sock_wrkr[][w.w_stream] = w
1213+
end
12161214
end
12171215

12181216
deregister_worker(pid) = deregister_worker(PGRP, pid)
12191217
function deregister_worker(pg, pid)
12201218
pg.workers = filter(x -> !(x.id == pid), pg.workers)
1221-
w = pop!(map_pid_wrkr, pid, nothing)
1219+
1220+
w = @lock map_pid_wrkr pop!(map_pid_wrkr[], pid, nothing)
12221221
if isa(w, Worker)
12231222
if isdefined(w, :r_stream)
1224-
pop!(map_sock_wrkr, w.r_stream, nothing)
1225-
if w.r_stream != w.w_stream
1226-
pop!(map_sock_wrkr, w.w_stream, nothing)
1223+
@lock map_sock_wrkr begin
1224+
pop!(map_sock_wrkr[], w.r_stream, nothing)
1225+
if w.r_stream != w.w_stream
1226+
pop!(map_sock_wrkr[], w.w_stream, nothing)
1227+
end
12271228
end
12281229
end
12291230

@@ -1240,7 +1241,7 @@ function deregister_worker(pg, pid)
12401241
end
12411242
end
12421243
end
1243-
push!(map_del_wrkr, pid)
1244+
@lock map_del_wrkr push!(map_del_wrkr[], pid)
12441245

12451246
# delete this worker from our remote reference client sets
12461247
ids = []
@@ -1270,7 +1271,7 @@ end
12701271

12711272
function interrupt(pid::Integer)
12721273
@assert myid() == 1
1273-
w = map_pid_wrkr[pid]
1274+
w = @lock map_pid_wrkr map_pid_wrkr[][pid]
12741275
if isa(w, Worker)
12751276
manage(w.manager, w.id, w.config, :interrupt)
12761277
end
@@ -1310,11 +1311,11 @@ function check_same_host(pids)
13101311
# We check first if all test pids have been started using the local manager,
13111312
# else we check for the same bind_to addr. This handles the special case
13121313
# where the local ip address may change - as during a system sleep/awake
1313-
if all(p -> (p==1) || (isa(map_pid_wrkr[p].manager, LocalManager)), pids)
1314+
@lock map_pid_wrkr if all(p -> (p==1) || (isa(map_pid_wrkr[][p].manager, LocalManager)), pids)
13141315
return true
13151316
else
1316-
first_bind_addr = notnothing(wp_bind_addr(map_pid_wrkr[pids[1]]))
1317-
return all(p -> notnothing(wp_bind_addr(map_pid_wrkr[p])) == first_bind_addr, pids[2:end])
1317+
first_bind_addr = notnothing(wp_bind_addr(map_pid_wrkr[][pids[1]]))
1318+
return all(p -> notnothing(wp_bind_addr(map_pid_wrkr[][p])) == first_bind_addr, pids[2:end])
13181319
end
13191320
end
13201321
end

src/macros.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
let nextidx = Threads.Atomic{Int}(0)
4-
global nextproc
5-
function nextproc()
6-
idx = Threads.atomic_add!(nextidx, 1)
7-
return workers()[(idx % nworkers()) + 1]
8-
end
3+
const nextidx = Threads.Atomic{Int}(0)
4+
function nextproc()
5+
idx = Threads.atomic_add!(nextidx, 1)
6+
return workers()[(idx % nworkers()) + 1]
97
end
108

119
spawnat(p, thunk) = remotecall(thunk, p)

src/managers.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -414,18 +414,9 @@ function manage(manager::SSHManager, id::Integer, config::WorkerConfig, op::Symb
414414
end
415415
end
416416

417-
let tunnel_port = 9201
418-
global next_tunnel_port
419-
function next_tunnel_port()
420-
retval = tunnel_port
421-
if tunnel_port > 32000
422-
tunnel_port = 9201
423-
else
424-
tunnel_port += 1
425-
end
426-
retval
427-
end
428-
end
417+
const tunnel_counter = Threads.Atomic{Int}(1)
418+
# This is defined such that the port numbers start at 9201 and wrap around at 32,000
419+
# Return the next SSH tunnel port, cycling through 9201..32000 inclusive.
# `atomic_add!` returns the counter value from before the increment (the
# counter starts at 1), so shift to a zero-based index before reducing it
# modulo the 22,800 ports in the range. Without the shift every wrap-around
# would hand out port 9200 and port 32000 would never be used.
# `mod` (rather than `%`) keeps the result in range even if the counter ever
# overflows to a negative value.
next_tunnel_port() = mod(Threads.atomic_add!(tunnel_counter, 1) - 1, 22_800) + 9201
429420

430421

431422
"""

src/process_messages.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
220220
if wpid < 1
221221
println(stderr, e, CapturedException(e, catch_backtrace()))
222222
println(stderr, "Process($(myid())) - Unknown remote, closing connection.")
223-
elseif !(wpid in map_del_wrkr)
223+
elseif @lock(map_del_wrkr, !(wpid in map_del_wrkr[]))
224224
werr = worker_from_id(wpid)
225225
oldstate = @atomic werr.state
226226
set_worker_state(werr, W_TERMINATED)
@@ -325,7 +325,7 @@ function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version)
325325
end
326326

327327
function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version)
328-
w = map_sock_wrkr[r_stream]
328+
w = @lock map_sock_wrkr map_sock_wrkr[][r_stream]
329329
w.version = version
330330
end
331331

@@ -378,7 +378,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf
378378
end
379379

380380
function handle_msg(msg::JoinCompleteMsg, header, r_stream, w_stream, version)
381-
w = map_sock_wrkr[r_stream]
381+
w = @lock map_sock_wrkr map_sock_wrkr[][r_stream]
382382
environ = something(w.config.environ, Dict())
383383
environ[:cpu_threads] = msg.cpu_threads
384384
w.config.environ = environ

0 commit comments

Comments
 (0)