|
779 | 779 | "source": [ |
780 | 780 | "### Make predictions with the CNN model\n", |
781 | 781 | "\n", |
782 | | - "With the model trained, we can use it to make predictions about some images. The [`predict`](https://www.tensorflow.org/api_docs/python/tf/keras/models/Sequential#predict) function call generates the output predictions given a set of input samples.\n" |
| 782 | + "With the model trained, we can use it to make predictions about some images." |
783 | 783 | ] |
784 | 784 | }, |
785 | 785 | { |
|
790 | 790 | }, |
791 | 791 | "outputs": [], |
792 | 792 | "source": [ |
793 | | - "predictions = cnn_model.predict(test_images)" |
| 793 | + "test_image, test_label = test_dataset[0]\n", |
| 794 | + "test_image = test_image.unsqueeze(0)\n", |
| 795 | + "cnn_model.eval()\n", |
| 796 | + "predictions_test_image = cnn_model(test_image)" |
794 | 797 | ] |
795 | 798 | }, |
796 | 799 | { |
|
799 | 802 | "id": "x9Kk1voUCaXJ" |
800 | 803 | }, |
801 | 804 | "source": [ |
802 | | - "With this function call, the model has predicted the label for each image in the testing set. Let's take a look at the prediction for the first image in the test dataset:" |
| 805 | + "With this function call, the model has predicted the label of the first image in the testing set. Let's take a look at the prediction:" |
803 | 806 | ] |
804 | 807 | }, |
805 | 808 | { |
|
810 | 813 | }, |
811 | 814 | "outputs": [], |
812 | 815 | "source": [ |
813 | | - "predictions[0]" |
| 816 | + "predictions_test_image" |
814 | 817 | ] |
815 | 818 | }, |
816 | 819 | { |
|
834 | 837 | "source": [ |
835 | 838 | "'''TODO: identify the digit with the highest confidence prediction for the first\n", |
836 | 839 | " image in the test dataset. '''\n", |
837 | | - "prediction = np.argmax(predictions[0])\n", |
| 840 | + "predictions_value = predictions_test_image.detach().numpy()\n", |
| 841 | + "prediction = np.argmax(predictions_value)\n", |
838 | 842 | "# prediction = # TODO\n", |
839 | | - "\n", |
840 | 843 | "print(prediction)" |
841 | 844 | ] |
842 | 845 | }, |
|
857 | 860 | }, |
858 | 861 | "outputs": [], |
859 | 862 | "source": [ |
860 | | - "print(\"Label of this digit is:\", test_labels[0])\n", |
861 | | - "plt.imshow(test_images[0,:,:,0], cmap=plt.cm.binary)\n", |
| 863 | + "print(\"Label of this digit is:\", test_label)\n", |
| 864 | + "plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n", |
862 | 865 | "comet_model_2.log_figure(figure=plt)" |
863 | 866 | ] |
864 | 867 | }, |
|
0 commit comments