Skip to content

Commit b850388

Browse files
authored
Add ensure_connected method to client_transport (#86)
Why === There's currently no way of making sure we're connected other than sending an RPC. This poses issues where we have errors propagating much later than they should be. What changed ============ Add method to allow us to check if we're connected and error as early as possible
1 parent 6968a17 commit b850388

2 files changed

Lines changed: 34 additions & 31 deletions

File tree

replit_river/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ async def close(self) -> None:
4242
await self._transport.close()
4343
logger.info(f"river client {self._client_id} closed")
4444

45+
async def ensure_connected(self) -> None:
46+
await self._transport.get_or_create_session()
47+
4548
async def send_rpc(
4649
self,
4750
service_name: str,
@@ -51,7 +54,7 @@ async def send_rpc(
5154
response_deserializer: Callable[[Any], ResponseType],
5255
error_deserializer: Callable[[Any], ErrorType],
5356
) -> ResponseType:
54-
session = await self._transport._get_or_create_session()
57+
session = await self._transport.get_or_create_session()
5558
return await session.send_rpc(
5659
service_name,
5760
procedure_name,
@@ -72,7 +75,7 @@ async def send_upload(
7275
response_deserializer: Callable[[Any], ResponseType],
7376
error_deserializer: Callable[[Any], ErrorType],
7477
) -> ResponseType:
75-
session = await self._transport._get_or_create_session()
78+
session = await self._transport.get_or_create_session()
7679
return await session.send_upload(
7780
service_name,
7881
procedure_name,
@@ -93,7 +96,7 @@ async def send_subscription(
9396
response_deserializer: Callable[[Any], ResponseType],
9497
error_deserializer: Callable[[Any], ErrorType],
9598
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
96-
session = await self._transport._get_or_create_session()
99+
session = await self._transport.get_or_create_session()
97100
return session.send_subscription(
98101
service_name,
99102
procedure_name,
@@ -114,7 +117,7 @@ async def send_stream(
114117
response_deserializer: Callable[[Any], ResponseType],
115118
error_deserializer: Callable[[Any], ErrorType],
116119
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
117-
session = await self._transport._get_or_create_session()
120+
session = await self._transport.get_or_create_session()
118121
return session.send_stream(
119122
service_name,
120123
procedure_name,

replit_river/client_transport.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,32 @@ async def close(self) -> None:
7373
self._rate_limiter.close()
7474
await self._close_all_sessions()
7575

76+
async def get_or_create_session(self) -> ClientSession:
77+
async with self._create_session_lock:
78+
existing_session = await self._get_existing_session()
79+
if not existing_session:
80+
return await self._create_new_session()
81+
is_session_open = await existing_session.is_session_open()
82+
if not is_session_open:
83+
return await self._create_new_session()
84+
is_ws_open = await existing_session.is_websocket_open()
85+
if is_ws_open:
86+
return existing_session
87+
new_ws, _, hs_response = await self._establish_new_connection(
88+
existing_session
89+
)
90+
if hs_response.status.sessionId == existing_session.session_id:
91+
logger.info(
92+
"Replacing ws connection in session id %s",
93+
existing_session.session_id,
94+
)
95+
await existing_session.replace_with_new_websocket(new_ws)
96+
return existing_session
97+
else:
98+
logger.info("Closing stale session %s", existing_session.session_id)
99+
await existing_session.close()
100+
return await self._create_new_session()
101+
76102
async def _get_existing_session(self) -> Optional[ClientSession]:
77103
async with self._session_lock:
78104
if not self._sessions:
@@ -190,33 +216,7 @@ async def _create_new_session(
190216
async def _retry_connection(self) -> ClientSession:
191217
if not self._transport_options.transparent_reconnect:
192218
await self._close_all_sessions()
193-
return await self._get_or_create_session()
194-
195-
async def _get_or_create_session(self) -> ClientSession:
196-
async with self._create_session_lock:
197-
existing_session = await self._get_existing_session()
198-
if not existing_session:
199-
return await self._create_new_session()
200-
is_session_open = await existing_session.is_session_open()
201-
if not is_session_open:
202-
return await self._create_new_session()
203-
is_ws_open = await existing_session.is_websocket_open()
204-
if is_ws_open:
205-
return existing_session
206-
new_ws, _, hs_response = await self._establish_new_connection(
207-
existing_session
208-
)
209-
if hs_response.status.sessionId == existing_session.session_id:
210-
logger.info(
211-
"Replacing ws connection in session id %s",
212-
existing_session.session_id,
213-
)
214-
await existing_session.replace_with_new_websocket(new_ws)
215-
return existing_session
216-
else:
217-
logger.info("Closing stale session %s", existing_session.session_id)
218-
await existing_session.close()
219-
return await self._create_new_session()
219+
return await self.get_or_create_session()
220220

221221
async def _send_handshake_request(
222222
self,

0 commit comments

Comments
 (0)