Skip to content

Commit 72cd320

Browse files
committed
Adding framed buffer
FramedBuffer encapsulates handling of the new message format. Essentially there are 1 or 2 frames that constitute a message. [Header Frame][Header Body][Payload Frame][Payload] or [Command Frame][Command Body] Signed-off-by: Jesse Jaggars <jjaggars@redhat.com>
1 parent b1f1ae0 commit 72cd320

2 files changed

Lines changed: 161 additions & 1 deletion

File tree

receptor/messages/envelope.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,129 @@
1+
import asyncio
12
import base64
23
import datetime
4+
import itertools
35
import json
46
import logging
5-
import uuid
67
import time
8+
import uuid
9+
from struct import pack, unpack
710

811
logger = logging.getLogger(__name__)
912

13+
MAX_INT64 = (2 ** 64 - 1)
14+
15+
16+
class FramedBuffer:
17+
"""
18+
A buffer that accumulates frames and bytes to produce a header and a
19+
payload.
20+
21+
This buffer assumes that an entire message (denoted by msg_id) will be
22+
sent before another message is sent.
23+
"""
24+
def __init__(self, loop=None):
25+
self.q = asyncio.Queue(loop=loop)
26+
self.header = None
27+
self.bb = bytearray()
28+
self.current_frame = None
29+
self.to_read = 0
30+
31+
async def put(self, data):
32+
if not self.to_read:
33+
return await self.handle_frame(data)
34+
await self.consume(data)
35+
36+
async def handle_frame(self, data):
37+
self.current_frame, rest = Frame.from_data(data)
38+
if self.current_frame.type in (Frame.START_MSG, Frame.PAYLOAD):
39+
self.to_read = self.current_frame.length
40+
await self.consume(rest)
41+
else:
42+
raise Exception("Unknown Frame Type")
43+
44+
async def consume(self, data):
45+
self.to_read -= len(data)
46+
self.bb += data
47+
if self.to_read == 0:
48+
await self.finish()
49+
50+
async def finish(self):
51+
if self.current_frame.type == Frame.START_MSG:
52+
self.header = Header(**json.loads(self.bb))
53+
elif self.current_frame.type == Frame.PAYLOAD:
54+
await self.q.put((self.header, self.bb))
55+
self.header = None
56+
self.to_read = 0
57+
self.bb = bytearray()
58+
59+
async def get(self):
60+
return await self.q.get()
61+
62+
63+
class Frame:
64+
START_MSG = 0
65+
PAYLOAD = 1
66+
FINISH = 2
67+
68+
def __init__(self, type_, version, length, msg_id, id_):
69+
self.type = type_
70+
self.version = version
71+
self.length = length
72+
self.msg_id = msg_id
73+
self.id = id_
74+
75+
def serialize(self):
76+
high, low = ((self.msg_id >> 64) & MAX_INT64, self.msg_id & MAX_INT64)
77+
return b''.join([
78+
pack("ccIi", chr(self.type).encode("ascii"), chr(self.version).encode("ascii"), self.id, self.length),
79+
pack(">QQ", high, low),
80+
])
81+
82+
@classmethod
83+
def deserialize(cls, buf):
84+
t, v, i, length = unpack("ccIi", buf[0:12])
85+
hi, lo = unpack(">QQ", buf[12:])
86+
msg_id = (hi << 64) | lo
87+
return cls(ord(t), ord(v), length, msg_id, i)
88+
89+
@classmethod
90+
def from_data(cls, data):
91+
return cls.deserialize(data[:28]), data[28:]
92+
93+
94+
class Header:
95+
def __init__(self, sender, recipient, route_list):
96+
self.sender = sender
97+
self.recipient = recipient
98+
self.route_list = route_list
99+
100+
def serialize(self):
101+
return json.dumps({"sender": self.sender, "recipient": self.recipient, "route_list": self.route_list}).encode("utf-8")
102+
103+
def __repr__(self):
104+
return f"Header: {self.sender}, {self.recipient}, {self.route_list}"
105+
106+
def __eq__(self, other):
107+
return (self.sender, self.recipient, self.route_list) == (other.sender, other.recipient, other.route_list)
108+
109+
110+
def gen_chunks(buffer, header, msg_id=None, chunksize=2 ** 8):
111+
if msg_id is None:
112+
msg_id = uuid.uuid4().int
113+
seq = itertools.count()
114+
buf = bytearray(chunksize)
115+
bv = memoryview(buf)
116+
header = header.serialize()
117+
yield Frame(Frame.START_MSG, 1, len(header), msg_id, next(seq)).serialize() + header
118+
bytes_read = buffer.readinto(buf)
119+
while bytes_read:
120+
f = Frame(Frame.PAYLOAD, 1, bytes_read, msg_id, next(seq)).serialize()
121+
if bytes_read == chunksize:
122+
yield f + bv.tobytes()
123+
else:
124+
yield f + bv[:bytes_read].tobytes()
125+
bytes_read = buffer.readinto(buf)
126+
10127

11128
class OuterEnvelope:
12129
def __init__(self, frame_id, sender, recipient, route_list, inner):

receptor/tests/test_protocol.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import io
2+
import uuid
3+
14
import pytest
25

36
from receptor import protocol
7+
from receptor.messages.envelope import Frame, FramedBuffer, Header, gen_chunks
48

59

610
def deser(x):
@@ -32,3 +36,42 @@ async def test_databuffer_many_msgs():
3236
assert msg[0] == await b.get()
3337
assert msg[1] == await b.get()
3438
assert b.q.empty()
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_framedbuffer():
43+
b = FramedBuffer()
44+
45+
msg_id = uuid.uuid4().int
46+
header = Header("node1", "node2", [])
47+
header_bytes = header.serialize()
48+
f1 = Frame(Frame.START_MSG, 1, len(header_bytes), msg_id, 1)
49+
50+
await b.put(f1.serialize() + header_bytes)
51+
52+
payload = b"tina loves butts"
53+
payload2 = b"yep yep yep"
54+
f2 = Frame(Frame.PAYLOAD, 1, len(payload) + len(payload2), msg_id, 2)
55+
56+
await b.put(f2.serialize() + payload)
57+
await b.put(payload2)
58+
59+
h, p = await b.get()
60+
61+
assert h == header
62+
assert p == payload + payload2
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_gen_chunks():
67+
68+
b = FramedBuffer()
69+
70+
header = Header("node1", "node2", [])
71+
payload = b"this is a test with a buffer"
72+
for chunk in gen_chunks(io.BytesIO(payload), header):
73+
await b.put(chunk)
74+
75+
h, p = await b.get()
76+
assert h == header
77+
assert p == payload

0 commit comments

Comments
 (0)