Skip to content

Commit 2f9b2f1

Browse files
committed
moving connection management from controller to connection package
Signed-off-by: Jesse Jaggars <jjaggars@redhat.com>
1 parent 5ec02a9 commit 2f9b2f1

6 files changed

Lines changed: 177 additions & 155 deletions

File tree

receptor/connection/__init__.py

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,122 +0,0 @@
1-
import logging
2-
3-
import asyncio
4-
from collections.abc import AsyncIterator
5-
from abc import abstractmethod, abstractproperty
6-
7-
from ..messages.envelope import FramedBuffer
8-
9-
logger = logging.getLogger(__name__)
10-
11-
12-
class Transport(AsyncIterator):
13-
14-
@abstractmethod
15-
async def close(self):
16-
pass
17-
18-
@abstractproperty
19-
def closed(self):
20-
pass
21-
22-
@abstractmethod
23-
async def send(self, bytes_):
24-
pass
25-
26-
27-
async def watch_queue(conn, buf):
28-
while not conn.closed:
29-
try:
30-
msg = await asyncio.wait_for(buf.get(), 5.0)
31-
except asyncio.TimeoutError:
32-
continue
33-
except Exception:
34-
logger.exception("watch_queue: error getting data from buffer")
35-
continue
36-
37-
try:
38-
await conn.send(msg)
39-
except Exception:
40-
logger.exception("watch_queue: error received trying to write")
41-
await buf.put(msg)
42-
return await conn.close()
43-
44-
45-
class Worker:
46-
def __init__(self, receptor, loop):
47-
self.receptor = receptor
48-
self.loop = loop
49-
self.conn = None
50-
self.buf = FramedBuffer(loop=self.loop)
51-
self.remote_id = None
52-
self.read_task = None
53-
self.handle_task = None
54-
self.write_task = None
55-
56-
def start_receiving(self):
57-
self.read_task = self.loop.create_task(self.receive())
58-
59-
async def receive(self):
60-
try:
61-
async for msg in self.conn:
62-
await self.buf.put(msg)
63-
except Exception:
64-
logger.exception("receive")
65-
66-
def register(self):
67-
self.receptor.update_connections(self.conn, id_=self.remote_id)
68-
69-
def unregister(self):
70-
self.receptor.remove_connection(self.conn, id_=self.remote_id, loop=self.loop)
71-
self._cancel(self.read_task)
72-
self._cancel(self.handle_task)
73-
self._cancel(self.write_task)
74-
75-
def _cancel(self, task):
76-
if task:
77-
task.cancel()
78-
79-
async def hello(self):
80-
logger.debug("sending HI")
81-
msg = self.receptor._say_hi().serialize()
82-
await self.conn.send(msg)
83-
84-
async def start_processing(self):
85-
logger.debug("sending routes")
86-
await self.receptor.send_route_advertisement()
87-
logger.debug("starting normal loop")
88-
self.handle_task = self.loop.create_task(
89-
self.receptor.message_handler(self.buf)
90-
)
91-
out = self.receptor.buffer_mgr.get_buffer_for_node(
92-
self.remote_id, self.receptor
93-
)
94-
self.write_task = self.loop.create_task(watch_queue(self.conn, out))
95-
return await self.write_task
96-
97-
async def _wait_handshake(self):
98-
logger.debug("waiting for HI")
99-
response = await self.buf.get() # TODO: deal with timeout
100-
self.remote_id = response.header["id"]
101-
self.register()
102-
103-
async def client(self, transport):
104-
try:
105-
self.conn = transport
106-
self.start_receiving()
107-
await self.hello()
108-
await self._wait_handshake()
109-
await self.start_processing()
110-
logger.debug("normal exit")
111-
finally:
112-
self.unregister()
113-
114-
async def server(self, transport):
115-
try:
116-
self.conn = transport
117-
self.start_receiving()
118-
await self._wait_handshake()
119-
await self.hello()
120-
await self.start_processing()
121-
finally:
122-
self.unregister()

receptor/connection/base.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import logging
2+
3+
import asyncio
4+
from collections.abc import AsyncIterator
5+
from abc import abstractmethod, abstractproperty
6+
7+
from ..messages.envelope import FramedBuffer
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class Transport(AsyncIterator):
13+
@abstractmethod
14+
async def close(self):
15+
pass
16+
17+
@abstractproperty
18+
def closed(self):
19+
pass
20+
21+
@abstractmethod
22+
async def send(self, bytes_):
23+
pass
24+
25+
26+
async def watch_queue(conn, buf):
27+
while not conn.closed:
28+
try:
29+
msg = await asyncio.wait_for(buf.get(), 5.0)
30+
except asyncio.TimeoutError:
31+
continue
32+
except Exception:
33+
logger.exception("watch_queue: error getting data from buffer")
34+
continue
35+
36+
try:
37+
await conn.send(msg)
38+
except Exception:
39+
logger.exception("watch_queue: error received trying to write")
40+
await buf.put(msg)
41+
return await conn.close()
42+
43+
44+
class Worker:
45+
def __init__(self, receptor, loop):
46+
self.receptor = receptor
47+
self.loop = loop
48+
self.conn = None
49+
self.buf = FramedBuffer(loop=self.loop)
50+
self.remote_id = None
51+
self.read_task = None
52+
self.handle_task = None
53+
self.write_task = None
54+
55+
def start_receiving(self):
56+
self.read_task = self.loop.create_task(self.receive())
57+
58+
async def receive(self):
59+
try:
60+
async for msg in self.conn:
61+
await self.buf.put(msg)
62+
except Exception:
63+
logger.exception("receive")
64+
65+
def register(self):
66+
self.receptor.update_connections(self.conn, id_=self.remote_id)
67+
68+
def unregister(self):
69+
self.receptor.remove_connection(self.conn, id_=self.remote_id, loop=self.loop)
70+
self._cancel(self.read_task)
71+
self._cancel(self.handle_task)
72+
self._cancel(self.write_task)
73+
74+
def _cancel(self, task):
75+
if task:
76+
task.cancel()
77+
78+
async def hello(self):
79+
logger.debug("sending HI")
80+
msg = self.receptor._say_hi().serialize()
81+
await self.conn.send(msg)
82+
83+
async def start_processing(self):
84+
logger.debug("sending routes")
85+
await self.receptor.send_route_advertisement()
86+
logger.debug("starting normal loop")
87+
self.handle_task = self.loop.create_task(
88+
self.receptor.message_handler(self.buf)
89+
)
90+
out = self.receptor.buffer_mgr.get_buffer_for_node(
91+
self.remote_id, self.receptor
92+
)
93+
self.write_task = self.loop.create_task(watch_queue(self.conn, out))
94+
return await self.write_task
95+
96+
async def _wait_handshake(self):
97+
logger.debug("waiting for HI")
98+
response = await self.buf.get() # TODO: deal with timeout
99+
self.remote_id = response.header["id"]
100+
self.register()
101+
102+
async def client(self, transport):
103+
try:
104+
self.conn = transport
105+
self.start_receiving()
106+
await self.hello()
107+
await self._wait_handshake()
108+
await self.start_processing()
109+
logger.debug("normal exit")
110+
finally:
111+
self.unregister()
112+
113+
async def server(self, transport):
114+
try:
115+
self.conn = transport
116+
self.start_receiving()
117+
await self._wait_handshake()
118+
await self.hello()
119+
await self.start_processing()
120+
finally:
121+
self.unregister()

receptor/connection/manager.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import functools
2+
import asyncio
3+
from urllib.parse import urlparse
4+
5+
from . import sock, ws
6+
7+
8+
def parse_peer(peer):
9+
if "://" not in peer:
10+
peer = f"receptor://{peer}"
11+
return urlparse(peer)
12+
13+
14+
class Manager:
15+
def __init__(self, factory, ssl_context, loop=None):
16+
self.factory = factory
17+
self.ssl_context = ssl_context
18+
self.loop = loop or asyncio.get_event_loop()
19+
20+
def get_listener(self, listen_url):
21+
service = parse_peer(listen_url)
22+
if service.scheme == "receptor":
23+
return asyncio.start_server(
24+
functools.partial(sock.serve, factory=self.factory),
25+
host=service.hostname,
26+
port=service.port,
27+
ssl=self.ssl_context,
28+
)
29+
elif service.scheme in ("ws", "wss"):
30+
return self.loop.create_server(
31+
ws.app(self.factory).make_handler(),
32+
service.hostname,
33+
service.port,
34+
ssl=self.ssl_context,
35+
)
36+
37+
def get_peer(self, peer):
38+
service = parse_peer(peer)
39+
if service.scheme == "receptor":
40+
return sock.connect(service.hostname, service.port, self.factory, self.loop)
41+
elif service.scheme in ("ws", "wss"):
42+
return ws.connect(peer, self.factory, self.loop)

receptor/connection/sock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from . import Transport
3+
from .base import Transport
44

55
logger = logging.getLogger(__name__)
66

receptor/connection/ws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import aiohttp.web
55
import asyncio
66

7-
from . import Transport
7+
from .base import Transport
88

99
logger = logging.getLogger(__name__)
1010

receptor/controller.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
22
import datetime
3-
import functools
43
import logging
54
import uuid
6-
from urllib.parse import urlparse
75

86
from .receptor import Receptor
97
from .messages import envelope
10-
from .connection import ws, sock, Worker
8+
from .connection.base import Worker
9+
from .connection.manager import Manager
1110

1211
logger = logging.getLogger(__name__)
1312

@@ -17,45 +16,27 @@ class Controller:
1716
def __init__(self, config, loop=asyncio.get_event_loop(), queue=None):
1817
self.receptor = Receptor(config)
1918
self.loop = loop
20-
self.factory = lambda: Worker(self.receptor, loop)
19+
self.connection_manager = Manager(
20+
lambda: Worker(self.receptor, loop),
21+
self.receptor.config.get_server_ssl_context(),
22+
loop
23+
)
2124
self.queue = queue
2225
if self.queue is None:
2326
self.queue = asyncio.Queue(loop=loop)
2427
self.receptor.response_queue = self.queue
2528

26-
def parse_peer(self, peer):
27-
if "://" not in peer:
28-
peer = f"receptor://{peer}"
29-
return urlparse(peer)
30-
3129
def enable_server(self, listen_url):
32-
service = self.parse_peer(listen_url)
33-
client_connected_cb = functools.partial(sock.serve, factory=self.factory)
34-
listener = asyncio.start_server(
35-
client_connected_cb,
36-
host=service.hostname,
37-
port=service.port,
38-
ssl=self.receptor.config.get_server_ssl_context())
39-
logger.info("Serving on {}:{}".format(service.hostname, service.port))
30+
listener = self.connection_manager.get_listener(listen_url)
31+
logger.info("Serving on %s", listen_url)
4032
self.loop.create_task(listener)
4133

4234
def enable_websocket_server(self, listen_url):
43-
service = urlparse(listen_url)
44-
listener = self.loop.create_server(
45-
ws.app(self.factory).make_handler(),
46-
service.hostname, service.port,
47-
ssl=self.receptor.config.get_server_ssl_context())
48-
logger.info("Serving websockets on {}:{}".format(service.hostname, service.port))
49-
self.loop.create_task(listener)
35+
self.enable_server(listen_url)
5036

5137
async def add_peer(self, peer):
52-
parsed = self.parse_peer(peer)
53-
if parsed.scheme == 'receptor':
54-
logger.info("Connecting to receptor peer {}".format(peer))
55-
await self.loop.create_task(sock.connect(parsed.hostname, parsed.port, self.factory, self.loop))
56-
elif parsed.scheme in ('ws', 'wss'):
57-
logger.info("Connecting to websocket peer {}".format(peer))
58-
await self.loop.create_task(ws.connect(peer, self.factory, self.loop))
38+
logger.info("Connecting to peer {}".format(peer))
39+
await self.loop.create_task(self.connection_manager.get_peer(peer))
5940

6041
async def recv(self):
6142
inner = await self.receptor.response_queue.get()

0 commit comments

Comments
 (0)