Skip to content

Commit 25130db

Browse files
authored
replace view with reshape for robustness (#5690)
#5684 (comment) Signed-off-by: Joeycho <joeyadamcho@gmail.com> Fixes #5684 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Joeycho <joeyadamcho@gmail.com>
1 parent 78d4f42 commit 25130db

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

monai/metrics/confusion_matrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
163163
batch_size, n_class = y_pred.shape[:2]
164164
# convert to [BNS], where S is the number of pixels for one sample.
165165
# As for classification tasks, S equals to 1.
166-
y_pred = y_pred.view(batch_size, n_class, -1)
167-
y = y.view(batch_size, n_class, -1)
166+
y_pred = y_pred.reshape(batch_size, n_class, -1)
167+
y = y.reshape(batch_size, n_class, -1)
168168
tp = ((y_pred + y) == 2).float()
169169
tn = ((y_pred + y) == 0).float()
170170

0 commit comments

Comments
 (0)