Skip to content

Commit 057b8df

Browse files
docs: type hinting utils.py
1 parent f87b50c commit 057b8df

2 files changed

Lines changed: 812 additions & 35 deletions

File tree

juju/utils.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import zipfile
1010
from collections import defaultdict
1111
from pathlib import Path
12-
from typing import Any
12+
from typing import Any, Callable, Coroutine
1313

1414
import yaml
1515
from pyasn1.codec.der.encoder import encode
@@ -20,7 +20,7 @@
2020
from .errors import JujuError
2121

2222

23-
async def execute_process(*cmd, log=None):
23+
async def execute_process(*cmd, log=None) -> bool:
2424
"""Wrapper around asyncio.create_subprocess_exec."""
2525
p = await asyncio.create_subprocess_exec(
2626
*cmd,
@@ -38,7 +38,7 @@ async def execute_process(*cmd, log=None):
3838
return p.returncode == 0
3939

4040

41-
def juju_config_dir():
41+
def juju_config_dir() -> str:
4242
"""Resolves and returns the path string to the juju configuration folder
4343
for the juju CLI tool. Of the following items, returns the first option
4444
that works (top to bottom):
@@ -60,7 +60,7 @@ def juju_config_dir():
6060
return str(config_dir.expanduser().resolve())
6161

6262

63-
def juju_ssh_key_paths():
63+
def juju_ssh_key_paths() -> tuple[str, str]:
6464
"""Resolves and returns the path strings for public and private ssh keys
6565
for juju CLI.
6666
"""
@@ -71,7 +71,7 @@ def juju_ssh_key_paths():
7171
return public_key_path, private_key_path
7272

7373

74-
def _read_ssh_key():
74+
def _read_ssh_key() -> str:
7575
"""Inner function for read_ssh_key, suitable for passing to our
7676
Executor.
7777
"""
@@ -82,7 +82,7 @@ def _read_ssh_key():
8282
return ssh_key
8383

8484

85-
async def read_ssh_key():
85+
async def read_ssh_key() -> str:
8686
"""Attempt to read the local juju admin's public ssh key, so that it can be
8787
passed on to a model.
8888
"""
@@ -124,7 +124,7 @@ async def put_all(self, value: Exception):
124124
await queue.put(value)
125125

126126

127-
async def block_until(*conditions, timeout=None, wait_period=0.5):
127+
async def block_until(*conditions: Callable[[], bool], timeout: float | None = None, wait_period: float = 0.5):
128128
"""Return only after all conditions are true.
129129
130130
If a timeout occurs, it cancels the task and raises
@@ -139,7 +139,7 @@ async def _block():
139139

140140

141141
async def block_until_with_coroutine(
142-
condition_coroutine, timeout=None, wait_period=0.5
142+
condition_coroutine: Callable[[], Coroutine[Any, Any, bool]], timeout: float | None = None, wait_period: float = 0.5
143143
):
144144
"""Return only after the given coroutine returns True.
145145
@@ -169,12 +169,12 @@ async def wait_for_bundle(model, bundle: str | Path, **kwargs) -> None:
169169
bundle = bundle_path / "bundle.yaml"
170170
except OSError:
171171
pass
172-
content: dict[str, Any] = yaml.safe_load(textwrap.dedent(bundle).strip())
172+
content: dict[str, Any] = yaml.safe_load(textwrap.dedent(str(bundle)).strip())
173173
apps = list(content.get("applications", content.get("services")).keys())
174174
await model.wait_for_idle(apps, **kwargs)
175175

176176

177-
async def run_with_interrupt(task, *events, log=None):
177+
async def run_with_interrupt(task, *events: asyncio.Event, log=None):
178178
"""Awaits a task while allowing it to be interrupted by one or more
179179
`asyncio.Event`s.
180180
@@ -228,7 +228,7 @@ class RegistrationInfo(univ.Sequence):
228228

229229

230230
def generate_user_controller_access_token(
231-
username, controller_endpoints, secret_key, controller_name
231+
username: str, controller_endpoints, secret_key: str, controller_name
232232
):
233233
"""Implement in python what is currently done in GO.
234234
@@ -258,14 +258,12 @@ def generate_user_controller_access_token(
258258
return base64.urlsafe_b64encode(registration_string)
259259

260260

261-
def get_local_charm_data(path, yaml_file):
261+
def get_local_charm_data(path: str | Path, yaml_file: str) -> dict[str, Any]:
262262
"""Retrieve Metadata of a Charm from its path.
263263
264-
:patam str path: Path of charm directory or .charm file :patam str
265-
yaml_
266-
file:
267-
name of the yaml file, can be either "metadata.yaml", or
268-
"manifest.yaml", or "charmcraft.yaml"
264+
:param str path: Path of charm directory or .charm file
265+
:param str yaml_file: name of the yaml file, can be either
266+
"metadata.yaml", or "manifest.yaml", or "charmcraft.yaml"
269267
270268
:return: Object of charm metadata
271269
"""
@@ -282,15 +280,15 @@ def get_local_charm_data(path, yaml_file):
282280
return metadata
283281

284282

285-
def get_local_charm_metadata(path):
283+
def get_local_charm_metadata(path: str | Path) -> dict[str, Any]:
286284
return get_local_charm_data(path, "metadata.yaml")
287285

288286

289-
def get_local_charm_manifest(path):
287+
def get_local_charm_manifest(path: str | Path) -> dict[str, Any]:
290288
return get_local_charm_data(path, "manifest.yaml")
291289

292290

293-
def get_local_charm_charmcraft_yaml(path):
291+
def get_local_charm_charmcraft_yaml(path: str | Path) -> dict[str, Any]:
294292
return get_local_charm_data(path, "charmcraft.yaml")
295293

296294

@@ -354,7 +352,7 @@ def get_local_charm_charmcraft_yaml(path):
354352
ALL_SERIES_VERSIONS = {**UBUNTU_SERIES, **KUBERNETES_SERIES}
355353

356354

357-
def get_series_version(series_name):
355+
def get_series_version(series_name: str) -> str:
358356
"""get_series_version outputs the version of the OS based on the given
359357
series e.g. jammy -> 22.04, kubernetes -> kubernetes.
360358
@@ -366,7 +364,7 @@ def get_series_version(series_name):
366364
return ALL_SERIES_VERSIONS[series_name]
367365

368366

369-
def get_version_series(version):
367+
def get_version_series(version: str) -> str:
370368
"""get_version_series is the opposite of the get_series_version. It outputs
371369
the series based on given OS version.
372370
@@ -378,7 +376,7 @@ def get_version_series(version):
378376
return list(UBUNTU_SERIES.keys())[list(UBUNTU_SERIES.values()).index(version)]
379377

380378

381-
def get_local_charm_base(series, charm_path, base_class):
379+
def get_local_charm_base(series: str, charm_path: str, base_class: type):
382380
"""Deduce the base [channel/osname] of a local charm based on what we know
383381
already.
384382
@@ -428,7 +426,7 @@ def get_local_charm_base(series, charm_path, base_class):
428426
return base_class(channel_for_base, os_name_for_base)
429427

430428

431-
def base_channel_to_series(channel):
429+
def base_channel_to_series(channel: str) -> str:
432430
"""Returns the series string using the track inside the base channel.
433431
434432
:param str channel: is track/risk (e.g. 20.04/stable)
@@ -437,7 +435,7 @@ def base_channel_to_series(channel):
437435
return get_version_series(origin.Channel.parse(channel).track)
438436

439437

440-
def parse_base_arg(base):
438+
def parse_base_arg(base: str) -> client.Base:
441439
"""Parses a given base into a Client.Base object :param base str : The base
442440
to deploy a charm (e.g. ubuntu@22.04)
443441
"""
@@ -463,7 +461,7 @@ def base_channel_from_series(track, risk, series):
463461
)
464462

465463

466-
def get_os_from_series(series):
464+
def get_os_from_series(series: str) -> str:
467465
if series in UBUNTU_SERIES:
468466
return "ubuntu"
469467
raise JujuError(f"os for the series {series} needs to be added")
@@ -479,7 +477,7 @@ def get_base_from_origin_or_channel(origin_or_channel, series=None):
479477
return client.Base(channel=channel, name=os_name)
480478

481479

482-
def series_for_charm(requested_series, supported_series):
480+
def series_for_charm(requested_series: str, supported_series: list[str]) -> str:
483481
"""series_for_charm takes a requested series and a list of series supported
484482
by a charm and returns the series which is relevant.
485483
@@ -506,7 +504,7 @@ def series_for_charm(requested_series, supported_series):
506504
)
507505

508506

509-
def user_requested(series_arg, supported_series, force):
507+
def user_requested(series_arg: str, supported_series: list[str], force: bool) -> str:
510508
series = series_for_charm(series_arg, supported_series)
511509
if force:
512510
series = series_arg
@@ -516,8 +514,8 @@ def user_requested(series_arg, supported_series, force):
516514

517515

518516
def series_selector(
519-
series_arg="", charm_url=None, model_config=None, supported_series=[], force=False
520-
):
517+
series_arg: str = "", charm_url=None, model_config=None, supported_series: list[str] = [], force: bool = False
518+
) -> str:
521519
"""Select series to deploy on.
522520
523521
series_selector corresponds to the CharmSeries() in
@@ -563,19 +561,19 @@ def series_selector(
563561
return DEFAULT_SUPPORTED_LTS
564562

565563

566-
def should_upgrade_resource(available_resource, existing_resources, arg_resources):
564+
def should_upgrade_resource(available_resource: dict[str, str], existing_resources, arg_resources) -> bool:
567565
"""Determine if the given resource should be upgraded.
568566
569567
Called in the context of upgrade_charm. Given a resource R, takes a look
570568
at the resources we already have and decides if we need to refresh R.
571569
572570
:param dict[str] available_resource: The dict representing the
573571
client.Resource coming from the charmhub api. We're considering if
574-
we need to refresh this during upgrade_charm. :param dict[str]
575-
existing_resources: The dict coming from
572+
we need to refresh this during upgrade_charm.
573+
:param dict[str] existing_resources: The dict coming from
576574
resources_facade.ListResources representing the resources of the
577-
currently deployed charm. :param dict[str] arg_resources: user
578-
provided resources to be refreshed
575+
currently deployed charm.
576+
:param dict[str] arg_resources: user provided resources to be refreshed
579577
580578
:result bool: The decision to refresh the given resource
581579
"""

0 commit comments

Comments
 (0)