Skip to content

Commit 1a23a60

Browse files
committed
chore: simplify the websocket response queue
1 parent a11dfa0 commit 1a23a60

4 files changed

Lines changed: 38 additions & 23 deletions

File tree

juju/client/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class Connection:
121121
_retries: int
122122
_retry_backoff: float
123123
uuid: str | None
124+
messages: IdQueue
124125

125126
@classmethod
126127
async def connect(
@@ -373,7 +374,7 @@ async def close(self, to_reconnect=False):
373374
if self.proxy is not None:
374375
self.proxy.close()
375376

376-
async def _recv(self, request_id):
377+
async def _recv(self, request_id: int) -> dict[str, Any]:
377378
if not self.is_open:
378379
raise websockets.exceptions.ConnectionClosed(
379380
websockets.frames.Close(

juju/jasyncio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from asyncio import (
2323
CancelledError,
24+
Task,
2425
create_task,
2526
wait,
2627
)
@@ -84,7 +85,7 @@
8485
ROOT_LOGGER = logging.getLogger()
8586

8687

87-
def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER):
88+
def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER) -> Task:
8889
"""Wrapper around "asyncio.create_task" to make sure the task
8990
exceptions are handled properly.
9091

juju/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2917,7 +2917,7 @@ async def _get_source_api(self, url):
29172917

29182918
async def wait_for_idle(
29192919
self,
2920-
apps=None,
2920+
apps: list[str] | None = None,
29212921
raise_on_error=True,
29222922
raise_on_blocked=False,
29232923
wait_for_active=False,
@@ -2927,7 +2927,7 @@ async def wait_for_idle(
29272927
status=None,
29282928
wait_for_at_least_units=None,
29292929
wait_for_exact_units=None,
2930-
):
2930+
) -> None:
29312931
"""Wait for applications in the model to settle into an idle state.
29322932
29332933
:param List[str] apps: Optional list of specific app names to wait on.

juju/utils.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

5+
import asyncio
46
import base64
57
import os
68
import textwrap
79
import zipfile
810
from collections import defaultdict
9-
from functools import partial
1011
from pathlib import Path
12+
from typing import Any
1113

1214
import yaml
1315
from pyasn1.codec.der.encoder import encode
@@ -20,11 +22,11 @@
2022

2123
async def execute_process(*cmd, log=None):
2224
"""Wrapper around asyncio.create_subprocess_exec."""
23-
p = await jasyncio.create_subprocess_exec(
25+
p = await asyncio.create_subprocess_exec(
2426
*cmd,
25-
stdin=jasyncio.subprocess.PIPE,
26-
stdout=jasyncio.subprocess.PIPE,
27-
stderr=jasyncio.subprocess.PIPE,
27+
stdin=asyncio.subprocess.PIPE,
28+
stdout=asyncio.subprocess.PIPE,
29+
stderr=asyncio.subprocess.PIPE,
2830
)
2931
stdout, stderr = await p.communicate()
3032
if log:
@@ -84,7 +86,7 @@ async def read_ssh_key():
8486
"""Attempt to read the local juju admin's public ssh key, so that it can be
8587
passed on to a model.
8688
"""
87-
loop = jasyncio.get_running_loop()
89+
loop = asyncio.get_running_loop()
8890
return await loop.run_in_executor(None, _read_ssh_key)
8991

9092

@@ -93,20 +95,31 @@ class IdQueue:
9395
ID.
9496
"""
9597

96-
def __init__(self, maxsize=0):
97-
self._queues = defaultdict(partial(jasyncio.Queue, maxsize))
98-
99-
async def get(self, id_):
98+
_queues: dict[int, asyncio.Queue[dict[str, Any] | Exception]]
99+
100+
def __init__(self):
101+
self._queues = defaultdict(asyncio.Queue)
102+
# FIXME cleanup needed.
103+
# in some cases an Exception is put into the queue.
104+
# if the main coro exits, this exception will be logged as "never awaited"
105+
# we gotta do something about that to keep the output clean.
106+
#
107+
# Additionally, it's conceivable that a response is put in the queue
108+
# and then an exception is put via put_all()
109+
# the reader only ever fetches one item, and exception is "never awaited"
110+
# rewrite put_all to replace the pending response instead.
111+
112+
async def get(self, id_: int) -> dict[str, Any]:
100113
value = await self._queues[id_].get()
101114
del self._queues[id_]
102115
if isinstance(value, Exception):
103116
raise value
104117
return value
105118

106-
async def put(self, id_, value):
119+
async def put(self, id_: int, value: dict[str, Any]):
107120
await self._queues[id_].put(value)
108121

109-
async def put_all(self, value):
122+
async def put_all(self, value: Exception):
110123
for queue in self._queues.values():
111124
await queue.put(value)
112125

@@ -120,9 +133,9 @@ async def block_until(*conditions, timeout=None, wait_period=0.5):
120133

121134
async def _block():
122135
while not all(c() for c in conditions):
123-
await jasyncio.sleep(wait_period)
136+
await asyncio.sleep(wait_period)
124137

125-
await jasyncio.shield(jasyncio.wait_for(_block(), timeout))
138+
await asyncio.shield(asyncio.wait_for(_block(), timeout))
126139

127140

128141
async def block_until_with_coroutine(
@@ -136,12 +149,12 @@ async def block_until_with_coroutine(
136149

137150
async def _block():
138151
while not await condition_coroutine():
139-
await jasyncio.sleep(wait_period)
152+
await asyncio.sleep(wait_period)
140153

141-
await jasyncio.shield(jasyncio.wait_for(_block(), timeout=timeout))
154+
await asyncio.shield(asyncio.wait_for(_block(), timeout=timeout))
142155

143156

144-
async def wait_for_bundle(model, bundle, **kwargs):
157+
async def wait_for_bundle(model, bundle: str | Path, **kwargs) -> None:
145158
"""Helper to wait for just the apps in a specific bundle.
146159
147160
Equivalent to loading the bundle, pulling out the app names, and calling::
@@ -156,8 +169,8 @@ async def wait_for_bundle(model, bundle, **kwargs):
156169
bundle = bundle_path / "bundle.yaml"
157170
except OSError:
158171
pass
159-
bundle = yaml.safe_load(textwrap.dedent(bundle).strip())
160-
apps = list(bundle.get("applications", bundle.get("services")).keys())
172+
content: dict[str, Any] = yaml.safe_load(textwrap.dedent(bundle).strip())
173+
apps = list(content.get("applications", content.get("services")).keys())
161174
await model.wait_for_idle(apps, **kwargs)
162175

163176

0 commit comments

Comments
 (0)