|
1165 | 1165 | "id": "yWCMu12w1BuD" |
1166 | 1166 | }, |
1167 | 1167 | "source": [ |
1168 | | - "Now, we will put this decoder together with the standard CNN classifier as our encoder to define the DB-VAE. Note that at this point, there is nothing special about how we put the model together that makes it a \"debiasing\" model -- that will come when we define the training operation. Here, we will define the core VAE architecture by sublassing the `Model` class; defining encoding, reparameterization, and decoding operations; and calling the network end-to-end." |
| 1168 | + "Now, we will put this decoder together with the standard CNN classifier as our encoder to define the DB-VAE. Note that at this point, there is nothing special about how we put the model together that makes it a \"debiasing\" model -- that will come when we define the training operation. Here, we will define the core VAE architecture by subclassing the `nn.Module` class; defining encoding, reparameterization, and decoding operations; and calling the network end-to-end." |
1169 | 1169 | ] |
1170 | 1170 | }, |
1171 | 1171 | { |
|
1242 | 1242 | "id": "M-clbYAj2waY" |
1243 | 1243 | }, |
1244 | 1244 | "source": [ |
1245 | | - "As stated, the encoder architecture is identical to the CNN from earlier in this lab. Note the outputs of our constructed DB_VAE model in the `call` function: `y_logit, z_mean, z_logsigma, z`. Think carefully about why each of these are outputted and their significance to the problem at hand.\n", |
| 1245 | + "As stated, the encoder architecture is identical to the CNN from earlier in this lab. Note the outputs of our constructed DB_VAE model in the `forward` function: `y_logit, z_mean, z_logsigma, x_recon`. Think carefully about why each of these is output and its significance to the problem at hand.\n", |
1246 | 1246 | "\n" |
1247 | 1247 | ] |
1248 | 1248 | }, |
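The structure described above -- an `nn.Module` subclass with separate encoding, reparameterization, and decoding steps whose `forward` returns `y_logit, z_mean, z_logsigma, x_recon` -- can be sketched as follows. This is a minimal illustration with stand-in fully-connected encoder/decoder modules and assumed dimensions, not the lab's actual CNN architecture:

```python
import torch
import torch.nn as nn

class DBVAE(nn.Module):
    """Minimal DB-VAE sketch: the encoder jointly produces a classification
    logit and the latent distribution parameters; the decoder reconstructs
    the input from a sampled latent z. Encoder/decoder here are stand-ins."""
    def __init__(self, latent_dim=8, in_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        # stand-in encoder: maps input to 1 logit + 2*latent_dim stats
        self.encoder = nn.Sequential(
            nn.Flatten(), nn.Linear(in_dim, 64), nn.ReLU(),
            nn.Linear(64, 1 + 2 * latent_dim),
        )
        # stand-in decoder: maps z back to the flattened input size
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, in_dim),
        )

    def encode(self, x):
        out = self.encoder(x)
        y_logit = out[:, :1]                       # classification logit
        z_mean = out[:, 1:1 + self.latent_dim]     # latent mean
        z_logsigma = out[:, 1 + self.latent_dim:]  # latent log std-dev
        return y_logit, z_mean, z_logsigma

    def reparameterize(self, z_mean, z_logsigma):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(z_mean)
        return z_mean + torch.exp(z_logsigma) * eps

    def forward(self, x):
        y_logit, z_mean, z_logsigma = self.encode(x)
        z = self.reparameterize(z_mean, z_logsigma)
        x_recon = self.decoder(z)
        return y_logit, z_mean, z_logsigma, x_recon
```

The reparameterization step is what keeps sampling differentiable: the randomness lives in `eps`, so gradients flow through `z_mean` and `z_logsigma`.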
|
1602 | 1602 | "\n", |
1603 | 1603 | " y_logit, z_mean, z_logsigma, x_recon = dbvae(x)\n", |
1604 | 1604 | "\n", |
| 1605 | + " '''TODO: call the DB_VAE loss function to compute the loss'''\n", |
1605 | 1606 | " loss, class_loss = debiasing_loss_function(\n", |
1606 | 1607 | " x, x_recon, y, y_logit, z_mean, z_logsigma\n", |
1607 | 1608 | " )\n", |
| 1609 | + " # loss, class_loss = debiasing_loss_function('''TODO arguments''') # TODO\n", |
1608 | 1610 | "\n", |
| 1611 | + " '''TODO: backpropagate'''\n", |
1609 | 1612 | " loss.backward()\n", |
1610 | 1613 | " optimizer.step()\n", |
1611 | 1614 | "\n", |
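The `debiasing_loss_function` called in the training step above combines a classification term with the VAE terms. As a hedged sketch of the usual DB-VAE form -- classification loss on every sample, reconstruction + KL applied only to face images (`y == 1`) -- it might look like the following; the `kl_weight` value, the L1 reconstruction error, and the gating by the label are assumptions for illustration, not the lab's exact definition:

```python
import torch
import torch.nn.functional as F

def debiasing_loss_function(x, x_recon, y, y_logit, z_mean, z_logsigma,
                            kl_weight=0.0005):
    """Sketch of a DB-VAE loss: BCE classification loss on all samples,
    VAE (reconstruction + KL) terms gated to face samples (y == 1)."""
    # KL divergence of N(z_mean, sigma^2) from the standard normal prior
    kl_loss = 0.5 * torch.sum(
        torch.exp(2 * z_logsigma) + z_mean**2 - 1 - 2 * z_logsigma, dim=1)
    # per-sample L1 reconstruction error, averaged over non-batch dims
    recon_loss = torch.mean(torch.abs(x - x_recon),
                            dim=tuple(range(1, x.dim())))
    vae_loss = recon_loss + kl_weight * kl_loss
    # per-sample binary classification loss (face vs. not-face)
    class_loss = F.binary_cross_entropy_with_logits(
        y_logit.squeeze(-1), y.float(), reduction="none")
    # VAE terms only contribute where the label indicates a face
    total_loss = torch.mean(class_loss + y.float() * vae_loss)
    return total_loss, torch.mean(class_loss)
```

Calling `loss.backward()` on the returned scalar then backpropagates through both the classifier and the VAE branches in one pass, as the training step above does.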
|
1688 | 1691 | ], |
1689 | 1692 | "source": [ |
1690 | 1693 | "dbvae.to(device)\n", |
1691 | | - "\n", |
1692 | | - "\n", |
1693 | 1694 | "dbvae_logits_list = []\n", |
1694 | 1695 | "for face in test_faces:\n", |
1695 | 1696 | " face = np.asarray(face, dtype=np.float32)\n", |
|
1702 | 1703 | " dbvae_logits_list.append(logit.detach().cpu().numpy())\n", |
1703 | 1704 | "\n", |
1704 | 1705 | "dbvae_logits_array = np.concatenate(dbvae_logits_list, axis=0)\n", |
1705 | | - "\n", |
1706 | 1706 | "dbvae_logits_tensor = torch.from_numpy(dbvae_logits_array)\n", |
1707 | 1707 | "dbvae_probs_tensor = torch.sigmoid(dbvae_logits_tensor)\n", |
1708 | 1708 | "dbvae_probs_array = dbvae_probs_tensor.squeeze(dim=-1).numpy()\n", |
|
1712 | 1712 | "std_probs_mean = standard_classifier_probs.mean(axis=1)\n", |
1713 | 1713 | "dbvae_probs_mean = dbvae_probs_array.reshape(len(keys), -1).mean(axis=1)\n", |
1714 | 1714 | "\n", |
1715 | | - "\n", |
1716 | 1715 | "plt.bar(xx, std_probs_mean, width=0.2, label=\"Standard CNN\")\n", |
1717 | 1716 | "plt.bar(xx + 0.2, dbvae_probs_mean, width=0.2, label=\"DB-VAE\")\n", |
1718 | 1717 | "\n", |
|