|
| 1 | +import asyncio |
1 | 2 | import base64 |
2 | 3 | import datetime |
| 4 | +import itertools |
3 | 5 | import json |
4 | 6 | import logging |
5 | | -import uuid |
6 | 7 | import time |
| 8 | +import uuid |
| 9 | +from struct import pack, unpack |
7 | 10 |
|
8 | 11 | logger = logging.getLogger(__name__) |
9 | 12 |
|
| 13 | +MAX_INT64 = (2 ** 64 - 1) |
| 14 | + |
| 15 | + |
| 16 | +class FramedBuffer: |
| 17 | + """ |
| 18 | + A buffer that accumulates frames and bytes to produce a header and a |
| 19 | + payload. |
| 20 | +
|
| 21 | + This buffer assumes that an entire message (denoted by msg_id) will be |
| 22 | + sent before another message is sent. |
| 23 | + """ |
| 24 | + def __init__(self, loop=None): |
| 25 | + self.q = asyncio.Queue(loop=loop) |
| 26 | + self.header = None |
| 27 | + self.bb = bytearray() |
| 28 | + self.current_frame = None |
| 29 | + self.to_read = 0 |
| 30 | + |
| 31 | + async def put(self, data): |
| 32 | + if not self.to_read: |
| 33 | + return await self.handle_frame(data) |
| 34 | + await self.consume(data) |
| 35 | + |
| 36 | + async def handle_frame(self, data): |
| 37 | + self.current_frame, rest = Frame.from_data(data) |
| 38 | + if self.current_frame.type in (Frame.START_MSG, Frame.PAYLOAD): |
| 39 | + self.to_read = self.current_frame.length |
| 40 | + await self.consume(rest) |
| 41 | + else: |
| 42 | + raise Exception("Unknown Frame Type") |
| 43 | + |
| 44 | + async def consume(self, data): |
| 45 | + self.to_read -= len(data) |
| 46 | + self.bb += data |
| 47 | + if self.to_read == 0: |
| 48 | + await self.finish() |
| 49 | + |
| 50 | + async def finish(self): |
| 51 | + if self.current_frame.type == Frame.START_MSG: |
| 52 | + self.header = Header(**json.loads(self.bb)) |
| 53 | + elif self.current_frame.type == Frame.PAYLOAD: |
| 54 | + await self.q.put((self.header, self.bb)) |
| 55 | + self.header = None |
| 56 | + self.to_read = 0 |
| 57 | + self.bb = bytearray() |
| 58 | + |
| 59 | + async def get(self): |
| 60 | + return await self.q.get() |
| 61 | + |
| 62 | + |
| 63 | +class Frame: |
| 64 | + START_MSG = 0 |
| 65 | + PAYLOAD = 1 |
| 66 | + FINISH = 2 |
| 67 | + |
| 68 | + def __init__(self, type_, version, length, msg_id, id_): |
| 69 | + self.type = type_ |
| 70 | + self.version = version |
| 71 | + self.length = length |
| 72 | + self.msg_id = msg_id |
| 73 | + self.id = id_ |
| 74 | + |
| 75 | + def serialize(self): |
| 76 | + high, low = ((self.msg_id >> 64) & MAX_INT64, self.msg_id & MAX_INT64) |
| 77 | + return b''.join([ |
| 78 | + pack("ccIi", chr(self.type).encode("ascii"), chr(self.version).encode("ascii"), self.id, self.length), |
| 79 | + pack(">QQ", high, low), |
| 80 | + ]) |
| 81 | + |
| 82 | + @classmethod |
| 83 | + def deserialize(cls, buf): |
| 84 | + t, v, i, length = unpack("ccIi", buf[0:12]) |
| 85 | + hi, lo = unpack(">QQ", buf[12:]) |
| 86 | + msg_id = (hi << 64) | lo |
| 87 | + return cls(ord(t), ord(v), length, msg_id, i) |
| 88 | + |
| 89 | + @classmethod |
| 90 | + def from_data(cls, data): |
| 91 | + return cls.deserialize(data[:28]), data[28:] |
| 92 | + |
| 93 | + |
| 94 | +class Header: |
| 95 | + def __init__(self, sender, recipient, route_list): |
| 96 | + self.sender = sender |
| 97 | + self.recipient = recipient |
| 98 | + self.route_list = route_list |
| 99 | + |
| 100 | + def serialize(self): |
| 101 | + return json.dumps({"sender": self.sender, "recipient": self.recipient, "route_list": self.route_list}).encode("utf-8") |
| 102 | + |
| 103 | + def __repr__(self): |
| 104 | + return f"Header: {self.sender}, {self.recipient}, {self.route_list}" |
| 105 | + |
| 106 | + def __eq__(self, other): |
| 107 | + return (self.sender, self.recipient, self.route_list) == (other.sender, other.recipient, other.route_list) |
| 108 | + |
| 109 | + |
| 110 | +def gen_chunks(buffer, header, msg_id=None, chunksize=2 ** 8): |
| 111 | + if msg_id is None: |
| 112 | + msg_id = uuid.uuid4().int |
| 113 | + seq = itertools.count() |
| 114 | + buf = bytearray(chunksize) |
| 115 | + bv = memoryview(buf) |
| 116 | + header = header.serialize() |
| 117 | + yield Frame(Frame.START_MSG, 1, len(header), msg_id, next(seq)).serialize() + header |
| 118 | + bytes_read = buffer.readinto(buf) |
| 119 | + while bytes_read: |
| 120 | + f = Frame(Frame.PAYLOAD, 1, bytes_read, msg_id, next(seq)).serialize() |
| 121 | + if bytes_read == chunksize: |
| 122 | + yield f + bv.tobytes() |
| 123 | + else: |
| 124 | + yield f + bv[:bytes_read].tobytes() |
| 125 | + bytes_read = buffer.readinto(buf) |
| 126 | + |
10 | 127 |
|
11 | 128 | class OuterEnvelope: |
12 | 129 | def __init__(self, frame_id, sender, recipient, route_list, inner): |
|
0 commit comments