11# Copyright 2023 Canonical Ltd.
22# Licensed under the Apache V2, see LICENCE file for details.
3+ from __future__ import annotations
34
45import base64
56import json
910import warnings
1011import weakref
1112from http .client import HTTPSConnection
12- from typing import Dict , Literal , Optional , Sequence
13+ from typing import Any , Literal , Sequence
1314
1415import macaroonbakery .bakery as bakery
1516import macaroonbakery .httpbakery as httpbakery
1617import websockets
1718from dateutil .parser import parse
19+ from typing_extensions import Self
1820
1921from juju import errors , jasyncio , tag , utils
2022from juju .client import client
2729log = logging .getLogger ("juju.client.connection" )
2830
2931
30- def facade_versions (name , versions ):
31- """facade_versions returns a new object that correctly returns a object in
32- format expected by the connection facades inspection.
33- :param name: name of the facade
34- :param versions: versions to support by the facade
35- """
36- if name .endswith ("Facade" ):
37- name = name [: - len ("Facade" )]
38- return {
39- name : {"versions" : versions },
40- }
41-
42-
4332class Monitor :
4433 """Monitor helper class for our Connection class.
4534
@@ -59,7 +48,7 @@ class Monitor:
5948 DISCONNECTING = "disconnecting"
6049 DISCONNECTED = "disconnected"
6150
62- def __init__ (self , connection ):
51+ def __init__ (self , connection : Connection ):
6352 self .connection = weakref .ref (connection )
6453 self .reconnecting = jasyncio .Lock ()
6554 self .close_called = jasyncio .Event ()
@@ -117,28 +106,40 @@ class Connection:
117106
118107 MAX_FRAME_SIZE = 2 ** 22
119108 "Maximum size for a single frame. Defaults to 4MB."
120- facades : Dict [str , int ]
121- _specified_facades : Dict [str , Sequence [int ]]
109+ facades : dict [str , int ]
110+ _specified_facades : dict [str , Sequence [int ]]
111+ bakery_client : Any
112+ usertag : str | None
113+ password : str | None
114+ name : str
115+ __request_id__ : int
116+ endpoints : list [tuple [str , str ]] | None # Set by juju/controller.py
117+ is_debug_log_connection : bool
118+ monitor : Monitor
119+ proxy : Any # Need to find types for this library
120+ max_frame_size : int
121+ _retries : int
122+ _retry_backoff : float
123+ uuid : str | None
122124
123125 @classmethod
124126 async def connect (
125127 cls ,
126128 endpoint = None ,
127- uuid = None ,
128- username = None ,
129- password = None ,
129+ uuid : str | None = None ,
130+ username : str | None = None ,
131+ password : str | None = None ,
130132 cacert = None ,
131133 bakery_client = None ,
132- max_frame_size = None ,
134+ max_frame_size : int | None = None ,
133135 retries = 3 ,
134136 retry_backoff = 10 ,
135- specified_facades : Optional [
136- Dict [str , Dict [Literal ["versions" ], Sequence [int ]]]
137- ] = None ,
137+ specified_facades : dict [str , dict [Literal ["versions" ], Sequence [int ]]]
138+ | None = None ,
138139 proxy = None ,
139140 debug_log_conn = None ,
140141 debug_log_params = {},
141- ):
142+ ) -> Self :
142143 """Connect to the websocket.
143144
144145 If uuid is None, the connection will be to the controller. Otherwise it
@@ -270,7 +271,7 @@ def ws(self):
270271 return self ._ws
271272
272273 @property
273- def username (self ):
274+ def username (self ) -> str | None :
274275 if not self .usertag :
275276 return None
276277 return self .usertag [len ("user-" ) :]
@@ -534,7 +535,7 @@ async def _do_ping():
534535 log .debug ("ping failed because of closed connection" )
535536 pass
536537
537- async def rpc (self , msg , encoder = None ):
538+ async def rpc (self , msg : dict , encoder = None ) -> dict :
538539 """Make an RPC to the API. The message is encoded as JSON
539540 using the given encoder if any.
540541 :param msg: Parameters for the call (will be encoded as JSON).
@@ -744,8 +745,13 @@ async def _try_endpoint(endpoint, cacert, delay):
744745 # only executed if inner loop's else did not continue
745746 # (i.e., inner loop did break due to successful connection)
746747 break
748+ else :
749+ # impossible, work around https://github.com/microsoft/pyright/issues/8791
750+ assert False # noqa: B011
751+
747752 for task in tasks :
748753 task .cancel ()
754+
749755 self ._ws = result [0 ]
750756 self .addr = result [1 ]
751757 self .endpoint = result [2 ]
0 commit comments