Image classification with transfer learning on TensorFlow

Xin Cheng
Mar 21, 2022 · 7 min read

In my previous transfer learning post, we reviewed transfer learning for NLP. Huggingface has made NLP transfer learning very easy. However, so far I have not found a similar framework covering the various computer vision tasks. Let's review the image classification task to see what the common pattern is. This post focuses on TensorFlow.

We will use the famous cats-and-dogs image classification task (telling whether an image shows a cat or a dog). TensorFlow has a good tutorial (with a Colab notebook) for beginners, and we will complement it with further explanations.

As mentioned in that tutorial, there are two ways to customize a pretrained model:

  1. Feature Extraction: Use the representations learned by a previous network to extract meaningful features from new samples. You simply add a new classifier, which will be trained from scratch, on top of the pretrained model so that you can repurpose the feature maps learned previously for the dataset.
    You do not need to (re)train the entire model. The base convolutional network already contains features that are generically useful for classifying pictures. However, the final, classification part of the pretrained model is specific to the original classification task, and subsequently specific to the set of classes on which the model was trained.
  2. Fine-Tuning: Unfreeze a few of the top layers of a frozen model base and jointly train both the newly-added classifier layers and the last layers of the base model. This allows us to “fine-tune” the higher-order feature representations in the base model in order to make them more relevant for the specific task.

Currently, the dominant model architecture for computer vision is the convolutional neural network (CNN). Refer to the Appendix for more information on CNNs; for now, just understand that the model is trained to automatically capture features like edges, contours, orientation, and texture that can be leveraged by upper-layer tasks.

General steps in image classification transfer learning

  1. Data loader
  2. Preprocessing
  3. Load pretrained model, freeze model layers according to your needs
  4. Add additional layers according to your needs, to form the final model
  5. Compile the model, setting up optimizer and loss function
  6. Train the model with model.fit
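Condensed into one block, the recipe looks roughly like this (a sketch only: the directory path is a placeholder, and each step is expanded with real code in the sections below):

import tensorflow as tf

# 1-2. Data loading and preprocessing: Keras reads, resizes and batches images
train_ds = tf.keras.utils.image_dataset_from_directory(
    'train',  # placeholder path; the real loading code appears below
    image_size=(160, 160), batch_size=32)

# 3. Pretrained base without its ImageNet classifier head, frozen
base = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3),
                                         include_top=False, weights='imagenet')
base.trainable = False

# 4. New layers on top form the final model
inputs = tf.keras.Input(shape=(160, 160, 3))
x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
x = base(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# 5. Compile with an optimizer and a loss function
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 6. Train
model.fit(train_ds, epochs=10)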

Feature extraction

Before the main model training, we need some code to load the dataset and set up preprocessing. You can jump directly to the "Create base model" section.

Import library and download images

This step just imports libraries and downloads the training images into the "train" and "validation" folders.

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

You will see the following folders under the Keras download folder, /root/.keras/datasets/cats_and_dogs_filtered:

train
  — cats
  — dogs
validation
  — cats
  — dogs

You can use Linux tools to inspect the original image size:

!apt install file
!apt install -y imagemagick
!file /root/.keras/datasets/cats_and_dogs_filtered/train/cats/cat.199.jpg
!identify /root/.keras/datasets/cats_and_dogs_filtered/train/cats/cat.199.jpg

This particular image is 270x319; the original images come in various sizes, which is why we resize them when loading.
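If you prefer to stay in Python, Pillow (preinstalled on Colab) gives the same information:

from PIL import Image

img = Image.open('/root/.keras/datasets/cats_and_dogs_filtered/train/cats/cat.199.jpg')
print(img.size)  # (width, height), e.g. (270, 319) for this image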

Load train and validation dataset

We will load the training dataset from the train folder and the validation dataset from the validation folder. Notice two parameters:

  - shuffle: whether to shuffle the data (default: True). If set to False, the data is sorted in alphanumeric order.
  - image_size: the size to resize images to after they are read from disk. Since the pipeline processes batches of images that must all have the same size, this must be provided.

BATCH_SIZE = 32
IMG_SIZE = (160, 160)
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)

As the original dataset doesn’t contain a test set, you will create one. To do so, determine how many batches of data are available in the validation set using tf.data.experimental.cardinality, then move 20% of them to a test set.

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
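Like the TensorFlow tutorial, we should also grab the class names and enable buffered prefetching at this point, so data loading overlaps with training:

# Capture class names before wrapping the dataset (the attribute is not
# preserved through tf.data transformations such as prefetch)
class_names = train_dataset.class_names  # ['cats', 'dogs']

AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)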

Data augmentation and preprocessing

You can apply data augmentation to the images to add diversity to the training set and help prevent overfitting, e.g. flipping horizontally or vertically, or rotating the images:

data_augmentation = tf.keras.Sequential([
    # A preprocessing layer which randomly flips images during training.
    tf.keras.layers.RandomFlip('horizontal_and_vertical'),
    # A preprocessing layer which randomly rotates images during training.
    tf.keras.layers.RandomRotation(0.2),
])
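These augmentation layers are only active during training; in inference mode they pass inputs through unchanged. A quick sanity check (using a random tensor as a stand-in for an image batch):

sample = tf.random.uniform((1, 160, 160, 3))
# With training=False the augmentation layers are no-ops: output equals input
print(tf.reduce_all(data_augmentation(sample, training=False) == sample).numpy())  # True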

Rescale pixel values

The tf.keras.applications.MobileNetV2 model expects pixel values in [-1, 1], but at this point the pixel values in your images are in [0, 255]. To rescale them, use the preprocessing method included with the model.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
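As the tutorial notes, you could alternatively do the same rescaling with a Keras layer:

# Maps [0, 255] -> [-1, 1], the range MobileNetV2 expects
rescale = tf.keras.layers.Rescaling(1./127.5, offset=-1)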

Create base model

We will use MobileNetV2, which performs well on mobile devices. With include_top=False, we load the convolutional base without its original ImageNet classification head. We will pass a batch of training images through the base model to see the features it outputs.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
# (32, 5, 5, 1280): each of the 32 images (our batch size) is converted
# to a 5x5x1280 block of features
print(feature_batch.shape)

Freeze base model

When you use the base model as a feature extractor, it is important to freeze the convolutional base before you compile and train the model. Freezing (by setting layer.trainable = False) prevents the weights in a given layer from being updated during training. MobileNetV2 has many layers, so setting the entire model's trainable flag to False freezes all of them.

base_model.trainable = False

Setup model architecture

When you define the model, you need to tell Keras how to map inputs to outputs; in our scenario: inputs -> data augmentation layer -> preprocess (rescale) layer -> MobileNetV2 -> GlobalAveragePooling2D layer -> Dropout layer -> Dense layer. Note that when we chain the layers below, base_model is called with training=False; this keeps its BatchNormalization layers in inference mode, which matters once we unfreeze layers for fine-tuning.

To generate predictions from the block of features, average over the 5x5 spatial locations using a tf.keras.layers.GlobalAveragePooling2D layer, converting the features to a single 1280-element vector per image. GlobalAveragePooling2D simply averages over the spatial dimensions, reducing each of them to one.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
# (32, 5, 5, 1280) -> (32, 1280)
print(feature_batch_average.shape)
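The pooling is just a mean over the two spatial axes, which you can verify directly:

# Equivalent to GlobalAveragePooling2D: average over the height and width axes
manual_average = tf.reduce_mean(feature_batch, axis=[1, 2])
print(manual_average.shape)  # (32, 1280), same as the layer's output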

Apply a tf.keras.layers.Dense layer to convert these features into a single prediction per image. You don't need an activation function here, because the prediction will be treated as a logit: positive values predict class 1, negative values class 0.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

Chain the layers

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
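With the base frozen, only the new Dense head is trainable, i.e. its kernel and bias:

model.summary()
print(len(model.trainable_variables))  # 2: the Dense layer's weights and biases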

Compile the model

Specify the optimizer and loss function. Since the model outputs a single logit per image, we use binary cross-entropy with from_logits=True.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
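Before training, it is worth evaluating the model to get a baseline; with a randomly initialized head, accuracy should be near chance (~50%):

loss0, accuracy0 = model.evaluate(validation_dataset)
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))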

Train the model

We keep the history of training, so later we can continue training from where we left off (and plot learning curves):

initial_epochs = 10
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
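Since we kept the history object, we can plot the learning curves; a minimal sketch:

# Keras records these keys because we compiled with metrics=['accuracy']
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')
plt.show()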

Fine-tuning

The main idea of fine-tuning is to adjust some weights in the pretrained model, especially in the last few layers, shifting them from generic feature maps toward features specific to your dataset. You should fine-tune only a small number of top layers rather than the whole MobileNet model. In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images. As you go higher up, the features become increasingly specific to the dataset on which the model was trained. The goal of fine-tuning is to adapt these specialized features to the new dataset, rather than overwrite the generic learning.

Un-freeze the top layers of the model

Now we want to adjust the weights of the pretrained model, but only after a certain layer. The following code first sets the base model to trainable, then sets all layers before layer 100 back to non-trainable (freezing the earlier layers, which contain simple and generic features).

base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100
# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

Setup model architecture

We reuse the same model architecture as in the feature extraction case.

Compile the model

As you are training a much larger model and want to readapt the pretrained weights, it is important to use a lower learning rate at this stage. Otherwise, your model could overfit very quickly.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
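You can confirm that far more variables are trainable now than the two we saw during feature extraction:

model.summary()
print(len(model.trainable_variables))  # many more than 2 with the top layers unfrozen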

Train the model

Here we continue training from where we left off with the feature extraction model:

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs
history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
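Finally, evaluate the fine-tuned model on the test set we carved out earlier:

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)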

Prediction

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)
print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

# class_names (['cats', 'dogs']) was captured from train_dataset when we
# loaded the data above
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")

Appendix
