diff --git a/ipsframework/services.py b/ipsframework/services.py
index 81f80d21..525f7805 100644
--- a/ipsframework/services.py
+++ b/ipsframework/services.py
@@ -2748,7 +2748,8 @@ def send_ensemble_instance_to_portal(ensemble_name: str, data_path: Path) -> Non
         try:
             # Note that we *always* use Dask to run the ensemble tasks
             num_submitted = self.submit_tasks(
-                task_pool_name, # block=True,
+                task_pool_name,
+                block=True,
                 use_dask=True,
                 dask_nodes=num_nodes,
                 dask_ppw=cores_per_instance,
@@ -3248,6 +3249,7 @@ def _make_worker_args(num_workers: int, num_threads: int, use_shifter: bool, shi
                                              hwthreads=hwthreads))
         try:
+            # FIXME: this is deprecated, but be mindful before deleting it outright
             file_id = str(self.services._portal_runid) if self.services._portal_runid > 0 else self.services._fallback_portal_runid
             self.worker_event_logfile = services.sim_name + '_' + file_id + '_' + self.name + '_{}.json'
             self.services.debug(f'Worker event log file: {self.worker_event_logfile}')
@@ -3277,6 +3279,23 @@ def _make_worker_args(num_workers: int, num_threads: int, use_shifter: bool, shi
             )
             self.active_tasks = self.queued_tasks
             self.queued_tasks = {}
+
+        if block:
+            self.services.debug('submit_dask_tasks: blocking tasks to await '
+                                'results')
+            # Await all the futures to finish, thereby blocking until they
+            # are all done.
+            result = self.dask_client.gather(self.futures, direct=True)
+            self.services.debug(f'submit_dask_tasks: have {len(result)} '
+                                'results, block released')
+            # TODO check actual result values for problems
+
+            # Set this to empty list so that get_dask_finished_tasks_status
+            # doesn't try to gather() needlessly again.
+            self.futures = []
+        else:
+            self.services.debug('submit_dask_tasks: not blocking tasks')
+
         return len(self.futures)
 
     def submit_tasks(
@@ -3451,20 +3470,27 @@ def get_dask_finished_tasks_status(self):
             return {}
 
-        self.services.debug('get_dask_finished_tasks_status: before gather()')
-        result = self.dask_client.gather(self.futures)
-        self.services.debug('get_dask_finished_tasks_status: after gather()')
-
-        # If we don't have a result, then there were no tasks to gather.
-        if result is None:
-            self.services.warning('No futures available in call to finished ')
-            self._shutdown_dask()
-            return {}
+        elif len(self.futures) > 0:
+            # submit_dask_tasks was called with block=False, so we await
+            # the futures here.
+            # FIXME This may not be an ideal location for this
+            self.services.debug('get_dask_finished_tasks_status: before gather()')
+            result = self.dask_client.gather(self.futures, direct=True)
+            self.services.debug('get_dask_finished_tasks_status: after gather()')
+            # If we don't have a result, then there were no tasks to gather.
+            if result is None:
+                self.services.warning(
+                    'No futures available in call to finished ')
+                self._shutdown_dask()
+                return {}
+            else:
+                self.services.debug(
+                    f'get_dask_finished_tasks_status: have {len(result)} futures')
         else:
-            self.services.debug(f'get_dask_finished_tasks_status: have {len(result)} futures')
+            # This is ok if submit_dask_tasks.block = True, but we echo this
+            # anyway in debug mode as a reality check.
+            self.services.debug('get_dask_finished_tasks_status: have no futures')
 
-        worker_names = [''.join(c for c in worker['name'] if c.isalnum()) for worker in self.dask_client.scheduler_info()['workers'].values()]
-        self.services.debug(f'get_dask_finished_tasks_status: worker_names: {worker_names!s}')
         # NOTE: You may get an exception stack trace from Dask, this is currently not believed to cause an issue.
         # We no longer need Dask running, so shut it down.
@@ -3472,6 +3498,8 @@ def get_dask_finished_tasks_status(self):
         self._shutdown_dask()
         self.services.debug(f'get_dask_finished_tasks_status: after _shutdown_dask()')
 
+        # TODO These probably should be migrated to _shutdown_dask() since
+        # these are part of that housekeeping.
         self.finished_tasks = {}
         self.active_tasks = {}
         self.services.wait_task(self.dask_workers_tid)