|
107 | 107 | "# Check that we are using a GPU, if not switch runtimes\n", |
108 | 108 | "# using Runtime > Change Runtime Type > GPU\n", |
109 | 109 | "assert torch.cuda.is_available(), \"Please enable GPU from runtime settings\"\n", |
110 | | - "assert COMET_API_KEY != \"\", \"Please insert your Comet API Key\"" |
| 110 | + "assert COMET_API_KEY != \"\", \"Please insert your Comet API Key\"\n", |
| 111 | + "\n", |
| 112 | + "# Set GPU for computation\n", |
| 113 | + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" |
111 | 114 | ] |
112 | 115 | }, |
113 | 116 | { |
|
329 | 332 | "\n", |
330 | 333 | " return x\n", |
331 | 334 | "\n", |
332 | | - "fc_model = FullyConnectedModel()" |
| 335 | + "fc_model = FullyConnectedModel().to(device)" |
333 | 336 | ] |
334 | 337 | }, |
335 | 338 | { |
|
410 | 413 | " total_pred = 0\n", |
411 | 414 | "\n", |
412 | 415 | " for images, labels in trainset_loader:\n", |
| 416 | + "        # Move tensors to the GPU so they are compatible with the model\n", |
| 417 | + " images, labels = images.to(device), labels.to(device)\n", |
413 | 418 | " # Clear gradients before performing backward pass\n", |
414 | 419 | " optimizer.zero_grad()\n", |
415 | 420 | " # Forward pass\n", |
|
489 | 494 | " # Disable gradient calculations when in inference mode\n", |
490 | 495 | " with torch.no_grad():\n", |
491 | 496 | " for images, labels in testset_loader:\n", |
| 497 | + "            # TODO: ensure evaluation happens on the GPU\n", |
| 498 | + " images, labels = images.to(device), labels.to(device)\n", |
| 499 | + " # images, labels = TODO\n", |
| 500 | + " \n", |
492 | 501 | " #TODO: feed the images into the model and obtain the predictions (forward pass)\n", |
493 | 502 | " outputs = model(images)\n", |
494 | 503 | " # outputs = TODO\n", |
|
632 | 641 | " return x\n", |
633 | 642 | "\n", |
634 | 643 | "# Instantiate the model\n", |
635 | | - "cnn_model = CNN()\n", |
| 644 | + "cnn_model = CNN().to(device)\n", |
636 | 645 | "# Initialize the model by passing some data through\n", |
637 | 646 | "image, label = train_dataset[0]\n", |
638 | | - "image = image.unsqueeze(0) # Add batch dimension → Shape: (1, 1, 28, 28)\n", |
| 647 | + "image = image.to(device).unsqueeze(0) # Add batch dimension → Shape: (1, 1, 28, 28)\n", |
639 | 648 | "output = cnn_model(image)\n", |
640 | 649 | "# Print the model summary\n", |
641 | 650 | "print(cnn_model)" |
|
665 | 674 | "outputs": [], |
666 | 675 | "source": [ |
667 | 676 | "# Rebuild the CNN model\n", |
668 | | - "cnn_model = CNN()\n", |
| 677 | + "cnn_model = CNN().to(device)\n", |
669 | 678 | "\n", |
670 | 679 | "# Define hyperparams\n", |
671 | 680 | "batch_size = 64\n", |
|
703 | 712 | "\n", |
704 | 713 | " # First grab a batch of training data which our data loader returns as a tensor\n", |
705 | 714 | " for idx, (images, labels) in enumerate(tqdm(trainset_loader)):\n", |
| 715 | + " images, labels = images.to(device), labels.to(device)\n", |
| 716 | + " \n", |
706 | 717 | " # Forward pass\n", |
707 | 718 | " #'''TODO: feed the images into the model and obtain the predictions'''\n", |
708 | 719 | " logits = cnn_model(images)\n", |
|
791 | 802 | "outputs": [], |
792 | 803 | "source": [ |
793 | 804 | "test_image, test_label = test_dataset[0]\n", |
794 | | - "test_image = test_image.unsqueeze(0)\n", |
| 805 | + "test_image = test_image.to(device).unsqueeze(0)\n", |
795 | 806 | "cnn_model.eval()\n", |
796 | 807 | "predictions_test_image = cnn_model(test_image)" |
797 | 808 | ] |
|
837 | 848 | "source": [ |
838 | 849 | "'''TODO: identify the digit with the highest confidence prediction for the first\n", |
839 | 850 | " image in the test dataset. '''\n", |
840 | | - "predictions_value = predictions_test_image.detach().numpy()\n", |
| 851 | + "predictions_value = predictions_test_image.cpu().detach().numpy() # .cpu() copies the tensor to host (CPU) memory before NumPy conversion\n", |
841 | 852 | "prediction = np.argmax(predictions_value)\n", |
842 | 853 | "# prediction = # TODO\n", |
843 | 854 | "print(prediction)" |
|
861 | 872 | "outputs": [], |
862 | 873 | "source": [ |
863 | 874 | "print(\"Label of this digit is:\", test_label)\n", |
864 | | - "plt.imshow(test_image[0,0,:,:], cmap=plt.cm.binary)\n", |
| 875 | + "plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n", |
865 | 876 | "comet_model_2.log_figure(figure=plt)" |
866 | 877 | ] |
867 | 878 | }, |
|
907 | 918 | "all_images = torch.cat(all_images) # Shape: (total_samples, 1, 28, 28)\n", |
908 | 919 | "\n", |
909 | 920 | "# Convert tensors to NumPy for compatibility with plotting functions\n", |
910 | | - "predictions = all_predictions.numpy() # Shape: (total_samples, num_classes)\n", |
911 | | - "test_labels = all_labels.numpy() # Shape: (total_samples,)\n", |
912 | | - "test_images = all_images.numpy() # Shape: (total_samples, 1, 28, 28)" |
| 921 | + "predictions = all_predictions.cpu().numpy() # Shape: (total_samples, num_classes)\n", |
| 922 | + "test_labels = all_labels.cpu().numpy() # Shape: (total_samples,)\n", |
| 923 | + "test_images = all_images.cpu().numpy() # Shape: (total_samples, 1, 28, 28)" |
913 | 924 | ] |
914 | 925 | }, |
915 | 926 | { |
|
0 commit comments