|
428 | 428 | " total_loss += loss.item()*images.size(0)\n", |
429 | 429 | "\n", |
430 | 430 | " # Calculate accuracy\n", |
431 | | - " _, predicted = torch.max(outputs, 1) # Get predicted class\n", |
| 431 | + " predicted = torch.argmax(outputs, dim=1) # Get predicted class\n", |
432 | 432 | " correct_pred += (predicted == labels).sum().item() # Count correct predictions\n", |
433 | 433 | " total_pred += labels.size(0) # Count total predictions\n", |
434 | 434 | "\n", |
|
509 | 509 | " # test_loss += TODO\n", |
510 | 510 | "\n", |
511 | 511 | " # Calculate accuracy\n", |
512 | | - " _, predicted = torch.max(outputs, 1)\n", |
| 512 | + " predicted = torch.argmax(outputs, dim=1)\n", |
513 | 513 | " #TODO: identify the digit with the highest confidence prediction for the first image in the test dataset.\n", |
514 | 514 | "\n", |
515 | 515 | " correct_pred += (predicted == labels).sum().item()\n", |
|
736 | 736 | " optimizer.step()\n", |
737 | 737 | "\n", |
738 | 738 | " # Calculate accuracy\n", |
739 | | - " _, predicted = torch.max(logits, 1) \n", |
| 739 | + " predicted = torch.argmax(logits, dim=1)\n", |
740 | 740 | " correct_pred += (predicted == labels).sum().item() \n", |
741 | 741 | " total_pred += labels.size(0)\n", |
742 | 742 | "\n", |
|
907 | 907 | " probabilities = torch.nn.functional.softmax(outputs, dim=1)\n", |
908 | 908 | "\n", |
909 | 909 | " # Get predicted classes\n", |
910 | | - " _, predicted = torch.max(outputs, 1)\n", |
| 910 | + " predicted = torch.argmax(outputs, dim=1)\n", |
911 | 911 | "\n", |
912 | 912 | " all_predictions.append(probabilities) \n", |
913 | 913 | " all_labels.append(labels)\n", |
|
0 commit comments