Skip to content

Commit 8e2ac6e

Browse files
committed
adding websockets
Signed-off-by: Jesse Jaggars <jjaggars@redhat.com>
1 parent 5e4db0b commit 8e2ac6e

7 files changed

Lines changed: 125 additions & 81 deletions

File tree

receptor/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def main(args=None):
4242

4343
try:
4444
config.go()
45-
except Exception as e:
46-
logger.error("An error occured while running receptor:\n%s" % (str(e),))
45+
except Exception:
46+
logger.exception("main: an error occured while running receptor")
4747
sys.exit(1)
4848

4949

receptor/controller.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66

77
from . import protocol
8+
from .ws import WSServer
89

910
logger = logging.getLogger(__name__)
1011

@@ -38,6 +39,13 @@ def mainloop(receptor, socket_path, loop=asyncio.get_event_loop()):
3839
config.controller_listen_address, config.controller_listen_port, ssl=config.get_server_ssl_context())
3940
logger.info("Serving on %s:%s", config.controller_listen_address, config.controller_listen_port)
4041
loop.create_task(listener)
42+
43+
ws_server = WSServer(receptor, loop)
44+
ws_listener = loop.create_server(ws_server.app().make_handler(),
45+
config.node_listen_address, config.node_listen_port + 1, ssl=config.get_server_ssl_context())
46+
loop.create_task(ws_listener)
47+
logger.info("Serving ws on %s:%s", config.node_listen_address, config.node_listen_port + 1)
48+
4149
control_listener = loop.create_unix_server(
4250
lambda: protocol.BasicControllerProtocol(receptor, loop),
4351
path=socket_path

receptor/messages/envelope.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(self, msg_id=None, header=None, payload=None):
2727
self.header = header
2828
self.payload = payload
2929

30+
def __repr__(self):
31+
return f"FramedMessage(msg_id={self.msg_id}, header={self.header}, payload={self.payload})"
32+
3033
def serialize(self):
3134
h = json.dumps(self.header).encode("utf-8")
3235
return b''.join([
@@ -99,6 +102,7 @@ async def consume(self, data):
99102
await self.handle_frame(rest)
100103

101104
async def finish(self):
105+
logger.debug("finish: %s", self.current_frame)
102106
if self.current_frame.type == Frame.Types.HEADER:
103107
self.header = json.loads(self.bb)
104108
elif self.current_frame.type == Frame.Types.PAYLOAD:

receptor/node.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import asyncio
22
import logging
3+
from urllib.parse import urlparse
34

45
from .protocol import BasicProtocol, create_peer
6+
from .ws import WSClient, WSServer
57

68
logger = logging.getLogger(__name__)
79

10+
def parse_peer(peer):
11+
if "://" not in peer:
12+
peer = f"receptor://{peer}"
13+
return urlparse(peer)
814

915
# FIXME: ping_interval is in the config, it shouldn't need to be passed as an arg here
1016
def mainloop(receptor, ping_interval=None, loop=asyncio.get_event_loop(), skip_run=False):
@@ -15,8 +21,21 @@ def mainloop(receptor, ping_interval=None, loop=asyncio.get_event_loop(), skip_r
1521
config.node_listen_address, config.node_listen_port, ssl=config.get_server_ssl_context())
1622
loop.create_task(listener)
1723
logger.info("Serving on %s:%s", config.node_listen_address, config.node_listen_port)
24+
25+
ws_server = WSServer(receptor, loop)
26+
ws_listener = loop.create_server(ws_server.app().make_handler(),
27+
config.node_listen_address, config.node_listen_port + 1, ssl=config.get_server_ssl_context())
28+
loop.create_task(ws_listener)
29+
logger.info("Serving ws on %s:%s", config.node_listen_address, config.node_listen_port + 1)
30+
1831
for peer in config.node_peers:
19-
loop.create_task(create_peer(receptor, loop, *peer.strip().split(":", 1)))
32+
parsed = parse_peer(peer)
33+
if parsed.scheme == "receptor://":
34+
loop.create_task(create_peer(receptor, loop, parsed.hostname, parsed.port))
35+
elif parsed.scheme in ("ws", "wss"):
36+
c = WSClient(receptor, loop)
37+
loop.create_task(c.connect(peer))
38+
2039
if ping_interval > 0:
2140
ping_time = (((int(loop.time()) + 1) // ping_interval) + 1) * ping_interval
2241
loop.call_at(ping_time, loop.create_task, send_pings_and_reschedule(receptor, loop, ping_time, ping_interval))

receptor/protocol.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,7 @@ def handle_handshake(self, data):
8484
self.loop.create_task(self.receptor.message_handler(self.incoming_buffer))
8585

8686
def send_handshake(self):
87-
msg = envelope.CommandMessage(header={
88-
"cmd": "HI",
89-
"id": self.receptor.node_id,
90-
"expire_time": time.time() + 10,
91-
"meta": dict(capabilities=self.receptor.work_manager.get_capabilities(),
92-
groups=self.receptor.config.node_groups,
93-
work=self.receptor.work_manager.get_work())
94-
})
95-
self.transport.write(msg.serialize())
87+
self.transport.write(self.receptor._say_hi().serialize())
9688

9789

9890
class BasicProtocol(BaseProtocol):

receptor/receptor.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ def update_connection_manifest(self, connection):
8484
last=time.time()))
8585
self.write_connection_manifest(manifest)
8686

87-
def update_connections(self, protocol_obj):
88-
self.router.register_edge(protocol_obj.id, self.node_id, 1)
89-
if protocol_obj.id in self.connections:
90-
self.connections[protocol_obj.id].append(protocol_obj)
87+
def update_connections(self, protocol_obj, id_=None):
88+
if id_ is None:
89+
id_ = protocol_obj.id
90+
self.router.register_edge(id_, self.node_id, 1)
91+
if id_ in self.connections:
92+
self.connections[id_].append(protocol_obj)
9193
else:
92-
self.connections[protocol_obj.id] = [protocol_obj]
93-
self.update_connection_manifest(protocol_obj.id)
94+
self.connections[id_] = [protocol_obj]
95+
self.update_connection_manifest(id_)
9496

9597
async def message_handler(self, buf):
9698
logger.debug("spawning message_handler")
@@ -128,6 +130,16 @@ async def shutdown_handler(self):
128130
return
129131
await asyncio.sleep(1)
130132

133+
def _say_hi(self):
134+
return envelope.CommandMessage(header={
135+
"cmd": "HI",
136+
"id": self.node_id,
137+
"expire_time": time.time() + 10,
138+
"meta": dict(capabilities=self.work_manager.get_capabilities(),
139+
groups=self.config.node_groups,
140+
work=self.work_manager.get_work())
141+
})
142+
131143
async def handle_route_advertisement(self, data):
132144
self.router.add_edges(data["edges"])
133145
await self.send_route_advertisement(data["edges"], data["seen"])

receptor/ws.py

Lines changed: 72 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,104 @@
1-
import json
21
import logging
3-
import time
42

3+
import asyncio
54
import aiohttp
5+
import aiohttp.web
66

7-
from .protocol import DataBuffer
7+
from .messages.envelope import FramedBuffer
88

99
logger = logging.getLogger(__name__)
1010

1111

12-
async def watch_queue(sock, buf):
13-
while sock.open:
12+
async def watch_queue(ws, buf):
13+
while not ws.closed:
1414
try:
1515
msg = await buf.get()
1616
except Exception:
17-
logger.exception("Error getting data from buffer")
18-
17+
logger.exception("watch_queue: error getting data from buffer")
18+
continue
19+
1920
try:
20-
sock.send(msg)
21+
await ws.send_bytes(msg)
2122
except Exception:
22-
logger.exception("Error received trying to write")
23+
logger.exception("watch_queue: error received trying to write")
2324
await buf.put(msg)
24-
return await sock.close()
25+
return await ws.close()
26+
logger.debug("watch_queue: ws is now closed")
2527

2628

27-
class WSClient:
29+
class WSBase:
2830
def __init__(self, receptor, loop):
2931
self.receptor = receptor
3032
self.loop = loop
33+
self.buf = FramedBuffer(loop=self.loop)
34+
self.remote_id = None
3135

32-
async def connect(self, uri):
33-
async with aiohttp.ClientSession().ws_connect(uri) as sock:
34-
# handshake
35-
node_id = await self.handshake(sock)
36-
incoming_buffer = DataBuffer()
37-
self.loop.create_task(self.receive(sock, incoming_buffer)) # reader
38-
39-
buf = self.receptor.buffer_mgr.get_buffer_for_node(node_id, self.receptor)
40-
self.loop.create_task(watch_queue(sock, buf)) # writer
41-
42-
self.loop.create_task(self.connect(uri))
36+
async def receive(self, ws):
37+
try:
38+
async for msg in ws:
39+
await self.buf.put(msg.data)
40+
except Exception:
41+
logger.exception("receive")
4342

43+
def register(self, ws):
44+
self.receptor.update_connections(ws, id_=self.remote_id)
4445

45-
async def handshake(self, sock):
46-
msg = json.dumps({
47-
"cmd": "HI",
48-
"id": self.receptor.node_id,
49-
"expire_time": time.time() + 10,
50-
"meta": {
51-
"capabilities": self.receptor.work_manager.get_capabilities(),
52-
"groups": self.receptor.config.node_groups,
53-
"work": self.receptor.work_manager.get_work(),
54-
}
55-
}).encode("utf-8")
56-
await sock.send_bytes(msg)
57-
response = await sock.receive().json()
58-
return response["id"]
46+
async def hello(self, ws):
47+
msg = self.receptor._say_hi().serialize()
48+
await ws.send_bytes(msg)
5949

60-
async def receive(self, sock, buf):
61-
self.loop.create_task(self.receptor.message_handler(buf))
62-
async for msg in sock.receive():
63-
buf.add(msg.data)
50+
async def start_processing(self, ws):
51+
self.loop.create_task(self.receptor.message_handler(self.buf))
52+
out = self.receptor.buffer_mgr.get_buffer_for_node(
53+
self.remote_id, self.receptor
54+
)
55+
return await watch_queue(ws, out)
6456

6557

66-
class WSServer:
58+
class WSClient(WSBase):
59+
async def connect(self, uri):
60+
try:
61+
async with aiohttp.ClientSession().ws_connect(uri) as ws:
62+
63+
logger.debug("connect: starting recv")
64+
recv_loop = self.loop.create_task(self.receive(ws)) # reader
65+
logger.debug("connect: sending HI")
66+
await self.hello(ws)
67+
logger.debug("connect: waiting for HI")
68+
response = await self.buf.get() # TODO: deal with timeout
69+
self.remote_id = response.header["id"]
70+
self.register(ws)
71+
logger.debug("connect: sending routes")
72+
await self.receptor.send_route_advertisement()
73+
logger.debug("connect: starting normal loop")
74+
await self.start_processing(ws)
75+
logger.debug("connect: normal exit")
76+
except Exception:
77+
logger.exception("connect")
78+
logger.debug("connect: reconnecting")
79+
self.loop.create_task(self.connect(uri))
6780

68-
def __init__(self, receptor, loop):
69-
self.receptor = receptor
70-
self.loop = loop
7181

82+
class WSServer(WSBase):
7283
async def serve(self, request):
7384

7485
ws = aiohttp.web.WebSocketResponse()
7586
await ws.prepare(request)
7687

77-
handshake = await ws.receive().json()
78-
await ws.send_json({
79-
"cmd": "HI",
80-
"id": self.receptor.node_id,
81-
"expire_time": time.time() + 10,
82-
"meta": {
83-
"capabilities": self.receptor.work_manager.get_capabilities(),
84-
"groups": self.receptor.config.node_groups,
85-
"work": self.receptor.work_manager.get_work(),
86-
}
87-
})
88-
89-
buf = self.receptor.buffer_mgr.get_buffer_for_node(handshake["id"], self.receptor)
90-
self.loop.create_task(watch_queue(ws, buf)) # writer
91-
92-
incoming_buffer = DataBuffer()
93-
self.loop.create_task(self.receptor.message_handler(incoming_buffer))
94-
async for msg in ws:
95-
incoming_buffer.add(msg.data)
88+
logger.debug("serve: starting recv")
89+
self.loop.create_task(self.receive(ws)) # reader
90+
logger.debug("serve: waiting for HI")
91+
response = await self.buf.get() # TODO: deal with timeout
92+
self.remote_id = response.header["id"]
93+
self.register(ws)
94+
logger.debug("serve: sending HI")
95+
await self.hello(ws)
96+
logger.debug("serve: sending routes")
97+
await self.receptor.send_route_advertisement()
98+
logger.debug("serve: starting normal recv loop")
99+
await self.start_processing(ws)
100+
101+
def app(self):
102+
app = aiohttp.web.Application()
103+
app.add_routes([aiohttp.web.get("/", self.serve)])
104+
return app

0 commit comments

Comments
 (0)