|
343 | 343 | "\n", |
344 | 344 | "### Define and train the CNN model\n", |
345 | 345 | "\n", |
346 | | - "Like we did in the first part of the lab, we'll define our CNN model, and then train on the CelebA and ImageNet datasets using the `tf.GradientTape` class and the `tf.GradientTape.gradient` method." |
| 346 | + "Like we did in the first part of the lab, we'll define our CNN model and then train it on the CelebA and ImageNet datasets using PyTorch's automatic differentiation (`torch.autograd`) via `loss.backward()` and `optimizer.step()`." |
347 | 347 | ] |
348 | 348 | }, |
349 | 349 | { |
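The `loss.backward()` / `optimizer.step()` pattern referenced in the converted text can be sketched as below; the model, data, and hyperparameters are placeholders, not the lab's actual CNN or datasets:

```python
import torch
import torch.nn as nn

# Stand-in model and data; the lab's real model is a CNN on face images.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x = torch.randn(8, 4)
y = torch.randn(8, 1)

optimizer.zero_grad()        # clear gradients from the previous step
loss = loss_fn(model(x), y)  # forward pass
loss.backward()              # autograd computes d(loss)/d(params)
optimizer.step()             # apply the gradient update
```

This replaces TensorFlow's `tf.GradientTape` pattern: autograd records the forward pass implicitly, so no explicit tape context is needed.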
|
870 | 870 | "\n", |
871 | 871 | " # TODO: Define the reconstruction loss as the mean absolute pixel-wise\n", |
872 | 872 | " # difference between the input and reconstruction. Hint: you'll need to\n", |
873 | | - " # use tf.reduce_mean, and supply an axis argument which specifies which\n", |
874 | | - " # dimensions to reduce over. For example, reconstruction loss needs to average\n", |
| 873 | + "    # use torch.mean, and supply a dim argument which specifies which\n", |
| 874 | + "    # dimensions to reduce over. For example, reconstruction loss needs to average\n", |
875 | 875 | " # over the height, width, and channel image dimensions.\n", |
876 | | - " # https://www.tensorflow.org/api_docs/python/tf/math/reduce_mean\n", |
| 876 | + " # https://pytorch.org/docs/stable/generated/torch.mean.html\n", |
877 | 877 | " reconstruction_loss = torch.mean(torch.abs(x - x_recon), dim=(1, 2, 3))\n", |
878 | 878 | "\n", |
879 | 879 | " # TODO: Define the VAE loss. Note this is given in the equation for L_{VAE}\n", |
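The `dim` argument in the hint above can be illustrated with made-up shapes (a batch of 16 RGB 64×64 images is an assumption, not the lab's actual input size):

```python
import torch

# Averaging over dims (1, 2, 3) -- channels, height, width -- collapses
# the per-pixel differences into one loss value per batch element.
x = torch.rand(16, 3, 64, 64)        # (batch, channels, height, width)
x_recon = torch.rand(16, 3, 64, 64)

per_sample_loss = torch.mean(torch.abs(x - x_recon), dim=(1, 2, 3))
print(per_sample_loss.shape)  # torch.Size([16]) -- one value per sample
```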
|
928 | 928 | "\n", |
929 | 929 | "\n", |
930 | 930 | "def sampling(z_mean, z_logsigma):\n", |
931 | | - " # By default, random.normal is \"standard\" (ie. mean=0 and std=1.0)\n", |
932 | | - " # batch, latent_dim = z_mean.shape\n", |
933 | | - " # epsilon = tf.random.normal(shape=(batch, latent_dim))\n", |
| 931 | + "    # randn_like samples from a standard normal (i.e. mean=0 and std=1.0)\n", |
| 932 | + " eps = torch.randn_like(z_mean)\n", |
934 | 933 | "\n", |
935 | 934 | " # # TODO: Define the reparameterization computation!\n", |
936 | 935 | " # # Note the equation is given in the text block immediately above.\n", |
| 936 | + " \n", |
| 937 | + " z = z_mean + torch.exp(z_logsigma) * eps\n", |
937 | 938 | " # z = # TODO\n", |
938 | | - " # return z\n", |
939 | | - "\n", |
940 | | - " eps = torch.randn_like(z_mean)\n", |
941 | | - " return z_mean + torch.exp(z_logsigma) * eps" |
| 939 | + " return z" |
942 | 940 | ] |
943 | 941 | }, |
944 | 942 | { |
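A quick usage sketch of the `sampling` function converted above (the batch size and latent dimension are arbitrary choices for illustration):

```python
import torch

def sampling(z_mean, z_logsigma):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I),
    # so gradients can flow through z_mean and z_logsigma.
    eps = torch.randn_like(z_mean)
    return z_mean + torch.exp(z_logsigma) * eps

z_mean = torch.zeros(4, 32)       # batch of 4, latent dim 32 (illustrative)
z_logsigma = torch.zeros(4, 32)   # log(sigma) = 0  =>  sigma = 1
z = sampling(z_mean, z_logsigma)
```

Because the randomness lives entirely in `eps`, driving `z_logsigma` toward negative infinity makes the sample collapse onto `z_mean`.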
|
1025 | 1023 | " vae_loss = vae_loss_function(x, x_pred, mu, logsigma)\n", |
1026 | 1024 | " # vae_loss = vae_loss_function('''TODO''') # TODO\n", |
1027 | 1025 | "\n", |
1028 | | - " # TODO: define the classification loss using sigmoid_cross_entropy\n", |
1029 | | - " # https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits\n", |
| 1026 | + "    # TODO: define the classification loss using binary_cross_entropy_with_logits\n", |
| 1027 | + " # https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html\n", |
1030 | 1028 | " # classification_loss = # TODO\n", |
1031 | 1029 | " classification_loss = F.binary_cross_entropy_with_logits(\n", |
1032 | 1030 | " y_logit, y, reduction=\"none\"\n", |
|
1037 | 1035 | " y = y.float()\n", |
1038 | 1036 | " face_indicator = (y == 1.0).float()\n", |
1039 | 1037 | "\n", |
1040 | | - " # TODO: define the DB-VAE total loss! Use tf.reduce_mean to average over all\n", |
| 1038 | + " # TODO: define the DB-VAE total loss! Use torch.mean to average over all\n", |
1041 | 1039 | " # samples\n", |
1042 | 1040 | " # total_loss = # TODO\n", |
1043 | 1041 | "    total_loss = torch.mean(classification_loss + face_indicator * vae_loss)\n", |
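The `reduction="none"` argument used above can be sketched with toy values (the logits and labels below are illustrative only, not the lab's data):

```python
import torch
import torch.nn.functional as F

# binary_cross_entropy_with_logits takes raw logits (no sigmoid applied
# beforehand) and, with reduction="none", returns one loss per sample
# rather than a scalar.
y_logit = torch.tensor([2.0, -1.0, 0.5])
y = torch.tensor([1.0, 0.0, 1.0])

classification_loss = F.binary_cross_entropy_with_logits(
    y_logit, y, reduction="none"
)

# Keeping per-sample values lets them be masked (e.g. by the face
# indicator) before the final torch.mean over the batch.
face_indicator = (y == 1.0).float()
```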
|