TensorFlow: training loop from scratch

The weights of a deep learning model start from a random state when the model object is created. As we expose the model to the training data, the weights are adjusted so that the outputs of the model's mathematical operations come close to the training labels.

The major steps that happen during training are:

  • Weight initialization
  • Forward pass of the data through the model layers
  • Calculation of the loss between the output of the forward pass and the labels, using a loss function
  • Calculation of the gradients with automatic differentiation
  • Updating the weights
  • Repeat from the forward pass
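To make these steps concrete, here is a minimal sketch of them on a toy problem, assuming a one-feature linear model instead of a neural network and made-up data drawn from y = 2x + 1. It uses plain NumPy and plain gradient descent with gradients derived by hand; in the TensorFlow loop later in this post, GradientTape computes the gradients automatically and Adam performs the update.

```python
import numpy as np

rng = np.random.default_rng(0)

# made-up data: y = 2x + 1 plus a little noise
x = rng.uniform(-1, 1, size=100)
y = 2 * x + 1 + 0.05 * rng.normal(size=100)

# 1. weight initialization from a random state
w, b = rng.normal(), rng.normal()
lr = 0.1  # learning rate

for step in range(500):
    # 2. forward pass
    y_pred = w * x + b
    # 3. loss: mean squared error between outputs and labels
    loss = np.mean((y_pred - y) ** 2)
    # 4. gradients of the loss w.r.t. w and b (derived by hand for this model)
    grad_w = np.mean(2 * (y_pred - y) * x)
    grad_b = np.mean(2 * (y_pred - y))
    # 5. update the weights with plain gradient descent
    w -= lr * grad_w
    b -= lr * grad_b
    # 6. the loop repeats from the forward pass

print(f'w = {w:.2f}, b = {b:.2f}, loss = {loss:.4f}')
```

After training, w and b end up close to the true values 2 and 1.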

If we want to control exactly how training happens, we need to write a custom loop over the epochs and over the batches of the training data that implements the steps above.

In each epoch of the custom loop, each batch of data is passed through the model inside a GradientTape scope. TensorFlow's GradientTape records the operations performed inside this scope so that gradients can be computed by automatic differentiation. The loss is then calculated with the loss function.
The GradientTape computes the gradients of the loss with respect to the model weights, and the optimizer uses these gradients to update the weights.
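As a minimal illustration of what the tape does, the sketch below records a toy loss w**2 for a single variable w (a made-up example, not the model in this post) and asks the tape for the gradient, which should equal 2w.

```python
import tensorflow as tf

# a single trainable variable, starting at 3.0
w = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # operations on trainable variables inside this scope are recorded
    loss = w ** 2

# automatic differentiation: d(loss)/dw = 2 * w = 6.0
grad = tape.gradient(loss, w)
print(grad.numpy())  # 6.0
```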

The following diagram shows the workflow of a custom training loop.

Let’s demonstrate this with an example.

We will use the MNIST data, which has 10 classes (the digits 0 to 9), and build a multiclass classification model that classifies each image into one of these 10 classes.

First, we import the Python libraries we need.

Python
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from sklearn.utils import shuffle

Read data and preprocess

We will use the MNIST dataset provided in Keras for this tutorial.

First, we check the size and shape of the data. Since we will normalize the data, we also check the range of its values.

Python
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train, y_train = shuffle(x_train, y_train)

print('\n\nData shape', x_train.shape)
print('minimum', np.min(x_train))
print('maximum', np.max(x_train))
Output
Data shape (60000, 28, 28)
minimum 0
maximum 255

In this example, we use only the x_train data to make the training and validation datasets. The testing part of the process is not shown, as it is the same as in the earlier posts.

The data values range from 0 to 255. We divide all the values by 255 to scale the data to the range 0 to 1. This normalization method is sufficient for our data; for other datasets, different normalization methods may be more appropriate.

Python
x_train = x_train / 255
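Dividing a uint8 array by 255 gives float64 values. A common variant, sketched below on a made-up uint8 array, casts to float32 first; it also shows standardization (zero mean, unit standard deviation) as one example of a different normalization method.

```python
import numpy as np

# made-up 8-bit pixel values standing in for MNIST images
images = np.array([[0, 128, 255],
                   [64, 32, 16]], dtype=np.uint8)

# min-max scaling to [0, 1]; casting to float32 matches Keras' default float
# type and uses half the memory of float64
scaled = images.astype(np.float32) / 255.0

# one alternative: standardization to zero mean and unit standard deviation
standardized = (images - images.mean()) / images.std()

print(scaled.min(), scaled.max())  # 0.0 1.0
```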

For this example, each sample needs to be a flattened array, as we will only use dense layers in the neural network. Let's check the shape of a sample first.

Python
print(x_train[0].shape)
Output
(28, 28)

So each sample has shape (28, 28). To flatten the samples, we do the following:

Python
x_train = x_train.reshape(-1, 28*28)
x_train.shape
Output
(60000, 784)
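To see that reshape(-1, 28*28) keeps each image intact, here is a small check on three made-up 28x28 arrays: -1 lets NumPy infer the number of samples, and each row of the result is one image read row by row.

```python
import numpy as np

# three made-up 28x28 "images" filled with distinct values
batch = np.arange(3 * 28 * 28).reshape(3, 28, 28)

# -1 lets NumPy infer the number of samples
flat = batch.reshape(-1, 28 * 28)
print(flat.shape)  # (3, 784)

# each row is one image read row by row (C order); pixels never mix between samples
print(np.array_equal(flat[1], batch[1].ravel()))  # True
```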

We will use the first 50,000 samples as training data and the rest as validation data. Because we shuffled the data earlier, slicing the arrays this way gives a random split between training and validation data.

Python
n_train = 50000
train_data = x_train[:n_train]
train_labels = y_train[:n_train]

val_data = x_train[n_train:]
val_labels = y_train[n_train:]

print(f'{train_data.shape=}')
print(f'{train_labels.shape=}')
print(f'{val_data.shape=}')
print(f'{val_labels.shape=}')
Output
train_data.shape=(50000, 784)
train_labels.shape=(50000,)
val_data.shape=(10000, 784)
val_labels.shape=(10000,)
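The shuffle-then-slice split can be sketched in plain NumPy, assuming tiny made-up arrays in place of MNIST. sklearn.utils.shuffle shuffles several arrays consistently; the sketch does the same with one shared permutation, so every sample stays paired with its label.

```python
import numpy as np

rng = np.random.default_rng(0)

# made-up arrays: sample i holds the value 10*i and has label i
data = np.arange(10).reshape(10, 1) * 10
labels = np.arange(10)

# one shared permutation keeps every sample aligned with its label
perm = rng.permutation(len(labels))
data, labels = data[perm], labels[perm]

# first n_train samples for training, the rest for validation
n_train = 8
train_data, val_data = data[:n_train], data[n_train:]
train_labels, val_labels = labels[:n_train], labels[n_train:]

print(train_data.shape, val_data.shape)  # (8, 1) (2, 1)
```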

Tensorflow dataset and BatchDataset

In the earlier posts, we reshaped our data and labels so that they were compatible with the neural network: the labels to shape (numsamples, 1) and the samples to shape (numsamples, numfeatures, 1).

Here, we will use another method: we convert our NumPy arrays into TensorFlow datasets and then batch these datasets using batch_size, which gives a BatchDataset. For this example, we use a batch size of 32.

To create a TensorFlow dataset, we use the from_tensor_slices method. We pass it a single argument, a tuple of data and labels, so each element of the dataset is a (sample, label) pair.

The batch method of the dataset is used to create the batches.

Python
# batch size
batch_size = 32

# prepare training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_dataset = train_dataset.batch(batch_size) # batch dataset

# prepare validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(batch_size) # batch dataset
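The sketch below shows what from_tensor_slices and batch produce on a made-up array of 10 samples with batch size 4: each dataset element is one (sample, label) pair, and batching groups consecutive elements, with a smaller final batch when the size does not divide evenly.

```python
import numpy as np
import tensorflow as tf

# 10 made-up samples with 3 features each, and integer labels 0..9
data = np.arange(30, dtype=np.float32).reshape(10, 3)
labels = np.arange(10)

# slices both arrays along the first axis: one (sample, label) pair per element
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
first_x, first_y = next(iter(dataset))
print(first_x.numpy(), first_y.numpy())  # [0. 1. 2.] 0

# batch(4) groups consecutive elements; the final batch keeps the remainder
batched = dataset.batch(4)
batch_sizes = [int(x.shape[0]) for x, y in batched]
print(batch_sizes)  # [4, 4, 2]
```

With 50,000 training samples and a batch size of 32, the same rule gives 1,563 batches, the last of which holds 16 samples (50,000 = 32 * 1,562 + 16). Since the data here is shuffled only once, before the split, calling shuffle(buffer_size) on the dataset before batch would be one way to reshuffle it on every epoch (reshuffle_each_iteration defaults to True).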

Get model

We will make a get_model function that returns the neural network. Note that we do not compile the model, as we are going to train it with a custom loop using GradientTape.

The output layer of the model has 10 neurons, one for each class. No activation function is applied to the output layer, so the model outputs raw logits rather than probabilities.

Python
# get model
def get_model():
    # input layer
    inp = keras.layers.Input(shape=train_data.shape[1:])

    # dense layer
    x = keras.layers.Dense(512, activation='relu')(inp)
    x = keras.layers.Dense(256, activation='relu')(x)
    x = keras.layers.Dense(128, activation='relu')(x)

    # output layer
    out = keras.layers.Dense(10)(x)

    # construct model
    model = keras.Model(inputs=inp, outputs=out)

    return model

model = get_model()
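Before training, a quick sanity check is to call the model on a batch and look at the output shape. The sketch rebuilds the same layers as get_model() with the input size written out as 784 so that it runs on its own, and feeds it a made-up random batch of 32 samples; count_params() reports the number of weights.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# same architecture as get_model(), with the input size written out (784 = 28*28)
inp = keras.layers.Input(shape=(784,))
x = keras.layers.Dense(512, activation='relu')(inp)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dense(128, activation='relu')(x)
out = keras.layers.Dense(10)(x)  # no activation: raw logits
model = keras.Model(inputs=inp, outputs=out)

# a made-up batch of 32 random flattened images
fake_batch = np.random.rand(32, 784).astype('float32')
logits = model(fake_batch, training=False)
print(logits.shape)          # (32, 10)
print(model.count_params())  # 567434
```

The parameter count follows from (inputs + 1) * units for each dense layer: 785 * 512 + 513 * 256 + 257 * 128 + 129 * 10 = 567,434.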

Optimizer and loss function

We will use Adam as the optimizer with a learning rate of 0.0001 (1e-4).

For the loss function, we use SparseCategoricalCrossentropy, as we have 10 classes and our labels are integer class indices rather than one-hot encoded vectors. Since the model outputs logits, we set from_logits=True.

Python
# optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# loss function
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
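To see what SparseCategoricalCrossentropy(from_logits=True) computes, the sketch below compares it with a by-hand softmax followed by the mean negative log-probability of the true class, and with CategoricalCrossentropy on one-hot labels. The logits and labels are made-up toy values with 3 classes.

```python
import numpy as np
import tensorflow as tf

# made-up logits for 2 samples and 3 classes, with integer labels
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]], dtype=np.float32)
labels = np.array([0, 2])

# integer labels + logits: the setup used in this post
sparse_cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
keras_loss = float(sparse_cce(labels, logits))

# the same quantity by hand: softmax, then mean negative log-probability of the true class
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
manual_loss = float(np.mean(-np.log(probs[np.arange(len(labels)), labels])))

# one-hot labels would need CategoricalCrossentropy instead
onehot = tf.one_hot(labels, depth=3)
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
onehot_loss = float(cce(onehot, logits))

print(keras_loss, manual_loss, onehot_loss)
```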

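We will track the accuracy metric. keras.metrics.Accuracy checks how often predictions equal the labels, so we will pass it predicted class indices rather than logits.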

Python
# metrics
train_acc = keras.metrics.Accuracy()
val_acc = keras.metrics.Accuracy()
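Because keras.metrics.Accuracy compares values for exact equality, it needs class indices, hence the tf.argmax calls in the training loop. As an alternative (not what this post uses), keras.metrics.SparseCategoricalAccuracy takes integer labels and logits directly. A small comparison on made-up values:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# made-up labels and logits for 4 samples and 3 classes
labels = np.array([0, 2, 1, 1])
logits = np.array([[3.0, 0.1, 0.2],
                   [0.1, 0.2, 3.0],
                   [0.5, 2.0, 0.1],
                   [2.0, 1.0, 0.1]], dtype=np.float32)  # last sample is misclassified

# Accuracy: compare labels with predicted class indices
acc = keras.metrics.Accuracy()
acc.update_state(labels, tf.argmax(logits, axis=1))

# SparseCategoricalAccuracy: takes the logits and applies argmax itself
sparse_acc = keras.metrics.SparseCategoricalAccuracy()
sparse_acc.update_state(labels, logits)

print(float(acc.result()), float(sparse_acc.result()))  # 0.75 0.75
```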

At the end of each epoch, we will store the loss and accuracy values in a dictionary.

Python
history = {
    'train_loss' : [],
    'val_loss' : [],
    'train_accuracy' : [],
    'val_accuracy' : []
}

Training

We now have everything ready to train our model. We will train it for 10 epochs.

The following code implements the training steps shown in the figure above.

Python
# number of epochs
num_epochs = 10

# iterate over epochs
for ep in range(num_epochs):
  print(f'Epoch {ep+1}')

  # iterate over batches
  for i, (train_batch_data, train_batch_labels) in enumerate(train_dataset):
    # gradient tape
    with tf.GradientTape() as tape:
      logits = model(train_batch_data, training=True)
      loss = loss_fn(train_batch_labels, logits)

    # calculate gradients
    grads = tape.gradient(loss, model.trainable_weights)

    # apply gradients
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # update training metric
    preds = tf.argmax(logits, axis=1)
    train_acc.update_state(train_batch_labels, preds)

    # print metric value after every 300 batches
    if i % 300 == 0:
      print(f'Loss = {loss}, Accuracy = {train_acc.result()}')

    # --- batch loop ends here ---------

  # record train loss
  history['train_loss'].append(loss)

  # record training accuracy
  history['train_accuracy'].append(train_acc.result())

  # reset metric state
  train_acc.reset_state()

  # validation loop
  for val_batch_data, val_batch_labels in val_dataset:
    val_logits = model(val_batch_data, training=False)
    val_loss = loss_fn(val_batch_labels, val_logits)

    # update validation metric
    val_preds =  tf.argmax(val_logits, axis=1)
    val_acc.update_state(val_batch_labels, val_preds)

  # record validation loss
  history['val_loss'].append(val_loss)

  # print and record validation accuracy
  print('Validation accuracy: ', val_acc.result())
  history['val_accuracy'].append(val_acc.result())

  # reset states of validation accuracy
  val_acc.reset_state()
   
  # --- epoch loop ends here ---------
Output
Epoch 1
Loss = 2.3693222999572754, Accuracy = 0.09375
Loss = 0.3300689458847046, Accuracy = 0.7864410281181335
Loss = 0.22833402454853058, Accuracy = 0.8482736945152283
Loss = 0.48036694526672363, Accuracy = 0.8740982413291931
Loss = 0.0692846029996872, Accuracy = 0.890169620513916
Loss = 0.06315110623836517, Accuracy = 0.9006704092025757
Validation accuracy:  tf.Tensor(0.9471, shape=(), dtype=float32)
Epoch 2
Loss = 0.6598396301269531, Accuracy = 0.875
Loss = 0.09225505590438843, Accuracy = 0.9491279125213623
Loss = 0.18820396065711975, Accuracy = 0.9515390992164612
Loss = 0.34771081805229187, Accuracy = 0.9537666440010071
Loss = 0.028796473518013954, Accuracy = 0.9557400345802307
Loss = 0.026022296398878098, Accuracy = 0.9572992920875549
Validation accuracy:  tf.Tensor(0.9612, shape=(), dtype=float32)
Epoch 3
Loss = 0.48503226041793823, Accuracy = 0.875
Loss = 0.04891586676239967, Accuracy = 0.9652200937271118
Loss = 0.1606135368347168, Accuracy = 0.9663061499595642
Loss = 0.24057382345199585, Accuracy = 0.9676747918128967
Loss = 0.015865761786699295, Accuracy = 0.9690882563591003
Loss = 0.01238587312400341, Accuracy = 0.9699575304985046
Validation accuracy:  tf.Tensor(0.967, shape=(), dtype=float32)
Epoch 4
Loss = 0.28929534554481506, Accuracy = 0.875
Loss = 0.043264687061309814, Accuracy = 0.9738371968269348
Loss = 0.1447778344154358, Accuracy = 0.9750936031341553
Loss = 0.15062391757965088, Accuracy = 0.9758254885673523
Loss = 0.008953276090323925, Accuracy = 0.9770764112472534
Loss = 0.005596090108156204, Accuracy = 0.9779521822929382
Validation accuracy:  tf.Tensor(0.9705, shape=(), dtype=float32)
Epoch 5
Loss = 0.1715865135192871, Accuracy = 0.90625
Loss = 0.04048454761505127, Accuracy = 0.9822466969490051
Loss = 0.14332400262355804, Accuracy = 0.9823211431503296
Loss = 0.11258415132761002, Accuracy = 0.9826234579086304
Loss = 0.005552412010729313, Accuracy = 0.9832951426506042
Loss = 0.0025558266788721085, Accuracy = 0.9839690327644348
Validation accuracy:  tf.Tensor(0.9726, shape=(), dtype=float32)
Epoch 6
Loss = 0.12078477442264557, Accuracy = 0.9375
Loss = 0.03734421730041504, Accuracy = 0.9870223999023438
Loss = 0.1515222042798996, Accuracy = 0.9867408275604248
Loss = 0.0950898602604866, Accuracy = 0.9872016906738281
Loss = 0.003616809379309416, Accuracy = 0.9876665472984314
Loss = 0.0012269732542335987, Accuracy = 0.9882161617279053
Validation accuracy:  tf.Tensor(0.974, shape=(), dtype=float32)
Epoch 7
Loss = 0.08778789639472961, Accuracy = 0.96875
Loss = 0.029765240848064423, Accuracy = 0.9908638000488281
Loss = 0.1529184728860855, Accuracy = 0.990432620048523
Loss = 0.08670826256275177, Accuracy = 0.9909822344779968
Loss = 0.002418466378003359, Accuracy = 0.9912573099136353
Loss = 0.0005564287421293557, Accuracy = 0.9914640188217163
Validation accuracy:  tf.Tensor(0.974, shape=(), dtype=float32)
Epoch 8
Loss = 0.05686352401971817, Accuracy = 0.96875
Loss = 0.022822609171271324, Accuracy = 0.9948089718818665
Loss = 0.1572648286819458, Accuracy = 0.9937084317207336
Loss = 0.06381712108850479, Accuracy = 0.9936875700950623
Loss = 0.001745270099490881, Accuracy = 0.9940414428710938
Loss = 0.00030787079595029354, Accuracy = 0.9942746758460999
Validation accuracy:  tf.Tensor(0.9756, shape=(), dtype=float32)
Epoch 9
Loss = 0.030033979564905167, Accuracy = 1.0
Loss = 0.023136965930461884, Accuracy = 0.9969891905784607
Loss = 0.1436910629272461, Accuracy = 0.9962042570114136
Loss = 0.03852612525224686, Accuracy = 0.9961154460906982
Loss = 0.001674670958891511, Accuracy = 0.9962010979652405
Loss = 0.00017247656069230288, Accuracy = 0.9962733387947083
Validation accuracy:  tf.Tensor(0.976, shape=(), dtype=float32)
Epoch 10
Loss = 0.01602054387331009, Accuracy = 1.0
Loss = 0.018718460574746132, Accuracy = 0.9980273842811584
Loss = 0.1131332591176033, Accuracy = 0.9969841837882996
Loss = 0.0441979318857193, Accuracy = 0.9970518946647644
Loss = 0.0009350811596959829, Accuracy = 0.9968776106834412
Loss = 0.0001424634683644399, Accuracy = 0.9969812035560608
Validation accuracy:  tf.Tensor(0.9763, shape=(), dtype=float32)
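The loop above runs eagerly, one Python operation at a time. The TensorFlow guide on writing a training loop from scratch shows wrapping the per-batch step in tf.function so that it is compiled into a graph, which usually speeds training up. Below is a sketch of that variant on a tiny made-up model and made-up data so that it runs on its own; it is an optional optimization, not something the loop above requires.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# tiny made-up model and data, just to exercise the step function
model = keras.Sequential([keras.layers.Input(shape=(4,)), keras.layers.Dense(3)])
optimizer = keras.optimizers.Adam(learning_rate=1e-2)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function  # traced into a graph on the first call, then reused
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss

rng = np.random.default_rng(0)
x = tf.constant(rng.normal(size=(16, 4)).astype(np.float32))
y = tf.constant(rng.integers(0, 3, size=16))

# repeated steps on the same batch should drive its loss down
losses = [float(train_step(x, y)) for _ in range(50)]
print(losses[0], losses[-1])
```

In the post's loop, the batch body would become a call to such a step function; since train_acc is updated from logits, the step would also need to return the logits (or update the metric itself).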

Explanation of the training loop

First we iterate over the epochs and then over the batches with nested for loops.

Python
# number of epochs
num_epochs = 10

# iterate over epochs
for ep in range(num_epochs):
  print(f'Epoch {ep+1}')

  # iterate over batches
  for i, (train_batch_data, train_batch_labels) in enumerate(train_dataset):

Now, inside a GradientTape scope, we calculate the logits and loss.

Python
with tf.GradientTape() as tape:
  logits = model(train_batch_data, training=True)
  loss = loss_fn(train_batch_labels, logits)

The logits are the raw outputs of the forward pass of the data through the model. They have shape (batch_size, output layer size), which in this case is (32, 10), except for the last training batch, which holds only 16 samples.

Passing training=True runs the model in training mode. This matters for layers such as Dropout and BatchNormalization, which behave differently in training and evaluation mode. Our model has neither, but passing the flag is good practice.
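To see why the training flag matters, here is a small sketch with a Dropout layer (not part of this post's model) applied to a made-up tensor of ones in both modes.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

dropout = keras.layers.Dropout(0.5)
x = tf.ones((1, 1000))

# evaluation mode: dropout passes the input through unchanged
eval_out = dropout(x, training=False).numpy()

# training mode: about half of the units are zeroed and the survivors are
# scaled by 1 / (1 - 0.5) = 2, so the expected sum stays the same
train_out = dropout(x, training=True).numpy()

print(np.array_equal(eval_out, x.numpy()))  # True
print(np.mean(train_out == 0))              # roughly 0.5
```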

The loss is a scalar float tensor: by default, SparseCategoricalCrossentropy averages the per-sample losses over the batch.

Python
# calculate gradients
grads = tape.gradient(loss, model.trainable_weights)

# apply gradients
optimizer.apply_gradients(zip(grads, model.trainable_weights))

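tape.gradient computes the gradient of the loss with respect to each trainable weight, and optimizer.apply_gradients uses these gradients to update the weights. This happens once for every batch.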
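The sketch below performs one such update on a single-weight Dense layer (a toy setup, not the model in this post), using plain SGD in place of Adam so that the result is easy to check: with w = 2, x = 3 and loss = (w*x)**2, the gradient is 2*w*x**2 = 36, and SGD with learning rate 0.01 sets w to 2 - 0.01 * 36 = 1.64.

```python
import tensorflow as tf
from tensorflow import keras

# toy layer with a single weight w = 2 and no bias
layer = keras.layers.Dense(1, use_bias=False,
                           kernel_initializer=keras.initializers.Constant(2.0))
x = tf.constant([[3.0]])
layer(x)  # calling the layer once creates its weight

with tf.GradientTape() as tape:
    y = layer(x)                  # y = w * x = 6
    loss = tf.reduce_sum(y ** 2)  # loss = (w * x)^2 = 36

# d(loss)/dw = 2 * w * x^2 = 36
grads = tape.gradient(loss, layer.trainable_weights)

# plain SGD: w <- w - learning_rate * gradient
optimizer = keras.optimizers.SGD(learning_rate=0.01)
optimizer.apply_gradients(zip(grads, layer.trainable_weights))

new_w = float(layer.get_weights()[0][0, 0])
print(new_w)  # about 1.64
```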

Python
# update training metric
preds = tf.argmax(logits, axis=1)
train_acc.update_state(train_batch_labels, preds)

Next, we update our metric, train_acc, for each batch. We first get the predictions from the logits. The logits have 10 values for each sample, one unnormalized score per class. Softmax turns these scores into class probabilities without changing their order, so the class with the highest logit is also the class with the highest probability, and that is the predicted class. We find it with tf.argmax(logits, axis=1).
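A quick check on made-up logits that softmax does not change which class scores highest, so taking tf.argmax of the logits or of the probabilities gives the same predictions:

```python
import tensorflow as tf

# made-up logits for 2 samples and 3 classes
logits = tf.constant([[1.0, 3.0, -2.0],
                      [0.5, 0.1, 0.2]])

# softmax turns each row into probabilities that sum to 1
probs = tf.nn.softmax(logits, axis=1)

preds_from_logits = tf.argmax(logits, axis=1)
preds_from_probs = tf.argmax(probs, axis=1)
print(preds_from_logits.numpy())  # [1 0]
print(preds_from_probs.numpy())   # [1 0]
```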

Python
# print metric value after every 300 batches
if i % 300 == 0:
  print(f'Loss = {loss}, Accuracy = {train_acc.result()}')

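Every 300 batches, we print the loss of the current batch and the training accuracy accumulated so far in the epoch. The condition i % 300 == 0 is true at batches 0, 300, 600, 900, 1200 and 1500 of the 1,563 batches, so six lines are printed per epoch, as in the output above.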

Python
# record train loss
history['train_loss'].append(loss)

# record training accuracy
history['train_accuracy'].append(train_acc.result())

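We record the training loss and training accuracy in the history dictionary. Note that loss here is the loss of the last batch of the epoch, not an average over the whole epoch, while train_acc.result() is the accuracy over all batches of the epoch.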
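If an epoch-level loss is wanted instead of the last batch's value, one option (an assumption, not what this post's loop does) is to accumulate the batch losses in a keras.metrics.Mean metric, the same way the accuracy is accumulated. The sketch uses made-up batch losses and weights each batch by its size, since the last batch is smaller.

```python
from tensorflow import keras

# made-up per-batch losses from one epoch; the last batch is smaller
batch_losses = [0.9, 0.5, 0.1]
batch_sizes = [32, 32, 16]

train_loss = keras.metrics.Mean()
for loss_value, n in zip(batch_losses, batch_sizes):
    # weight each batch by its size so the result is a per-sample mean
    train_loss.update_state(loss_value, sample_weight=n)

epoch_loss = float(train_loss.result())
print(epoch_loss)  # about 0.58, i.e. (0.9*32 + 0.5*32 + 0.1*16) / 80

train_loss.reset_state()  # start the next epoch from zero
```

In the training loop, this would mean calling update_state once per batch with the batch loss and batch size, appending result() to history['train_loss'] after the batch loop, and calling reset_state() before the next epoch; the validation loss could be handled the same way.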

Python
# reset metric state
train_acc.reset_state()

Then we reset the state of the training accuracy metric, so that the next epoch starts counting from zero.
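Keras metrics accumulate over every update_state call until they are reset, which is why train_acc.result() is a running accuracy for the epoch. A small sketch on made-up predictions:

```python
from tensorflow import keras

acc = keras.metrics.Accuracy()
acc.update_state([1, 2, 3], [1, 2, 0])  # 2 of 3 predictions correct
acc.update_state([4], [4])              # 1 of 1 correct
running = float(acc.result())
print(running)  # 0.75, i.e. 3 correct out of 4 seen so far

acc.reset_state()  # forget everything seen so far
after_reset = float(acc.result())
print(after_reset)  # 0.0
```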

Python
# validation loop
for val_batch_data, val_batch_labels in val_dataset:
  val_logits = model(val_batch_data, training=False)
  val_loss = loss_fn(val_batch_labels, val_logits)

  # update validation metric
  val_preds =  tf.argmax(val_logits, axis=1)
  val_acc.update_state(val_batch_labels, val_preds)

The validation loop passes the validation data through the model and computes predictions with the weights as they are after the current epoch of training. Here we put the model in evaluation mode by setting training=False.

The actual predicted class is calculated using the tf.argmax method.

The validation loss and accuracy are recorded in the history dictionary. As with the training loss, val_loss holds the loss of the last validation batch only.

Python
# record validation loss
history['val_loss'].append(val_loss)

# print and record validation accuracy
print('Validation accuracy: ', val_acc.result())
history['val_accuracy'].append(val_acc.result())

In this way, we collect all the results in history and print the validation accuracy after each epoch.

Python
# reset states of validation accuracy
val_acc.reset_state()

Before the next epoch starts, we reset the state of validation accuracy.

Visualizing the training metrics

We stored the loss and accuracy values in history, so they can be plotted with Matplotlib as below.

Python
# plot history
fig, ax = plt.subplots(2, sharex=True)
# accuracy history
ax[0].plot(history['train_accuracy'], label='train_accuracy')
ax[0].plot(history['val_accuracy'], label='val_accuracy')
ax[0].legend()
ax[0].set_title('Accuracy values')


# loss history
ax[1].plot(history['train_loss'], label='train_loss')
ax[1].plot(history['val_loss'], label='val_loss')
ax[1].legend()
ax[1].set_title('Loss values')

plt.show()