Skip to content

Commit 93cbd57

Browse files
committed
move the model and tensors to the GPU, and move results back to the CPU as needed for eval
1 parent c9aafd5 commit 93cbd57

1 file changed

Lines changed: 22 additions & 11 deletions

File tree

lab2/solutions/PT_Part1_MNIST_Solution.ipynb

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@
107107
"# Check that we are using a GPU, if not switch runtimes\n",
108108
"# using Runtime > Change Runtime Type > GPU\n",
109109
"assert torch.cuda.is_available(), \"Please enable GPU from runtime settings\"\n",
110-
"assert COMET_API_KEY != \"\", \"Please insert your Comet API Key\""
110+
"assert COMET_API_KEY != \"\", \"Please insert your Comet API Key\"\n",
111+
"\n",
112+
"# Set GPU for computation\n",
113+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
111114
]
112115
},
113116
{
@@ -329,7 +332,7 @@
329332
"\n",
330333
" return x\n",
331334
"\n",
332-
"fc_model = FullyConnectedModel()"
335+
"fc_model = FullyConnectedModel().to(device)"
333336
]
334337
},
335338
{
@@ -410,6 +413,8 @@
410413
" total_pred = 0\n",
411414
"\n",
412415
" for images, labels in trainset_loader:\n",
416+
" # Move tensors to GPU so compatible with model\n",
417+
" images, labels = images.to(device), labels.to(device)\n",
413418
" # Clear gradients before performing backward pass\n",
414419
" optimizer.zero_grad()\n",
415420
" # Forward pass\n",
@@ -489,6 +494,10 @@
489494
" # Disable gradient calculations when in inference mode\n",
490495
" with torch.no_grad():\n",
491496
" for images, labels in testset_loader:\n",
497+
" # TODO: ensure evalaution happens on the GPU\n",
498+
" images, labels = images.to(device), labels.to(device)\n",
499+
" # images, labels = TODO\n",
500+
" \n",
492501
" #TODO: feed the images into the model and obtain the predictions (forward pass)\n",
493502
" outputs = model(images)\n",
494503
" # outputs = TODO\n",
@@ -632,10 +641,10 @@
632641
" return x\n",
633642
"\n",
634643
"# Instantiate the model\n",
635-
"cnn_model = CNN()\n",
644+
"cnn_model = CNN().to(device)\n",
636645
"# Initialize the model by passing some data through\n",
637646
"image, label = train_dataset[0]\n",
638-
"image = image.unsqueeze(0) # Add batch dimension → Shape: (1, 1, 28, 28)\n",
647+
"image = image.to(device).unsqueeze(0) # Add batch dimension → Shape: (1, 1, 28, 28)\n",
639648
"output = cnn_model(image)\n",
640649
"# Print the model summary\n",
641650
"print(cnn_model)"
@@ -665,7 +674,7 @@
665674
"outputs": [],
666675
"source": [
667676
"# Rebuild the CNN model\n",
668-
"cnn_model = CNN()\n",
677+
"cnn_model = CNN().to(device)\n",
669678
"\n",
670679
"# Define hyperparams\n",
671680
"batch_size = 64\n",
@@ -703,6 +712,8 @@
703712
"\n",
704713
" # First grab a batch of training data which our data loader returns as a tensor\n",
705714
" for idx, (images, labels) in enumerate(tqdm(trainset_loader)):\n",
715+
" images, labels = images.to(device), labels.to(device)\n",
716+
" \n",
706717
" # Forward pass\n",
707718
" #'''TODO: feed the images into the model and obtain the predictions'''\n",
708719
" logits = cnn_model(images)\n",
@@ -791,7 +802,7 @@
791802
"outputs": [],
792803
"source": [
793804
"test_image, test_label = test_dataset[0]\n",
794-
"test_image = test_image.unsqueeze(0)\n",
805+
"test_image = test_image.to(device).unsqueeze(0)\n",
795806
"cnn_model.eval()\n",
796807
"predictions_test_image = cnn_model(test_image)"
797808
]
@@ -837,7 +848,7 @@
837848
"source": [
838849
"'''TODO: identify the digit with the highest confidence prediction for the first\n",
839850
" image in the test dataset. '''\n",
840-
"predictions_value = predictions_test_image.detach().numpy()\n",
851+
"predictions_value = predictions_test_image.cpu().detach().numpy() #.cpu() to copy tensor to memory first\n",
841852
"prediction = np.argmax(predictions_value)\n",
842853
"# prediction = # TODO\n",
843854
"print(prediction)"
@@ -861,7 +872,7 @@
861872
"outputs": [],
862873
"source": [
863874
"print(\"Label of this digit is:\", test_label)\n",
864-
"plt.imshow(test_image[0,0,:,:], cmap=plt.cm.binary)\n",
875+
"plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n",
865876
"comet_model_2.log_figure(figure=plt)"
866877
]
867878
},
@@ -907,9 +918,9 @@
907918
"all_images = torch.cat(all_images) # Shape: (total_samples, 1, 28, 28)\n",
908919
"\n",
909920
"# Convert tensors to NumPy for compatibility with plotting functions\n",
910-
"predictions = all_predictions.numpy() # Shape: (total_samples, num_classes)\n",
911-
"test_labels = all_labels.numpy() # Shape: (total_samples,)\n",
912-
"test_images = all_images.numpy() # Shape: (total_samples, 1, 28, 28)"
921+
"predictions = all_predictions.cpu().numpy() # Shape: (total_samples, num_classes)\n",
922+
"test_labels = all_labels.cpu().numpy() # Shape: (total_samples,)\n",
923+
"test_images = all_images.cpu().numpy() # Shape: (total_samples, 1, 28, 28)"
913924
]
914925
},
915926
{

0 commit comments

Comments
 (0)