1414import shutil
1515import subprocess
1616import sys
17+ import time
18+ import warnings
1719from copy import deepcopy
1820from pathlib import Path
1921from tempfile import TemporaryDirectory
20- from typing import Any , Dict , List , Mapping
22+ from typing import Any , Dict , List , Mapping , Optional , Union
23+ from urllib .parse import urlparse
2124
2225import torch
2326
@@ -284,15 +287,75 @@ def get_output_path(self):
284287}
285288
286289
290+ def _download_algos_url (url : str , at_path : str ):
291+ """
292+ Downloads the algorithm templates release archive, and extracts it into a parent directory of the at_path folder.
293+ Returns a dictionary of the algorithm templates.
294+ """
295+ at_path = os .path .abspath (at_path )
296+ zip_download_dir = TemporaryDirectory ()
297+ algo_compressed_file = os .path .join (zip_download_dir .name , "algo_templates.tar.gz" )
298+
299+ download_attempts = 3
300+ for i in range (download_attempts ):
301+ try :
302+ download_and_extract (url = url , filepath = algo_compressed_file , output_dir = os .path .dirname (at_path ))
303+ except Exception as e :
304+ msg = f"Download and extract of { url } failed, attempt { i + 1 } /{ download_attempts } ."
305+ if i < download_attempts - 1 :
306+ warnings .warn (msg )
307+ time .sleep (i )
308+ else :
309+ zip_download_dir .cleanup ()
310+ raise ValueError (msg ) from e
311+ else :
312+ break
313+
314+ zip_download_dir .cleanup ()
315+
316+ algos_all = deepcopy (default_algos )
317+ for name in algos_all :
318+ algos_all [name ]["template_path" ] = os .path .join (at_path , algos_all [name ]["template_path" ])
319+
320+ return algos_all
321+
322+
323+ def _copy_algos_folder (folder , at_path ):
324+ """
325+ Copies the algorithm templates folder to at_path.
326+ Returns a dictionary of of algorithm templates.
327+ """
328+ folder = os .path .abspath (folder )
329+ at_path = os .path .abspath (at_path )
330+
331+ if folder != at_path :
332+ if os .path .exists (at_path ):
333+ shutil .rmtree (at_path )
334+ shutil .copytree (folder , at_path )
335+
336+ algos_all = {}
337+ for name in os .listdir (at_path ):
338+ if os .path .exists (os .path .join (folder , name , "scripts" , "algo.py" )):
339+ algos_all [name ] = dict (
340+ _target_ = f"{ name } .scripts.algo.{ name .capitalize ()} Algo" , template_path = os .path .join (at_path , name )
341+ )
342+ if len (algos_all ) == 0 :
343+ raise ValueError (f"Unable to find any algos in { folder } " )
344+
345+ return algos_all
346+
347+
287348class BundleGen (AlgoGen ):
288349 """
289350 This class generates a set of bundles according to the cross-validation folds, each of them can run independently.
290351
291352 Args:
292353 algo_path: the directory path to save the algorithm templates. Default is the current working dir.
293- algos: if dictionary, it outlines the algorithm to use. if None, automatically download the zip file
294- from the default link. if string, it represents the download link.
295- The current default options are released at:
354+ algos: If dictionary, it outlines the algorithm to use. If a list or a string, defines a subset of names of
355+ the algorithms to use, e.g. ('segresnet', 'dints') out of the full set of algorithm templates provided
356+ by templates_path_or_url. Defaults to None - to use all available algorithms.
357+ templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template
358+ zip url will be downloaded and extracted into the algo_path. The current default options are released at:
296359 https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg
297360 data_stats_filename: the path to the data stats file (generated by DataAnalyzer)
298361 data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
@@ -303,22 +366,39 @@ class BundleGen(AlgoGen):
303366 python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/data_stats.yaml"
304367 """
305368
306- def __init__ (self , algo_path : str = "." , algos = None , data_stats_filename = None , data_src_cfg_name = None ):
307- self .algos : Any = []
308-
309- if algos is None or isinstance (algos , str ):
310- # trigger the download process
311- zip_download_dir = TemporaryDirectory ()
312- algo_compressed_file = os .path .join (zip_download_dir .name , "algo_templates.tar.gz" )
313- download_and_extract (default_algo_zip if algos is None else algos , algo_compressed_file , algo_path )
314- zip_download_dir .cleanup ()
315- sys .path .insert (0 , os .path .join (algo_path , "algorithm_templates" ))
316- algos = deepcopy (default_algos )
317- for name in algos :
318- algos [name ]["template_path" ] = os .path .join (
319- algo_path , "algorithm_templates" , algos [name ]["template_path" ]
320- )
369+ def __init__ (
370+ self ,
371+ algo_path : str = "." ,
372+ algos : Optional [Union [Dict , List , str ]] = None ,
373+ templates_path_or_url : Optional [str ] = None ,
374+ data_stats_filename : Optional [str ] = None ,
375+ data_src_cfg_name : Optional [str ] = None ,
376+ ):
377+
378+ if algos is None or isinstance (algos , (list , tuple , str )):
379+
380+ if templates_path_or_url is None :
381+ templates_path_or_url = default_algo_zip
382+
383+ at_path = os .path .join (os .path .abspath (algo_path ), "algorithm_templates" )
384+
385+ if os .path .isdir (templates_path_or_url ):
386+ # if a local folder, copy if necessary
387+ algos_all = _copy_algos_folder (folder = templates_path_or_url , at_path = at_path )
388+ elif urlparse (templates_path_or_url ).scheme in ("http" , "https" ):
389+ # if url, trigger the download and extract process
390+ algos_all = _download_algos_url (url = templates_path_or_url , at_path = at_path )
391+ else :
392+ raise ValueError (f"{ self .__class__ } received invalid templates_path_or_url: { templates_path_or_url } " )
393+
394+ if algos is not None :
395+ algos = {k : v for k , v in algos_all .items () if k in ensure_tuple (algos )} # keep only provided
396+ if len (algos ) == 0 :
397+ raise ValueError (f"Unable to find provided algos in { algos_all } " )
398+ else :
399+ algos = algos_all
321400
401+ self .algos : Any = []
322402 if isinstance (algos , dict ):
323403 for algo_name , algo_params in algos .items ():
324404
@@ -327,7 +407,9 @@ def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, d
327407 sys .path .append (template_path )
328408
329409 try :
330- self .algos .append (ConfigParser (algo_params ).get_parsed_content ())
410+ onealgo = ConfigParser (algo_params ).get_parsed_content ()
411+ onealgo .name = algo_name
412+ self .algos .append (onealgo )
331413 except RuntimeError as e :
332414 msg = """Please make sure the folder structure of an Algo Template follows
333415 [algo_name]
@@ -339,9 +421,8 @@ def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, d
339421 └── validate.py
340422 """
341423 raise RuntimeError (msg ) from e
342- self .algos [- 1 ].name = algo_name
343424 else :
344- self . algos = ensure_tuple ( algos )
425+ raise ValueError ( "Unexpected error algos is not a dict" )
345426
346427 self .data_stats_filename = data_stats_filename
347428 self .data_src_cfg_filename = data_src_cfg_name
0 commit comments