PyTorch Transfer Learning for End-to-End Multiclass Image Classification

Article by Rahul Agarwal | June 01, 2020

In this post we’ll create an end-to-end pipeline for multiclass image classification using PyTorch and transfer learning. This will include training the model, putting the model’s results in a form that can be shown to a business audience, and functions to help deploy the model easily.

Have you ever wondered how Facebook takes care of the abusive and inappropriate images shared by some of its users? Or how Facebook’s tagging feature works? Or how Google Lens recognizes products through images?

All of the above are examples of image classification in different settings. Multiclass image classification is a common task in computer vision, where we assign an image to one of several categories based on its content.

In the past, I always used Keras for computer vision projects. However, recently when the opportunity to work on multiclass image classification presented itself, I decided to use PyTorch. I have already moved from Keras to PyTorch for all NLP tasks, so why not vision, too? PyTorch is powerful, and I also like its more pythonic structure.

But before we learn how to do image classification, let’s first look at transfer learning, the most common method for dealing with such problems.



What is Transfer Learning?

Transfer learning is the process of repurposing knowledge from one task to another. From a modeling perspective, this means using a model trained on one dataset and fine-tuning it for use with another. But why does it work?

Let’s start with some background. Every year the visual recognition community comes together for a very particular challenge: The Imagenet Challenge. The task in this challenge is to classify 1,000,000 images into 1,000 categories.

This challenge has already resulted in researchers training big convolutional deep learning models. The results have included great models like Resnet50 and Inception.

But, what does it mean to train a neural model? Essentially, it means the researchers have learned the weights for a neural network after training the model on a million images.

So, what if we could get those weights? We could then use them and load them into our own neural networks model to predict on the test dataset, right? Actually, we can go even further than that; we can add an extra layer on top of the neural network these researchers have prepared to classify our own dataset.

While the exact workings of these complex models are still a mystery, we do know that the lower convolutional layers capture low-level image features like edges and gradients. In comparison, higher convolutional layers capture more and more intricate details, such as body parts, faces, and other compositional features.

You can see how the first few layers capture basic shapes, and the shapes become more and more complex in the later layers. (Source:

In the example above from ZFNet (a variant of Alexnet), one of the first convolutional neural networks to achieve success on the Imagenet task, you can see how the lower layers capture lines and edges, and the later layers capture more complex features. The final fully-connected layers are generally assumed to capture information that is relevant for solving the respective task, e.g. ZFNet’s fully-connected layers indicate which features are relevant for classifying an image into one of 1,000 object categories.

For a new vision task, it is possible for us to simply use the off-the-shelf features of a state-of-the-art CNN pre-trained on ImageNet, and train a new model on these extracted features.

The intuition behind this idea is that a model trained to recognize animals might also be used to recognize cats vs dogs. In our case, a model that has been trained on 1000 different categories has seen a lot of real-world information, and we can use this information to create our own custom classifier.

So that’s the theory and intuition. How do we get it to actually work? Let’s look at some code. You can find the complete code for this post on Github.



Data Exploration

We will start with the Boat Dataset from Kaggle to understand the multiclass image classification problem. This dataset contains about 1,500 pictures of boats of different types: buoys, cruise ships, ferry boats, freight boats, gondolas, inflatable boats, kayaks, paper boats, and sailboats. Our goal is to create a model that looks at a boat image and classifies it into the correct category.

Here’s a sample of images from the dataset:

pytorch transfer learning data exploration example images


And here are the category counts:

Since the categories freight boats, inflatable boats, and boats don’t have many images, we will remove these categories when we train our model.



Creating the Required Directory Structure

Before we can go through with training our deep learning models, we need to create the required directory structure for our images. Right now, our data directory structure looks like this:


We need our images to be in three folders: train, val and test. We will then train the model on the images in the train dataset, validate on the val dataset and finally test with the test dataset.
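One way to decide which image goes into which folder is a per-category random split. The sketch below is only an illustration under stated assumptions: the function name `assign_splits`, the 25% fractions, and the seed are hypothetical and not taken from the original data preparation code.

```python
import random
from collections import defaultdict

def assign_splits(categories, val_frac=0.15, test_frac=0.15, seed=42):
    """Return one 'train' / 'val' / 'test' label per image.

    `categories` holds the category of each image. The split is done
    separately inside each category so that every category shows up
    in all three folders.
    """
    rng = random.Random(seed)
    by_cat = defaultdict(list)
    for idx, cat in enumerate(categories):
        by_cat[cat].append(idx)

    splits = [None] * len(categories)
    for indices in by_cat.values():
        rng.shuffle(indices)
        n_val = int(len(indices) * val_frac)
        n_test = int(len(indices) * test_frac)
        for pos, idx in enumerate(indices):
            if pos < n_val:
                splits[idx] = "val"
            elif pos < n_val + n_test:
                splits[idx] = "test"
            else:
                splits[idx] = "train"
    return splits

# Toy input: 20 images in each of two categories.
categories = ["kayak"] * 20 + ["gondola"] * 20
splits = assign_splits(categories, val_frac=0.25, test_frac=0.25)
print(splits.count("train"), splits.count("val"), splits.count("test"))
```

The returned list can be stored as a column next to the file paths, which is the kind of information the copy step needs later on.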


You might have your data in a different format, but I have found that apart from the usual libraries, the glob.glob and os.system functions are very helpful. Here you can find the complete data preparation code. Now let’s take a quick look at some of the not-so-used libraries that I found useful while doing data prep.


What is glob.glob?

Simply put, glob lets you get the names of files or folders in a directory using shell-style wildcard patterns (such as * and ?) rather than full regular expressions. For example, you can do something like this:

from glob import glob
categories = glob("images/*")
['images/kayak', 'images/boats', 'images/gondola', 'images/sailboat', 'images/inflatable boat', 'images/paper boat', 'images/buoy', 'images/cruise ship', 'images/freight boat', 'images/ferry boat']
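To make the wildcard behaviour concrete, here is a small self-contained sketch that builds a throwaway folder tree and queries it. The folder and file names are made up for the demonstration; only `glob`, `os`, and `tempfile` from the standard library are used.

```python
import os
import tempfile
from glob import glob

# Build a tiny throwaway tree that mimics images/<category>/<file>.
root = tempfile.mkdtemp()
for category, filename in [("kayak", "a.jpg"), ("kayak", "b.png"), ("buoy", "c.jpg")]:
    os.makedirs(os.path.join(root, "images", category), exist_ok=True)
    open(os.path.join(root, "images", category, filename), "w").close()

# '*' matches any run of characters within a single path component.
categories = sorted(glob(os.path.join(root, "images", "*")))

# One '*' per directory level: every .jpg file in every category folder.
jpgs = sorted(glob(os.path.join(root, "images", "*", "*.jpg")))

# '**' together with recursive=True walks any number of directory levels.
all_files = sorted(glob(os.path.join(root, "**", "*.*"), recursive=True))

print([os.path.basename(c) for c in categories])
print([os.path.basename(p) for p in jpgs])
print(len(all_files))
```

Note that glob returns results in arbitrary order, which is why the sketch sorts them before use.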

What is os.system?

os.system is a function in the os library that lets you run any shell command from within Python. I generally use it to run Linux commands, but it can also be used to run R scripts from Python as shown here. For example, I use it in my data preparation to copy files from one directory to another after getting the information from a pandas data frame. I also use f-string formatting.

import os
for i,row in fulldf.iterrows():
    # Boat category
    cat = row['category']
    # section is train,val or test
    section = row['type']
    # input filepath to copy
    ipath = row['filepath']
    # output filepath to paste
    opath = ipath.replace(f"images/",f"data/{section}/")
    # running the cp command
    os.system(f"cp '{ipath}' '{opath}'")
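The cp command above assumes a Unix-like shell and an existing destination folder. If either is in doubt, a portable alternative is the standard library's shutil module. The sketch below is one possible variant, not the code used in the post; the helper name `copy_into_split` and its default folder names are assumptions.

```python
import os
import shutil
import tempfile

def copy_into_split(ipath, section, src_root="images", dst_root="data"):
    """Copy <src_root>/<category>/<file> to <dst_root>/<section>/<category>/<file>."""
    rel = os.path.relpath(ipath, src_root)              # <category>/<file>
    opath = os.path.join(dst_root, section, rel)
    os.makedirs(os.path.dirname(opath), exist_ok=True)  # create folders on demand
    shutil.copy(ipath, opath)
    return opath

# Demonstration on a temporary folder with one empty stand-in "image".
work = tempfile.mkdtemp()
src_root = os.path.join(work, "images")
dst_root = os.path.join(work, "data")
os.makedirs(os.path.join(src_root, "cruise ship"))
src_file = os.path.join(src_root, "cruise ship", "ship1.jpg")
open(src_file, "w").close()

copied = copy_into_split(src_file, "train", src_root=src_root, dst_root=dst_root)
print(os.path.relpath(copied, work))
```

Because no shell is involved, file names with spaces or quotes need no special escaping.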

Now that we have our data in the required folder structures, we can move on to the more exciting parts.



Data Preprocessing


1. Imagenet Preprocessing

In order to use our images with a network trained on the Imagenet dataset, we need to preprocess our images in the same way as the images the network was trained on. For that, we need to bring the images to 224×224 and normalize them with the Imagenet channel means and standard deviations. We can use the torchvision transforms library to do that: here we take a CenterCrop of 224×224 and then normalize. The operations defined below happen sequentially. You can find a list of all transforms provided by PyTorch here.

        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])


2. Data Augmentation

We can do a lot more preprocessing for data augmentations. Neural networks work better with a lot of data. Data augmentation is a strategy which we use at training time to increase the amount of data we have.

For example, we can flip the image of a boat horizontally, and it will still be a boat. Or we can randomly crop images or add color jitters. Here is the image transforms dictionary I used, which applies both the Imagenet preprocessing and the augmentations. This dictionary contains the various transforms we have for the train, test and validation data as used in this great post. As you’d expect, we don’t apply the horizontal flips or other data augmentation transforms to the test data and validation data because we don’t want to get predictions on an augmented image.

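# Image transformations
image_transforms = {
    # Train uses data augmentation
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),  # the flip mentioned in the text
        transforms.CenterCrop(size=224), # Imagenet standards
        transforms.ToTensor(),  # Normalize expects a tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Imagenet standards
    ]),
    # Validation does not use augmentation
    'valid': transforms.Compose([
        transforms.Resize(size=256),  # assumed: this step was lost from the listing
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    # Test does not use augmentation
    'test': transforms.Compose([
        transforms.Resize(size=256),  # assumed, as for validation
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}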

Here is an example of the train transforms applied to an image in the training dataset. Not only do we get a huge range of different images from one single image, but it also helps our network become invariant to the object orientation.

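from PIL import Image
import matplotlib.pyplot as plt

ex_img = Image.open('/home/rahul/projects/compvisblog/data/train/cruise ship/cruise-ship-oasis-of-the-seas-boat-water-482183.jpg')
t = image_transforms['train']
plt.figure(figsize=(24, 24))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    # imshow_tensor is a plotting helper that is not shown in this post
    _ = imshow_tensor(t(ex_img), ax=ax)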



Data Loaders

The next step is to provide the training, validation, and test dataset locations to PyTorch. We can do this by using the PyTorch datasets and DataLoader class. This part of the code will mostly remain the same if we have our data in the required directory structures.

from torchvision import datasets
from torch.utils.data import DataLoader

# Datasets from folders
traindir = "data/train"
validdir = "data/val"
testdir = "data/test"
data = {
    'train': datasets.ImageFolder(root=traindir, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=validdir, transform=image_transforms['valid']),
    'test': datasets.ImageFolder(root=testdir, transform=image_transforms['test'])
}
# Dataloader iterators, make sure to shuffle
# (batch_size is not defined in this excerpt; pick a value that fits your GPU memory)
dataloaders = {
    'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True,num_workers=10),
    'val': DataLoader(data['valid'], batch_size=batch_size, shuffle=True,num_workers=10),
    'test': DataLoader(data['test'], batch_size=batch_size, shuffle=True,num_workers=10)
}

These dataloaders help us to iterate through the dataset. For example, we will use the dataloader below in our model training. The data variable will contain data in the form (batch_size, color_channels, height, width), while target has shape (batch_size) and holds the label information.

train_loader = dataloaders['train']
for ii, (data, target) in enumerate(train_loader):
    ...  # data: (batch_size, color_channels, height, width); target: (batch_size,)




1. Create the model using a pre-trained model

At present, the following pre-trained models are available to use in the torchvision library:

I will use resnet50 on our dataset, but you can effectively use any other model depending on your preference.

from torchvision import models
model = models.resnet50(pretrained=True)

We start by freezing our model weights since we don’t want to change the weights for the resnet50 models.

# Freeze model weights
for param in model.parameters():
    param.requires_grad = False

The next thing we need to do is replace the linear classification layer in the model with our custom classifier. I have found that to do this, it is better first to see the model structure to determine the final linear layer. We can do this simply by printing the model object:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  ...
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  ...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

Here we find that the final linear layer that takes the input from the convolutional layers is named fc.

We can now simply replace the fc layer using our custom neural network. This neural network takes input from the previous layer to fc and gives the log softmax output of shape (batch_size x n_classes).

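n_inputs = model.fc.in_features
model.fc = nn.Sequential(
                      nn.Linear(n_inputs, 256),
                      nn.ReLU(),  # assumed: a nonlinearity between the two linear layers
                      nn.Linear(256, n_classes),
                      nn.LogSoftmax(dim=1))  # the log softmax output described above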

Please note that the new layers added now are fully trainable by default.


2. Load the model on GPU

We can use a single GPU or multiple GPUs (if we have them) using DataParallel from PyTorch. The code below detects whether a GPU is available, counts the GPUs, and loads the model onto them. Right now I am training my models on dual Titan RTX GPUs.

from torch import cuda, nn

# Whether to train on a gpu
train_on_gpu = cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')
# Number of gpus
multi_gpu = False
if train_on_gpu:
    gpu_count = cuda.device_count()
    print(f'{gpu_count} gpus detected.')
    if gpu_count > 1:
        multi_gpu = True
if train_on_gpu:
    model = model.to('cuda')
if multi_gpu:
    model = nn.DataParallel(model)


3. Define criterion and optimizers

One of the most important choices when training any model is the loss function and the optimizer. This is a multiclass classification problem, so we want categorical cross-entropy, and we use the Adam optimizer, a widely used default. Because our model already applies a LogSoftmax operation to its output, we use the NLL loss: NLL loss on log-softmax outputs is equivalent to cross-entropy loss on the raw scores.

from torch import nn, optim
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())
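To see why LogSoftmax followed by NLL loss amounts to cross-entropy, the arithmetic can be worked through for one example in plain Python. This is a hand-rolled illustration with made-up scores, not the PyTorch implementation; nn.NLLLoss additionally averages over the batch by default.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the true class."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy computed directly from the raw scores."""
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[target] / sum(exps))

logits = [2.0, 0.5, -1.0]   # made-up raw scores for three classes
target = 0                  # index of the true class

loss_a = nll_loss(log_softmax(logits), target)
loss_b = cross_entropy(logits, target)
print(loss_a, loss_b)
```

Both routes give the same number, so the choice between them depends only on whether the model's last layer already applies the log softmax.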


4. Training the model

Below you’ll find the full code used to train the model. It might look pretty big on its own, but essentially what we are doing is as follows:

  • Start running epochs. In each epoch:
  • Set the model mode to train using model.train().
  • Loop through the data using the train dataloader.
  • Load your data to the GPU using the data, target = data.cuda(), target.cuda() command
  • Set the existing gradients in the optimizer to zero using optimizer.zero_grad()
  • Run the forward pass through the batch using output = model(data)
  • Compute loss using loss = criterion(output, target)
  • Backpropagate the losses through the network using loss.backward()
  • Take an optimizer step to change the weights in the whole network using optimizer.step()
  • All the other steps in the training loop are just to maintain the history and calculate accuracy.
  • Set the model mode to eval using model.eval().
  • Get predictions for the validation data using valid_loader and calculate valid_loss and valid_acc
  • Print the validation loss and validation accuracy results every print_every epochs.
  • Save the best model based on validation loss.
  • Early Stopping: If the validation loss doesn’t improve for max_epochs_stop epochs, stop the training and load the best available model with the minimum validation loss.
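The steps above can be sketched as a compact loop. This is a minimal illustration under assumptions, not the post's training function: a small fully connected model and synthetic random data stand in for the ResNet and the boat images, and the hyperparameter values are arbitrary. The names max_epochs_stop and print_every follow the list above.

```python
import copy
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

def make_loader(n):
    """Synthetic stand-in data: 20 random features, 3 classes."""
    x = torch.randn(n, 20)
    y = x[:, :3].argmax(dim=1)  # label = index of the largest of the first 3 features
    return DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

train_loader, valid_loader = make_loader(512), make_loader(128)

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

n_epochs, max_epochs_stop, print_every = 30, 3, 5
best_loss, best_state, epochs_no_improve = float("inf"), None, 0
history = []

for epoch in range(n_epochs):
    # Training pass
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()             # clear old gradients
        output = model(data)              # forward pass
        loss = criterion(output, target)
        loss.backward()                   # backpropagate
        optimizer.step()                  # update trainable weights
        train_loss += loss.item() * data.size(0)

    # Validation pass: no gradients, no weight updates
    model.eval()
    valid_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            valid_loss += criterion(output, target).item() * data.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()

    train_loss /= len(train_loader.dataset)
    valid_loss /= len(valid_loader.dataset)
    valid_acc = correct / len(valid_loader.dataset)
    history.append((train_loss, valid_loss, valid_acc))
    if (epoch + 1) % print_every == 0:
        print(f"Epoch {epoch + 1}: train {train_loss:.3f}, valid {valid_loss:.3f}, acc {valid_acc:.2%}")

    # Keep the best weights; stop early when validation loss stalls
    if valid_loss < best_loss:
        best_loss, epochs_no_improve = valid_loss, 0
        best_state = copy.deepcopy(model.state_dict())
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= max_epochs_stop:
            break

# Restore the weights with the lowest validation loss
model.load_state_dict(best_state)
```

Keeping a copy of the best state in memory is a design choice for a small model; for a large network, saving the state dict to disk with torch.save serves the same purpose.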

Here is the output from running the above code, showing just the last few epochs. The validation accuracy started at ~55% in the first epoch, and we ended up with a validation accuracy of ~90%.


And here are the training curves showing the loss and accuracy metrics:

Training curves


Inference and Model Results

When using our model, we want to look at its results in several different ways. For one, we require test accuracies and confusion matrices. All of the code for creating these results is in the code notebook.

1. Test Results

The overall accuracy of the model on the test set is:

Overall Accuracy: 88.65 %

Here is the confusion matrix for results on the test dataset:


We can also look at the category-wise accuracies. I have also added the train counts to see the results from a new perspective.
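For readers who want to reproduce these summaries from a list of true and predicted labels, the computation is short enough to write by hand. The sketch below uses made-up labels and hypothetical helper names; it is not the notebook's code.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows are true classes, columns are predicted classes."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix

def per_class_accuracy(matrix, labels):
    """Fraction of each true class that was predicted correctly."""
    return {label: (row[i] / sum(row) if sum(row) else float("nan"))
            for i, (label, row) in enumerate(zip(labels, matrix))}

# Made-up predictions for six test images.
labels = ["kayak", "gondola", "sailboat"]
y_true = ["kayak", "kayak", "gondola", "sailboat", "sailboat", "sailboat"]
y_pred = ["kayak", "gondola", "gondola", "sailboat", "sailboat", "kayak"]

cm = confusion_matrix(y_true, y_pred, labels)
class_acc = per_class_accuracy(cm, labels)
overall = sum(cm[i][i] for i in range(len(labels))) / len(y_true)
print(cm)
print(class_acc, overall)
```

The diagonal of the matrix holds the correct predictions, so the overall accuracy is the diagonal sum divided by the number of images.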


2. Visualizing Predictions for a Single Image

For deployment purposes, it helps to be able to get predictions for a single image. You can get the code from the notebook.


3. Visualizing Predictions for a Category

We can also see the category-wise results for debugging purposes and presentations.


4. Test results with Test Time Augmentation

We can also do test time augmentation to increase our test accuracy. Here I am using a new test data loader and transforms:

# Image transformations
tta_random_image_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.CenterCrop(size=224), # Imagenet standards
        transforms.ToTensor(),  # Normalize expects a tensor
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]) # Imagenet standards
    ])
# Datasets from folders
ttadata = {
    'test': datasets.ImageFolder(root=testdir, transform=tta_random_image_transforms)
}
# Dataloader iterators
ttadataloader = {
    'test': DataLoader(ttadata['test'], batch_size=512, shuffle=False,num_workers=10)
}

We can then get the predictions on the test set using the below function:

In the function above, I am applying the tta_random_image_transforms to each image 5 times before getting its prediction. The final prediction is the average of all five predictions. When I used TTA over the whole test dataset, the accuracy increased by around 1%.
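The averaging idea can be sketched as follows. This is a minimal illustration under assumptions, not the function from the notebook: it averages class probabilities (the post does not say whether probabilities or log-probabilities are averaged), the name predict_tta is hypothetical, and a noisy toy dataset stands in for an ImageFolder with random transforms. The loader must not shuffle, so that each pass visits the images in the same order.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class NoisyDataset(Dataset):
    """Stand-in for a dataset with random transforms: every read of the
    same index returns a slightly different tensor."""
    def __init__(self, base, labels, noise=0.1):
        self.base, self.labels, self.noise = base, labels, noise

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x = self.base[idx]
        return x + self.noise * torch.randn_like(x), self.labels[idx]

def predict_tta(model, loader, n_passes=5):
    """Average class probabilities over n_passes runs through a non-shuffled loader."""
    model.eval()
    total = None
    with torch.no_grad():
        for _ in range(n_passes):
            # The model outputs log-probabilities, so exp() recovers probabilities.
            probs = torch.cat([model(x).exp() for x, _ in loader])
            total = probs if total is None else total + probs
    return total / n_passes

torch.manual_seed(0)
dataset = NoisyDataset(torch.randn(10, 8), torch.zeros(10, dtype=torch.long))
loader = DataLoader(dataset, batch_size=4, shuffle=False)
toy_model = nn.Sequential(nn.Linear(8, 3), nn.LogSoftmax(dim=1))

avg_probs = predict_tta(toy_model, loader, n_passes=5)
tta_preds = avg_probs.argmax(dim=1)
print(avg_probs.shape, tta_preds)
```

Each averaged row is still a valid probability distribution, so the final class is simply the argmax of the averaged row.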

TTA Accuracy: 89.71%

Also, here are the category-wise results for TTA compared to the normal results:


In this small dataset, the TTA might not seem to add much value, but I have noticed that it adds value with big datasets.




In this post, I talked about the end-to-end pipeline for working on a multiclass image classification project using PyTorch and transfer learning. We created some ready-made code to train a model using transfer learning, visualized the results, used test time augmentation, and got predictions for a single image so that the model can be deployed when needed using a tool like Streamlit.

You can find the complete code for this post on Github.

The Author
Rahul Agarwal

Rahul is a data scientist currently working with Facebook. He enjoys working with data-intensive problems and is constantly in search of new ideas to work on. Contact him on Twitter: @MLWhiz

