Skip to content

Commit b2f8822

Browse files
committed
websockets now use the abstraction
Signed-off-by: Jesse Jaggars <jjaggars@redhat.com>
1 parent d77725e commit b2f8822

2 files changed

Lines changed: 51 additions & 39 deletions

File tree

receptor/ws.py renamed to receptor/connection.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
import functools
34
import asyncio
45
import aiohttp
56
import aiohttp.web
@@ -9,8 +10,7 @@
910
logger = logging.getLogger(__name__)
1011

1112

12-
class Connection:
13-
13+
class Transport:
1414
def __aiter__(self):
1515
return self
1616

@@ -25,11 +25,10 @@ def closed(self):
2525
raise NotImplementedError("subclasses should implement this")
2626

2727
async def send(self, bytes_):
28-
pass
29-
28+
raise NotImplementedError("subclasses should implement this")
3029

31-
class WebSocket(Connection):
3230

31+
class WebSocket(Transport):
3332
def __init__(self, ws):
3433
self.ws = ws
3534

@@ -67,7 +66,7 @@ async def watch_queue(conn, buf):
6766
logger.debug("watch_queue: ws is now closed")
6867

6968

70-
class WSBase:
69+
class Worker:
7170
def __init__(self, receptor, loop):
7271
self.receptor = receptor
7372
self.loop = loop
@@ -126,43 +125,55 @@ async def _wait_handshake(self):
126125
self.remote_id = response.header["id"]
127126
self.register()
128127

128+
async def client(self, transport):
129+
try:
130+
self.conn = transport
131+
self.start_receiving()
132+
await self.hello()
133+
await self._wait_handshake()
134+
await self.start_processing()
135+
logger.debug("connect: normal exit")
136+
finally:
137+
self.unregister()
129138

130-
class WSClient(WSBase):
131-
async def connect(self, uri):
132-
async with aiohttp.ClientSession().ws_connect(uri) as ws:
133-
try:
134-
self.conn = WebSocket(ws)
135-
self.start_receiving()
136-
await self.hello()
137-
await self._wait_handshake()
138-
await self.start_processing()
139-
logger.debug("connect: normal exit")
140-
except Exception:
141-
logger.exception("connect")
142-
finally:
143-
self.unregister()
144-
await asyncio.sleep(5)
145-
logger.debug("connect: reconnecting")
146-
self.loop.create_task(self.connect(uri))
147-
148-
149-
class WSServer(WSBase):
150-
async def serve(self, request):
151-
152-
ws = aiohttp.web.WebSocketResponse()
153-
await ws.prepare(request)
154-
155-
self.conn = WebSocket(ws)
156-
139+
async def server(self, transport):
157140
try:
141+
self.conn = transport
158142
self.start_receiving()
159143
await self._wait_handshake()
160144
await self.hello()
161145
await self.start_processing()
162146
finally:
163147
self.unregister()
164148

165-
def app(self):
166-
app = aiohttp.web.Application()
167-
app.add_routes([aiohttp.web.get("/", self.serve)])
168-
return app
149+
150+
async def connect(uri, factory, loop=None):
151+
if not loop:
152+
loop = asyncio.get_event_loop()
153+
154+
worker = factory()
155+
try:
156+
async with aiohttp.ClientSession().ws_connect(uri) as ws:
157+
t = WebSocket(ws)
158+
await worker.client(t)
159+
except Exception:
160+
logger.exception("connect")
161+
finally:
162+
await asyncio.sleep(5)
163+
logger.debug("reconnecting")
164+
loop.create_task(connect(uri, factory=factory, loop=loop))
165+
166+
167+
async def serve(request, factory):
168+
ws = aiohttp.web.WebSocketResponse()
169+
await ws.prepare(request)
170+
171+
t = WebSocket(ws)
172+
await factory().server(t)
173+
174+
175+
def app(factory):
176+
handler = functools.partial(serve, factory=factory)
177+
app = aiohttp.web.Application()
178+
app.add_routes([aiohttp.web.get("/", handler)])
179+
return app

receptor/controller.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import uuid
55
from urllib.parse import urlparse
66

7-
from .ws import WSServer, WSClient
87
from .protocol import BasicProtocol, create_peer
98
from .receptor import Receptor
109
from .messages import envelope
10+
from . import connection
1111

1212
logger = logging.getLogger(__name__)
1313

@@ -38,8 +38,9 @@ def enable_server(self, listen_url):
3838

3939
def enable_websocket_server(self, listen_url):
4040
service = urlparse(listen_url)
41+
factory = lambda: connection.Worker(receptor, loop)
4142
listener = self.loop.create_server(
42-
WSServer(self.receptor, self.loop).app().make_handler(),
43+
connection.app(factory).make_handler(),
4344
service.hostname, service.port,
4445
ssl=self.receptor.config.get_server_ssl_context())
4546
logger.info("Serving websockets on {}:{}".format(service.hostname, service.port))

0 commit comments

Comments
 (0)