22import logging
33from collections .abc import AsyncIterable
44from datetime import timedelta
5- from typing import Any , AsyncGenerator , Callable , Coroutine
5+ from typing import Any , AsyncGenerator , Callable , Coroutine , assert_never
66
7- import nanoid # type: ignore
7+ import nanoid
88import websockets
99from aiochannel import Channel
1010from aiochannel .errors import ChannelClosed
1111from opentelemetry .trace import Span
1212from websockets .exceptions import ConnectionClosed
1313
14- from replit_river .common_session import add_msg_to_stream
1514from replit_river .error_schema import (
1615 ERROR_CODE_CANCEL ,
1716 ERROR_CODE_STREAM_CLOSED ,
2524 parse_transport_msg ,
2625)
2726from replit_river .seq_manager import (
28- IgnoreMessageException ,
27+ IgnoreMessage ,
2928 InvalidMessageException ,
3029 OutOfOrderMessageException ,
3130)
3433
3534from .rpc import (
3635 ACK_BIT ,
37- STREAM_CLOSED_BIT ,
3836 STREAM_OPEN_BIT ,
3937 ErrorType ,
4038 InitType ,
4543logger = logging .getLogger (__name__ )
4644
4745
46+ STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2
47+
48+
4849class ClientSession (Session ):
4950 def __init__ (
5051 self ,
@@ -81,6 +82,13 @@ async def do_close_websocket() -> None:
8182
8283 self ._setup_heartbeats_task (do_close_websocket )
8384
85+ async def replace_with_new_websocket (
86+ self , new_ws : websockets .WebSocketCommonProtocol
87+ ) -> None :
88+ await super ().replace_with_new_websocket (new_ws )
89+ # serve() terminates itself when the ws dies, so we need to start it again
90+ await self .start_serve_responses ()
91+
8492 async def start_serve_responses (self ) -> None :
8593 self ._task_manager .create_task (self .serve ())
8694
@@ -120,31 +128,54 @@ async def _handle_messages_from_ws(self) -> None:
120128 ws_wrapper = self ._ws_wrapper
121129 async for message in ws_wrapper .ws :
122130 try :
123- if not await ws_wrapper .is_open ():
131+ if not ws_wrapper .is_open ():
124132 # We should not process messages if the websocket is closed.
125133 break
126- msg = parse_transport_msg (message , self ._transport_options )
134+ msg = parse_transport_msg (message )
135+ if isinstance (msg , str ):
136+ logger .debug ("Ignoring transport message" , exc_info = True )
137+ continue
127138
128139 logger .debug (f"{ self ._transport_id } got a message %r" , msg )
129140
130141 # Update bookkeeping
131- await self ._seq_manager .check_seq_and_update (msg )
142+ match self ._seq_manager .check_seq_and_update (msg ):
143+ case IgnoreMessage ():
144+ continue
145+ case None :
146+ pass
147+ case other :
148+ assert_never (other )
149+
132150 await self ._buffer .remove_old_messages (
133151 self ._seq_manager .receiver_ack ,
134152 )
135153 self ._reset_session_close_countdown ()
136154
137155 if msg .controlFlags & ACK_BIT != 0 :
138156 continue
139- async with self ._stream_lock :
140- stream = self ._streams .get (msg .streamId , None )
157+ stream = self ._streams .get (msg .streamId , None )
141158 if msg .controlFlags & STREAM_OPEN_BIT == 0 :
142159 if not stream :
143160 logger .warning ("no stream for %s" , msg .streamId )
144- raise IgnoreMessageException (
145- "no stream for message, ignoring"
146- )
147- await add_msg_to_stream (msg , stream )
161+ continue
162+
163+ if (
164+ msg .controlFlags & STREAM_CLOSED_BIT != 0
165+ and msg .payload .get ("type" , None ) == "CLOSE"
166+ ):
167+ # close message is not sent to the stream
168+ pass
169+ else :
170+ try :
171+ await stream .put (msg .payload )
172+ except ChannelClosed :
173+ # The client is no longer interested in this stream,
174+ # just drop the message.
175+ pass
176+ except RuntimeError as e :
177+ raise InvalidMessageException (e ) from e
178+
148179 else :
149180 raise InvalidMessageException (
150181 "Client should not receive stream open bit"
@@ -153,11 +184,7 @@ async def _handle_messages_from_ws(self) -> None:
153184 if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
154185 if stream :
155186 stream .close ()
156- async with self ._stream_lock :
157- del self ._streams [msg .streamId ]
158- except IgnoreMessageException :
159- logger .debug ("Ignoring transport message" , exc_info = True )
160- continue
187+ del self ._streams [msg .streamId ]
161188 except OutOfOrderMessageException :
162189 logger .exception ("Out of order message, closing connection" )
163190 await ws_wrapper .close ()
0 commit comments