|
1 | | -import json |
2 | 1 | import logging |
3 | | -import time |
4 | 2 |
|
| 3 | +import asyncio |
5 | 4 | import aiohttp |
| 5 | +import aiohttp.web |
6 | 6 |
|
7 | | -from .protocol import DataBuffer |
| 7 | +from .messages.envelope import FramedBuffer |
8 | 8 |
|
9 | 9 | logger = logging.getLogger(__name__) |
10 | 10 |
|
11 | 11 |
|
12 | | -async def watch_queue(sock, buf): |
13 | | - while sock.open: |
| 12 | +async def watch_queue(ws, buf): |
| 13 | + while not ws.closed: |
14 | 14 | try: |
15 | 15 | msg = await buf.get() |
16 | 16 | except Exception: |
17 | | - logger.exception("Error getting data from buffer") |
18 | | - |
| 17 | + logger.exception("watch_queue: error getting data from buffer") |
| 18 | + continue |
| 19 | + |
19 | 20 | try: |
20 | | - sock.send(msg) |
| 21 | + await ws.send_bytes(msg) |
21 | 22 | except Exception: |
22 | | - logger.exception("Error received trying to write") |
| 23 | + logger.exception("watch_queue: error received trying to write") |
23 | 24 | await buf.put(msg) |
24 | | - return await sock.close() |
| 25 | + return await ws.close() |
| 26 | + logger.debug("watch_queue: ws is now closed") |
25 | 27 |
|
26 | 28 |
|
27 | | -class WSClient: |
| 29 | +class WSBase: |
28 | 30 | def __init__(self, receptor, loop): |
29 | 31 | self.receptor = receptor |
30 | 32 | self.loop = loop |
| 33 | + self.buf = FramedBuffer(loop=self.loop) |
| 34 | + self.remote_id = None |
31 | 35 |
|
32 | | - async def connect(self, uri): |
33 | | - async with aiohttp.ClientSession().ws_connect(uri) as sock: |
34 | | - # handshake |
35 | | - node_id = await self.handshake(sock) |
36 | | - incoming_buffer = DataBuffer() |
37 | | - self.loop.create_task(self.receive(sock, incoming_buffer)) # reader |
38 | | - |
39 | | - buf = self.receptor.buffer_mgr.get_buffer_for_node(node_id, self.receptor) |
40 | | - self.loop.create_task(watch_queue(sock, buf)) # writer |
41 | | - |
42 | | - self.loop.create_task(self.connect(uri)) |
| 36 | + async def receive(self, ws): |
| 37 | + try: |
| 38 | + async for msg in ws: |
| 39 | + await self.buf.put(msg.data) |
| 40 | + except Exception: |
| 41 | + logger.exception("receive") |
43 | 42 |
|
| 43 | + def register(self, ws): |
| 44 | + self.receptor.update_connections(ws, id_=self.remote_id) |
44 | 45 |
|
45 | | - async def handshake(self, sock): |
46 | | - msg = json.dumps({ |
47 | | - "cmd": "HI", |
48 | | - "id": self.receptor.node_id, |
49 | | - "expire_time": time.time() + 10, |
50 | | - "meta": { |
51 | | - "capabilities": self.receptor.work_manager.get_capabilities(), |
52 | | - "groups": self.receptor.config.node_groups, |
53 | | - "work": self.receptor.work_manager.get_work(), |
54 | | - } |
55 | | - }).encode("utf-8") |
56 | | - await sock.send_bytes(msg) |
57 | | - response = await sock.receive().json() |
58 | | - return response["id"] |
| 46 | + async def hello(self, ws): |
| 47 | + msg = self.receptor._say_hi().serialize() |
| 48 | + await ws.send_bytes(msg) |
59 | 49 |
|
60 | | - async def receive(self, sock, buf): |
61 | | - self.loop.create_task(self.receptor.message_handler(buf)) |
62 | | - async for msg in sock.receive(): |
63 | | - buf.add(msg.data) |
| 50 | + async def start_processing(self, ws): |
| 51 | + self.loop.create_task(self.receptor.message_handler(self.buf)) |
| 52 | + out = self.receptor.buffer_mgr.get_buffer_for_node( |
| 53 | + self.remote_id, self.receptor |
| 54 | + ) |
| 55 | + return await watch_queue(ws, out) |
64 | 56 |
|
65 | 57 |
|
66 | | -class WSServer: |
| 58 | +class WSClient(WSBase): |
| 59 | + async def connect(self, uri): |
| 60 | + try: |
| 61 | + async with aiohttp.ClientSession().ws_connect(uri) as ws: |
| 62 | + |
| 63 | + logger.debug("connect: starting recv") |
| 64 | + recv_loop = self.loop.create_task(self.receive(ws)) # reader |
| 65 | + logger.debug("connect: sending HI") |
| 66 | + await self.hello(ws) |
| 67 | + logger.debug("connect: waiting for HI") |
| 68 | + response = await self.buf.get() # TODO: deal with timeout |
| 69 | + self.remote_id = response.header["id"] |
| 70 | + self.register(ws) |
| 71 | + logger.debug("connect: sending routes") |
| 72 | + await self.receptor.send_route_advertisement() |
| 73 | + logger.debug("connect: starting normal loop") |
| 74 | + await self.start_processing(ws) |
| 75 | + logger.debug("connect: normal exit") |
| 76 | + except Exception: |
| 77 | + logger.exception("connect") |
| 78 | + logger.debug("connect: reconnecting") |
| 79 | + self.loop.create_task(self.connect(uri)) |
67 | 80 |
|
68 | | - def __init__(self, receptor, loop): |
69 | | - self.receptor = receptor |
70 | | - self.loop = loop |
71 | 81 |
|
| 82 | +class WSServer(WSBase): |
72 | 83 | async def serve(self, request): |
73 | 84 |
|
74 | 85 | ws = aiohttp.web.WebSocketResponse() |
75 | 86 | await ws.prepare(request) |
76 | 87 |
|
77 | | - handshake = await ws.receive().json() |
78 | | - await ws.send_json({ |
79 | | - "cmd": "HI", |
80 | | - "id": self.receptor.node_id, |
81 | | - "expire_time": time.time() + 10, |
82 | | - "meta": { |
83 | | - "capabilities": self.receptor.work_manager.get_capabilities(), |
84 | | - "groups": self.receptor.config.node_groups, |
85 | | - "work": self.receptor.work_manager.get_work(), |
86 | | - } |
87 | | - }) |
88 | | - |
89 | | - buf = self.receptor.buffer_mgr.get_buffer_for_node(handshake["id"], self.receptor) |
90 | | - self.loop.create_task(watch_queue(ws, buf)) # writer |
91 | | - |
92 | | - incoming_buffer = DataBuffer() |
93 | | - self.loop.create_task(self.receptor.message_handler(incoming_buffer)) |
94 | | - async for msg in ws: |
95 | | - incoming_buffer.add(msg.data) |
| 88 | + logger.debug("serve: starting recv") |
| 89 | + self.loop.create_task(self.receive(ws)) # reader |
| 90 | + logger.debug("serve: waiting for HI") |
| 91 | + response = await self.buf.get() # TODO: deal with timeout |
| 92 | + self.remote_id = response.header["id"] |
| 93 | + self.register(ws) |
| 94 | + logger.debug("serve: sending HI") |
| 95 | + await self.hello(ws) |
| 96 | + logger.debug("serve: sending routes") |
| 97 | + await self.receptor.send_route_advertisement() |
| 98 | + logger.debug("serve: starting normal recv loop") |
| 99 | + await self.start_processing(ws) |
| 100 | + |
| 101 | + def app(self): |
| 102 | + app = aiohttp.web.Application() |
| 103 | + app.add_routes([aiohttp.web.get("/", self.serve)]) |
| 104 | + return app |
0 commit comments