Skip to content

Commit ad36cb1

Browse files
committed
make predictions with the CNN
1 parent 2d2f292 commit ad36cb1

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

lab2/solutions/PT_Part1_MNIST_Solution.ipynb

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@
779779
"source": [
780780
"### Make predictions with the CNN model\n",
781781
"\n",
782-
"With the model trained, we can use it to make predictions about some images. The [`predict`](https://www.tensorflow.org/api_docs/python/tf/keras/models/Sequential#predict) function call generates the output predictions given a set of input samples.\n"
782+
"With the model trained, we can use it to make predictions about some images."
783783
]
784784
},
785785
{
@@ -790,7 +790,10 @@
790790
},
791791
"outputs": [],
792792
"source": [
793-
"predictions = cnn_model.predict(test_images)"
793+
"test_image, test_label = test_dataset[0]\n",
794+
"test_image = test_image.unsqueeze(0)\n",
795+
"cnn_model.eval()\n",
796+
"predictions_test_image = cnn_model(test_image)"
794797
]
795798
},
796799
{
@@ -799,7 +802,7 @@
799802
"id": "x9Kk1voUCaXJ"
800803
},
801804
"source": [
802-
"With this function call, the model has predicted the label for each image in the testing set. Let's take a look at the prediction for the first image in the test dataset:"
805+
"With this function call, the model has predicted the label of the first image in the testing set. Let's take a look at the prediction:"
803806
]
804807
},
805808
{
@@ -810,7 +813,7 @@
810813
},
811814
"outputs": [],
812815
"source": [
813-
"predictions[0]"
816+
"predictions_test_image"
814817
]
815818
},
816819
{
@@ -834,9 +837,9 @@
834837
"source": [
835838
"'''TODO: identify the digit with the highest confidence prediction for the first\n",
836839
" image in the test dataset. '''\n",
837-
"prediction = np.argmax(predictions[0])\n",
840+
"predictions_value = predictions_test_image.detach().numpy()\n",
841+
"prediction = np.argmax(predictions_value)\n",
838842
"# prediction = # TODO\n",
839-
"\n",
840843
"print(prediction)"
841844
]
842845
},
@@ -857,8 +860,8 @@
857860
},
858861
"outputs": [],
859862
"source": [
860-
"print(\"Label of this digit is:\", test_labels[0])\n",
861-
"plt.imshow(test_images[0,:,:,0], cmap=plt.cm.binary)\n",
863+
"print(\"Label of this digit is:\", test_label)\n",
864+
"plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n",
862865
"comet_model_2.log_figure(figure=plt)"
863866
]
864867
},

0 commit comments

Comments
 (0)