Skip to content

Commit 0ae97cf

Browse files
authored
Merge pull request #54 from jhjaggars/websockets
Adding Websockets
2 parents 53a803f + 184bf48 commit 0ae97cf

12 files changed

Lines changed: 313 additions & 30 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ venv.bak/
111111

112112
# mypy
113113
.mypy_cache/
114+
graph_*.dot

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ flake8 = "*"
1313
[packages]
1414
python-dateutil = "*"
1515
prometheus-client = "*"
16+
aiohttp = "*"
1617

1718
[requires]
1819
python_version = "3.6"

Pipfile.lock

Lines changed: 113 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

receptor/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def main(args=None):
4242

4343
try:
4444
config.go()
45-
except Exception as e:
46-
logger.error("An error occured while running receptor:\n%s" % (str(e),))
45+
except Exception:
46+
logger.exception("main: an error occured while running receptor")
4747
sys.exit(1)
4848

4949

receptor/controller.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66

77
from . import protocol
8+
from .ws import WSServer
89

910
logger = logging.getLogger(__name__)
1011

@@ -38,6 +39,13 @@ def mainloop(receptor, socket_path, loop=asyncio.get_event_loop()):
3839
config.controller_listen_address, config.controller_listen_port, ssl=config.get_server_ssl_context())
3940
logger.info("Serving on %s:%s", config.controller_listen_address, config.controller_listen_port)
4041
loop.create_task(listener)
42+
43+
ws_server = WSServer(receptor, loop)
44+
ws_listener = loop.create_server(ws_server.app().make_handler(),
45+
config.node_listen_address, config.node_listen_port + 1, ssl=config.get_server_ssl_context())
46+
loop.create_task(ws_listener)
47+
logger.info("Serving ws on %s:%s", config.node_listen_address, config.node_listen_port + 1)
48+
4149
control_listener = loop.create_unix_server(
4250
lambda: protocol.BasicControllerProtocol(receptor, loop),
4351
path=socket_path

receptor/messages/envelope.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(self, msg_id=None, header=None, payload=None):
2727
self.header = header
2828
self.payload = payload
2929

30+
def __repr__(self):
31+
return f"FramedMessage(msg_id={self.msg_id}, header={self.header}, payload={self.payload})"
32+
3033
def serialize(self):
3134
h = json.dumps(self.header).encode("utf-8")
3235
return b''.join([
@@ -89,7 +92,6 @@ async def handle_frame(self, data):
8992
await self.consume(rest)
9093

9194
async def consume(self, data):
92-
logger.debug("consuming %d bytes; to_read = %d bytes", len(data), self.to_read)
9395
data, rest = data[:self.to_read], data[self.to_read:]
9496
self.to_read -= len(data)
9597
self.bb += data
@@ -117,6 +119,9 @@ async def finish(self):
117119
async def get(self):
118120
return await self.q.get()
119121

122+
def get_nowait(self):
123+
return self.q.get_nowait()
124+
120125

121126
class Frame:
122127
"""

receptor/node.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import asyncio
22
import logging
3+
from urllib.parse import urlparse
34

45
from .protocol import BasicProtocol, create_peer
6+
from .ws import WSClient, WSServer
57

68
logger = logging.getLogger(__name__)
79

10+
def parse_peer(peer):
11+
if "://" not in peer:
12+
peer = f"receptor://{peer}"
13+
return urlparse(peer)
814

915
# FIXME: ping_interval is in the config, it shouldn't need to be passed as an arg here
1016
def mainloop(receptor, ping_interval=None, loop=asyncio.get_event_loop(), skip_run=False):
@@ -15,8 +21,23 @@ def mainloop(receptor, ping_interval=None, loop=asyncio.get_event_loop(), skip_r
1521
config.node_listen_address, config.node_listen_port, ssl=config.get_server_ssl_context())
1622
loop.create_task(listener)
1723
logger.info("Serving on %s:%s", config.node_listen_address, config.node_listen_port)
24+
25+
ws_server = WSServer(receptor, loop)
26+
ws_listener = loop.create_server(ws_server.app().make_handler(),
27+
config.node_listen_address, config.node_listen_port + 1, ssl=config.get_server_ssl_context())
28+
loop.create_task(ws_listener)
29+
logger.info("Serving ws on %s:%s", config.node_listen_address, config.node_listen_port + 1)
30+
1831
for peer in config.node_peers:
19-
loop.create_task(create_peer(receptor, loop, *peer.strip().split(":", 1)))
32+
parsed = parse_peer(peer)
33+
if parsed.scheme == "receptor":
34+
loop.create_task(create_peer(receptor, loop, parsed.hostname, parsed.port))
35+
elif parsed.scheme in ("ws", "wss"):
36+
c = WSClient(receptor, loop)
37+
loop.create_task(c.connect(peer))
38+
else:
39+
logger.warn(f"invalid peer: %s -> %s", peer, parsed)
40+
2041
if ping_interval > 0:
2142
ping_time = (((int(loop.time()) + 1) // ping_interval) + 1) * ping_interval
2243
loop.call_at(ping_time, loop.create_task, send_pings_and_reschedule(receptor, loop, ping_time, ping_interval))
@@ -34,6 +55,6 @@ async def send_pings_and_reschedule(receptor, loop, ping_time, ping_interval):
3455
logger.debug(f'Scheduling mesh ping.')
3556
for node_id in receptor.router.get_nodes():
3657
await receptor.router.ping_node(node_id)
37-
loop.call_at(ping_time + ping_interval,
58+
loop.call_at(ping_time + ping_interval,
3859
loop.create_task, send_pings_and_reschedule(
3960
receptor, loop, ping_time + ping_interval, ping_interval))

receptor/protocol.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,7 @@ def handle_handshake(self, data):
8484
self.loop.create_task(self.receptor.message_handler(self.incoming_buffer))
8585

8686
def send_handshake(self):
87-
msg = envelope.CommandMessage(header={
88-
"cmd": "HI",
89-
"id": self.receptor.node_id,
90-
"expire_time": time.time() + 10,
91-
"meta": dict(capabilities=self.receptor.work_manager.get_capabilities(),
92-
groups=self.receptor.config.node_groups,
93-
work=self.receptor.work_manager.get_work())
94-
})
95-
self.transport.write(msg.serialize())
87+
self.transport.write(self.receptor._say_hi().serialize())
9688

9789

9890
class BasicProtocol(BaseProtocol):

receptor/receptor.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,6 @@ def update_connection_manifest(self, connection):
8484
last=time.time()))
8585
self.write_connection_manifest(manifest)
8686

87-
def update_connections(self, protocol_obj):
88-
self.router.register_edge(protocol_obj.id, self.node_id, 1)
89-
if protocol_obj.id in self.connections:
90-
self.connections[protocol_obj.id].append(protocol_obj)
91-
else:
92-
self.connections[protocol_obj.id] = [protocol_obj]
93-
self.update_connection_manifest(protocol_obj.id)
94-
9587
async def message_handler(self, buf):
9688
logger.debug("spawning message_handler")
9789
while True:
@@ -106,10 +98,27 @@ async def message_handler(self, buf):
10698
else:
10799
await self.handle_message(data)
108100

101+
def update_connections(self, protocol_obj, id_=None):
102+
if id_ is None:
103+
id_ = protocol_obj.id
104+
105+
self.router.register_edge(id_, self.node_id, 1)
106+
if id_ in self.connections:
107+
self.connections[id_].append(protocol_obj)
108+
else:
109+
self.connections[id_] = [protocol_obj]
110+
self.update_connection_manifest(id_)
111+
109112
def add_connection(self, protocol_obj):
110113
self.update_connections(protocol_obj)
111114

112-
def remove_connection(self, protocol_obj):
115+
def remove_connection(self, protocol_obj, id_=None, loop=None):
116+
if id_ is None:
117+
id_ = protocol_obj.id
118+
119+
if loop is None:
120+
loop = protocol_obj.loop
121+
113122
notify_connections = []
114123
for connection_node in self.connections:
115124
if protocol_obj in self.connections[connection_node]:
@@ -120,14 +129,24 @@ def remove_connection(self, protocol_obj):
120129
self.router.debug_router()
121130
self.update_connection_manifest(connection_node)
122131
notify_connections += self.connections[connection_node]
123-
protocol_obj.loop.create_task(self.send_route_advertisement(self.router.get_edges()))
132+
loop.create_task(self.send_route_advertisement(self.router.get_edges()))
124133

125134
async def shutdown_handler(self):
126135
while True:
127136
if self.stop:
128137
return
129138
await asyncio.sleep(1)
130139

140+
def _say_hi(self):
141+
return envelope.CommandMessage(header={
142+
"cmd": "HI",
143+
"id": self.node_id,
144+
"expire_time": time.time() + 10,
145+
"meta": dict(capabilities=self.work_manager.get_capabilities(),
146+
groups=self.config.node_groups,
147+
work=self.work_manager.get_work())
148+
})
149+
131150
async def handle_route_advertisement(self, data):
132151
self.router.add_edges(data["edges"])
133152
await self.send_route_advertisement(data["edges"], data["seen"])

receptor/tests/test_framedbuffer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import uuid
34

@@ -119,10 +120,6 @@ async def test_malformed_frame(framed_buffer, msg_id):
119120
)
120121

121122

122-
@pytest.mark.skip(
123-
reason="""
124-
This test illustrates that sending an incomplete stream corrupts the transport"""
125-
)
126123
@pytest.mark.asyncio
127124
async def test_too_short(framed_buffer, msg_id):
128125
f1 = Frame(Frame.Types.HEADER, 1, 100, 1, 1)
@@ -133,4 +130,5 @@ async def test_too_short(framed_buffer, msg_id):
133130
await framed_buffer.put(f1.serialize() + too_short_header)
134131
await framed_buffer.put(f2.serialize() + too_short_payload)
135132

136-
await framed_buffer.get()
133+
with pytest.raises(asyncio.QueueEmpty):
134+
framed_buffer.get_nowait()

0 commit comments

Comments
 (0)