Skip to content

Commit 80e6d95

Browse files
authored
Add backpressure to send buffer (#97)
Why === Currently the send buffer has no backpressure, so if a user of the river client sends too many messages too quickly, the send buffer will overflow and the client will error out. Rather than relying on users to rate limit themselves, we introduce backpressure so client naturally will only send messages as fast as the river client/server is able to consume them. What changed ============ - Add a condvar to the message buffer to allow waiting for buffer space to come available when the buffer is full - Add a closed bit to the message buffer so when a session is shut down, we have a way of unblocking all the futures that are stuck waiting for space in the buffer - Close the message buffer when the session is closed Test plan ========= - Added a test which sends `2 * MAX_MESSAGE_BUFFER_SIZE` messages to a river server. Previously this resulted in an error. - Added some unit tests for the `MessageBuffer` to ensure backpressure works and closing works
1 parent 41647dc commit 80e6d95

5 files changed

Lines changed: 118 additions & 11 deletions

File tree

replit_river/message_buffer.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,37 @@
88
logger = logging.getLogger(__name__)
99

1010

11+
class MessageBufferClosedError(BaseException):
12+
"""Raised when a message buffer is closed and is not accepting new messages."""
13+
14+
1115
class MessageBuffer:
1216
"""A buffer to store messages and support current updates"""
1317

1418
def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
1519
self.max_size = max_num_messages
1620
self.buffer: list[TransportMessage] = []
1721
self._lock = asyncio.Lock()
22+
self._space_available_cond = asyncio.Condition(lock=self._lock)
23+
self._closed = False
1824

1925
async def empty(self) -> bool:
2026
"""Check if the buffer is empty"""
2127
async with self._lock:
2228
return len(self.buffer) == 0
2329

2430
async def put(self, message: TransportMessage) -> None:
25-
"""Add a message to the buffer"""
26-
async with self._lock:
27-
if len(self.buffer) >= self.max_size:
28-
logger.error("Buffer is full, dropping message")
29-
raise ValueError("Buffer is full")
31+
"""Add a message to the buffer. Blocks until there is space in the buffer.
32+
33+
Raises:
34+
MessageBufferClosedError: if the buffer is closed.
35+
"""
36+
async with self._space_available_cond:
37+
await self._space_available_cond.wait_for(
38+
lambda: len(self.buffer) < self.max_size or self._closed
39+
)
40+
if self._closed:
41+
raise MessageBufferClosedError("message buffer is closed")
3042
self.buffer.append(message)
3143

3244
async def peek(self) -> Optional[TransportMessage]:
@@ -40,3 +52,12 @@ async def remove_old_messages(self, min_seq: int) -> None:
4052
"""Remove messages in the buffer with a seq number less than min_seq."""
4153
async with self._lock:
4254
self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq]
55+
self._space_available_cond.notify_all()
56+
57+
async def close(self) -> None:
58+
"""
59+
Closes the message buffer and rejects any pending put operations.
60+
"""
61+
async with self._lock:
62+
self._closed = True
63+
self._space_available_cond.notify_all()

replit_river/session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from aiochannel import Channel, ChannelClosed
99
from websockets.exceptions import ConnectionClosed
1010

11-
from replit_river.message_buffer import MessageBuffer
11+
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
1212
from replit_river.messages import (
1313
FailedSendingMessageException,
1414
WebsocketClosedException,
@@ -386,10 +386,8 @@ async def send_message(
386386
async with self._msg_lock:
387387
try:
388388
await self._buffer.put(msg)
389-
except Exception:
390-
# We should close the session when there are too many messages in
391-
# buffer
392-
await self.close()
389+
except MessageBufferClosedError:
390+
# The session is closed and is no longer accepting new messages.
393391
return
394392
async with self._ws_lock:
395393
if not await self._ws_wrapper.is_open():
@@ -542,6 +540,8 @@ async def close(self) -> None:
542540

543541
await self.close_websocket(self._ws_wrapper, should_retry=False)
544542

543+
await self._buffer.close()
544+
545545
# Clear the session in transports
546546
await self._close_session_callback(self)
547547

replit_river/transport_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pydantic import BaseModel
55

6-
MAX_MESSAGE_BUFFER_SIZE = 1024
6+
MAX_MESSAGE_BUFFER_SIZE = 128
77

88

99
class ConnectionRetryOptions(BaseModel):

tests/test_communication.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from replit_river.client import Client
77
from replit_river.error_schema import RiverError
8+
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
89
from tests.conftest import deserialize_error, deserialize_response, serialize_request
910

1011

@@ -41,6 +42,27 @@ async def upload_data() -> AsyncGenerator[str, None]:
4142
assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3"
4243

4344

45+
@pytest.mark.asyncio
46+
async def test_upload_more_than_send_buffer_max(client: Client) -> None:
47+
iterations = MAX_MESSAGE_BUFFER_SIZE * 2
48+
49+
async def upload_data() -> AsyncGenerator[str, None]:
50+
for _ in range(0, iterations):
51+
yield "Data"
52+
53+
response = await client.send_upload(
54+
"test_service",
55+
"upload_method",
56+
"Initial Data",
57+
upload_data(),
58+
serialize_request,
59+
serialize_request,
60+
deserialize_response,
61+
deserialize_response,
62+
) # type: ignore
63+
assert response == "Uploaded: Initial Data" + (", Data" * iterations)
64+
65+
4466
@pytest.mark.asyncio
4567
async def test_upload_empty(client: Client) -> None:
4668
async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:

tests/test_message_buffer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
6+
from replit_river.rpc import TransportMessage
7+
8+
9+
def mock_transport_message(seq: int) -> TransportMessage:
10+
return TransportMessage(
11+
seq=seq,
12+
id="test",
13+
ack=0,
14+
from_="test",
15+
to="test",
16+
streamId="test",
17+
controlFlags=0,
18+
payload=0,
19+
model_config={},
20+
)
21+
22+
23+
async def test_message_buffer_backpressure() -> None:
24+
"""
25+
Tests that MessageBuffer.put blocks until there is space in the buffer,
26+
creating back pressure in the client.
27+
"""
28+
buffer = MessageBuffer(max_num_messages=1)
29+
30+
iterations = 100
31+
32+
# We use a queue as a way to sync our test logic with the background
33+
# task with the testing logic.
34+
sync_events: asyncio.Queue[None] = asyncio.Queue()
35+
36+
async def put_messages() -> None:
37+
for i in range(0, iterations):
38+
await buffer.put(mock_transport_message(seq=i))
39+
await sync_events.put(None)
40+
41+
background_puts = asyncio.create_task(put_messages())
42+
43+
for i in range(1, iterations):
44+
# Wait for the put call to return.
45+
await sync_events.get()
46+
assert len(buffer.buffer) == 1
47+
await buffer.remove_old_messages(i)
48+
49+
await background_puts
50+
51+
52+
async def test_message_buffer_close() -> None:
53+
"""
54+
Tests that MessageBuffer.put raises an exception when the buffer
55+
is closed while the put operation is waiting for space in the buffer.
56+
"""
57+
buffer = MessageBuffer(max_num_messages=1)
58+
await buffer.put(mock_transport_message(seq=1))
59+
background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1)))
60+
await buffer.close()
61+
with pytest.raises(MessageBufferClosedError):
62+
await background_put
63+
with pytest.raises(MessageBufferClosedError):
64+
await buffer.put(mock_transport_message(seq=1))

0 commit comments

Comments
 (0)