1212async def watch_queue (ws , buf ):
1313 while not ws .closed :
1414 try :
15- msg = await buf .get ()
15+ msg = await asyncio .wait_for (buf .get (), 5.0 )
16+ except asyncio .TimeoutError :
17+ continue
1618 except Exception :
1719 logger .exception ("watch_queue: error getting data from buffer" )
1820 continue
@@ -32,6 +34,12 @@ def __init__(self, receptor, loop):
3234 self .loop = loop
3335 self .buf = FramedBuffer (loop = self .loop )
3436 self .remote_id = None
37+ self .read_task = None
38+ self .handle_task = None
39+ self .write_task = None
40+
41+ def start_receiving (self , ws ):
42+ self .read_task = self .loop .create_task (self .receive (ws ))
3543
3644 async def receive (self , ws ):
3745 try :
@@ -43,25 +51,35 @@ async def receive(self, ws):
4351 def register (self , ws ):
4452 self .receptor .update_connections (ws , id_ = self .remote_id )
4553
54+ def unregister (self , ws ):
55+ self .receptor .remove_connection (ws , id_ = self .remote_id )
56+ self ._cancel (self .read_task )
57+ self ._cancel (self .handle_task )
58+ self ._cancel (self .write_task )
59+
60+ def _cancel (self , task ):
61+ if task :
62+ task .cancel ()
63+
4664 async def hello (self , ws ):
4765 msg = self .receptor ._say_hi ().serialize ()
4866 await ws .send_bytes (msg )
4967
5068 async def start_processing (self , ws ):
51- self .loop .create_task (self .receptor .message_handler (self .buf ))
69+ self .handle_task = self . loop .create_task (self .receptor .message_handler (self .buf ))
5270 out = self .receptor .buffer_mgr .get_buffer_for_node (
5371 self .remote_id , self .receptor
5472 )
55- return await watch_queue (ws , out )
73+ self .write_task = self .loop .create_task (watch_queue (ws , out ))
74+ return await self .write_task
5675
5776
5877class WSClient (WSBase ):
5978 async def connect (self , uri ):
60- try :
61- async with aiohttp .ClientSession ().ws_connect (uri ) as ws :
62-
79+ async with aiohttp .ClientSession ().ws_connect (uri ) as ws :
80+ try :
6381 logger .debug ("connect: starting recv" )
64- recv_loop = self .loop . create_task ( self . receive ( ws )) # reader
82+ self .start_receiving ( ws )
6583 logger .debug ("connect: sending HI" )
6684 await self .hello (ws )
6785 logger .debug ("connect: waiting for HI" )
@@ -73,10 +91,13 @@ async def connect(self, uri):
7391 logger .debug ("connect: starting normal loop" )
7492 await self .start_processing (ws )
7593 logger .debug ("connect: normal exit" )
76- except Exception :
77- logger .exception ("connect" )
78- logger .debug ("connect: reconnecting" )
79- self .loop .create_task (self .connect (uri ))
94+ except Exception :
95+ logger .exception ("connect" )
96+ finally :
97+ self .unregister (ws )
98+ await asyncio .sleep (5 )
99+ logger .debug ("connect: reconnecting" )
100+ self .loop .create_task (self .connect (uri ))
80101
81102
82103class WSServer (WSBase ):
@@ -97,6 +118,7 @@ async def serve(self, request):
97118 await self .receptor .send_route_advertisement ()
98119 logger .debug ("serve: starting normal recv loop" )
99120 await self .start_processing (ws )
121+ self .unregister (ws )
100122
101123 def app (self ):
102124 app = aiohttp .web .Application ()
0 commit comments