Skip to content

Commit f25672d

Browse files
Pushing ownership of "close()" out of session into transport
1 parent 86c6d27 commit f25672d

3 files changed

Lines changed: 149 additions & 109 deletions

File tree

src/replit_river/v2/client_transport.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import asyncio
12
import logging
3+
from asyncio import Event, shield
24
from collections.abc import Awaitable, Callable
35
from typing import Generic
46

57
import nanoid
8+
from websockets.asyncio.client import ClientConnection
69

710
from replit_river.rate_limiter import LeakyBucketRateLimit
11+
from replit_river.task_manager import BackgroundTaskManager
812
from replit_river.transport_options import (
913
HandshakeMetadataType,
1014
TransportOptions,
@@ -17,6 +21,7 @@
1721

1822
class ClientTransport(Generic[HandshakeMetadataType]):
1923
_session: Session | None
24+
_closing_event: tuple[Session, Event, Awaitable[None]] | None
2025

2126
def __init__(
2227
self,
@@ -28,6 +33,7 @@ def __init__(
2833
self._session = None
2934
self._transport_id = nanoid.generate()
3035
self._transport_options = transport_options
36+
self._closing_event = None
3137

3238
self._uri_and_metadata_factory = uri_and_metadata_factory
3339
self._client_id = client_id
@@ -37,16 +43,57 @@ def __init__(
3743
)
3844

3945
async def close(self) -> None:
46+
"""
47+
A very simple function that only defers to session's close(), which
48+
defers to the parameter we pass in to the Session constructor.
49+
No logic in here.
50+
"""
4051
self._rate_limiter.close()
4152
if self._session:
42-
await self._session.close()
43-
logger.info(
44-
"Transport closed",
45-
extra={
46-
"client_id": self._client_id,
47-
"transport_id": self._transport_id,
48-
},
49-
)
53+
self._session.close()
54+
55+
if self._closing_event:
56+
await self._closing_event[1].wait()
57+
58+
def _trigger_close(
59+
self,
60+
signal_closing: Callable[[], None],
61+
task_manager: BackgroundTaskManager, # .cancel_all_tasks()
62+
terminate_remaining_output_streams: Callable[[], None],
63+
join_output_streams_with_timeout: Callable[[], Awaitable[None]],
64+
ws: ClientConnection | None,
65+
become_closed: Callable[[], None],
66+
) -> Event:
67+
if self._closing_event:
68+
return self._closing_event[1]
69+
if self._session is None:
70+
noop = asyncio.Event()
71+
noop.set()
72+
return noop
73+
74+
closing_event = Event()
75+
76+
async def _do_close() -> None:
77+
session = self._session
78+
signal_closing()
79+
await task_manager.cancel_all_tasks()
80+
terminate_remaining_output_streams()
81+
await join_output_streams_with_timeout()
82+
if ws:
83+
await ws.close()
84+
become_closed()
85+
# Ensure that we've not established a new session in the
86+
# meantime somehow.
87+
if self._session is session:
88+
self._session = None
89+
closing_event.set()
90+
91+
self._closing_event = (
92+
self._session,
93+
closing_event,
94+
shield(asyncio.create_task(_do_close(), name="do_close")),
95+
)
96+
return self._closing_event[1]
5097

5198
async def get_or_create_session(self) -> Session:
5299
"""
@@ -57,16 +104,23 @@ async def get_or_create_session(self) -> Session:
57104
if not existing_session or existing_session.is_terminal():
58105
logger.info("Creating new session")
59106
if existing_session:
60-
await existing_session.close()
107+
await existing_session.close().wait()
108+
if self._closing_event and self._closing_event[0] == existing_session:
109+
await self._closing_event[2]
110+
else:
111+
logger.error(
112+
"This should not be possible, "
113+
"self._closing_event should always refer to existing_session",
114+
)
61115
new_session = Session(
62116
client_id=self._client_id,
63117
server_id=self._server_id,
64118
session_id=nanoid.generate(),
65119
transport_options=self._transport_options,
66-
close_session_callback=self._delete_session,
67120
retry_connection_callback=self._retry_connection,
68121
uri_and_metadata_factory=self._uri_and_metadata_factory,
69122
rate_limiter=self._rate_limiter,
123+
trigger_close=self._trigger_close,
70124
)
71125

72126
self._session = new_session
@@ -78,21 +132,6 @@ async def get_or_create_session(self) -> Session:
78132
async def _retry_connection(self) -> Session:
79133
if self._session and not self._transport_options.transparent_reconnect:
80134
logger.info("transparent_reconnect not set, closing {self._transport_id}")
81-
await self._session.close()
135+
await self._session.close().wait()
82136
logger.debug("Triggering get_or_create_session")
83137
return await self.get_or_create_session()
84-
85-
def _delete_session(self, session: Session) -> None:
86-
if self._session is session:
87-
self._session = None
88-
else:
89-
logger.warning(
90-
"Session attempted to close itself but it was not the "
91-
"active session, doing nothing",
92-
extra={
93-
"client_id": self._client_id,
94-
"transport_id": self._transport_id,
95-
"active_session_id": self._session and self._session.session_id,
96-
"orphan_session_id": session.session_id,
97-
},
98-
)

src/replit_river/v2/session.py

Lines changed: 45 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Coroutine,
1515
Literal,
1616
NotRequired,
17+
Protocol,
1718
TypeAlias,
1819
TypedDict,
1920
assert_never,
@@ -112,7 +113,6 @@ class ResultError(TypedDict):
112113
trace_propagator = TraceContextTextMapPropagator()
113114
trace_setter = TransportMessageTracingSetter()
114115

115-
CloseSessionCallback: TypeAlias = Callable[["Session"], None]
116116
RetryConnectionCallback: TypeAlias = Callable[
117117
[],
118118
Coroutine[Any, Any, Any],
@@ -131,14 +131,25 @@ class StreamMeta(TypedDict):
131131
output: Channel[ResultType]
132132

133133

134+
class TriggerCloseCall(Protocol):
135+
def __call__(
136+
self,
137+
signal_closing: Callable[[], None],
138+
task_manager: BackgroundTaskManager, # .cancel_all_tasks()
139+
terminate_remaining_output_streams: Callable[[], None],
140+
join_output_streams_with_timeout: Callable[[], Awaitable[None]],
141+
ws: ClientConnection | None,
142+
become_closed: Callable[[], None],
143+
) -> asyncio.Event: ...
144+
145+
134146
class Session[HandshakeMetadata]:
135147
_server_id: str
136148
session_id: str
137149
_transport_options: TransportOptions
138150

139151
# session state, only modified during closing
140152
_state: SessionState
141-
_close_session_callback: CloseSessionCallback
142153
_close_session_after_time_secs: float | None
143154
_connecting_task: asyncio.Task[None] | None
144155
_wait_for_connected: asyncio.Event
@@ -168,19 +179,19 @@ class Session[HandshakeMetadata]:
168179
seq: int # Last sent sequence number
169180

170181
# Terminating
171-
_terminating_task: asyncio.Task[None] | None
182+
_trigger_close: TriggerCloseCall
172183

173184
def __init__(
174185
self,
175186
server_id: str,
176187
session_id: str,
177188
transport_options: TransportOptions,
178-
close_session_callback: CloseSessionCallback,
179189
client_id: str,
180190
rate_limiter: RateLimiter,
181191
uri_and_metadata_factory: Callable[
182192
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
183193
],
194+
trigger_close: TriggerCloseCall,
184195
retry_connection_callback: RetryConnectionCallback | None = None,
185196
) -> None:
186197
self._server_id = server_id
@@ -189,7 +200,6 @@ def __init__(
189200

190201
# session state
191202
self._state = SessionState.NO_CONNECTION
192-
self._close_session_callback = close_session_callback
193203
self._close_session_after_time_secs: float | None = None
194204
self._connecting_task = None
195205
self._wait_for_connected = asyncio.Event()
@@ -227,7 +237,7 @@ def __init__(
227237
self.seq = 0
228238

229239
# Terminating
230-
self._terminating_task = None
240+
self._trigger_close = trigger_close
231241

232242
self._start_recv_from_ws()
233243
self._start_buffered_message_sender()
@@ -304,7 +314,7 @@ def unbind_connecting_task() -> None:
304314
close_ws_in_background=close_ws_in_background,
305315
transition_connected=transition_connected,
306316
unbind_connecting_task=unbind_connecting_task,
307-
close_session=self._close_internal_nowait,
317+
close_session=self.close,
308318
)
309319
)
310320

@@ -313,12 +323,6 @@ def unbind_connecting_task() -> None:
313323
except asyncio.CancelledError:
314324
pass
315325

316-
if self._terminating_task:
317-
try:
318-
await self._terminating_task
319-
except asyncio.CancelledError:
320-
pass
321-
322326
def is_terminal(self) -> bool:
323327
"""
324328
If the session is in a terminal state.
@@ -394,52 +398,16 @@ async def _enqueue_message(
394398
# Wake up buffered_message_sender
395399
self._process_messages.set()
396400

397-
async def close(
401+
def close(
398402
self,
399403
reason: Exception | None = None,
400-
) -> None:
404+
) -> asyncio.Event:
401405
"""Close the session and all associated streams."""
402-
if self._terminating_task:
403-
try:
404-
logger.debug("Session already closing, waiting...")
405-
async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
406-
await self._terminating_task
407-
except asyncio.TimeoutError:
408-
logger.warning(
409-
f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
410-
"seconds to close, leaking",
411-
)
412-
return
413-
try:
414-
await self._close_internal(reason)
415-
except asyncio.CancelledError:
416-
pass
417-
418-
def _close_internal_nowait(self, reason: Exception | None = None) -> None:
419-
"""
420-
When calling close() from asyncio Tasks, we must not block.
421-
422-
This function does so, deferring to the underlying infrastructure for
423-
creating self._terminating_task.
424-
"""
425-
self._close_internal(reason)
426-
427-
def _close_internal(self, reason: Exception | None = None) -> asyncio.Task[None]:
428-
"""
429-
Internal close method. Subsequent calls past the first do not block.
430-
431-
This is intended to be the primary driver of a session being torn down
432-
and returned to its initial state.
433-
434-
NB: This function is intended to be the sole lifecycle manager of
435-
self._terminating_task. Waiting on the completion of that task is optional,
436-
but the population of that property is critical.
437-
438-
NB: We must not await the task returned from this function from chained tasks
439-
inside this session, otherwise we will create a thread loop.
440-
"""
441406

442-
async def do_close() -> None:
407+
def signal_closing() -> None:
408+
"""
409+
Roughly "kill 15"
410+
"""
443411
logger.info(
444412
f"{self.session_id} closing session to {self._server_id}, "
445413
f"ws: {self._ws}"
@@ -454,11 +422,10 @@ async def do_close() -> None:
454422
# ... message processor so it can exit cleanly
455423
self._process_messages.set()
456424

457-
# Wait to permit the waiting tasks to shut down gracefully
458-
await asyncio.sleep(0.25)
459-
460-
await self._task_manager.cancel_all_tasks()
461-
425+
def terminate_remaining_output_streams() -> None:
426+
"""
427+
Roughly "kill 9"
428+
"""
462429
for stream_id, stream_meta in self._streams.items():
463430
stream_meta["output"].close()
464431
# Wake up backpressured writers
@@ -475,6 +442,11 @@ async def do_close() -> None:
475442
"Unable to tell the caller that the session is going away",
476443
)
477444
stream_meta["release_backpressured_waiter"]()
445+
446+
async def join_output_streams_with_timeout() -> None:
447+
"""
448+
Roughly "wait"
449+
"""
478450
# Before we GC the streams, let's wait for all tasks to be closed gracefully
479451
try:
480452
async with asyncio.timeout(
@@ -500,21 +472,21 @@ async def do_close() -> None:
500472
)
501473
self._streams.clear()
502474

503-
if self._ws:
504-
# The Session isn't guaranteed to live much longer than this close()
505-
# invocation, so let's await this close to avoid dropping the socket.
506-
await self._ws.close()
507-
475+
def become_closed() -> None:
476+
pass
508477
self._state = SessionState.CLOSED
509478

510479
# Clear the session in transports
511480
# This will get us GC'd, so this should be the last thing.
512-
self._close_session_callback(self)
513481

514-
if not self._terminating_task:
515-
self._terminating_task = asyncio.create_task(do_close())
516-
517-
return self._terminating_task
482+
return self._trigger_close(
483+
signal_closing,
484+
self._task_manager, # .cancel_all_tasks()
485+
terminate_remaining_output_streams,
486+
join_output_streams_with_timeout,
487+
self._ws,
488+
become_closed,
489+
)
518490

519491
def _start_buffered_message_sender(
520492
self,
@@ -657,7 +629,7 @@ async def block_until_connected() -> None:
657629
get_state=lambda: self._state,
658630
get_ws=lambda: self._ws,
659631
transition_no_connection=transition_no_connection,
660-
close_session=self._close_internal_nowait,
632+
close_session=self.close,
661633
assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping,
662634
get_stream=lambda stream_id: self._streams.get(stream_id),
663635
enqueue_message=self._enqueue_message,
@@ -1105,7 +1077,7 @@ async def _do_ensure_connected[HandshakeMetadata](
11051077
close_ws_in_background: Callable[[ClientConnection], None],
11061078
transition_connected: Callable[[ClientConnection], None],
11071079
unbind_connecting_task: Callable[[], None],
1108-
close_session: Callable[[Exception | None], None],
1080+
close_session: Callable[[Exception | None], asyncio.Event],
11091081
) -> None:
11101082
logger.info("Attempting to establish new ws connection")
11111083

@@ -1273,7 +1245,7 @@ async def _recv_from_ws(
12731245
get_state: Callable[[], SessionState],
12741246
get_ws: Callable[[], ClientConnection | None],
12751247
transition_no_connection: Callable[[], Awaitable[None]],
1276-
close_session: Callable[[Exception | None], None],
1248+
close_session: Callable[[Exception | None], asyncio.Event],
12771249
assert_incoming_seq_bookkeeping: Callable[
12781250
[str, int, int], Literal[True] | _IgnoreMessage
12791251
],

0 commit comments

Comments
 (0)