Skip to content

Commit aef2ecc

Browse files
authored
AutoRunner enhancements: local folder and algorithm subsets (#5623)
AutoRunner class improvements - adds auto-runner option **templates_path_or_url** to optionally accept a local templates folder location. The default (None) is the same: to download the release zip e.g. AutoRunner(templates_path_or_url='/my/local/algorithm_templates') - enhances **algos** input option to define a subset of algorithm names to run (e.g. AutoRunner(algos=['segresnet','dints']) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <amyronenko@nvidia.com>
1 parent f380cf2 commit aef2ecc

2 files changed

Lines changed: 136 additions & 25 deletions

File tree

monai/apps/auto3dseg/auto_runner.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import subprocess
1515
from copy import deepcopy
1616
from time import sleep
17-
from typing import Any, Dict, List, Optional, Tuple, Union
17+
from typing import Any, Dict, List, Optional, Union
1818

1919
import numpy as np
2020
import torch
@@ -64,7 +64,11 @@ class AutoRunner:
6464
work_dir: working directory to save the intermediate and final results.
6565
input: the configuration dictionary or the file path to the configuration in form of YAML.
6666
The configuration should contain datalist, dataroot, modality, multigpu, and class_names info.
67-
algos: optionally specify a list of algorithms to use.
67+
algos: optionally specify algorithms to use. If a dictionary, must be in the form
68+
{"algname": dict(_target_="algname.scripts.algo.AlgnameAlgo", template_path="algname"), ...}
69+
If a list or a string, defines a subset of names of the algorithms to use, e.g. 'segresnet' or
70+
['segresnet', 'dints'] out of the full set of algorithm templates provided by templates_path_or_url.
71+
Defaults to None, to use all available algorithms.
6872
analyze: on/off switch to run DataAnalyzer and generate a datastats report. Defaults to None, to automatically
6973
decide based on cache, and run data analysis only if we have not completed this step yet.
7074
algo_gen: on/off switch to run AlgoGen and generate templated BundleAlgos. Defaults to None, to automatically
@@ -79,6 +83,8 @@ class AutoRunner:
7983
datasets.
8084
not_use_cache: if the value is True, it will ignore all cached results in data analysis,
8185
algorithm generation, or training, and start the pipeline from scratch.
86+
templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template
87+
zip url will be downloaded and extracted into the work_dir.
8288
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
8389
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
8490
@@ -106,6 +112,27 @@ class AutoRunner:
106112
runner = AutoRunner(work_dir=work_dir, input=input)
107113
runner.run()
108114
115+
- User can specify a subset of algorithms to use and run AutoRunner:
116+
117+
.. code-block:: python
118+
119+
work_dir = "./work_dir"
120+
input = "path_to_yaml_data_cfg"
121+
algos = ["segresnet", "dints"]
122+
runner = AutoRunner(work_dir=work_dir, input=input, algos=algos)
123+
runner.run()
124+
125+
- User can specify a local folder with algorithm templates and run AutoRunner:
126+
127+
.. code-block:: python
128+
129+
work_dir = "./work_dir"
130+
input = "path_to_yaml_data_cfg"
131+
algos = "segresnet"
132+
templates_path_or_url = "./local_path_to/algorithm_templates"
133+
runner = AutoRunner(work_dir=work_dir, input=input, algos=algos, templates_path_or_url=templates_path_or_url)
134+
runner.run()
135+
109136
- User can specify training parameters by:
110137
111138
.. code-block:: python
@@ -181,14 +208,15 @@ def __init__(
181208
self,
182209
work_dir: str = "./work_dir",
183210
input: Union[Dict[str, Any], str, None] = None,
184-
algos: Optional[Union[Tuple, List]] = None,
211+
algos: Optional[Union[Dict, List, str]] = None,
185212
analyze: Optional[bool] = None,
186213
algo_gen: Optional[bool] = None,
187214
train: Optional[bool] = None,
188215
hpo: bool = False,
189216
hpo_backend: str = "nni",
190217
ensemble: bool = True,
191218
not_use_cache: bool = False,
219+
templates_path_or_url: Optional[str] = None,
192220
**kwargs,
193221
):
194222

@@ -198,6 +226,7 @@ def __init__(
198226
self.work_dir = os.path.abspath(work_dir)
199227
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
200228
self.algos = algos
229+
self.templates_path_or_url = templates_path_or_url
201230

202231
if input is None and os.path.isfile(self.data_src_cfg_name):
203232
input = self.data_src_cfg_name
@@ -558,6 +587,7 @@ def run(self):
558587
bundle_generator = BundleGen(
559588
algos=self.algos,
560589
algo_path=self.work_dir,
590+
templates_path_or_url=self.templates_path_or_url,
561591
data_stats_filename=self.datastats_filename,
562592
data_src_cfg_name=self.data_src_cfg_name,
563593
)

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
import shutil
1515
import subprocess
1616
import sys
17+
import time
18+
import warnings
1719
from copy import deepcopy
1820
from pathlib import Path
1921
from tempfile import TemporaryDirectory
20-
from typing import Any, Dict, List, Mapping
22+
from typing import Any, Dict, List, Mapping, Optional, Union
23+
from urllib.parse import urlparse
2124

2225
import torch
2326

@@ -284,15 +287,75 @@ def get_output_path(self):
284287
}
285288

286289

290+
def _download_algos_url(url: str, at_path: str) -> Dict[str, Any]:
    """
    Download the algorithm templates release archive and extract it next to ``at_path``.

    The archive is extracted into the parent directory of ``at_path``
    (presumably the archive unpacks into a folder named like the basename of ``at_path`` - TODO confirm).
    Up to three download attempts are made before giving up.

    Args:
        url: location of the algorithm templates archive.
        at_path: folder where the templates are expected to be found after extraction.

    Returns:
        A deep copy of ``default_algos`` in which every ``template_path`` is joined onto the absolute ``at_path``.

    Raises:
        ValueError: if downloading/extracting fails on the last attempt (chained to the original exception).
    """
    at_path = os.path.abspath(at_path)
    download_attempts = 3

    # The context manager removes the temporary download folder on every exit path,
    # including exceptions that the retry handler below does not turn into ValueError.
    with TemporaryDirectory() as zip_download_dir:
        algo_compressed_file = os.path.join(zip_download_dir, "algo_templates.tar.gz")

        for i in range(download_attempts):
            try:
                download_and_extract(url=url, filepath=algo_compressed_file, output_dir=os.path.dirname(at_path))
            except Exception as e:
                msg = f"Download and extract of {url} failed, attempt {i+1}/{download_attempts}."
                if i < download_attempts - 1:
                    warnings.warn(msg)
                    # wait i seconds before the next attempt (no wait after the first failure)
                    time.sleep(i)
                else:
                    raise ValueError(msg) from e
            else:
                break

    # copy the defaults so that the module-level dictionary is never mutated
    algos_all = deepcopy(default_algos)
    for name in algos_all:
        algos_all[name]["template_path"] = os.path.join(at_path, algos_all[name]["template_path"])

    return algos_all
321+
322+
323+
def _copy_algos_folder(folder: str, at_path: str) -> Dict[str, Dict[str, str]]:
    """
    Copy the algorithm templates folder to ``at_path`` and collect the algorithms found there.

    A sub-folder ``name`` counts as an algorithm template when it contains ``scripts/algo.py``.
    If ``folder`` and ``at_path`` resolve to the same absolute path nothing is copied; otherwise any
    existing ``at_path`` is removed and replaced by a copy of ``folder``.

    Args:
        folder: the local folder holding the algorithm templates.
        at_path: the destination folder for the templates.

    Returns:
        A dictionary of algorithm templates, keyed by sub-folder name in sorted order, in the form
        ``{name: {"_target_": "<name>.scripts.algo.<Name>Algo", "template_path": "<at_path>/<name>"}}``.

    Raises:
        ValueError: if ``folder`` is located inside ``at_path`` (replacing ``at_path`` would delete the
            source templates), or if no algorithm template is found.
    """
    folder = os.path.abspath(folder)
    at_path = os.path.abspath(at_path)

    if folder != at_path:
        try:
            source_inside_destination = os.path.commonpath([folder, at_path]) == at_path
        except ValueError:
            # raised when the two paths have nothing in common, e.g. different drives on Windows
            source_inside_destination = False
        if source_inside_destination:
            # refuse before rmtree: removing at_path would wipe the templates we are about to copy
            raise ValueError(
                f"Unable to copy algos from {folder}: it is located inside the destination {at_path}, "
                "which would be deleted before copying"
            )

        if os.path.exists(at_path):
            shutil.rmtree(at_path)
        shutil.copytree(folder, at_path)

    algos_all = {}
    # sorted, because os.listdir returns entries in arbitrary order
    for name in sorted(os.listdir(at_path)):
        if os.path.exists(os.path.join(at_path, name, "scripts", "algo.py")):
            algos_all[name] = dict(
                _target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=os.path.join(at_path, name)
            )
    if not algos_all:
        raise ValueError(f"Unable to find any algos in {folder}")

    return algos_all
346+
347+
287348
class BundleGen(AlgoGen):
288349
"""
289350
This class generates a set of bundles according to the cross-validation folds, each of them can run independently.
290351
291352
Args:
292353
algo_path: the directory path to save the algorithm templates. Default is the current working dir.
293-
algos: if dictionary, it outlines the algorithm to use. if None, automatically download the zip file
294-
from the default link. if string, it represents the download link.
295-
The current default options are released at:
354+
algos: If dictionary, it outlines the algorithm to use. If a list or a string, defines a subset of names of
355+
the algorithms to use, e.g. ('segresnet', 'dints') out of the full set of algorithm templates provided
356+
by templates_path_or_url. Defaults to None - to use all available algorithms.
357+
templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template
358+
zip url will be downloaded and extracted into the algo_path. The current default options are released at:
296359
https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg
297360
data_stats_filename: the path to the data stats file (generated by DataAnalyzer)
298361
data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
@@ -303,22 +366,39 @@ class BundleGen(AlgoGen):
303366
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/data_stats.yaml"
304367
"""
305368

306-
def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, data_src_cfg_name=None):
307-
self.algos: Any = []
308-
309-
if algos is None or isinstance(algos, str):
310-
# trigger the download process
311-
zip_download_dir = TemporaryDirectory()
312-
algo_compressed_file = os.path.join(zip_download_dir.name, "algo_templates.tar.gz")
313-
download_and_extract(default_algo_zip if algos is None else algos, algo_compressed_file, algo_path)
314-
zip_download_dir.cleanup()
315-
sys.path.insert(0, os.path.join(algo_path, "algorithm_templates"))
316-
algos = deepcopy(default_algos)
317-
for name in algos:
318-
algos[name]["template_path"] = os.path.join(
319-
algo_path, "algorithm_templates", algos[name]["template_path"]
320-
)
369+
def __init__(
370+
self,
371+
algo_path: str = ".",
372+
algos: Optional[Union[Dict, List, str]] = None,
373+
templates_path_or_url: Optional[str] = None,
374+
data_stats_filename: Optional[str] = None,
375+
data_src_cfg_name: Optional[str] = None,
376+
):
377+
378+
if algos is None or isinstance(algos, (list, tuple, str)):
379+
380+
if templates_path_or_url is None:
381+
templates_path_or_url = default_algo_zip
382+
383+
at_path = os.path.join(os.path.abspath(algo_path), "algorithm_templates")
384+
385+
if os.path.isdir(templates_path_or_url):
386+
# if a local folder, copy if necessary
387+
algos_all = _copy_algos_folder(folder=templates_path_or_url, at_path=at_path)
388+
elif urlparse(templates_path_or_url).scheme in ("http", "https"):
389+
# if url, trigger the download and extract process
390+
algos_all = _download_algos_url(url=templates_path_or_url, at_path=at_path)
391+
else:
392+
raise ValueError(f"{self.__class__} received invalid templates_path_or_url: {templates_path_or_url}")
393+
394+
if algos is not None:
395+
algos = {k: v for k, v in algos_all.items() if k in ensure_tuple(algos)} # keep only provided
396+
if len(algos) == 0:
397+
raise ValueError(f"Unable to find provided algos in {algos_all}")
398+
else:
399+
algos = algos_all
321400

401+
self.algos: Any = []
322402
if isinstance(algos, dict):
323403
for algo_name, algo_params in algos.items():
324404

@@ -327,7 +407,9 @@ def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, d
327407
sys.path.append(template_path)
328408

329409
try:
330-
self.algos.append(ConfigParser(algo_params).get_parsed_content())
410+
onealgo = ConfigParser(algo_params).get_parsed_content()
411+
onealgo.name = algo_name
412+
self.algos.append(onealgo)
331413
except RuntimeError as e:
332414
msg = """Please make sure the folder structure of an Algo Template follows
333415
[algo_name]
@@ -339,9 +421,8 @@ def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, d
339421
└── validate.py
340422
"""
341423
raise RuntimeError(msg) from e
342-
self.algos[-1].name = algo_name
343424
else:
344-
self.algos = ensure_tuple(algos)
425+
raise ValueError("Unexpected error algos is not a dict")
345426

346427
self.data_stats_filename = data_stats_filename
347428
self.data_src_cfg_filename = data_src_cfg_name

0 commit comments

Comments
 (0)