Skip to content

Commit 6a62d83

Browse files
authored
Merge pull request #1189 from dimaqq/chore-connnection-types
#1189 Pull the Connection type hints from #1097 Only minor cleanup beyond that, no logical changes.
2 parents 706d174 + c7f5c91 commit 6a62d83

10 files changed

Lines changed: 173 additions & 100 deletions

File tree

juju/client/connection.py

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

45
import base64
56
import json
@@ -9,37 +10,29 @@
910
import warnings
1011
import weakref
1112
from http.client import HTTPSConnection
12-
from typing import Dict, Literal, Optional, Sequence
13+
from typing import Any, Literal, Sequence
1314

1415
import macaroonbakery.bakery as bakery
1516
import macaroonbakery.httpbakery as httpbakery
1617
import websockets
1718
from dateutil.parser import parse
19+
from typing_extensions import Self, TypeAlias, overload
1820

1921
from juju import errors, jasyncio, tag, utils
2022
from juju.client import client
2123
from juju.utils import IdQueue
2224
from juju.version import CLIENT_VERSION
2325

26+
from .facade import TypeEncoder, _Json, _RichJson
2427
from .facade_versions import client_facade_versions, known_unsupported_facades
2528

29+
SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]"
30+
_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol"
31+
2632
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
2733
log = logging.getLogger("juju.client.connection")
2834

2935

30-
def facade_versions(name, versions):
31-
"""facade_versions returns a new object that correctly returns a object in
32-
format expected by the connection facades inspection.
33-
:param name: name of the facade
34-
:param versions: versions to support by the facade
35-
"""
36-
if name.endswith("Facade"):
37-
name = name[: -len("Facade")]
38-
return {
39-
name: {"versions": versions},
40-
}
41-
42-
4336
class Monitor:
4437
"""Monitor helper class for our Connection class.
4538
@@ -59,7 +52,7 @@ class Monitor:
5952
DISCONNECTING = "disconnecting"
6053
DISCONNECTED = "disconnected"
6154

62-
def __init__(self, connection):
55+
def __init__(self, connection: Connection):
6356
self.connection = weakref.ref(connection)
6457
self.reconnecting = jasyncio.Lock()
6558
self.close_called = jasyncio.Event()
@@ -117,28 +110,41 @@ class Connection:
117110

118111
MAX_FRAME_SIZE = 2**22
119112
"Maximum size for a single frame. Defaults to 4MB."
120-
facades: Dict[str, int]
121-
_specified_facades: Dict[str, Sequence[int]]
113+
facades: dict[str, int]
114+
_specified_facades: dict[str, Sequence[int]]
115+
bakery_client: Any
116+
usertag: str | None
117+
password: str | None
118+
name: str
119+
__request_id__: int
120+
endpoints: list[tuple[str, str]] | None # Set by juju/controller.py
121+
is_debug_log_connection: bool
122+
monitor: Monitor
123+
proxy: Any # Need to find types for this library
124+
max_frame_size: int
125+
_retries: int
126+
_retry_backoff: float
127+
uuid: str | None
128+
messages: IdQueue
129+
_ws: _WebSocket | None
122130

123131
@classmethod
124132
async def connect(
125133
cls,
126134
endpoint=None,
127-
uuid=None,
128-
username=None,
129-
password=None,
135+
uuid: str | None = None,
136+
username: str | None = None,
137+
password: str | None = None,
130138
cacert=None,
131139
bakery_client=None,
132-
max_frame_size=None,
140+
max_frame_size: int | None = None,
133141
retries=3,
134142
retry_backoff=10,
135-
specified_facades: Optional[
136-
Dict[str, Dict[Literal["versions"], Sequence[int]]]
137-
] = None,
143+
specified_facades: SpecifiedFacades | None = None,
138144
proxy=None,
139145
debug_log_conn=None,
140146
debug_log_params={},
141-
):
147+
) -> Self:
142148
"""Connect to the websocket.
143149
144150
If uuid is None, the connection will be to the controller. Otherwise it
@@ -270,7 +276,7 @@ def ws(self):
270276
return self._ws
271277

272278
@property
273-
def username(self):
279+
def username(self) -> str | None:
274280
if not self.usertag:
275281
return None
276282
return self.usertag[len("user-") :]
@@ -299,7 +305,7 @@ def _get_ssl(self, cert=None):
299305
context.check_hostname = False
300306
return context
301307

302-
async def _open(self, endpoint, cacert):
308+
async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
303309
if self.is_debug_log_connection:
304310
assert self.uuid
305311
url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log"
@@ -372,7 +378,7 @@ async def close(self, to_reconnect=False):
372378
if self.proxy is not None:
373379
self.proxy.close()
374380

375-
async def _recv(self, request_id):
381+
async def _recv(self, request_id: int) -> dict[str, Any]:
376382
if not self.is_open:
377383
raise websockets.exceptions.ConnectionClosed(
378384
websockets.frames.Close(
@@ -534,7 +540,19 @@ async def _do_ping():
534540
log.debug("ping failed because of closed connection")
535541
pass
536542

537-
async def rpc(self, msg, encoder=None):
543+
@overload
544+
async def rpc(
545+
self, msg: dict[str, _Json], encoder: None = None
546+
) -> dict[str, _Json]: ...
547+
548+
@overload
549+
async def rpc(
550+
self, msg: dict[str, _RichJson], encoder: TypeEncoder
551+
) -> dict[str, _Json]: ...
552+
553+
async def rpc(
554+
self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None
555+
) -> dict[str, _Json]:
538556
"""Make an RPC to the API. The message is encoded as JSON
539557
using the given encoder if any.
540558
:param msg: Parameters for the call (will be encoded as JSON).
@@ -710,7 +728,9 @@ async def _connect(self, endpoints):
710728
if len(endpoints) == 0:
711729
raise errors.JujuConnectionError("no endpoints to connect to")
712730

713-
async def _try_endpoint(endpoint, cacert, delay):
731+
async def _try_endpoint(
732+
endpoint, cacert, delay
733+
) -> tuple[_WebSocket, str, str, str]:
714734
if delay:
715735
await jasyncio.sleep(delay)
716736
return await self._open(endpoint, cacert)
@@ -722,6 +742,8 @@ async def _try_endpoint(endpoint, cacert, delay):
722742
jasyncio.ensure_future(_try_endpoint(endpoint, cacert, 0.1 * i))
723743
for i, (endpoint, cacert) in enumerate(endpoints)
724744
]
745+
result: tuple[_WebSocket, str, str, str] | None = None
746+
725747
for attempt in range(self._retries + 1):
726748
for task in jasyncio.as_completed(tasks):
727749
try:
@@ -744,8 +766,12 @@ async def _try_endpoint(endpoint, cacert, delay):
744766
# only executed if inner loop's else did not continue
745767
# (i.e., inner loop did break due to successful connection)
746768
break
769+
747770
for task in tasks:
748771
task.cancel()
772+
773+
assert result # loop raises or sets the result
774+
749775
self._ws = result[0]
750776
self.addr = result[1]
751777
self.endpoint = result[2]

juju/client/connector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

45
import copy
56
import logging
7+
from typing import Any
68

79
import macaroonbakery.httpbakery as httpbakery
810
from packaging import version
@@ -33,9 +35,9 @@ class Connector:
3335

3436
def __init__(
3537
self,
36-
max_frame_size=None,
37-
bakery_client=None,
38-
jujudata=None,
38+
max_frame_size: int | None = None,
39+
bakery_client: Any | None = None,
40+
jujudata: Any | None = None,
3941
):
4042
"""Initialize a connector that will use the given parameters
4143
by default when making a new connection
@@ -52,7 +54,7 @@ def is_connected(self):
5254
"""Report whether there is a currently connected controller or not"""
5355
return self._connection is not None
5456

55-
def connection(self):
57+
def connection(self) -> Connection:
5658
"""Return the current connection; raises an exception if there
5759
is no current connection.
5860
"""

juju/client/facade.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

45
import argparse
56
import builtins
@@ -13,13 +14,22 @@
1314
from collections import defaultdict
1415
from glob import glob
1516
from pathlib import Path
16-
from typing import Any, Dict, List, Mapping, Sequence
17+
from typing import Any, Mapping, Sequence
1718

1819
import packaging.version
1920
import typing_inspect
21+
from typing_extensions import TypeAlias
2022

2123
from . import codegen
2224

25+
# Plain JSON, what is received from Juju
26+
_JsonLeaf: TypeAlias = "None | bool | int | float | str"
27+
_Json: TypeAlias = "_JsonLeaf|list[_Json]|dict[str, _Json]"
28+
29+
# Type-enriched JSON, what can be sent to Juju
30+
_RichLeaf: TypeAlias = "_JsonLeaf|Type"
31+
_RichJson: TypeAlias = "_RichLeaf|list[_RichJson]|dict[str, _RichJson]"
32+
2333
_marker = object()
2434

2535
JUJU_VERSION = re.compile(r"[0-9]+\.[0-9-]+[\.\-][0-9a-z]+(\.[0-9]+)?")
@@ -634,7 +644,7 @@ class {name}Facade(Type):
634644

635645

636646
class TypeEncoder(json.JSONEncoder):
637-
def default(self, obj):
647+
def default(self, obj: _RichJson) -> _Json:
638648
if isinstance(obj, Type):
639649
return obj.serialize()
640650
return json.JSONEncoder.default(self, obj)
@@ -653,7 +663,7 @@ def __eq__(self, other):
653663

654664
return self.__dict__ == other.__dict__
655665

656-
async def rpc(self, msg):
666+
async def rpc(self, msg: dict[str, _RichJson]) -> _Json:
657667
result = await self.connection.rpc(msg, encoder=TypeEncoder)
658668
return result
659669

@@ -704,13 +714,13 @@ def _parse_nested_list_entry(expr, result_dict):
704714
return cls(**d)
705715
return None
706716

707-
def serialize(self):
717+
def serialize(self) -> dict[str, _Json]:
708718
d = {}
709719
for attr, tgt in self._toSchema.items():
710720
d[tgt] = getattr(self, attr)
711721
return d
712722

713-
def to_json(self):
723+
def to_json(self) -> str:
714724
return json.dumps(self.serialize(), cls=TypeEncoder, sort_keys=True)
715725

716726
def __contains__(self, key):
@@ -917,8 +927,8 @@ def generate_definitions(schemas):
917927

918928

919929
def generate_facades(
920-
schemas: Dict[str, List[Schema]],
921-
) -> Dict[str, Dict[int, codegen.Capture]]:
930+
schemas: dict[str, list[Schema]],
931+
) -> dict[str, dict[int, codegen.Capture]]:
922932
captures = defaultdict(codegen.Capture)
923933

924934
# Build the Facade classes

juju/client/overrides.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

45
import re
5-
from collections import namedtuple
6+
from typing import Any, NamedTuple
67

78
from . import _client, _definitions
89
from .facade import ReturnMapping, Type, TypeEncoder
@@ -22,6 +23,12 @@
2223
]
2324

2425

26+
class _Change(NamedTuple):
27+
entity: str
28+
type: str
29+
data: dict[str, Any]
30+
31+
2532
class Delta(Type):
2633
"""A single websocket delta.
2734
@@ -42,12 +49,11 @@ class Delta(Type):
4249
_toSchema = {"deltas": "deltas"}
4350
_toPy = {"deltas": "deltas"}
4451

45-
def __init__(self, deltas=None):
52+
def __init__(self, deltas: tuple[str, str, dict[str, Any]]):
4653
""":param deltas: [str, str, object]"""
4754
self.deltas = deltas
4855

49-
Change = namedtuple("Change", "entity type data")
50-
change = Change(*self.deltas)
56+
change = _Change(*self.deltas)
5157

5258
self.entity = change.entity
5359
self.type = change.type

juju/delta.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

4-
from .client import client
5+
from . import model
6+
from .client import client, overrides
57

68

7-
def get_entity_delta(d):
9+
def get_entity_delta(d: overrides.Delta):
810
return _delta_types[d.entity](d.deltas)
911

1012

@@ -13,12 +15,14 @@ def get_entity_class(entity_type):
1315

1416

1517
class EntityDelta(client.Delta):
16-
def get_id(self):
18+
data: dict[str, str]
19+
20+
def get_id(self) -> str:
1721
return self.data["id"]
1822

1923
@classmethod
20-
def get_entity_class(cls):
21-
return None
24+
def get_entity_class(cls) -> type[model.ModelEntity]:
25+
raise NotImplementedError()
2226

2327

2428
class ActionDelta(EntityDelta):

juju/jasyncio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from asyncio import (
2323
CancelledError,
24+
Task,
2425
create_task,
2526
wait,
2627
)
@@ -84,7 +85,7 @@
8485
ROOT_LOGGER = logging.getLogger()
8586

8687

87-
def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER):
88+
def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER) -> Task:
8889
"""Wrapper around "asyncio.create_task" to make sure the task
8990
exceptions are handled properly.
9091

0 commit comments

Comments
 (0)