Skip to content

Commit 4bdfa3f

Browse files
committed
fixing protocol selection and adding task cancellation
Signed-off-by: Jesse Jaggars <jjaggars@redhat.com>
1 parent f82e500 commit 4bdfa3f

3 files changed

Lines changed: 52 additions & 24 deletions

File tree

receptor/node.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ def mainloop(receptor, ping_interval=None, loop=asyncio.get_event_loop(), skip_r
3030

3131
for peer in config.node_peers:
3232
parsed = parse_peer(peer)
33-
if parsed.scheme == "receptor://":
33+
if parsed.scheme == "receptor":
3434
loop.create_task(create_peer(receptor, loop, parsed.hostname, parsed.port))
3535
elif parsed.scheme in ("ws", "wss"):
3636
c = WSClient(receptor, loop)
3737
loop.create_task(c.connect(peer))
38+
else:
39+
print(f"invalid peer: {peer} -> {parsed}")
3840

3941
if ping_interval > 0:
4042
ping_time = (((int(loop.time()) + 1) // ping_interval) + 1) * ping_interval
@@ -53,6 +55,6 @@ async def send_pings_and_reschedule(receptor, loop, ping_time, ping_interval):
5355
logger.debug(f'Scheduling mesh ping.')
5456
for node_id in receptor.router.get_nodes():
5557
await receptor.router.ping_node(node_id)
56-
loop.call_at(ping_time + ping_interval,
58+
loop.call_at(ping_time + ping_interval,
5759
loop.create_task, send_pings_and_reschedule(
5860
receptor, loop, ping_time + ping_interval, ping_interval))

receptor/receptor.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,6 @@ def update_connection_manifest(self, connection):
8484
last=time.time()))
8585
self.write_connection_manifest(manifest)
8686

87-
def update_connections(self, protocol_obj, id_=None):
88-
if id_ is None:
89-
id_ = protocol_obj.id
90-
self.router.register_edge(id_, self.node_id, 1)
91-
if id_ in self.connections:
92-
self.connections[id_].append(protocol_obj)
93-
else:
94-
self.connections[id_] = [protocol_obj]
95-
self.update_connection_manifest(id_)
96-
9787
async def message_handler(self, buf):
9888
logger.debug("spawning message_handler")
9989
while True:
@@ -108,10 +98,24 @@ async def message_handler(self, buf):
10898
else:
10999
await self.handle_message(data)
110100

101+
def update_connections(self, protocol_obj, id_=None):
102+
if id_ is None:
103+
id_ = protocol_obj.id
104+
105+
self.router.register_edge(id_, self.node_id, 1)
106+
if id_ in self.connections:
107+
self.connections[id_].append(protocol_obj)
108+
else:
109+
self.connections[id_] = [protocol_obj]
110+
self.update_connection_manifest(id_)
111+
111112
def add_connection(self, protocol_obj):
112113
self.update_connections(protocol_obj)
113114

114-
def remove_connection(self, protocol_obj):
115+
def remove_connection(self, protocol_obj, id_=None):
116+
if id_ is None:
117+
id_ = protocol_obj.id
118+
115119
notify_connections = []
116120
for connection_node in self.connections:
117121
if protocol_obj in self.connections[connection_node]:

receptor/ws.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
async def watch_queue(ws, buf):
1313
while not ws.closed:
1414
try:
15-
msg = await buf.get()
15+
msg = await asyncio.wait_for(buf.get(), 5.0)
16+
except asyncio.TimeoutError:
17+
continue
1618
except Exception:
1719
logger.exception("watch_queue: error getting data from buffer")
1820
continue
@@ -32,6 +34,12 @@ def __init__(self, receptor, loop):
3234
self.loop = loop
3335
self.buf = FramedBuffer(loop=self.loop)
3436
self.remote_id = None
37+
self.read_task = None
38+
self.handle_task = None
39+
self.write_task = None
40+
41+
def start_receiving(self, ws):
42+
self.read_task = self.loop.create_task(self.receive(ws))
3543

3644
async def receive(self, ws):
3745
try:
@@ -43,25 +51,35 @@ async def receive(self, ws):
4351
def register(self, ws):
4452
self.receptor.update_connections(ws, id_=self.remote_id)
4553

54+
def unregister(self, ws):
55+
self.receptor.remove_connection(ws, id_=self.remote_id)
56+
self._cancel(self.read_task)
57+
self._cancel(self.handle_task)
58+
self._cancel(self.write_task)
59+
60+
def _cancel(self, task):
61+
if task:
62+
task.cancel()
63+
4664
async def hello(self, ws):
4765
msg = self.receptor._say_hi().serialize()
4866
await ws.send_bytes(msg)
4967

5068
async def start_processing(self, ws):
51-
self.loop.create_task(self.receptor.message_handler(self.buf))
69+
self.handle_task = self.loop.create_task(self.receptor.message_handler(self.buf))
5270
out = self.receptor.buffer_mgr.get_buffer_for_node(
5371
self.remote_id, self.receptor
5472
)
55-
return await watch_queue(ws, out)
73+
self.write_task = self.loop.create_task(watch_queue(ws, out))
74+
return await self.write_task
5675

5776

5877
class WSClient(WSBase):
5978
async def connect(self, uri):
60-
try:
61-
async with aiohttp.ClientSession().ws_connect(uri) as ws:
62-
79+
async with aiohttp.ClientSession().ws_connect(uri) as ws:
80+
try:
6381
logger.debug("connect: starting recv")
64-
recv_loop = self.loop.create_task(self.receive(ws)) # reader
82+
self.start_receiving(ws)
6583
logger.debug("connect: sending HI")
6684
await self.hello(ws)
6785
logger.debug("connect: waiting for HI")
@@ -73,10 +91,13 @@ async def connect(self, uri):
7391
logger.debug("connect: starting normal loop")
7492
await self.start_processing(ws)
7593
logger.debug("connect: normal exit")
76-
except Exception:
77-
logger.exception("connect")
78-
logger.debug("connect: reconnecting")
79-
self.loop.create_task(self.connect(uri))
94+
except Exception:
95+
logger.exception("connect")
96+
finally:
97+
self.unregister(ws)
98+
await asyncio.sleep(5)
99+
logger.debug("connect: reconnecting")
100+
self.loop.create_task(self.connect(uri))
80101

81102

82103
class WSServer(WSBase):
@@ -97,6 +118,7 @@ async def serve(self, request):
97118
await self.receptor.send_route_advertisement()
98119
logger.debug("serve: starting normal recv loop")
99120
await self.start_processing(ws)
121+
self.unregister(ws)
100122

101123
def app(self):
102124
app = aiohttp.web.Application()

0 commit comments

Comments
 (0)