11import logging
22
3+ import functools
34import asyncio
45import aiohttp
56import aiohttp .web
910logger = logging .getLogger (__name__ )
1011
1112
12- class Connection :
13-
13+ class Transport :
1414 def __aiter__ (self ):
1515 return self
1616
@@ -25,11 +25,10 @@ def closed(self):
2525 raise NotImplementedError ("subclasses should implement this" )
2626
2727 async def send (self , bytes_ ):
28- pass
29-
28+ raise NotImplementedError ("subclasses should implement this" )
3029
31- class WebSocket (Connection ):
3230
31+ class WebSocket (Transport ):
3332 def __init__ (self , ws ):
3433 self .ws = ws
3534
@@ -67,7 +66,7 @@ async def watch_queue(conn, buf):
6766 logger .debug ("watch_queue: ws is now closed" )
6867
6968
70- class WSBase :
69+ class Worker :
7170 def __init__ (self , receptor , loop ):
7271 self .receptor = receptor
7372 self .loop = loop
@@ -126,43 +125,55 @@ async def _wait_handshake(self):
126125 self .remote_id = response .header ["id" ]
127126 self .register ()
128127
128+ async def client (self , transport ):
129+ try :
130+ self .conn = transport
131+ self .start_receiving ()
132+ await self .hello ()
133+ await self ._wait_handshake ()
134+ await self .start_processing ()
135+ logger .debug ("connect: normal exit" )
136+ finally :
137+ self .unregister ()
129138
130- class WSClient (WSBase ):
131- async def connect (self , uri ):
132- async with aiohttp .ClientSession ().ws_connect (uri ) as ws :
133- try :
134- self .conn = WebSocket (ws )
135- self .start_receiving ()
136- await self .hello ()
137- await self ._wait_handshake ()
138- await self .start_processing ()
139- logger .debug ("connect: normal exit" )
140- except Exception :
141- logger .exception ("connect" )
142- finally :
143- self .unregister ()
144- await asyncio .sleep (5 )
145- logger .debug ("connect: reconnecting" )
146- self .loop .create_task (self .connect (uri ))
147-
148-
149- class WSServer (WSBase ):
150- async def serve (self , request ):
151-
152- ws = aiohttp .web .WebSocketResponse ()
153- await ws .prepare (request )
154-
155- self .conn = WebSocket (ws )
156-
139+ async def server (self , transport ):
157140 try :
141+ self .conn = transport
158142 self .start_receiving ()
159143 await self ._wait_handshake ()
160144 await self .hello ()
161145 await self .start_processing ()
162146 finally :
163147 self .unregister ()
164148
165- def app (self ):
166- app = aiohttp .web .Application ()
167- app .add_routes ([aiohttp .web .get ("/" , self .serve )])
168- return app
149+
150+ async def connect (uri , factory , loop = None ):
151+ if not loop :
152+ loop = asyncio .get_event_loop ()
153+
154+ worker = factory ()
155+ try :
156+ async with aiohttp .ClientSession ().ws_connect (uri ) as ws :
157+ t = WebSocket (ws )
158+ await worker .client (t )
159+ except Exception :
160+ logger .exception ("connect" )
161+ finally :
162+ await asyncio .sleep (5 )
163+ logger .debug ("reconnecting" )
164+ loop .create_task (connect (uri , factory = factory , loop = loop ))
165+
166+
167+ async def serve (request , factory ):
168+ ws = aiohttp .web .WebSocketResponse ()
169+ await ws .prepare (request )
170+
171+ t = WebSocket (ws )
172+ await factory ().server (t )
173+
174+
175+ def app (factory ):
176+ handler = functools .partial (serve , factory = factory )
177+ app = aiohttp .web .Application ()
178+ app .add_routes ([aiohttp .web .get ("/" , handler )])
179+ return app
0 commit comments