Skip to content

Commit d3e9826

Browse files
Make template_path in BundleAlgo consistently point to the algorithm_templates folder (#6505)
Fixes #6502 Fixes #6501 ### Background - BundleAlgo `template_path` points to `algorithm_templates/<Algo>` <= MONAI 1.1. Recent PR #6436 changed that to `algorithm_templates` and make it easier to instantiate the BundleAlgo from "algorithm_templates" folder when the folders are moved. - But `BundleGen` did not update accordingly and still save `algorithm_templates/<Algo>` as the `template_path` , which breaks NNI HPO feature which relies heavily on these paths. ### Description - Make `template_path` points to `algorithm_templates` - Avoid using the latest `filelock` (3.12.0) which breaks nni - Fix comments in #6485 (comment) in #6485 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: Mingxin <18563433+mingxin-zheng@users.noreply.github.com>
1 parent f1a1677 commit d3e9826

3 files changed

Lines changed: 22 additions & 18 deletions

File tree

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class BundleAlgo(Algo):
5252
5353
from monai.apps.auto3dseg import BundleAlgo
5454
55-
data_stats_yaml = "/workspace/datastats.yaml"
56-
algo = BundleAlgo(template_path=../algorithms/templates/segresnet2d/configs)
55+
data_stats_yaml = "../datastats.yaml"
56+
algo = BundleAlgo(template_path="../algorithm_templates")
5757
algo.set_data_stats(data_stats_yaml)
5858
# algo.set_data_src("../data_src.json")
5959
algo.export_to_disk(".", algo_name="segresnet2d_1")
@@ -69,7 +69,8 @@ def __init__(self, template_path: PathLike):
6969
Create an Algo instance based on the predefined Algo template.
7070
7171
Args:
72-
template_path: path to the root of the algo template.
72+
template_path: path to a folder that contains the algorithm templates.
73+
Please check https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates
7374
7475
"""
7576

@@ -154,7 +155,8 @@ def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> Non
154155
os.makedirs(self.output_path, exist_ok=True)
155156
if os.path.isdir(self.output_path):
156157
shutil.rmtree(self.output_path)
157-
shutil.copytree(str(self.template_path), self.output_path)
158+
# copy algorithm_templates/<Algo> to the working directory output_path
159+
shutil.copytree(os.path.join(str(self.template_path), self.name), self.output_path)
158160
else:
159161
self.output_path = str(self.template_path)
160162
if kwargs.pop("fill_template", True):
@@ -342,10 +344,10 @@ def get_output_path(self):
342344

343345
# default algorithms
344346
default_algos = {
345-
"segresnet2d": dict(_target_="segresnet2d.scripts.algo.Segresnet2dAlgo", template_path="segresnet2d"),
346-
"dints": dict(_target_="dints.scripts.algo.DintsAlgo", template_path="dints"),
347-
"swinunetr": dict(_target_="swinunetr.scripts.algo.SwinunetrAlgo", template_path="swinunetr"),
348-
"segresnet": dict(_target_="segresnet.scripts.algo.SegresnetAlgo", template_path="segresnet"),
347+
"segresnet2d": dict(_target_="segresnet2d.scripts.algo.Segresnet2dAlgo"),
348+
"dints": dict(_target_="dints.scripts.algo.DintsAlgo"),
349+
"swinunetr": dict(_target_="swinunetr.scripts.algo.SwinunetrAlgo"),
350+
"segresnet": dict(_target_="segresnet.scripts.algo.SegresnetAlgo"),
349351
}
350352

351353

@@ -377,7 +379,7 @@ def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:
377379

378380
algos_all = deepcopy(default_algos)
379381
for name in algos_all:
380-
algos_all[name]["template_path"] = os.path.join(at_path, algos_all[name]["template_path"])
382+
algos_all[name]["template_path"] = at_path
381383

382384
return algos_all
383385

@@ -398,9 +400,7 @@ def _copy_algos_folder(folder, at_path):
398400
algos_all = {}
399401
for name in os.listdir(at_path):
400402
if os.path.exists(os.path.join(folder, name, "scripts", "algo.py")):
401-
algos_all[name] = dict(
402-
_target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=os.path.join(at_path, name)
403-
)
403+
algos_all[name] = dict(_target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=at_path)
404404
logger.info(f"Copying template: {name} -- {algos_all[name]}")
405405
if not algos_all:
406406
raise ValueError(f"Unable to find any algos in {folder}")
@@ -463,7 +463,7 @@ def __init__(
463463
self.algos: Any = []
464464
if isinstance(algos, dict):
465465
for algo_name, algo_params in sorted(algos.items()):
466-
template_path = os.path.dirname(algo_params.get("template_path", "."))
466+
template_path = algo_params.get("template_path", ".")
467467
if len(template_path) > 0 and template_path not in sys.path:
468468
sys.path.append(template_path)
469469

@@ -486,7 +486,7 @@ def __init__(
486486
raise ValueError("Unexpected error algos is not a dict")
487487

488488
self.data_stats_filename = data_stats_filename
489-
self.data_src_cfg_filename = data_src_cfg_name
489+
self.data_src_cfg_name = data_src_cfg_name
490490
self.history: list[dict] = []
491491

492492
def set_data_stats(self, data_stats_filename: str) -> None:
@@ -502,18 +502,18 @@ def get_data_stats(self):
502502
"""Get the filename of the data stats"""
503503
return self.data_stats_filename
504504

505-
def set_data_src(self, data_src_cfg_filename):
505+
def set_data_src(self, data_src_cfg_name):
506506
"""
507507
Set the data source filename
508508
509509
Args:
510-
data_src_cfg_filename: filename of data_source file
510+
data_src_cfg_name: filename of data_source file
511511
"""
512-
self.data_src_cfg_filename = data_src_cfg_filename
512+
self.data_src_cfg_name = data_src_cfg_name
513513

514514
def get_data_src(self):
515515
"""Get the data source filename"""
516-
return self.data_src_cfg_filename
516+
return self.data_src_cfg_name
517517

518518
def get_history(self) -> list:
519519
"""Get the history of the bundleAlgo object with their names/identifiers"""

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsRe
5151
onnx>=1.13.0
5252
onnxruntime; python_version <= '3.10'
5353
typeguard<3 # https://github.com/microsoft/nni/issues/5457
54+
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523

tests/test_auto3dseg_bundlegen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import os
1515
import shutil
16+
import sys
1617
import tempfile
1718
import unittest
1819

@@ -126,6 +127,7 @@ def test_move_bundle_gen_folder(self) -> None:
126127
data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")
127128
ConfigParser.export_config_file(data_src, data_src_cfg)
128129

130+
sys_path = sys.path.copy()
129131
with skip_if_downloading_fails():
130132
bundle_generator = BundleGen(
131133
algo_path=work_dir,
@@ -138,6 +140,7 @@ def test_move_bundle_gen_folder(self) -> None:
138140
history_before = bundle_generator.get_history()
139141
export_bundle_algo_history(history_before)
140142

143+
sys.path = sys_path # prevent the import_bundle_algo_history from using the path "work_dir/algorithm_templates"
141144
tempfile.TemporaryDirectory()
142145
work_dir_new = os.path.join(test_path, "workdir_2")
143146
shutil.move(work_dir, work_dir_new)

0 commit comments

Comments
 (0)