Skip to content

Commit 7c7f55c

Browse files
committed
use torch.argmax instead of torch.max to remove tuple unpacking confusion
1 parent 93cbd57 commit 7c7f55c

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

lab2/solutions/PT_Part1_MNIST_Solution.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@
428428
" total_loss += loss.item()*images.size(0)\n",
429429
"\n",
430430
" # Calculate accuracy\n",
431-
" _, predicted = torch.max(outputs, 1) # Get predicted class\n",
431+
" predicted = torch.argmax(outputs, dim=1) # Get predicted class\n",
432432
" correct_pred += (predicted == labels).sum().item() # Count correct predictions\n",
433433
" total_pred += labels.size(0) # Count total predictions\n",
434434
"\n",
@@ -509,7 +509,7 @@
509509
" # test_loss += TODO\n",
510510
"\n",
511511
" # Calculate accuracy\n",
512-
" _, predicted = torch.max(outputs, 1)\n",
512+
" predicted = torch.argmax(outputs, dim=1)\n",
513513
" #TODO: identify the digit with the highest confidence prediction for the first image in the test dataset.\n",
514514
"\n",
515515
" correct_pred += (predicted == labels).sum().item()\n",
@@ -736,7 +736,7 @@
736736
" optimizer.step()\n",
737737
"\n",
738738
" # Calculate accuracy\n",
739-
" _, predicted = torch.max(logits, 1) \n",
739+
" predicted = torch.argmax(logits, dim=1)\n",
740740
" correct_pred += (predicted == labels).sum().item() \n",
741741
" total_pred += labels.size(0)\n",
742742
"\n",
@@ -907,7 +907,7 @@
907907
" probabilities = torch.nn.functional.softmax(outputs, dim=1)\n",
908908
"\n",
909909
" # Get predicted classes\n",
910-
" _, predicted = torch.max(outputs, 1)\n",
910+
" predicted = torch.argmax(outputs, dim=1)\n",
911911
"\n",
912912
" all_predictions.append(probabilities) \n",
913913
" all_labels.append(labels)\n",

0 commit comments

Comments
 (0)