99logger = logging .getLogger (__name__ )
1010
1111
12- async def watch_queue (ws , buf ):
13- while not ws .closed :
12+ class Connection :
13+
14+ def __aiter__ (self ):
15+ return self
16+
17+ async def __anext__ (self ):
18+ raise NotImplementedError ("subclasses should implement this" )
19+
20+ async def close (self ):
21+ raise NotImplementedError ("subclasses should implement this" )
22+
23+ @property
24+ def closed (self ):
25+ raise NotImplementedError ("subclasses should implement this" )
26+
27+ async def send (self , bytes_ ):
28+ pass
29+
30+
31+ class WebSocket (Connection ):
32+
33+ def __init__ (self , ws ):
34+ self .ws = ws
35+
36+ async def __anext__ (self ):
37+ msg = await self .ws .__anext__ ()
38+ return msg .data
39+
40+ async def close (self ):
41+ return await self .ws .close ()
42+
43+ @property
44+ def closed (self ):
45+ return self .ws .closed
46+
47+ async def send (self , bytes_ ):
48+ await self .ws .send_bytes (bytes_ )
49+
50+
51+ async def watch_queue (conn , buf ):
52+ while not conn .closed :
1453 try :
1554 msg = await asyncio .wait_for (buf .get (), 5.0 )
1655 except asyncio .TimeoutError :
@@ -20,40 +59,41 @@ async def watch_queue(ws, buf):
2059 continue
2160
2261 try :
23- await ws . send_bytes (msg )
62+ await conn . send (msg )
2463 except Exception :
2564 logger .exception ("watch_queue: error received trying to write" )
2665 await buf .put (msg )
27- return await ws .close ()
66+ return await conn .close ()
2867 logger .debug ("watch_queue: ws is now closed" )
2968
3069
3170class WSBase :
3271 def __init__ (self , receptor , loop ):
3372 self .receptor = receptor
3473 self .loop = loop
74+ self .conn = None
3575 self .buf = FramedBuffer (loop = self .loop )
3676 self .remote_id = None
3777 self .read_task = None
3878 self .handle_task = None
3979 self .write_task = None
4080
41- def start_receiving (self , ws ):
81+ def start_receiving (self ):
4282 logger .debug ("starting recv" )
43- self .read_task = self .loop .create_task (self .receive (ws ))
83+ self .read_task = self .loop .create_task (self .receive ())
4484
45- async def receive (self , ws ):
85+ async def receive (self ):
4686 try :
47- async for msg in ws :
48- await self .buf .put (msg . data )
87+ async for msg in self . conn :
88+ await self .buf .put (msg )
4989 except Exception :
5090 logger .exception ("receive" )
5191
52- def register (self , ws ):
53- self .receptor .update_connections (ws , id_ = self .remote_id )
92+ def register (self ):
93+ self .receptor .update_connections (self . conn , id_ = self .remote_id )
5494
55- def unregister (self , ws ):
56- self .receptor .remove_connection (ws , id_ = self .remote_id , loop = self .loop )
95+ def unregister (self ):
96+ self .receptor .remove_connection (self . conn , id_ = self .remote_id , loop = self .loop )
5797 self ._cancel (self .read_task )
5898 self ._cancel (self .handle_task )
5999 self ._cancel (self .write_task )
@@ -62,12 +102,12 @@ def _cancel(self, task):
62102 if task :
63103 task .cancel ()
64104
65- async def hello (self , ws ):
105+ async def hello (self ):
66106 logger .debug ("sending HI" )
67107 msg = self .receptor ._say_hi ().serialize ()
68- await ws . send_bytes (msg )
108+ await self . conn . send (msg )
69109
70- async def start_processing (self , ws ):
110+ async def start_processing (self ):
71111 logger .debug ("sending routes" )
72112 await self .receptor .send_route_advertisement ()
73113 logger .debug ("starting normal loop" )
@@ -77,29 +117,30 @@ async def start_processing(self, ws):
77117 out = self .receptor .buffer_mgr .get_buffer_for_node (
78118 self .remote_id , self .receptor
79119 )
80- self .write_task = self .loop .create_task (watch_queue (ws , out ))
120+ self .write_task = self .loop .create_task (watch_queue (self . conn , out ))
81121 return await self .write_task
82122
83- async def _wait_handshake (self , ws ):
123+ async def _wait_handshake (self ):
84124 logger .debug ("serve: waiting for HI" )
85125 response = await self .buf .get () # TODO: deal with timeout
86126 self .remote_id = response .header ["id" ]
87- self .register (ws )
127+ self .register ()
88128
89129
90130class WSClient (WSBase ):
91131 async def connect (self , uri ):
92132 async with aiohttp .ClientSession ().ws_connect (uri ) as ws :
93133 try :
94- self .start_receiving (ws )
95- await self .hello (ws )
96- await self ._wait_handshake (ws )
97- await self .start_processing (ws )
134+ self .conn = WebSocket (ws )
135+ self .start_receiving ()
136+ await self .hello ()
137+ await self ._wait_handshake ()
138+ await self .start_processing ()
98139 logger .debug ("connect: normal exit" )
99140 except Exception :
100141 logger .exception ("connect" )
101142 finally :
102- self .unregister (ws )
143+ self .unregister ()
103144 await asyncio .sleep (5 )
104145 logger .debug ("connect: reconnecting" )
105146 self .loop .create_task (self .connect (uri ))
@@ -111,13 +152,15 @@ async def serve(self, request):
111152 ws = aiohttp .web .WebSocketResponse ()
112153 await ws .prepare (request )
113154
155+ self .conn = WebSocket (ws )
156+
114157 try :
115- self .start_receiving (ws )
116- await self ._wait_handshake (ws )
117- await self .hello (ws )
118- await self .start_processing (ws )
158+ self .start_receiving ()
159+ await self ._wait_handshake ()
160+ await self .hello ()
161+ await self .start_processing ()
119162 finally :
120- self .unregister (ws )
163+ self .unregister ()
121164
122165 def app (self ):
123166 app = aiohttp .web .Application ()
0 commit comments