Skip to content

Commit 9784506

Browse files
Support MLFlow Handler for single process/multi task environment (#5728)
Signed-off-by: Sachidanand Alle <sachidanand.alle@gmail.com> The current MLFlow Handler fails when you invoke 2 train requests back to back with different URIs, or multiple train requests within the same process. This is mainly due to the use of global state, where the active experiment and active run are saved globally and shared across invocations. This will cause conflicts between 2 invocations with 2 different URIs. Fixes ------ - Use MLFlow Client to create experiments/runs instead of global functions. - Save the current run through the lifecycle of the handler. If any handler has the same experiment name and same run name, all the metrics will be merged as part of the same run (e.g. train and validation handler). - If the run name is not provided (fall back on default), then the last active run within the same experiment (sorted based on start time) is used for adding the metrics. The above two conditions will help create behavior similar to using `mlflow.active_run()` Verified --------- - Running single and multi gpu training on bundles - spleen_ct_segmentation_v0.1.0 - spleen_deepedit_annotation_v0.1.0 - swin_unetr_btcv_segmentation_v0.1.0 - Running Training workflows for both single and multi gpu in MONAI Label - Verified against running a shared/single tracking URI (where all the experiments get saved) - Verified against individual eval/mlruns per bundle/workflow > I suggest the original owner of this handler verify/test all the behaviors that are currently supported. Error Description --------- Error stack when you run two train workflows within the same process (simply one after another). 
``` [2022-12-13 21:08:11,095] [4047823] [MainThread] [ERROR] (uvicorn.error:369) - Exception in ASGI application Traceback (most recent call last): File "/localhome/sachi/.local/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 366, in run_asgi result = await app(self.scope, self.receive, self.send) File "/localhome/sachi/.local/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 75, in __call__ return await self.app(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/fastapi/applications.py", line 199, in __call__ await super().__call__(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/applications.py", line 112, in __call__ await self.middleware_stack(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/middleware/errors.py", line 181, in __call__ raise exc from None File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/middleware/errors.py", line 159, in __call__ await self.app(scope, receive, _send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/middleware/cors.py", line 78, in __call__ await self.app(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/exceptions.py", line 82, in __call__ raise exc from None File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/exceptions.py", line 71, in __call__ await self.app(scope, receive, sender) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/routing.py", line 580, in __call__ await route.handle(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/routing.py", line 241, in handle await self.app(scope, receive, send) File "/localhome/sachi/.local/lib/python3.10/site-packages/starlette/routing.py", line 52, in app response = await func(request) File "/localhome/sachi/.local/lib/python3.10/site-packages/fastapi/routing.py", line 
219, in app raw_response = await run_endpoint_function( File "/localhome/sachi/.local/lib/python3.10/site-packages/fastapi/routing.py", line 152, in run_endpoint_function return await dependant.call(**values) File "/localhome/sachi/Projects/monailabel/monailabel/endpoints/train.py", line 96, in api_run_model return run_model(model, params, run_sync, enqueue) File "/localhome/sachi/Projects/monailabel/monailabel/endpoints/train.py", line 55, in run_model res, detail = AsyncTask.run("train", request=request, params=params, force_sync=run_sync, enqueue=enqueue) File "/localhome/sachi/Projects/monailabel/monailabel/utils/async_tasks/task.py", line 43, in run return instance.train(request), None File "/localhome/sachi/Projects/monailabel/monailabel/interfaces/app.py", line 422, in train result = task(request, self.datastore()) File "/localhome/sachi/Projects/monailabel/monailabel/tasks/train/basic_train.py", line 458, in __call__ res = self.train(0, world_size, req, datalist) File "/localhome/sachi/Projects/monailabel/monailabel/tasks/train/basic_train.py", line 545, in train context.trainer.run() File "/localhome/sachi/Projects/MONAI/monai/engines/trainer.py", line 53, in run super().run() File "/localhome/sachi/Projects/MONAI/monai/engines/workflow.py", line 281, in run super().run(data=self.data_loader, max_epochs=self.state.max_epochs) File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 892, in run return self._internal_run() File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 935, in _internal_run return next(self._internal_run_generator) File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 993, in _internal_run_as_gen self._handle_exception(e) File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 636, in _handle_exception self._fire_event(Events.EXCEPTION_RAISED, e) File 
"/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 425, in _fire_event func(*first, *(event_args + others), **kwargs) File "/localhome/sachi/Projects/MONAI/monai/handlers/stats_handler.py", line 181, in exception_raised raise e File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 946, in _internal_run_as_gen self._fire_event(Events.STARTED) File "/localhome/sachi/.local/lib/python3.10/site-packages/ignite/engine/engine.py", line 425, in _fire_event func(*first, *(event_args + others), **kwargs) File "/localhome/sachi/Projects/MONAI/monai/handlers/mlflow_handler.py", line 183, in start self._delete_exist_param_in_dict(attrs) File "/localhome/sachi/Projects/MONAI/monai/handlers/mlflow_handler.py", line 141, in _delete_exist_param_in_dict log_data = self.client.get_run(cur_run.info.run_id).data File "/localhome/sachi/.local/lib/python3.10/site-packages/mlflow/tracking/client.py", line 150, in get_run return self._tracking_client.get_run(run_id) File "/localhome/sachi/.local/lib/python3.10/site-packages/mlflow/tracking/_tracking_service/client.py", line 72, in get_run return self.store.get_run(run_id) File "/localhome/sachi/.local/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 623, in get_run run_info = self._get_run_info(run_id) File "/localhome/sachi/.local/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 646, in _get_run_info raise MlflowException( mlflow.exceptions.MlflowException: Run '1765aea084a3417586d052d9d8240039' not found FAILED [ 72%] ``` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. 
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Sachidanand Alle <sachidanand.alle@gmail.com>
1 parent b2359b7 commit 9784506

1 file changed

Lines changed: 59 additions & 19 deletions

File tree

monai/handlers/mlflow_handler.py

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
2323
mlflow, _ = optional_import("mlflow")
24+
mlflow.entities, _ = optional_import("mlflow.entities")
2425

2526
if TYPE_CHECKING:
2627
from ignite.engine import Engine
@@ -52,7 +53,7 @@ class MLFlowHandler:
5253
Args:
5354
tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment
5455
variable to have MLflow find a URI from there. in both cases, the URI can either be
55-
a HTTP/HTTPS URI for a remote server, a database connection string, or a local path
56+
an HTTP/HTTPS URI for a remote server, a database connection string, or a local path
5657
to log data to a directory. The URI defaults to path `mlruns`.
5758
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.
5859
iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.
@@ -109,9 +110,6 @@ def __init__(
109110
optimizer_param_names: Union[str, Sequence[str]] = "lr",
110111
close_on_complete: bool = False,
111112
) -> None:
112-
if tracking_uri is not None:
113-
mlflow.set_tracking_uri(tracking_uri)
114-
115113
self.iteration_log = iteration_log
116114
self.epoch_log = epoch_log
117115
self.epoch_logger = epoch_logger
@@ -125,8 +123,10 @@ def __init__(
125123
self.experiment_param = experiment_param
126124
self.artifacts = ensure_tuple(artifacts)
127125
self.optimizer_param_names = ensure_tuple(optimizer_param_names)
128-
self.client = mlflow.MlflowClient()
126+
self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None)
129127
self.close_on_complete = close_on_complete
128+
self.experiment = None
129+
self.cur_run = None
130130

131131
def _delete_exist_param_in_dict(self, param_dict: Dict) -> None:
132132
"""
@@ -135,9 +135,11 @@ def _delete_exist_param_in_dict(self, param_dict: Dict) -> None:
135135
Args:
136136
param_dict: parameter dict to be logged to mlflow.
137137
"""
138+
if self.cur_run is None:
139+
return
140+
138141
key_list = list(param_dict.keys())
139-
cur_run = mlflow.active_run()
140-
log_data = self.client.get_run(cur_run.info.run_id).data
142+
log_data = self.client.get_run(self.cur_run.info.run_id).data
141143
log_param_dict = log_data.params
142144
for key in key_list:
143145
if key in log_param_dict:
@@ -167,17 +169,52 @@ def start(self, engine: Engine) -> None:
167169
Check MLFlow status and start if not active.
168170
169171
"""
170-
mlflow.set_experiment(self.experiment_name)
171-
if mlflow.active_run() is None:
172+
self._set_experiment()
173+
if not self.experiment:
174+
raise ValueError(f"Failed to set experiment '{self.experiment_name}' as the active experiment")
175+
176+
if not self.cur_run:
172177
run_name = f"run_{time.strftime('%Y%m%d_%H%M%S')}" if self.run_name is None else self.run_name
173-
mlflow.start_run(run_name=run_name)
178+
runs = self.client.search_runs(self.experiment.experiment_id)
179+
runs = [r for r in runs if r.info.run_name == run_name or not self.run_name]
180+
if runs:
181+
self.cur_run = self.client.get_run(runs[-1].info.run_id) # pick latest active run
182+
else:
183+
self.cur_run = self.client.create_run(experiment_id=self.experiment.experiment_id, run_name=run_name)
174184

175185
if self.experiment_param:
176-
mlflow.log_params(self.experiment_param)
186+
self._log_params(self.experiment_param)
177187

178188
attrs = {attr: getattr(engine.state, attr, None) for attr in self.default_tracking_params}
179189
self._delete_exist_param_in_dict(attrs)
180-
mlflow.log_params(attrs)
190+
self._log_params(attrs)
191+
192+
def _set_experiment(self):
193+
experiment = self.experiment
194+
if not experiment:
195+
experiment = self.client.get_experiment_by_name(self.experiment_name)
196+
if not experiment:
197+
experiment_id = self.client.create_experiment(self.experiment_name)
198+
experiment = self.client.get_experiment(experiment_id)
199+
200+
if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:
201+
raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
202+
self.experiment = experiment
203+
204+
def _log_params(self, params: Dict[str, Any]) -> None:
205+
if not self.cur_run:
206+
raise ValueError("Current Run is not Active to log params")
207+
params_arr = [mlflow.entities.Param(key, str(value)) for key, value in params.items()]
208+
self.client.log_batch(run_id=self.cur_run.info.run_id, metrics=[], params=params_arr, tags=[])
209+
210+
def _log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
211+
if not self.cur_run:
212+
raise ValueError("Current Run is not Active to log metrics")
213+
214+
run_id = self.cur_run.info.run_id
215+
timestamp = int(time.time() * 1000)
216+
metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
217+
self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
181218

182219
def _parse_artifacts(self):
183220
"""
@@ -202,17 +239,20 @@ def complete(self) -> None:
202239
"""
203240
Handler for train or validation/evaluation completed Event.
204241
"""
205-
if self.artifacts:
242+
if self.artifacts and self.cur_run:
206243
artifact_list = self._parse_artifacts()
207244
for artifact in artifact_list:
208-
mlflow.log_artifact(artifact)
245+
self.client.log_artifact(self.cur_run.info.run_id, artifact)
209246

210247
def close(self) -> None:
211248
"""
212249
Stop current running logger of MLFlow.
213250
214251
"""
215-
mlflow.end_run()
252+
if self.cur_run:
253+
status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
254+
self.client.set_terminated(self.cur_run.info.run_id, status)
255+
self.cur_run = None
216256

217257
def epoch_completed(self, engine: Engine) -> None:
218258
"""
@@ -257,11 +297,11 @@ def _default_epoch_log(self, engine: Engine) -> None:
257297
return
258298

259299
current_epoch = self.global_epoch_transform(engine.state.epoch)
260-
mlflow.log_metrics(log_dict, step=current_epoch)
300+
self._log_metrics(log_dict, step=current_epoch)
261301

262302
if self.state_attributes is not None:
263303
attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
264-
mlflow.log_metrics(attrs, step=current_epoch)
304+
self._log_metrics(attrs, step=current_epoch)
265305

266306
def _default_iteration_log(self, engine: Engine) -> None:
267307
"""
@@ -281,7 +321,7 @@ def _default_iteration_log(self, engine: Engine) -> None:
281321
if not isinstance(loss, dict):
282322
loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss}
283323

284-
mlflow.log_metrics(loss, step=engine.state.iteration)
324+
self._log_metrics(loss, step=engine.state.iteration)
285325

286326
# If there is optimizer attr in engine, then record parameters specified in init function.
287327
if hasattr(engine, "optimizer"):
@@ -291,4 +331,4 @@ def _default_iteration_log(self, engine: Engine) -> None:
291331
f"{param_name} group_{i}": float(param_group[param_name])
292332
for i, param_group in enumerate(cur_optimizer.param_groups)
293333
}
294-
mlflow.log_metrics(params, step=engine.state.iteration)
334+
self._log_metrics(params, step=engine.state.iteration)

0 commit comments

Comments
 (0)