Skip to content

Commit 99d8fc1

Browse files
committed
chore: no need to assert connection
1 parent ee6414e commit 99d8fc1

3 files changed

Lines changed: 18 additions & 26 deletions

File tree

juju/client/connection.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .facade_versions import client_facade_versions, known_unsupported_facades
2828

2929
SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]"
30-
_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol"
30+
_WebSocket: TypeAlias = websockets.WebSocketClientProtocol
3131

3232
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
3333
log = logging.getLogger("juju.client.connection")
@@ -291,7 +291,7 @@ def is_using_old_client(self):
291291
def is_open(self):
292292
return self.monitor.status == Monitor.CONNECTED
293293

294-
def _get_ssl(self, cert=None):
294+
def _get_ssl(self, cert: str | None = None) -> ssl.SSLContext:
295295
context = ssl.create_default_context(
296296
purpose=ssl.Purpose.SERVER_AUTH, cadata=cert
297297
)
@@ -305,7 +305,9 @@ def _get_ssl(self, cert=None):
305305
context.check_hostname = False
306306
return context
307307

308-
async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
308+
async def _open(
309+
self, endpoint: str, cacert: str
310+
) -> tuple[_WebSocket, str, str, str]:
309311
if self.is_debug_log_connection:
310312
assert self.uuid
311313
url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log"
@@ -323,10 +325,6 @@ async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
323325
sock = self.proxy.socket()
324326
server_hostname = "juju-app"
325327

326-
def _exit_tasks():
327-
for task in jasyncio.all_tasks():
328-
task.cancel()
329-
330328
return (
331329
(
332330
await websockets.connect(
@@ -342,7 +340,7 @@ def _exit_tasks():
342340
cacert,
343341
)
344342

345-
async def close(self, to_reconnect=False):
343+
async def close(self, to_reconnect: bool = False):
346344
if not self._ws:
347345
return
348346
self.monitor.close_called.set()
@@ -380,11 +378,7 @@ async def close(self, to_reconnect=False):
380378

381379
async def _recv(self, request_id: int) -> dict[str, Any]:
382380
if not self.is_open:
383-
raise websockets.exceptions.ConnectionClosed(
384-
websockets.frames.Close(
385-
websockets.frames.CloseCode.NORMAL_CLOSURE, "websocket closed"
386-
)
387-
)
381+
raise websockets.exceptions.ConnectionClosedOK(None, None)
388382
try:
389383
return await self.messages.get(request_id)
390384
except GeneratorExit:
@@ -626,7 +620,7 @@ async def rpc(
626620

627621
return result
628622

629-
def _http_headers(self):
623+
def _http_headers(self) -> dict[str, str]:
630624
"""Return dictionary of http headers necessary for making an http
631625
connection to the endpoint of this Connection.
632626
@@ -640,7 +634,7 @@ def _http_headers(self):
640634
token = base64.b64encode(creds.encode())
641635
return {"Authorization": f"Basic {token.decode()}"}
642636

643-
def https_connection(self):
637+
def https_connection(self) -> tuple[HTTPSConnection, dict[str, str], str]:
644638
"""Return an https connection to this Connection's endpoint.
645639
646640
Returns a 3-tuple containing::

juju/client/connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
self.model_name = None
5151
self.jujudata = jujudata or FileJujuData()
5252

53-
def is_connected(self):
53+
def is_connected(self) -> bool:
5454
"""Report whether there is a currently connected controller or not"""
5555
return self._connection is not None
5656

@@ -60,6 +60,7 @@ def connection(self) -> Connection:
6060
"""
6161
if not self.is_connected():
6262
raise NoConnectionException("not connected")
63+
assert self._connection
6364
return self._connection
6465

6566
async def connect(self, **kwargs):

juju/model.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -924,10 +924,7 @@ def add_local_charm(self, charm_file, series="", size=None):
924924
instead.
925925
926926
"""
927-
connection = self.connection()
928-
assert connection
929-
930-
conn, headers, path_prefix = connection.https_connection()
927+
conn, headers, path_prefix = self.connection().https_connection()
931928
path = "%s/charms?series=%s" % (path_prefix, series)
932929
headers["Content-Type"] = "application/zip"
933930
if size:
@@ -1320,14 +1317,14 @@ async def _all_watcher():
13201317
del allwatcher.Id
13211318
continue
13221319
except websockets.ConnectionClosed:
1323-
connection = self.connection()
1324-
assert connection
1325-
monitor = connection.monitor
1326-
if monitor.status == monitor.ERROR:
1320+
if self.connection().monitor.status == connection.Monitor.ERROR:
13271321
# closed unexpectedly, try to reopen
13281322
log.warning("Watcher: connection closed, reopening")
1329-
await connection.reconnect()
1330-
if monitor.status != monitor.CONNECTED:
1323+
await self.connection().reconnect()
1324+
if (
1325+
self.connection().monitor.status
1326+
!= connection.Monitor.CONNECTED
1327+
):
13311328
# reconnect failed; abort and shutdown
13321329
log.error(
13331330
"Watcher: automatic reconnect "

0 commit comments

Comments
 (0)