2525import torch
2626from torch .cuda import is_available
2727
28+ from monai .apps .mmars .mmars import _get_all_ngc_models
2829from monai .apps .utils import _basename , download_url , extractall , get_logger
2930from monai .bundle .config_item import ConfigComponent
3031from monai .bundle .config_parser import ConfigParser
4243
4344logger = get_logger (module_name = __name__ )
4445
# Default source for bundle downloads. Set the environment variable
# BUNDLE_DOWNLOAD_SRC="ngc" to use NGC by default; when the variable is unset,
# "github" is used.
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
48+
4549
4650def _update_args (args : Optional [Union [str , Dict ]] = None , ignore_none : bool = True , ** kwargs ) -> Dict :
4751 """
@@ -130,9 +134,11 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
130134 return f"https://github.com/{ repo_owner } /{ repo_name } /releases/download/{ tag_name } /{ filename } "
131135
132136
137+ def _get_ngc_bundle_url (model_name : str , version : str ):
138+ return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{ model_name } /versions/{ version } /zip"
139+
140+
133141def _download_from_github (repo : str , download_path : Path , filename : str , progress : bool = True ):
134- if len (repo .split ("/" )) != 3 :
135- raise ValueError ("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`." )
136142 repo_owner , repo_name , tag_name = repo .split ("/" )
137143 if ".zip" not in filename :
138144 filename += ".zip"
@@ -142,6 +148,45 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres
142148 extractall (filepath = filepath , output_dir = download_path , has_base = True )
143149
144150
151+ def _add_ngc_prefix (name : str , prefix : str = "monai_" ):
152+ if name .startswith (prefix ):
153+ return name
154+ return f"{ prefix } { name } "
155+
156+
157+ def _remove_ngc_prefix (name : str , prefix : str = "monai_" ):
158+ if name .startswith (prefix ):
159+ return name [len (prefix ) :]
160+ return name
161+
162+
def _download_from_ngc(
    download_path: Path, filename: str, version: str, remove_prefix: Optional[str], progress: bool
) -> None:
    """
    Download a bundle zip archive from NGC and extract it under ``download_path``.

    Args:
        download_path: directory that receives both the zip file and the extracted folder.
        filename: bundle name; the default NGC prefix is added if it is missing.
        version: bundle version, used in the download URL and the zip file name.
        remove_prefix: if truthy, this prefix is stripped from the bundle name when
            naming the extracted folder. If `None` or empty, the prefixed name is kept.
        progress: whether to display a progress bar while downloading.
    """
    # the URL and the archive name always use the prefixed name
    filename = _add_ngc_prefix(filename)
    url = _get_ngc_bundle_url(model_name=filename, version=version)
    filepath = download_path / f"{filename}_v{version}.zip"
    if remove_prefix:
        # strip the prefix the caller asked for, not the helper's default
        filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
    extract_path = download_path / f"{filename}"
    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
    extractall(filepath=filepath, output_dir=extract_path, has_base=True)
173+
174+
175+ def _get_latest_bundle_version (source : str , name : str , repo : str ):
176+ if source == "ngc" :
177+ name = _add_ngc_prefix (name )
178+ model_dict = _get_all_ngc_models (name )
179+ for v in model_dict .values ():
180+ if v ["name" ] == name :
181+ return v ["latest" ]
182+ return None
183+ elif source == "github" :
184+ repo_owner , repo_name , tag_name = repo .split ("/" )
185+ return get_bundle_versions (name , repo = os .path .join (repo_owner , repo_name ), tag = tag_name )["latest_version" ]
186+ else :
187+ raise ValueError (f"To get the latest bundle version, source should be 'github' or 'ngc', got { source } ." )
188+
189+
145190def _process_bundle_dir (bundle_dir : Optional [PathLike ] = None ):
146191 if bundle_dir is None :
147192 get_dir , has_home = optional_import ("torch.hub" , name = "get_dir" )
@@ -156,9 +201,10 @@ def download(
156201 name : Optional [str ] = None ,
157202 version : Optional [str ] = None ,
158203 bundle_dir : Optional [PathLike ] = None ,
159- source : str = "github" ,
160- repo : str = "Project-MONAI/model-zoo/hosting_storage_v1" ,
204+ source : str = download_source ,
205+ repo : Optional [ str ] = None ,
161206 url : Optional [str ] = None ,
207+ remove_prefix : Optional [str ] = "monai_" ,
162208 progress : bool = True ,
163209 args_file : Optional [str ] = None ,
164210):
@@ -175,9 +221,12 @@ def download(
175221 # Execute this module as a CLI entry, and download bundle from the model-zoo repo:
176222 python -m monai.bundle download --name <bundle_name> --version "0.1.0" --bundle_dir "./"
177223
178- # Execute this module as a CLI entry, and download bundle:
224+ # Execute this module as a CLI entry, and download bundle from specified github repo :
179225 python -m monai.bundle download --name <bundle_name> --source "github" --repo "repo_owner/repo_name/release_tag"
180226
227+ # Execute this module as a CLI entry, and download bundle from ngc with latest version:
228+ python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
229+
181230 # Execute this module as a CLI entry, and download bundle via URL:
182231 python -m monai.bundle download --name <bundle_name> --url <url>
183232
@@ -190,18 +239,27 @@ def download(
190239
191240 Args:
192241 name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
193- for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
242+ for example:
243+ "spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
194244 https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
195- version: version name of the target bundle to download, like: "0.1.0".
245+ "monai_brats_mri_segmentation" in ngc:
246+ https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
247+ version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
248+ the latest version.
196249 bundle_dir: target directory to store the downloaded data.
197250 Default is `bundle` subfolder under `torch.hub.get_dir()`.
198251 source: storage location name. This argument is used when `url` is `None`.
199- "github" is currently the only supported value.
200- repo: repo name. This argument is used when `url` is `None`.
201- If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
252+ In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
253+ it should be "ngc" or "github".
254+ repo: repo name. This argument is used when `url` is `None` and `source` is "github".
255+ If used, it should be in the form of "repo_owner/repo_name/release_tag".
202256 url: url to download the data. If not `None`, data will be downloaded directly
203257 and `source` will not be checked.
204258 If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
259+ remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
260+ have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
261+ maintain the consistency between these two sources, remove prefix is necessary.
262+ Therefore, if specified, downloaded folder name will remove the prefix.
205263 progress: whether to display a progress bar.
206264 args_file: a JSON or YAML file to provide default values for all the args in this function.
207265 so that the command line inputs can be simplified.
@@ -215,17 +273,20 @@ def download(
215273 source = source ,
216274 repo = repo ,
217275 url = url ,
276+ remove_prefix = remove_prefix ,
218277 progress = progress ,
219278 )
220279
221280 _log_input_summary (tag = "download" , args = _args )
222- source_ , repo_ , progress_ , name_ , version_ , bundle_dir_ , url_ = _pop_args (
223- _args , "source" , "repo " , "progress" , name = None , version = None , bundle_dir = None , url = None
281+ source_ , progress_ , remove_prefix_ , repo_ , name_ , version_ , bundle_dir_ , url_ = _pop_args (
282+ _args , "source" , "progress " , remove_prefix = None , repo = None , name = None , version = None , bundle_dir = None , url = None
224283 )
225284
226285 bundle_dir_ = _process_bundle_dir (bundle_dir_ )
227- if name_ is not None and version_ is not None :
228- name_ = "_v" .join ([name_ , version_ ])
286+ if repo_ is None :
287+ repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
288+ if len (repo_ .split ("/" )) != 3 :
289+ raise ValueError ("repo should be in the form of `repo_owner/repo_name/release_tag`." )
229290
230291 if url_ is not None :
231292 if name_ is not None :
@@ -234,14 +295,27 @@ def download(
234295 filepath = bundle_dir_ / f"{ _basename (url_ )} "
235296 download_url (url = url_ , filepath = filepath , hash_val = None , progress = progress_ )
236297 extractall (filepath = filepath , output_dir = bundle_dir_ , has_base = True )
237- elif source_ == "github" :
238- if name_ is None :
239- raise ValueError (f"To download from source: Github, `name` must be provided, got { name_ } ." )
240- _download_from_github (repo = repo_ , download_path = bundle_dir_ , filename = name_ , progress = progress_ )
241298 else :
242- raise NotImplementedError (
243- f"Currently only download from provided URL in `url` or Github is implemented, got source: { source_ } ."
244- )
299+ if name_ is None :
300+ raise ValueError (f"To download from source: { source_ } , `name` must be provided." )
301+ if version_ is None :
302+ version_ = _get_latest_bundle_version (source = source_ , name = name_ , repo = repo_ )
303+ if source_ == "github" :
304+ if version_ is not None :
305+ name_ = "_v" .join ([name_ , version_ ])
306+ _download_from_github (repo = repo_ , download_path = bundle_dir_ , filename = name_ , progress = progress_ )
307+ elif source_ == "ngc" :
308+ _download_from_ngc (
309+ download_path = bundle_dir_ ,
310+ filename = name_ ,
311+ version = version_ ,
312+ remove_prefix = remove_prefix_ ,
313+ progress = progress_ ,
314+ )
315+ else :
316+ raise NotImplementedError (
317+ f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: { source_ } ."
318+ )
245319
246320
247321def load (
@@ -250,8 +324,8 @@ def load(
250324 model_file : Optional [str ] = None ,
251325 load_ts_module : bool = False ,
252326 bundle_dir : Optional [PathLike ] = None ,
253- source : str = "github" ,
254- repo : str = "Project-MONAI/model-zoo/hosting_storage_v1" ,
327+ source : str = download_source ,
328+ repo : Optional [ str ] = None ,
255329 progress : bool = True ,
256330 device : Optional [str ] = None ,
257331 key_in_ckpt : Optional [str ] = None ,
@@ -263,18 +337,25 @@ def load(
263337 Load model weights or TorchScript module of a bundle.
264338
265339 Args:
266- name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
340+ name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
341+ for example:
342+ "spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
267343 https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
268- version: version name of the target bundle to download, like: "0.1.0".
344+ "monai_brats_mri_segmentation" in ngc:
345+ https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
346+ version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
347+ the latest version.
269348 model_file: the relative path of the model weights or TorchScript module within bundle.
270349 If `None`, "models/model.pt" or "models/model.ts" will be used.
271350 load_ts_module: a flag to specify if loading the TorchScript module.
272351 bundle_dir: directory the weights/TorchScript module will be loaded from.
273352 Default is `bundle` subfolder under `torch.hub.get_dir()`.
274353 source: storage location name. This argument is used when `model_file` is not existing locally and need to be
275- downloaded first. "github" is currently the only supported value.
276- repo: repo name. This argument is used when `model_file` is not existing locally and need to be
277- downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
354+ downloaded first.
355+ In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
356+ it should be "ngc" or "github".
357+ repo: repo name. This argument is used when `url` is `None` and `source` is "github".
358+ If used, it should be in the form of "repo_owner/repo_name/release_tag".
278359 progress: whether to display a progress bar when downloading.
279360 device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
280361 key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
@@ -421,7 +502,7 @@ def get_bundle_versions(
421502
422503 bundles_info = _get_all_bundles_info (repo = repo , tag = tag , auth_token = auth_token )
423504 if bundle_name not in bundles_info :
424- raise ValueError (f"bundle: { bundle_name } is not existing." )
505+ raise ValueError (f"bundle: { bundle_name } is not existing in repo: { repo } ." )
425506 bundle_info = bundles_info [bundle_name ]
426507 all_versions = sorted (bundle_info .keys ())
427508
0 commit comments