11import asyncio
22import logging
33from collections .abc import Awaitable , Callable
4- from typing import Generic , Optional , Tuple , TypeVar
4+ from typing import Generic , Optional , Tuple
55
66import websockets
77from pydantic import ValidationError
3737 InvalidMessageException ,
3838)
3939from replit_river .transport import Transport
40- from replit_river .transport_options import TransportOptions
40+ from replit_river .transport_options import (
41+ HandshakeMetadataType ,
42+ TransportOptions ,
43+ UriAndMetadata ,
44+ )
4145
4246logger = logging .getLogger (__name__ )
4347
4448
45- HandshakeType = TypeVar ("HandshakeType" )
46-
47-
48- class ClientTransport (Transport , Generic [HandshakeType ]):
49+ class ClientTransport (Transport , Generic [HandshakeMetadataType ]):
4950 def __init__ (
5051 self ,
51- websocket_uri_factory : Callable [[], Awaitable [str ]],
52+ uri_and_metadata_factory : Callable [[], Awaitable [UriAndMetadata ]],
5253 client_id : str ,
5354 server_id : str ,
5455 transport_options : TransportOptions ,
55- handshake_metadata_factory : Callable [[], Awaitable [HandshakeType ]],
5656 ):
5757 super ().__init__ (
5858 transport_id = client_id ,
5959 transport_options = transport_options ,
6060 is_server = False ,
6161 )
62- self ._websocket_uri_factory = websocket_uri_factory
62+ self ._uri_and_metadata_factory = uri_and_metadata_factory
6363 self ._client_id = client_id
6464 self ._server_id = server_id
6565 self ._rate_limiter = LeakyBucketRateLimit (
6666 transport_options .connection_retry_options
6767 )
68- self ._handshake_metadata_factory = handshake_metadata_factory
6968 # We want to make sure there's only one session creation at a time
7069 self ._create_session_lock = asyncio .Lock ()
7170
@@ -121,7 +120,7 @@ async def _establish_new_connection(
121120 old_session : Optional [ClientSession ] = None ,
122121 ) -> Tuple [
123122 WebSocketCommonProtocol ,
124- ControlMessageHandshakeRequest [HandshakeType ],
123+ ControlMessageHandshakeRequest [HandshakeMetadataType ],
125124 ControlMessageHandshakeResponse ,
126125 ]:
127126 """Build a new websocket connection with retry logic."""
@@ -147,9 +146,8 @@ async def _establish_new_connection(
147146 old_session = None
148147
149148 try :
150- websocket_uri = await self ._websocket_uri_factory ()
151- handshake_metadata = await self ._handshake_metadata_factory ()
152- ws = await websockets .connect (websocket_uri )
149+ uri_and_metadata = await self ._uri_and_metadata_factory ()
150+ ws = await websockets .connect (uri_and_metadata ["uri" ])
153151 session_id = (
154152 self .generate_nanoid ()
155153 if not old_session
@@ -164,7 +162,7 @@ async def _establish_new_connection(
164162 self ._transport_id ,
165163 self ._server_id ,
166164 session_id ,
167- handshake_metadata ,
165+ uri_and_metadata [ "metadata" ] ,
168166 ws ,
169167 old_session ,
170168 )
@@ -223,11 +221,11 @@ async def _send_handshake_request(
223221 transport_id : str ,
224222 to_id : str ,
225223 session_id : str ,
226- handshake_metadata : Optional [HandshakeType ],
224+ handshake_metadata : Optional [HandshakeMetadataType ],
227225 websocket : WebSocketCommonProtocol ,
228226 expected_session_state : ExpectedSessionState ,
229- ) -> ControlMessageHandshakeRequest [HandshakeType ]:
230- handshake_request = ControlMessageHandshakeRequest [HandshakeType ](
227+ ) -> ControlMessageHandshakeRequest [HandshakeMetadataType ]:
228+ handshake_request = ControlMessageHandshakeRequest [HandshakeMetadataType ](
231229 type = "HANDSHAKE_REQ" ,
232230 protocolVersion = PROTOCOL_VERSION ,
233231 sessionId = session_id ,
@@ -292,11 +290,12 @@ async def _establish_handshake(
292290 transport_id : str ,
293291 to_id : str ,
294292 session_id : str ,
295- handshake_metadata : HandshakeType ,
293+ handshake_metadata : HandshakeMetadataType ,
296294 websocket : WebSocketCommonProtocol ,
297295 old_session : Optional [ClientSession ],
298296 ) -> Tuple [
299- ControlMessageHandshakeRequest [HandshakeType ], ControlMessageHandshakeResponse
297+ ControlMessageHandshakeRequest [HandshakeMetadataType ],
298+ ControlMessageHandshakeResponse ,
300299 ]:
301300 try :
302301 handshake_request = await self ._send_handshake_request (
0 commit comments