|
861 | 861 | "outputs": [], |
862 | 862 | "source": [ |
863 | 863 | "print(\"Label of this digit is:\", test_label)\n", |
864 | | - "plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n", |
| 864 | + "plt.imshow(test_image[0,0,:,:], cmap=plt.cm.binary)\n", |
865 | 865 | "comet_model_2.log_figure(figure=plt)" |
866 | 866 | ] |
867 | 867 | }, |
|
871 | 871 | "id": "ygh2yYC972ne" |
872 | 872 | }, |
873 | 873 | "source": [ |
874 | | - "It is! Let's visualize the classification results on the MNIST dataset. We will plot images from the test dataset along with their predicted label, as well as a histogram that provides the prediction probabilities for each of the digits:" |
| 874 | + "It is! Let's visualize the classification results on the MNIST dataset. We will plot images from the test dataset along with their predicted label, as well as a histogram that provides the prediction probabilities for each of the digits.\n", |
| 875 | + "\n", |
| 876 | + "Recall that in PyTorch the MNIST dataset is typically accessed using a DataLoader to iterate through the test set in smaller, manageable batches. By appending the predictions, test labels, and test images from each batch, we will first gradually accumulate all the data needed for visualization into singular variables to observe our model's predictions." |
| 877 | + ] |
| 878 | + }, |
| 879 | + { |
| 880 | + "cell_type": "code", |
| 881 | + "execution_count": null, |
| 882 | + "metadata": {}, |
| 883 | + "outputs": [], |
| 884 | + "source": [ |
| 885 | + "# Initialize variables to store all data\n", |
| 886 | + "all_predictions = []\n", |
| 887 | + "all_labels = []\n", |
| 888 | + "all_images = []\n", |
| 889 | + "\n", |
| 890 | + "# Process test set in batches\n", |
| 891 | + "with torch.no_grad():\n", |
| 892 | + " for images, labels in testset_loader:\n", |
| 893 | + " outputs = cnn_model(images)\n", |
| 894 | + "\n", |
| 895 | + " # Apply softmax to get probabilities\n", |
| 896 | + " probabilities = torch.nn.functional.softmax(outputs, dim=1)\n", |
| 897 | + "\n", |
| 898 | + " # Get predicted classes\n", |
| 899 | + " _, predicted = torch.max(outputs, 1)\n", |
| 900 | + "\n", |
| 901 | + " all_predictions.append(probabilities) \n", |
| 902 | + " all_labels.append(labels)\n", |
| 903 | + " all_images.append(images)\n", |
| 904 | + "\n", |
| 905 | + "all_predictions = torch.cat(all_predictions) # Shape: (total_samples, num_classes)\n", |
| 906 | + "all_labels = torch.cat(all_labels) # Shape: (total_samples,)\n", |
| 907 | + "all_images = torch.cat(all_images) # Shape: (total_samples, 1, 28, 28)\n", |
| 908 | + "\n", |
| 909 | + "# Convert tensors to NumPy for compatibility with plotting functions\n", |
| 910 | + "predictions = all_predictions.numpy() # Shape: (total_samples, num_classes)\n", |
| 911 | + "test_labels = all_labels.numpy() # Shape: (total_samples,)\n", |
| 912 | + "test_images = all_images.numpy() # Shape: (total_samples, 1, 28, 28)" |
875 | 913 | ] |
876 | 914 | }, |
877 | 915 | { |
|
0 commit comments