1414 Coroutine ,
1515 Literal ,
1616 NotRequired ,
17+ Protocol ,
1718 TypeAlias ,
1819 TypedDict ,
1920 assert_never ,
@@ -112,7 +113,6 @@ class ResultError(TypedDict):
112113trace_propagator = TraceContextTextMapPropagator ()
113114trace_setter = TransportMessageTracingSetter ()
114115
115- CloseSessionCallback : TypeAlias = Callable [["Session" ], None ]
116116RetryConnectionCallback : TypeAlias = Callable [
117117 [],
118118 Coroutine [Any , Any , Any ],
@@ -131,14 +131,25 @@ class StreamMeta(TypedDict):
131131 output : Channel [ResultType ]
132132
133133
134+ class TriggerCloseCall (Protocol ):
135+ def __call__ (
136+ self ,
137+ signal_closing : Callable [[], None ],
138+ task_manager : BackgroundTaskManager , # .cancel_all_tasks()
139+ terminate_remaining_output_streams : Callable [[], None ],
140+ join_output_streams_with_timeout : Callable [[], Awaitable [None ]],
141+ ws : ClientConnection | None ,
142+ become_closed : Callable [[], None ],
143+ ) -> asyncio .Event : ...
144+
145+
134146class Session [HandshakeMetadata ]:
135147 _server_id : str
136148 session_id : str
137149 _transport_options : TransportOptions
138150
139151 # session state, only modified during closing
140152 _state : SessionState
141- _close_session_callback : CloseSessionCallback
142153 _close_session_after_time_secs : float | None
143154 _connecting_task : asyncio .Task [None ] | None
144155 _wait_for_connected : asyncio .Event
@@ -168,19 +179,19 @@ class Session[HandshakeMetadata]:
168179 seq : int # Last sent sequence number
169180
170181 # Terminating
171- _terminating_task : asyncio . Task [ None ] | None
182+ _trigger_close : TriggerCloseCall
172183
173184 def __init__ (
174185 self ,
175186 server_id : str ,
176187 session_id : str ,
177188 transport_options : TransportOptions ,
178- close_session_callback : CloseSessionCallback ,
179189 client_id : str ,
180190 rate_limiter : RateLimiter ,
181191 uri_and_metadata_factory : Callable [
182192 [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
183193 ],
194+ trigger_close : TriggerCloseCall ,
184195 retry_connection_callback : RetryConnectionCallback | None = None ,
185196 ) -> None :
186197 self ._server_id = server_id
@@ -189,7 +200,6 @@ def __init__(
189200
190201 # session state
191202 self ._state = SessionState .NO_CONNECTION
192- self ._close_session_callback = close_session_callback
193203 self ._close_session_after_time_secs : float | None = None
194204 self ._connecting_task = None
195205 self ._wait_for_connected = asyncio .Event ()
@@ -227,7 +237,7 @@ def __init__(
227237 self .seq = 0
228238
229239 # Terminating
230- self ._terminating_task = None
240+ self ._trigger_close = trigger_close
231241
232242 self ._start_recv_from_ws ()
233243 self ._start_buffered_message_sender ()
@@ -304,7 +314,7 @@ def unbind_connecting_task() -> None:
304314 close_ws_in_background = close_ws_in_background ,
305315 transition_connected = transition_connected ,
306316 unbind_connecting_task = unbind_connecting_task ,
307- close_session = self ._close_internal_nowait ,
317+ close_session = self .close ,
308318 )
309319 )
310320
@@ -313,12 +323,6 @@ def unbind_connecting_task() -> None:
313323 except asyncio .CancelledError :
314324 pass
315325
316- if self ._terminating_task :
317- try :
318- await self ._terminating_task
319- except asyncio .CancelledError :
320- pass
321-
322326 def is_terminal (self ) -> bool :
323327 """
324328 If the session is in a terminal state.
@@ -394,52 +398,16 @@ async def _enqueue_message(
394398 # Wake up buffered_message_sender
395399 self ._process_messages .set ()
396400
397- async def close (
401+ def close (
398402 self ,
399403 reason : Exception | None = None ,
400- ) -> None :
404+ ) -> asyncio . Event :
401405 """Close the session and all associated streams."""
402- if self ._terminating_task :
403- try :
404- logger .debug ("Session already closing, waiting..." )
405- async with asyncio .timeout (SESSION_CLOSE_TIMEOUT_SEC ):
406- await self ._terminating_task
407- except asyncio .TimeoutError :
408- logger .warning (
409- f"Session took longer than { SESSION_CLOSE_TIMEOUT_SEC } "
410- "seconds to close, leaking" ,
411- )
412- return
413- try :
414- await self ._close_internal (reason )
415- except asyncio .CancelledError :
416- pass
417-
418- def _close_internal_nowait (self , reason : Exception | None = None ) -> None :
419- """
420- When calling close() from asyncio Tasks, we must not block.
421-
422- This function does so, deferring to the underlying infrastructure for
423- creating self._terminating_task.
424- """
425- self ._close_internal (reason )
426-
427- def _close_internal (self , reason : Exception | None = None ) -> asyncio .Task [None ]:
428- """
429- Internal close method. Subsequent calls past the first do not block.
430-
431- This is intended to be the primary driver of a session being torn down
432- and returned to its initial state.
433-
434- NB: This function is intended to be the sole lifecycle manager of
435- self._terminating_task. Waiting on the completion of that task is optional,
436- but the population of that property is critical.
437-
438- NB: We must not await the task returned from this function from chained tasks
439- inside this session, otherwise we will create a thread loop.
440- """
441406
442- async def do_close () -> None :
407+ def signal_closing () -> None :
408+ """
409+ Roughly "kill 15"
410+ """
443411 logger .info (
444412 f"{ self .session_id } closing session to { self ._server_id } , "
445413 f"ws: { self ._ws } "
@@ -454,11 +422,10 @@ async def do_close() -> None:
454422 # ... message processor so it can exit cleanly
455423 self ._process_messages .set ()
456424
457- # Wait to permit the waiting tasks to shut down gracefully
458- await asyncio .sleep (0.25 )
459-
460- await self ._task_manager .cancel_all_tasks ()
461-
425+ def terminate_remaining_output_streams () -> None :
426+ """
427+ Roughly "kill 9"
428+ """
462429 for stream_id , stream_meta in self ._streams .items ():
463430 stream_meta ["output" ].close ()
464431 # Wake up backpressured writers
@@ -475,6 +442,11 @@ async def do_close() -> None:
475442 "Unable to tell the caller that the session is going away" ,
476443 )
477444 stream_meta ["release_backpressured_waiter" ]()
445+
446+ async def join_output_streams_with_timeout () -> None :
447+ """
448+ Roughly "wait"
449+ """
478450 # Before we GC the streams, let's wait for all tasks to be closed gracefully
479451 try :
480452 async with asyncio .timeout (
@@ -500,21 +472,21 @@ async def do_close() -> None:
500472 )
501473 self ._streams .clear ()
502474
503- if self ._ws :
504- # The Session isn't guaranteed to live much longer than this close()
505- # invocation, so let's await this close to avoid dropping the socket.
506- await self ._ws .close ()
507-
475+ def become_closed () -> None :
476+ pass
508477 self ._state = SessionState .CLOSED
509478
510479 # Clear the session in transports
511480 # This will get us GC'd, so this should be the last thing.
512- self ._close_session_callback (self )
513481
514- if not self ._terminating_task :
515- self ._terminating_task = asyncio .create_task (do_close ())
516-
517- return self ._terminating_task
482+ return self ._trigger_close (
483+ signal_closing ,
484+ self ._task_manager , # .cancel_all_tasks()
485+ terminate_remaining_output_streams ,
486+ join_output_streams_with_timeout ,
487+ self ._ws ,
488+ become_closed ,
489+ )
518490
519491 def _start_buffered_message_sender (
520492 self ,
@@ -657,7 +629,7 @@ async def block_until_connected() -> None:
657629 get_state = lambda : self ._state ,
658630 get_ws = lambda : self ._ws ,
659631 transition_no_connection = transition_no_connection ,
660- close_session = self ._close_internal_nowait ,
632+ close_session = self .close ,
661633 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
662634 get_stream = lambda stream_id : self ._streams .get (stream_id ),
663635 enqueue_message = self ._enqueue_message ,
@@ -1105,7 +1077,7 @@ async def _do_ensure_connected[HandshakeMetadata](
11051077 close_ws_in_background : Callable [[ClientConnection ], None ],
11061078 transition_connected : Callable [[ClientConnection ], None ],
11071079 unbind_connecting_task : Callable [[], None ],
1108- close_session : Callable [[Exception | None ], None ],
1080+ close_session : Callable [[Exception | None ], asyncio . Event ],
11091081) -> None :
11101082 logger .info ("Attempting to establish new ws connection" )
11111083
@@ -1273,7 +1245,7 @@ async def _recv_from_ws(
12731245 get_state : Callable [[], SessionState ],
12741246 get_ws : Callable [[], ClientConnection | None ],
12751247 transition_no_connection : Callable [[], Awaitable [None ]],
1276- close_session : Callable [[Exception | None ], None ],
1248+ close_session : Callable [[Exception | None ], asyncio . Event ],
12771249 assert_incoming_seq_bookkeeping : Callable [
12781250 [str , int , int ], Literal [True ] | _IgnoreMessage
12791251 ],
0 commit comments