Skip to content

Commit a8418ec

Browse files
authored
[websocket] use background tasks instead of cleanup function (#100)
Why === * The cleanup function added a burden on the caller to make sure the tasks were awaited. * Instead, follow the advice of https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task and store them in the background properly What changed === * Revert cleanup commit * Add background websocket task storage Test plan === * Ran a pytest, and didn't see "Pending task destroyed" exceptions
1 parent 2d5fb02 commit a8418ec

5 files changed

Lines changed: 16 additions & 36 deletions

File tree

replit_river/client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import logging
32
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
43
from typing import Any, Generic, Optional, Union
@@ -39,11 +38,10 @@ def __init__(
3938
transport_options=transport_options,
4039
)
4140

42-
async def close(self) -> asyncio.Task | None:
41+
async def close(self) -> None:
4342
logger.info(f"river client {self._client_id} start closing")
44-
cleanup_task = await self._transport.close()
43+
await self._transport.close()
4544
logger.info(f"river client {self._client_id} closed")
46-
return cleanup_task
4745

4846
async def ensure_connected(self) -> None:
4947
await self._transport.get_or_create_session()

replit_river/client_transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def __init__(
6868
# We want to make sure there's only one session creation at a time
6969
self._create_session_lock = asyncio.Lock()
7070

71-
async def close(self) -> asyncio.Task:
71+
async def close(self) -> None:
7272
self._rate_limiter.close()
73-
return await self._close_all_sessions()
73+
await self._close_all_sessions()
7474

7575
async def get_or_create_session(self) -> ClientSession:
7676
async with self._create_session_lock:

replit_river/session.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -434,18 +434,15 @@ async def _send_responses_from_output_stream(
434434

435435
async def close_websocket(
436436
self, ws_wrapper: WebsocketWrapper, should_retry: bool
437-
) -> asyncio.Task | None:
437+
) -> None:
438438
"""Mark the websocket as closed, close the websocket, and retry if needed."""
439-
cleanup_websocket_task: asyncio.Task | None = None
440439
async with self._ws_lock:
441440
# Already closed.
442441
if not await ws_wrapper.is_open():
443-
logger.info("websocket wrapper already closed")
444-
return cleanup_websocket_task
445-
cleanup_websocket_task = await ws_wrapper.close()
442+
return
443+
await ws_wrapper.close()
446444
if should_retry and self._retry_connection_callback:
447445
self._task_manager.create_task(self._retry_connection_callback())
448-
return cleanup_websocket_task
449446

450447
async def _open_stream_and_call_handler(
451448
self,
@@ -526,9 +523,8 @@ async def _remove_acked_messages_in_buffer(self) -> None:
526523
async def start_serve_responses(self) -> None:
527524
self._task_manager.create_task(self.serve())
528525

529-
async def close(self) -> asyncio.Task | None:
526+
async def close(self) -> None:
530527
"""Close the session and all associated streams."""
531-
cleanup_websocket_task: asyncio.Task | None = None
532528
logger.info(
533529
f"{self._transport_id} closing session "
534530
f"to {self._to_id}, ws: {self._ws_wrapper.id}, "
@@ -537,14 +533,12 @@ async def close(self) -> asyncio.Task | None:
537533
async with self._state_lock:
538534
if self._state != SessionState.ACTIVE:
539535
# already closing
540-
return cleanup_websocket_task
536+
return
541537
self._state = SessionState.CLOSING
542538
self._reset_session_close_countdown()
543539
await self._task_manager.cancel_all_tasks()
544540

545-
cleanup_websocket_task = await self.close_websocket(
546-
self._ws_wrapper, should_retry=False
547-
)
541+
await self.close_websocket(self._ws_wrapper, should_retry=False)
548542

549543
await self._buffer.close()
550544

@@ -559,4 +553,3 @@ async def close(self) -> asyncio.Task | None:
559553
self._streams.clear()
560554

561555
self._state = SessionState.CLOSED
562-
return cleanup_websocket_task

replit_river/transport.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def __init__(
2727
self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}
2828
self._session_lock = asyncio.Lock()
2929

30-
async def _close_all_sessions(self) -> asyncio.Task:
31-
cleanup_tasks: list[asyncio.Task] = []
30+
async def _close_all_sessions(self) -> None:
3231
sessions = self._sessions.values()
3332
logger.info(
3433
f"start closing sessions {self._transport_id}, number sessions : "
@@ -39,18 +38,10 @@ async def _close_all_sessions(self) -> asyncio.Task:
3938
# closing sessions requires access to the session lock, so we need to close
4039
# them one by one to be safe
4140
for session in sessions_to_close:
42-
cleanup_task = await session.close()
43-
if cleanup_task:
44-
cleanup_tasks.append(cleanup_task)
41+
await session.close()
4542

4643
logger.info(f"Transport closed {self._transport_id}")
4744

48-
async def cleanup() -> None:
49-
for cleanup_task in cleanup_tasks:
50-
await cleanup_task
51-
52-
return asyncio.create_task(cleanup())
53-
5445
async def _delete_session(self, session: Session) -> None:
5546
async with self._session_lock:
5647
if session._to_id in self._sessions:

replit_river/websocket_wrapper.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from websockets import WebSocketCommonProtocol
66

77
logger = logging.getLogger(__name__)
8+
_background_tasks: set[asyncio.Task] = set()
89

910

1011
class WsState(enum.Enum):
@@ -24,14 +25,11 @@ async def is_open(self) -> bool:
2425
async with self.ws_lock:
2526
return self.ws_state == WsState.OPEN
2627

27-
async def close(self) -> asyncio.Task | None:
28+
async def close(self) -> None:
2829
async with self.ws_lock:
2930
if self.ws_state == WsState.OPEN:
3031
self.ws_state = WsState.CLOSING
3132
task = asyncio.create_task(self.ws.close())
32-
task.add_done_callback(
33-
lambda _: logger.debug("old websocket %s closed.", self.ws.id)
34-
)
33+
_background_tasks.add(task)
34+
task.add_done_callback(_background_tasks.discard)
3535
self.ws_state = WsState.CLOSED
36-
return task
37-
return None

0 commit comments

Comments
 (0)