Skip to content

Commit 2d5fb02

Browse files
authored
[replit_river] return cleanup task from client.disconnect() (#99)
Why === * The task created by the websocket wrapper was orphaned. Tasks need to be awaited somewhere, or you get errors like ``` RuntimeError: no running event loop Task was destroyed but it is pending! ``` * Since it takes a while to finish, we don't want to wait for it in certain cases, so instead we'll return it as a cleanup task that the caller can await as appropriate. What changed === * When the websocket close task is made, return it. * At every level, return the task and combine it with other cleanup tasks as appropriate Test plan === * The behavior shouldn't be different unless you await the cleanup function. If you do await it, you won't get a pending task exception when closing the event loop.
1 parent 80e6d95 commit 2d5fb02

5 files changed

Lines changed: 33 additions & 13 deletions

File tree

replit_river/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
34
from typing import Any, Generic, Optional, Union
@@ -38,10 +39,11 @@ def __init__(
3839
transport_options=transport_options,
3940
)
4041

41-
async def close(self) -> None:
42+
async def close(self) -> asyncio.Task | None:
4243
logger.info(f"river client {self._client_id} start closing")
43-
await self._transport.close()
44+
cleanup_task = await self._transport.close()
4445
logger.info(f"river client {self._client_id} closed")
46+
return cleanup_task
4547

4648
async def ensure_connected(self) -> None:
4749
await self._transport.get_or_create_session()

replit_river/client_transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def __init__(
6868
# We want to make sure there's only one session creation at a time
6969
self._create_session_lock = asyncio.Lock()
7070

71-
async def close(self) -> None:
71+
async def close(self) -> asyncio.Task:
7272
self._rate_limiter.close()
73-
await self._close_all_sessions()
73+
return await self._close_all_sessions()
7474

7575
async def get_or_create_session(self) -> ClientSession:
7676
async with self._create_session_lock:

replit_river/session.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,15 +434,18 @@ async def _send_responses_from_output_stream(
434434

435435
async def close_websocket(
436436
self, ws_wrapper: WebsocketWrapper, should_retry: bool
437-
) -> None:
437+
) -> asyncio.Task | None:
438438
"""Mark the websocket as closed, close the websocket, and retry if needed."""
439+
cleanup_websocket_task: asyncio.Task | None = None
439440
async with self._ws_lock:
440441
# Already closed.
441442
if not await ws_wrapper.is_open():
442-
return
443-
await ws_wrapper.close()
443+
logger.info("websocket wrapper already closed")
444+
return cleanup_websocket_task
445+
cleanup_websocket_task = await ws_wrapper.close()
444446
if should_retry and self._retry_connection_callback:
445447
self._task_manager.create_task(self._retry_connection_callback())
448+
return cleanup_websocket_task
446449

447450
async def _open_stream_and_call_handler(
448451
self,
@@ -523,8 +526,9 @@ async def _remove_acked_messages_in_buffer(self) -> None:
523526
async def start_serve_responses(self) -> None:
524527
self._task_manager.create_task(self.serve())
525528

526-
async def close(self) -> None:
529+
async def close(self) -> asyncio.Task | None:
527530
"""Close the session and all associated streams."""
531+
cleanup_websocket_task: asyncio.Task | None = None
528532
logger.info(
529533
f"{self._transport_id} closing session "
530534
f"to {self._to_id}, ws: {self._ws_wrapper.id}, "
@@ -533,12 +537,14 @@ async def close(self) -> None:
533537
async with self._state_lock:
534538
if self._state != SessionState.ACTIVE:
535539
# already closing
536-
return
540+
return cleanup_websocket_task
537541
self._state = SessionState.CLOSING
538542
self._reset_session_close_countdown()
539543
await self._task_manager.cancel_all_tasks()
540544

541-
await self.close_websocket(self._ws_wrapper, should_retry=False)
545+
cleanup_websocket_task = await self.close_websocket(
546+
self._ws_wrapper, should_retry=False
547+
)
542548

543549
await self._buffer.close()
544550

@@ -553,3 +559,4 @@ async def close(self) -> None:
553559
self._streams.clear()
554560

555561
self._state = SessionState.CLOSED
562+
return cleanup_websocket_task

replit_river/transport.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def __init__(
2727
self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}
2828
self._session_lock = asyncio.Lock()
2929

30-
async def _close_all_sessions(self) -> None:
30+
async def _close_all_sessions(self) -> asyncio.Task:
31+
cleanup_tasks: list[asyncio.Task] = []
3132
sessions = self._sessions.values()
3233
logger.info(
3334
f"start closing sessions {self._transport_id}, number sessions : "
@@ -38,10 +39,18 @@ async def _close_all_sessions(self) -> None:
3839
# closing sessions requires access to the session lock, so we need to close
3940
# them one by one to be safe
4041
for session in sessions_to_close:
41-
await session.close()
42+
cleanup_task = await session.close()
43+
if cleanup_task:
44+
cleanup_tasks.append(cleanup_task)
4245

4346
logger.info(f"Transport closed {self._transport_id}")
4447

48+
async def cleanup() -> None:
49+
for cleanup_task in cleanup_tasks:
50+
await cleanup_task
51+
52+
return asyncio.create_task(cleanup())
53+
4554
async def _delete_session(self, session: Session) -> None:
4655
async with self._session_lock:
4756
if session._to_id in self._sessions:

replit_river/websocket_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async def is_open(self) -> bool:
2424
async with self.ws_lock:
2525
return self.ws_state == WsState.OPEN
2626

27-
async def close(self) -> None:
27+
async def close(self) -> asyncio.Task | None:
2828
async with self.ws_lock:
2929
if self.ws_state == WsState.OPEN:
3030
self.ws_state = WsState.CLOSING
@@ -33,3 +33,5 @@ async def close(self) -> None:
3333
lambda _: logger.debug("old websocket %s closed.", self.ws.id)
3434
)
3535
self.ws_state = WsState.CLOSED
36+
return task
37+
return None

0 commit comments

Comments
 (0)