Skip to content

Commit b500a0d

Browse files
authored
Combine uri and metadata factory (#87)
Why === It's unnecessary to have two functions. With the way our metadata and websocket uri are setup, it's just added complexity What changed ============ - Make `handshake_metadata_factory` and `websocket_uri_factory` a single `uri_and_metadata_factory` that grabs both. - Returns a new TypeDict `UriAndMetadata` - Rename `HandshakeType` to `HandshakeMetadataType` - `HandshakeMetadataType` `TypeVar` is reused (not sure if i'm doing some python no-no here). Also moved it to the options module to avoid circular imports
1 parent b850388 commit b500a0d

5 files changed

Lines changed: 54 additions & 45 deletions

File tree

replit_river/client.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
3-
from typing import Any, Generic, Optional, TypeVar, Union
3+
from typing import Any, Generic, Optional, Union
44

55
from replit_river.client_transport import ClientTransport
6-
from replit_river.transport_options import TransportOptions
6+
from replit_river.transport_options import (
7+
HandshakeMetadataType,
8+
TransportOptions,
9+
UriAndMetadata,
10+
)
711

812
from .rpc import (
913
ErrorType,
@@ -15,26 +19,23 @@
1519
logger = logging.getLogger(__name__)
1620

1721

18-
HandshakeType = TypeVar("HandshakeType")
19-
20-
21-
class Client(Generic[HandshakeType]):
22+
class Client(Generic[HandshakeMetadataType]):
2223
def __init__(
2324
self,
24-
websocket_uri_factory: Callable[[], Awaitable[str]],
25+
uri_and_metadata_factory: Callable[
26+
[], Awaitable[UriAndMetadata[HandshakeMetadataType]]
27+
],
2528
client_id: str,
2629
server_id: str,
2730
transport_options: TransportOptions,
28-
handshake_metadata_factory: Callable[[], Awaitable[HandshakeType]],
2931
) -> None:
3032
self._client_id = client_id
3133
self._server_id = server_id
32-
self._transport = ClientTransport[HandshakeType](
33-
websocket_uri_factory=websocket_uri_factory,
34+
self._transport = ClientTransport[HandshakeMetadataType](
35+
uri_and_metadata_factory=uri_and_metadata_factory,
3436
client_id=client_id,
3537
server_id=server_id,
3638
transport_options=transport_options,
37-
handshake_metadata_factory=handshake_metadata_factory,
3839
)
3940

4041
async def close(self) -> None:

replit_river/client_transport.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import Awaitable, Callable
4-
from typing import Generic, Optional, Tuple, TypeVar
4+
from typing import Generic, Optional, Tuple
55

66
import websockets
77
from pydantic import ValidationError
@@ -37,35 +37,34 @@
3737
InvalidMessageException,
3838
)
3939
from replit_river.transport import Transport
40-
from replit_river.transport_options import TransportOptions
40+
from replit_river.transport_options import (
41+
HandshakeMetadataType,
42+
TransportOptions,
43+
UriAndMetadata,
44+
)
4145

4246
logger = logging.getLogger(__name__)
4347

4448

45-
HandshakeType = TypeVar("HandshakeType")
46-
47-
48-
class ClientTransport(Transport, Generic[HandshakeType]):
49+
class ClientTransport(Transport, Generic[HandshakeMetadataType]):
4950
def __init__(
5051
self,
51-
websocket_uri_factory: Callable[[], Awaitable[str]],
52+
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
5253
client_id: str,
5354
server_id: str,
5455
transport_options: TransportOptions,
55-
handshake_metadata_factory: Callable[[], Awaitable[HandshakeType]],
5656
):
5757
super().__init__(
5858
transport_id=client_id,
5959
transport_options=transport_options,
6060
is_server=False,
6161
)
62-
self._websocket_uri_factory = websocket_uri_factory
62+
self._uri_and_metadata_factory = uri_and_metadata_factory
6363
self._client_id = client_id
6464
self._server_id = server_id
6565
self._rate_limiter = LeakyBucketRateLimit(
6666
transport_options.connection_retry_options
6767
)
68-
self._handshake_metadata_factory = handshake_metadata_factory
6968
# We want to make sure there's only one session creation at a time
7069
self._create_session_lock = asyncio.Lock()
7170

@@ -121,7 +120,7 @@ async def _establish_new_connection(
121120
old_session: Optional[ClientSession] = None,
122121
) -> Tuple[
123122
WebSocketCommonProtocol,
124-
ControlMessageHandshakeRequest[HandshakeType],
123+
ControlMessageHandshakeRequest[HandshakeMetadataType],
125124
ControlMessageHandshakeResponse,
126125
]:
127126
"""Build a new websocket connection with retry logic."""
@@ -147,9 +146,8 @@ async def _establish_new_connection(
147146
old_session = None
148147

149148
try:
150-
websocket_uri = await self._websocket_uri_factory()
151-
handshake_metadata = await self._handshake_metadata_factory()
152-
ws = await websockets.connect(websocket_uri)
149+
uri_and_metadata = await self._uri_and_metadata_factory()
150+
ws = await websockets.connect(uri_and_metadata["uri"])
153151
session_id = (
154152
self.generate_nanoid()
155153
if not old_session
@@ -164,7 +162,7 @@ async def _establish_new_connection(
164162
self._transport_id,
165163
self._server_id,
166164
session_id,
167-
handshake_metadata,
165+
uri_and_metadata["metadata"],
168166
ws,
169167
old_session,
170168
)
@@ -223,11 +221,11 @@ async def _send_handshake_request(
223221
transport_id: str,
224222
to_id: str,
225223
session_id: str,
226-
handshake_metadata: Optional[HandshakeType],
224+
handshake_metadata: Optional[HandshakeMetadataType],
227225
websocket: WebSocketCommonProtocol,
228226
expected_session_state: ExpectedSessionState,
229-
) -> ControlMessageHandshakeRequest[HandshakeType]:
230-
handshake_request = ControlMessageHandshakeRequest[HandshakeType](
227+
) -> ControlMessageHandshakeRequest[HandshakeMetadataType]:
228+
handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType](
231229
type="HANDSHAKE_REQ",
232230
protocolVersion=PROTOCOL_VERSION,
233231
sessionId=session_id,
@@ -292,11 +290,12 @@ async def _establish_handshake(
292290
transport_id: str,
293291
to_id: str,
294292
session_id: str,
295-
handshake_metadata: HandshakeType,
293+
handshake_metadata: HandshakeMetadataType,
296294
websocket: WebSocketCommonProtocol,
297295
old_session: Optional[ClientSession],
298296
) -> Tuple[
299-
ControlMessageHandshakeRequest[HandshakeType], ControlMessageHandshakeResponse
297+
ControlMessageHandshakeRequest[HandshakeMetadataType],
298+
ControlMessageHandshakeResponse,
300299
]:
301300
try:
302301
handshake_request = await self._send_handshake_request(

replit_river/rpc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
stringify_exception,
3131
)
3232
from replit_river.task_manager import BackgroundTaskManager
33-
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
33+
from replit_river.transport_options import (
34+
MAX_MESSAGE_BUFFER_SIZE,
35+
HandshakeMetadataType,
36+
)
3437

3538
logger = logging.getLogger(__name__)
3639

@@ -62,15 +65,12 @@ class ExpectedSessionState(BaseModel):
6265
nextSentSeq: Optional[int] = None
6366

6467

65-
HandshakeType = TypeVar("HandshakeType")
66-
67-
68-
class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeType]):
68+
class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeMetadataType]):
6969
type: Literal["HANDSHAKE_REQ"] = "HANDSHAKE_REQ"
7070
protocolVersion: str
7171
sessionId: str
7272
expectedSessionState: ExpectedSessionState
73-
metadata: Optional[HandshakeType] = None
73+
metadata: Optional[HandshakeMetadataType] = None
7474

7575

7676
class HandShakeStatus(BaseModel):

replit_river/transport_options.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Generic, TypedDict, TypeVar
23

34
from pydantic import BaseModel
45

@@ -50,3 +51,11 @@ def create_from_env(cls) -> "TransportOptions":
5051
heartbeat_ms=heartbeat_ms,
5152
heartbeats_until_dead=heartbeats_to_dead,
5253
)
54+
55+
56+
HandshakeMetadataType = TypeVar("HandshakeMetadataType")
57+
58+
59+
class UriAndMetadata(TypedDict, Generic[HandshakeMetadataType]):
60+
uri: str
61+
metadata: HandshakeMetadataType

tests/conftest.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from websockets.server import serve
99

1010
from replit_river.client import Client
11+
from replit_river.client_transport import UriAndMetadata
1112
from replit_river.error_schema import RiverError
1213
from replit_river.rpc import (
1314
GrpcContext,
@@ -138,20 +139,19 @@ async def client(
138139
no_logging_error: NoErrors,
139140
) -> AsyncGenerator[Client, None]:
140141

141-
async def websocket_uri_factory() -> str:
142-
return "ws://localhost:8765"
143-
144-
async def handshake_metadata_factory() -> None:
145-
return None
142+
async def websocket_uri_factory() -> UriAndMetadata[None]:
143+
return {
144+
"uri": "ws://localhost:8765",
145+
"metadata": None,
146+
}
146147

147148
try:
148149
async with serve(server.serve, "localhost", 8765):
149-
client: Client[Literal[None]] = Client(
150-
websocket_uri_factory,
150+
client: Client[Literal[None]] = Client[None](
151+
uri_and_metadata_factory=websocket_uri_factory,
151152
client_id="test_client",
152153
server_id="test_server",
153154
transport_options=transport_options,
154-
handshake_metadata_factory=handshake_metadata_factory,
155155
)
156156
try:
157157
yield client

0 commit comments

Comments
 (0)