Skip to content

Commit e084a70

Browse files
committed
chore: connection impl type hint
1 parent 4b6b1ce commit e084a70

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

juju/client/connection.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .facade_versions import client_facade_versions, known_unsupported_facades
2828

2929
SPECIFIED_FACADES: TypeAlias = dict[str, dict[Literal["versions"], Sequence[int]]]
30+
_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol"
3031

3132
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
3233
log = logging.getLogger("juju.client.connection")
@@ -125,6 +126,7 @@ class Connection:
125126
_retry_backoff: float
126127
uuid: str | None
127128
messages: IdQueue
129+
_ws: _WebSocket | None
128130

129131
@classmethod
130132
async def connect(
@@ -303,7 +305,7 @@ def _get_ssl(self, cert=None):
303305
context.check_hostname = False
304306
return context
305307

306-
async def _open(self, endpoint, cacert):
308+
async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
307309
if self.is_debug_log_connection:
308310
assert self.uuid
309311
url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log"
@@ -726,7 +728,9 @@ async def _connect(self, endpoints):
726728
if len(endpoints) == 0:
727729
raise errors.JujuConnectionError("no endpoints to connect to")
728730

729-
async def _try_endpoint(endpoint, cacert, delay):
731+
async def _try_endpoint(
732+
endpoint, cacert, delay
733+
) -> tuple[_WebSocket, str, str, str]:
730734
if delay:
731735
await jasyncio.sleep(delay)
732736
return await self._open(endpoint, cacert)
@@ -738,6 +742,8 @@ async def _try_endpoint(endpoint, cacert, delay):
738742
jasyncio.ensure_future(_try_endpoint(endpoint, cacert, 0.1 * i))
739743
for i, (endpoint, cacert) in enumerate(endpoints)
740744
]
745+
result: tuple[_WebSocket, str, str, str] | None = None
746+
741747
for attempt in range(self._retries + 1):
742748
for task in jasyncio.as_completed(tasks):
743749
try:
@@ -760,13 +766,12 @@ async def _try_endpoint(endpoint, cacert, delay):
760766
# only executed if inner loop's else did not continue
761767
# (i.e., inner loop did break due to successful connection)
762768
break
763-
else:
764-
# impossible, work around https://github.com/microsoft/pyright/issues/8791
765-
assert False # noqa: B011
766769

767770
for task in tasks:
768771
task.cancel()
769772

773+
assert result # loop raises or sets the result
774+
770775
self._ws = result[0]
771776
self.addr = result[1]
772777
self.endpoint = result[2]

0 commit comments

Comments
 (0)