11# Copyright 2023 Canonical Ltd.
22# Licensed under the Apache V2, see LICENCE file for details.
3+ from __future__ import annotations
34
45import base64
56import json
910import warnings
1011import weakref
1112from http .client import HTTPSConnection
12- from typing import Dict , Literal , Optional , Sequence
13+ from typing import Any , Literal , Sequence
1314
1415import macaroonbakery .bakery as bakery
1516import macaroonbakery .httpbakery as httpbakery
1617import websockets
1718from dateutil .parser import parse
19+ from typing_extensions import Self , TypeAlias , overload
1820
1921from juju import errors , jasyncio , tag , utils
2022from juju .client import client
2123from juju .utils import IdQueue
2224from juju .version import CLIENT_VERSION
2325
26+ from .facade import TypeEncoder , _Json , _RichJson
2427from .facade_versions import client_facade_versions , known_unsupported_facades
2528
29+ SpecifiedFacades : TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]"
30+ _WebSocket : TypeAlias = "websockets.legacy.client.WebSocketClientProtocol"
31+
2632LEVELS = ["TRACE" , "DEBUG" , "INFO" , "WARNING" , "ERROR" ]
2733log = logging .getLogger ("juju.client.connection" )
2834
2935
30- def facade_versions (name , versions ):
31- """facade_versions returns a new object that correctly returns a object in
32- format expected by the connection facades inspection.
33- :param name: name of the facade
34- :param versions: versions to support by the facade
35- """
36- if name .endswith ("Facade" ):
37- name = name [: - len ("Facade" )]
38- return {
39- name : {"versions" : versions },
40- }
41-
42-
4336class Monitor :
4437 """Monitor helper class for our Connection class.
4538
@@ -59,7 +52,7 @@ class Monitor:
5952 DISCONNECTING = "disconnecting"
6053 DISCONNECTED = "disconnected"
6154
62- def __init__ (self , connection ):
55+ def __init__ (self , connection : Connection ):
6356 self .connection = weakref .ref (connection )
6457 self .reconnecting = jasyncio .Lock ()
6558 self .close_called = jasyncio .Event ()
@@ -117,28 +110,41 @@ class Connection:
117110
118111 MAX_FRAME_SIZE = 2 ** 22
119112 "Maximum size for a single frame. Defaults to 4MB."
120- facades : Dict [str , int ]
121- _specified_facades : Dict [str , Sequence [int ]]
113+ facades : dict [str , int ]
114+ _specified_facades : dict [str , Sequence [int ]]
115+ bakery_client : Any
116+ usertag : str | None
117+ password : str | None
118+ name : str
119+ __request_id__ : int
120+ endpoints : list [tuple [str , str ]] | None # Set by juju/controller.py
121+ is_debug_log_connection : bool
122+ monitor : Monitor
123+ proxy : Any # Need to find types for this library
124+ max_frame_size : int
125+ _retries : int
126+ _retry_backoff : float
127+ uuid : str | None
128+ messages : IdQueue
129+ _ws : _WebSocket | None
122130
123131 @classmethod
124132 async def connect (
125133 cls ,
126134 endpoint = None ,
127- uuid = None ,
128- username = None ,
129- password = None ,
135+ uuid : str | None = None ,
136+ username : str | None = None ,
137+ password : str | None = None ,
130138 cacert = None ,
131139 bakery_client = None ,
132- max_frame_size = None ,
140+ max_frame_size : int | None = None ,
133141 retries = 3 ,
134142 retry_backoff = 10 ,
135- specified_facades : Optional [
136- Dict [str , Dict [Literal ["versions" ], Sequence [int ]]]
137- ] = None ,
143+ specified_facades : SpecifiedFacades | None = None ,
138144 proxy = None ,
139145 debug_log_conn = None ,
140146 debug_log_params = {},
141- ):
147+ ) -> Self :
142148 """Connect to the websocket.
143149
144150 If uuid is None, the connection will be to the controller. Otherwise it
@@ -270,7 +276,7 @@ def ws(self):
270276 return self ._ws
271277
272278 @property
273- def username (self ):
279+ def username (self ) -> str | None :
274280 if not self .usertag :
275281 return None
276282 return self .usertag [len ("user-" ) :]
@@ -299,7 +305,7 @@ def _get_ssl(self, cert=None):
299305 context .check_hostname = False
300306 return context
301307
302- async def _open (self , endpoint , cacert ):
308+ async def _open (self , endpoint , cacert ) -> tuple [ _WebSocket , str , str , str ] :
303309 if self .is_debug_log_connection :
304310 assert self .uuid
305311 url = f"wss://user-{ self .username } :{ self .password } @{ endpoint } /model/{ self .uuid } /log"
@@ -372,7 +378,7 @@ async def close(self, to_reconnect=False):
372378 if self .proxy is not None :
373379 self .proxy .close ()
374380
375- async def _recv (self , request_id ) :
381+ async def _recv (self , request_id : int ) -> dict [ str , Any ] :
376382 if not self .is_open :
377383 raise websockets .exceptions .ConnectionClosed (
378384 websockets .frames .Close (
@@ -534,7 +540,19 @@ async def _do_ping():
534540 log .debug ("ping failed because of closed connection" )
535541 pass
536542
537- async def rpc (self , msg , encoder = None ):
543+ @overload
544+ async def rpc (
545+ self , msg : dict [str , _Json ], encoder : None = None
546+ ) -> dict [str , _Json ]: ...
547+
548+ @overload
549+ async def rpc (
550+ self , msg : dict [str , _RichJson ], encoder : TypeEncoder
551+ ) -> dict [str , _Json ]: ...
552+
553+ async def rpc (
554+ self , msg : dict [str , Any ], encoder : json .JSONEncoder | None = None
555+ ) -> dict [str , _Json ]:
538556 """Make an RPC to the API. The message is encoded as JSON
539557 using the given encoder if any.
540558 :param msg: Parameters for the call (will be encoded as JSON).
@@ -710,7 +728,9 @@ async def _connect(self, endpoints):
710728 if len (endpoints ) == 0 :
711729 raise errors .JujuConnectionError ("no endpoints to connect to" )
712730
713- async def _try_endpoint (endpoint , cacert , delay ):
731+ async def _try_endpoint (
732+ endpoint , cacert , delay
733+ ) -> tuple [_WebSocket , str , str , str ]:
714734 if delay :
715735 await jasyncio .sleep (delay )
716736 return await self ._open (endpoint , cacert )
@@ -722,6 +742,8 @@ async def _try_endpoint(endpoint, cacert, delay):
722742 jasyncio .ensure_future (_try_endpoint (endpoint , cacert , 0.1 * i ))
723743 for i , (endpoint , cacert ) in enumerate (endpoints )
724744 ]
745+ result : tuple [_WebSocket , str , str , str ] | None = None
746+
725747 for attempt in range (self ._retries + 1 ):
726748 for task in jasyncio .as_completed (tasks ):
727749 try :
@@ -744,8 +766,12 @@ async def _try_endpoint(endpoint, cacert, delay):
744766 # only executed if inner loop's else did not continue
745767 # (i.e., inner loop did break due to successful connection)
746768 break
769+
747770 for task in tasks :
748771 task .cancel ()
772+
773+ assert result # loop raises or sets the result
774+
749775 self ._ws = result [0 ]
750776 self .addr = result [1 ]
751777 self .endpoint = result [2 ]
0 commit comments