Skip to content

Commit 469b146

Browse files
Remove hard-coded sigmoid for voting ensemble method in Auto3DSeg (#5734)
Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Fixes item 3 and 5 in #5564. ### Description The `vote` method in `ensemble_pred` currently does not work for under sigmoid mode, because the function overrides the argument to False before the `VoteEnsemble`. Also, if the user only trains a small number of algorithm (1 fold for 1 algo) and forgets the update the `n_best` (default is 5) in `AlgoEnsembleBestN` , instead of throwing an error, the fix will automatically use all available algos after posting a warning. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
1 parent a09d2a2 commit 469b146

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

monai/apps/auto3dseg/ensemble_builder.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,11 @@ def ensemble_pred(self, preds, sigmoid=False):
107107
prob = MeanEnsemble()(preds)
108108
return prob2class(prob, dim=0, keepdim=True, sigmoid=sigmoid)
109109
elif self.mode == "vote":
110-
classes = [prob2class(p, dim=0, keepdim=True, sigmoid=False) for p in preds]
111-
return VoteEnsemble(num_classes=preds[0].shape[0])(classes)
110+
classes = [prob2class(p, dim=0, keepdim=True, sigmoid=sigmoid) for p in preds]
111+
if sigmoid:
112+
return VoteEnsemble()(classes) # do not specify num_classes for one-hot encoding
113+
else:
114+
return VoteEnsemble(num_classes=preds[0].shape[0])(classes)
112115

113116
def __call__(self, pred_param: Optional[Dict[str, Any]] = None):
114117
"""
@@ -194,7 +197,8 @@ def collect_algos(self, n_best: int = -1):
194197

195198
ranks = self.sort_score()
196199
if len(ranks) < n_best:
197-
raise ValueError("Number of available algos is less than user-defined N")
200+
warn(f"Found {len(ranks)} available algos (pre-defined n_best={n_best}). All {len(ranks)} will be used.")
201+
n_best = len(ranks)
198202

199203
# get the indices that the rank is larger than N
200204
indices = [i for (i, r) in enumerate(ranks) if r >= n_best]

0 commit comments

Comments
 (0)