|
962 | 962 | "comet_model_2.end()" |
963 | 963 | ] |
964 | 964 | }, |
965 | | - { |
966 | | - "cell_type": "markdown", |
967 | | - "metadata": { |
968 | | - "id": "k-2glsRiMdqa" |
969 | | - }, |
970 | | - "source": [ |
971 | | - "## 1.4 Training the model 2.0\n", |
972 | | - "\n", |
973 | | - "Earlier in the lab, we used the [`fit`](https://www.tensorflow.org/api_docs/python/tf/keras/models/Sequential#fit) function call to train the model. This function is quite high-level and intuitive, which is really useful for simpler models. As you may be able to tell, this function abstracts away many details of the training call, so we have less control over the training process; that finer control can be useful in other contexts.\n", |
974 | | - "\n", |
975 | | - "As an alternative to this, we can use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) class to record differentiation operations during training, and then call the [`tf.GradientTape.gradient`](https://www.tensorflow.org/api_docs/python/tf/GradientTape#gradient) function to actually compute the gradients. You may recall seeing this in Lab 1 Part 1, but let's take another look at this here.\n", |
976 | | - "\n", |
977 | | - "We'll use this framework to train our `cnn_model` using stochastic gradient descent." |
978 | | - ] |
979 | | - }, |
980 | | - { |
981 | | - "cell_type": "code", |
982 | | - "execution_count": null, |
983 | | - "metadata": { |
984 | | - "id": "Wq34id-iN1Ml" |
985 | | - }, |
986 | | - "outputs": [], |
987 | | - "source": [ |
988 | | - "# Rebuild the CNN model\n", |
989 | | - "cnn_model = build_cnn_model()\n", |
990 | | - "\n", |
991 | | - "batch_size = 12\n", |
992 | | - "loss_history = mdl.util.LossHistory(smoothing_factor=0.95) # to record the evolution of the loss\n", |
993 | | - "plotter = mdl.util.PeriodicPlotter(sec=2, xlabel='Iterations', ylabel='Loss', scale='semilogy')\n", |
994 | | - "optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2) # define our optimizer\n", |
995 | | - "\n", |
996 | | - "comet_ml.init(project_name=\"6.s191lab2_part1_CNN2\")\n", |
997 | | - "comet_model_3 = comet_ml.Experiment()\n", |
998 | | - "\n", |
999 | | - "if hasattr(tqdm, '_instances'): tqdm._instances.clear() # clear if it exists\n", |
1000 | | - "\n", |
1001 | | - "for idx in tqdm(range(0, train_images.shape[0], batch_size)):\n", |
1002 | | - " # First grab a batch of training data and convert the input images to tensors\n", |
1003 | | - " (images, labels) = (train_images[idx:idx+batch_size], train_labels[idx:idx+batch_size])\n", |
1004 | | - " images = tf.convert_to_tensor(images, dtype=tf.float32)\n", |
1005 | | - "\n", |
1006 | | - " # GradientTape to record differentiation operations\n", |
1007 | | - " with tf.GradientTape() as tape:\n", |
1008 | | - " #'''TODO: feed the images into the model and obtain the predictions'''\n", |
1009 | | - " logits = cnn_model(images)\n", |
1010 | | - " # logits = # TODO\n", |
1011 | | - "\n", |
1012 | | - " #'''TODO: compute the categorical cross entropy loss\n", |
1013 | | - " loss_value = tf.keras.backend.sparse_categorical_crossentropy(labels, logits)\n", |
1014 | | - " comet_model_3.log_metric(\"loss\", loss_value.numpy().mean(), step=idx)\n", |
1015 | | - " # loss_value = tf.keras.backend.sparse_categorical_crossentropy('''TODO''', '''TODO''') # TODO\n", |
1016 | | - "\n", |
1017 | | - " loss_history.append(loss_value.numpy().mean()) # append the loss to the loss_history record\n", |
1018 | | - " plotter.plot(loss_history.get())\n", |
1019 | | - "\n", |
1020 | | - " # Backpropagation\n", |
1021 | | - " '''TODO: Use the tape to compute the gradient against all parameters in the CNN model.\n", |
1022 | | - " Use cnn_model.trainable_variables to access these parameters.'''\n", |
1023 | | - " grads = tape.gradient(loss_value, cnn_model.trainable_variables)\n", |
1024 | | - " # grads = # TODO\n", |
1025 | | - " optimizer.apply_gradients(zip(grads, cnn_model.trainable_variables))\n", |
1026 | | - "\n", |
1027 | | - "comet_model_3.log_figure(figure=plt)\n", |
1028 | | - "comet_model_3.end()\n" |
1029 | | - ] |
1030 | | - }, |
1031 | 965 | { |
1032 | 966 | "cell_type": "markdown", |
1033 | 967 | "metadata": { |
|