@@ -231,6 +231,7 @@ async def connect(
231231 retries = 3 ,
232232 retry_backoff = 10 ,
233233 specified_facades = None ,
234+ proxy = None ,
234235 ):
235236 """Connect to the websocket.
236237
@@ -308,6 +309,10 @@ async def connect(
308309 max_frame_size = self .MAX_FRAME_SIZE
309310 self .max_frame_size = max_frame_size
310311
312+ self .proxy = proxy
313+ if self .proxy is not None :
314+ self .proxy .connect ()
315+
311316 _endpoints = [(endpoint , cacert )] if isinstance (endpoint , str ) else [(e , cacert ) for e in endpoint ]
312317 for _ep in _endpoints :
313318 try :
@@ -348,12 +353,23 @@ async def _open(self, endpoint, cacert):
348353 else :
349354 url = "wss://{}/api" .format (endpoint )
350355
356+ # We need to establish a server_hostname here for TLS sni if we are
357+ # connecting through a proxy as the Juju controller certificates will
358+ # not be covering the proxy
359+ sock = None
360+ server_hostname = None
361+ if self .proxy is not None :
362+ sock = self .proxy .socket ()
363+ server_hostname = "juju-app"
364+
351365 return (await websockets .connect (
352366 url ,
353367 ssl = self ._get_ssl (cacert ),
354368 loop = self .loop ,
355369 max_size = self .max_frame_size ,
356- ), url , endpoint , cacert )
370+ server_hostname = server_hostname ,
371+ sock = sock ,
372+ )), url , endpoint , cacert
357373
358374 async def close (self ):
359375 if not self .ws :
@@ -364,6 +380,9 @@ async def close(self):
364380 await self .ws .close ()
365381 self .ws = None
366382
383+ if self .proxy is not None :
384+ self .proxy .close ()
385+
367386 async def _recv (self , request_id ):
368387 if not self .is_open :
369388 raise websockets .exceptions .ConnectionClosed (0 , 'websocket closed' )
@@ -551,11 +570,9 @@ async def clone(self):
551570 return await Connection .connect (** self .connect_params ())
552571
553572 def connect_params (self ):
554- """Return a tuple of parameters suitable for passing to
573+ """Return a dict of parameters suitable for passing to
555574 Connection.connect that can be used to make a new connection
556- to the same controller (and model if specified. The first
557- element in the returned tuple holds the endpoint argument;
558- the other holds a dict of the keyword args.
575+ to the same controller (and model if specified).
559576 """
560577 return {
561578 'endpoint' : self .endpoint ,
@@ -566,6 +583,7 @@ def connect_params(self):
566583 'bakery_client' : self .bakery_client ,
567584 'loop' : self .loop ,
568585 'max_frame_size' : self .max_frame_size ,
586+ 'proxy' : self .proxy ,
569587 }
570588
571589 async def controller (self ):
0 commit comments