Skip to content

Commit a5fbe71

Browse files
Han123supre-commit-ci[bot]ericspodKumoLiu
authored
Refactor Export for Model Conversion and Saving (#7934)
Fixes #6375 . ### Description Changes to be made based on the [previous discussion #7835](#7835). Modify the `_export` function to call the `saver` parameter for saving different models. Rewrite the `onnx_export` function using the updated `_export` to achieve consistency in model format conversion and saving. * Rewrite `onnx_export` to call `_export` with `convert_to_onnx` and appropriate `kwargs`. * Add a `saver: Callable` parameter to `_export`, replacing `save_net_with_metadata`. * Pass `save_net_with_metadata` function wrapped with `partial` to set parameters like `include_config_vals` and `append_timestamp`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han123su <popsmall212@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent de2a819 commit a5fbe71

1 file changed

Lines changed: 29 additions & 17 deletions

File tree

monai/bundle/scripts.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import warnings
1919
import zipfile
2020
from collections.abc import Mapping, Sequence
21+
from functools import partial
2122
from pathlib import Path
2223
from pydoc import locate
2324
from shutil import copyfile
@@ -1254,6 +1255,7 @@ def verify_net_in_out(
12541255

12551256
def _export(
12561257
converter: Callable,
1258+
saver: Callable,
12571259
parser: ConfigParser,
12581260
net_id: str,
12591261
filepath: str,
@@ -1268,6 +1270,8 @@ def _export(
12681270
Args:
12691271
converter: a callable object that takes a torch.nn.module and kwargs as input and
12701272
converts the module to another type.
1273+
saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
1274+
(extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
12711275
parser: a ConfigParser of the bundle to be converted.
12721276
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
12731277
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
@@ -1307,14 +1311,9 @@ def _export(
13071311
# add .json extension to all extra files which are always encoded as JSON
13081312
extra_files = {k + ".json": v for k, v in extra_files.items()}
13091313

1310-
save_net_with_metadata(
1311-
jit_obj=net,
1312-
filename_prefix_or_stream=filepath,
1313-
include_config_vals=False,
1314-
append_timestamp=False,
1315-
meta_values=parser.get().pop("_meta_", None),
1316-
more_extra_files=extra_files,
1317-
)
1314+
meta_values = parser.get().pop("_meta_", None)
1315+
saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
1316+
13181317
logger.info(f"exported to file: {filepath}.")
13191318

13201319

@@ -1413,17 +1412,23 @@ def onnx_export(
14131412
input_shape_ = _get_fake_input_shape(parser=parser)
14141413

14151414
inputs_ = [torch.rand(input_shape_)]
1416-
net = parser.get_parsed_content(net_id_)
1417-
if has_ignite:
1418-
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
1419-
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
1420-
else:
1421-
ckpt = torch.load(ckpt_file_)
1422-
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])
14231415

14241416
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
1425-
onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
1426-
onnx.save(onnx_model, filepath_)
1417+
1418+
def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
1419+
onnx.save(onnx_obj, filename_prefix_or_stream)
1420+
1421+
_export(
1422+
convert_to_onnx,
1423+
save_onnx,
1424+
parser,
1425+
net_id=net_id_,
1426+
filepath=filepath_,
1427+
ckpt_file=ckpt_file_,
1428+
config_file=config_file_,
1429+
key_in_ckpt=key_in_ckpt_,
1430+
**converter_kwargs_,
1431+
)
14271432

14281433

14291434
def ckpt_export(
@@ -1544,8 +1549,12 @@ def ckpt_export(
15441549

15451550
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
15461551
# Use the given converter to convert a model and save with metadata, config content
1552+
1553+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
1554+
15471555
_export(
15481556
convert_to_torchscript,
1557+
save_ts,
15491558
parser,
15501559
net_id=net_id_,
15511560
filepath=filepath_,
@@ -1715,8 +1724,11 @@ def trt_export(
17151724
}
17161725
converter_kwargs_.update(trt_api_parameters)
17171726

1727+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
1728+
17181729
_export(
17191730
convert_to_trt,
1731+
save_ts,
17201732
parser,
17211733
net_id=net_id_,
17221734
filepath=filepath_,

0 commit comments

Comments
 (0)