55import json
66import logging
77import ssl
8+ import signal
89import urllib .request
910import weakref
1011from http .client import HTTPSConnection
@@ -217,7 +218,7 @@ def status(self):
217218 if connection .is_debug_log_connection :
218219 stopped = connection ._debug_log_task .cancelled ()
219220 else :
220- stopped = connection ._receiver_task .cancelled ()
221+ stopped = connection ._receiver_task is not None and connection . _receiver_task .cancelled ()
221222
222223 if stopped or not connection ._ws .open :
223224 return self .ERROR
@@ -418,6 +419,14 @@ async def _open(self, endpoint, cacert):
418419 sock = self .proxy .socket ()
419420 server_hostname = "juju-app"
420421
422+ def _exit_tasks ():
423+ for task in jasyncio .all_tasks ():
424+ task .cancel ()
425+
426+ loop = jasyncio .get_running_loop ()
427+ for sig in (signal .SIGINT , signal .SIGTERM ):
428+ loop .add_signal_handler (sig , _exit_tasks )
429+
421430 return (await websockets .connect (
422431 url ,
423432 ssl = self ._get_ssl (cacert ),
@@ -431,25 +440,41 @@ async def close(self, to_reconnect=False):
431440 return
432441 self .monitor .close_called .set ()
433442
443+ # Cancel all the tasks (that we started):
434444 if self ._pinger_task :
435445 self ._pinger_task .cancel ()
436- self ._pinger_task = None
437446 if self ._receiver_task :
438447 self ._receiver_task .cancel ()
439- self ._receiver_task = None
440448 if self ._debug_log_task :
441449 self ._debug_log_task .cancel ()
442- self ._debug_log_task = None
443- # Allow a second for tasks to be cancelled
444- await jasyncio .sleep (1 )
445450
446451 if self ._ws and not self ._ws .closed :
447452 await self ._ws .close ()
448- self ._ws = None
453+
454+ if not to_reconnect :
455+ try :
456+ log .debug ('Gathering all tasks for connection close' )
457+
458+ # Avoid gathering the current task
459+ tasks_need_to_be_gathered = [task for task in jasyncio .all_tasks () if task != jasyncio .current_task ()]
460+ await jasyncio .gather (* tasks_need_to_be_gathered )
461+ except jasyncio .CancelledError :
462+ pass
463+ except websockets .exceptions .ConnectionClosed :
464+ pass
465+
466+ self ._pinger_task = None
467+ self ._receiver_task = None
468+ self ._debug_log_task = None
449469
450470 if self .proxy is not None :
451471 self .proxy .close ()
452472
473+ # Remove signal handlers
474+ loop = jasyncio .get_running_loop ()
475+ for sig in (signal .SIGINT , signal .SIGTERM ):
476+ loop .remove_signal_handler (sig )
477+
453478 async def _recv (self , request_id ):
454479 if not self .is_open :
455480 raise websockets .exceptions .ConnectionClosed (0 , 'websocket closed' )
@@ -517,15 +542,15 @@ async def _debug_logger(self):
517542 self .debug_log_shown_lines += number_of_lines_written
518543
519544 if self .debug_log_shown_lines >= self .debug_log_params ['limit' ]:
520- jasyncio .create_task (self .close ())
545+ jasyncio .create_task (self .close (), name = "Task_Close" )
521546 return
522547
523548 except KeyError as e :
524549 log .exception ('Unexpected debug line -- %s' % e )
525- jasyncio .create_task (self .close ())
550+ jasyncio .create_task (self .close (), name = "Task_Close" )
526551 raise
527552 except jasyncio .CancelledError :
528- jasyncio .create_task (self .close ())
553+ jasyncio .create_task (self .close (), name = "Task_Close" )
529554 raise
530555 except websockets .exceptions .ConnectionClosed :
531556 log .warning ('Debug Logger: Connection closed, reconnecting' )
@@ -536,7 +561,7 @@ async def _debug_logger(self):
536561 return
537562 except Exception as e :
538563 log .exception ("Error in debug logger : %s" % e )
539- jasyncio .create_task (self .close ())
564+ jasyncio .create_task (self .close (), name = "Task_Close" )
540565 raise
541566
542567 async def _receiver (self ):
@@ -552,7 +577,8 @@ async def _receiver(self):
552577 result = json .loads (result )
553578 await self .messages .put (result ['request-id' ], result )
554579 except jasyncio .CancelledError :
555- raise
580+ log .debug ('Receiver: Cancelled' )
581+ pass
556582 except websockets .exceptions .ConnectionClosed as e :
557583 log .warning ('Receiver: Connection closed, reconnecting' )
558584 await self .messages .put_all (e )
@@ -592,7 +618,8 @@ async def _do_ping():
592618 break
593619 await jasyncio .sleep (10 )
594620 except jasyncio .CancelledError :
595- raise
621+ log .debug ('Pinger: Cancelled' )
622+ pass
596623 except websockets .exceptions .ConnectionClosed :
597624 # The connection has closed - we can't do anything
598625 # more until the connection is restarted.
@@ -769,7 +796,7 @@ async def reconnect(self):
769796 if not self .is_debug_log_connection :
770797 self ._build_facades (res .get ('facades' , {}))
771798 if not self ._pinger_task :
772- self ._pinger_task = jasyncio .create_task (self ._pinger ())
799+ self ._pinger_task = jasyncio .create_task (self ._pinger (), name = "Task_Pinger" )
773800
774801 async def _connect (self , endpoints ):
775802 if len (endpoints ) == 0 :
@@ -820,12 +847,12 @@ async def _try_endpoint(endpoint, cacert, delay):
820847 # If this is a debug-log connection, and the _debug_log_task
821848 # is not created yet, then go ahead and schedule it
822849 if self .is_debug_log_connection and not self ._debug_log_task :
823- self ._debug_log_task = jasyncio .create_task (self ._debug_logger ())
850+ self ._debug_log_task = jasyncio .create_task (self ._debug_logger (), name = "Task_Debug_Log" )
824851
825852 # If this is regular connection, and we dont have a
826853 # receiver_task yet, then schedule a _receiver_task
827854 elif not self .is_debug_log_connection and not self ._receiver_task :
828- self ._receiver_task = jasyncio .create_task (self ._receiver ())
855+ self ._receiver_task = jasyncio .create_task (self ._receiver (), name = "Task_Receiver" )
829856
830857 log .debug ("Driver connected to juju %s" , self .addr )
831858 self .monitor .close_called .clear ()
@@ -880,7 +907,7 @@ async def _connect_with_redirect(self, endpoints):
880907 login_result = await self ._connect_with_login (e .endpoints )
881908 self ._build_facades (login_result .get ('facades' , {}))
882909 if not self ._pinger_task :
883- self ._pinger_task = jasyncio .create_task (self ._pinger ())
910+ self ._pinger_task = jasyncio .create_task (self ._pinger (), name = "Task_Pinger" )
884911
885912 # _build_facades takes the facade list that comes from the connection with the controller,
886913 # validates that the client knows about them (client_facades) and builds the facade list
0 commit comments