Skip to content

Commit 9a48fe6

Browse files
committed
cleaned up some descriptions and added a missing TODO
1 parent 9cf53b4 commit 9a48fe6

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

lab2/solutions/PT_Part2_Debiasing_Solution.ipynb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,7 @@
11651165
"id": "yWCMu12w1BuD"
11661166
},
11671167
"source": [
1168-
"Now, we will put this decoder together with the standard CNN classifier as our encoder to define the DB-VAE. Note that at this point, there is nothing special about how we put the model together that makes it a \"debiasing\" model -- that will come when we define the training operation. Here, we will define the core VAE architecture by sublassing the `Model` class; defining encoding, reparameterization, and decoding operations; and calling the network end-to-end."
1168+
"Now, we will put this decoder together with the standard CNN classifier as our encoder to define the DB-VAE. Note that at this point, there is nothing special about how we put the model together that makes it a \"debiasing\" model -- that will come when we define the training operation. Here, we will define the core VAE architecture by sublassing `nn.Module` class; defining encoding, reparameterization, and decoding operations; and calling the network end-to-end."
11691169
]
11701170
},
11711171
{
@@ -1242,7 +1242,7 @@
12421242
"id": "M-clbYAj2waY"
12431243
},
12441244
"source": [
1245-
"As stated, the encoder architecture is identical to the CNN from earlier in this lab. Note the outputs of our constructed DB_VAE model in the `call` function: `y_logit, z_mean, z_logsigma, z`. Think carefully about why each of these are outputted and their significance to the problem at hand.\n",
1245+
"As stated, the encoder architecture is identical to the CNN from earlier in this lab. Note the outputs of our constructed DB_VAE model in the `forward` function: `y_logit, z_mean, z_logsigma, z`. Think carefully about why each of these are outputted and their significance to the problem at hand.\n",
12461246
"\n"
12471247
]
12481248
},
@@ -1602,10 +1602,13 @@
16021602
"\n",
16031603
" y_logit, z_mean, z_logsigma, x_recon = dbvae(x)\n",
16041604
"\n",
1605+
" '''TODO: call the DB_VAE loss function to compute the loss'''\n",
16051606
" loss, class_loss = debiasing_loss_function(\n",
16061607
" x, x_recon, y, y_logit, z_mean, z_logsigma\n",
16071608
" )\n",
1609+
" # loss, class_loss = debiasing_loss_function('''TODO arguments''') # TODO\n",
16081610
"\n",
1611+
" '''TODO: backpropagate'''\n",
16091612
" loss.backward()\n",
16101613
" optimizer.step()\n",
16111614
"\n",
@@ -1688,8 +1691,6 @@
16881691
],
16891692
"source": [
16901693
"dbvae.to(device)\n",
1691-
"\n",
1692-
"\n",
16931694
"dbvae_logits_list = []\n",
16941695
"for face in test_faces:\n",
16951696
" face = np.asarray(face, dtype=np.float32)\n",
@@ -1702,7 +1703,6 @@
17021703
" dbvae_logits_list.append(logit.detach().cpu().numpy())\n",
17031704
"\n",
17041705
"dbvae_logits_array = np.concatenate(dbvae_logits_list, axis=0)\n",
1705-
"\n",
17061706
"dbvae_logits_tensor = torch.from_numpy(dbvae_logits_array)\n",
17071707
"dbvae_probs_tensor = torch.sigmoid(dbvae_logits_tensor)\n",
17081708
"dbvae_probs_array = dbvae_probs_tensor.squeeze(dim=-1).numpy()\n",
@@ -1712,7 +1712,6 @@
17121712
"std_probs_mean = standard_classifier_probs.mean(axis=1)\n",
17131713
"dbvae_probs_mean = dbvae_probs_array.reshape(len(keys), -1).mean(axis=1)\n",
17141714
"\n",
1715-
"\n",
17161715
"plt.bar(xx, std_probs_mean, width=0.2, label=\"Standard CNN\")\n",
17171716
"plt.bar(xx + 0.2, dbvae_probs_mean, width=0.2, label=\"DB-VAE\")\n",
17181717
"\n",

0 commit comments

Comments
 (0)