Fine-tuning for image classification using PyTorch

I’m a tf/keras fan, but the number of models covered by tf/keras is quite low, whereas in PyTorch you can find state-of-the-art models very easily.

There are plenty of resources available for PyTorch.

Moreover, very few pre-trained weights are available in tf/keras, which is not the case for PyTorch.

Just take a look at the repository here: https://github.



PyTorch also offers a much easier and more independent interface than tensorflow/keras.

I still find tf/keras better when it comes to deploying to a production environment and for NLP problems.

Enough of the introduction.

Let’s start the business of fine-tuning.

For this article, we will be using the data from: https://www.



The task is to recognize artwork from The Metropolitan Museum of Art in New York (The Met).

It is a multi-label, multi-class problem.

Every image can have one or more classes associated with it. In the training CSV, the left column holds the image ids and the right column holds the classes associated with each image id.

To fine-tune using PyTorch, we need the following:

- A dataset generator
- The model itself
- A training/validation loop
- The training code
- The inference code

Let’s start with a data generator for training data.

Let’s look at the class CollectionsDataset.

__init__:

- csv_file: the path to the CSV described above.
- root_dir: the directory where the images are located. In our case it is “./input/train/”.
- num_classes: the total number of classes.
- transform: the transformations we would like to apply to the image. We will come to transformations later.

__len__:

This should return the total number of data samples that we have. In our case, it is the length of the supplied data frame.

__getitem__:

This function takes an argument called “index”. Given an index, it has to return an image and its label(s) (for training data) or just an image (for test data). Note that the index starts from 0.
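A minimal sketch of such a training dataset is shown here. The column names `id` and `attribute_ids`, the space-separated label format, and the `.png` extension are assumptions made for illustration; adjust them to match your CSV and image files.

```python
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CollectionsDataset(Dataset):
    """Training dataset: yields (image, multi-hot label vector) pairs."""

    def __init__(self, csv_file, root_dir, num_classes, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        # One sample per row of the data frame.
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Assumption: images are stored as "<id>.png" inside root_dir.
        path = os.path.join(self.root_dir, f"{row['id']}.png")
        image = Image.open(path).convert("RGB")

        # Assumption: labels are space-separated class indices, e.g. "3 17 42".
        labels = torch.zeros(self.num_classes, dtype=torch.float32)
        for class_id in str(row["attribute_ids"]).split():
            labels[int(class_id)] = 1.0

        if self.transform is not None:
            image = self.transform(image)
        return image, labels
```

Each label vector has one entry per class, set to 1 for every class attached to the image, which is the target format a multi-label loss expects.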

Similarly, one can write a generator for test data. Here, we do not return any labels, as we don’t have any.
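A test-time counterpart could look roughly like this. The class name `CollectionsDatasetTest`, the `id` column, and the `.png` extension are assumptions for illustration.

```python
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class CollectionsDatasetTest(Dataset):
    """Test dataset: yields images only, since no labels are available."""

    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Assumption: images are stored as "<id>.png" inside root_dir.
        image_id = self.data.iloc[index]["id"]
        path = os.path.join(self.root_dir, f"{image_id}.png")
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image
```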

Now, we need a model! For the model, we will use the Python package “pretrainedmodels” as mentioned above.

What we have done differently here is to use AdaptiveAvgPool2d.

To let the network learn from images whose size is not the standard 224×224, we need to replace the avg_pool layer of the model with an AdaptiveAvgPool2d layer.
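To see why this matters, the short sketch below compares a fixed 7×7 average pool with an adaptive one on two feature-map sizes. The fixed 7×7 kernel and the 8-channel feature maps are assumptions chosen for illustration; a stride-32 network produces a 7×7 map for a 224×224 input and a larger map for larger inputs.

```python
import torch
import torch.nn as nn

fixed_pool = nn.AvgPool2d(kernel_size=7, stride=1)  # only collapses 7x7 maps to 1x1
adaptive_pool = nn.AdaptiveAvgPool2d(1)             # always produces a 1x1 map

shapes = {}
for side in (7, 10):
    feature_map = torch.randn(1, 8, side, side)
    shapes[side] = (
        tuple(fixed_pool(feature_map).shape),
        tuple(adaptive_pool(feature_map).shape),
    )
    print(side, shapes[side])

# 7  -> fixed: (1, 8, 1, 1), adaptive: (1, 8, 1, 1)
# 10 -> fixed: (1, 8, 4, 4), adaptive: (1, 8, 1, 1)
```

With the fixed pool, the larger input leaves a 4×4 map, so the flattened feature vector no longer matches the size the following linear layer expects. The adaptive pool keeps the feature size constant.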

We have also added a couple of batch normalization layers, dropout, and a linear layer.

Our output is a linear layer with 1103 output features, which is the number of classes in our problem.

Please note that this part of the code is crucial.

If you change the model, you have to see what needs to be adjusted for your image classification problem.
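The sketch below illustrates the kind of adjustment described above. `TinyBackbone` is a hypothetical stand-in that only mimics the attribute names `layer4`, `avg_pool` and `last_linear`; in practice the backbone would be a pretrained network from the “pretrainedmodels” package. The hidden linear layer, the ReLU, and the dropout rates are assumptions, not values taken from the article.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained network (not a real model)."""

    def __init__(self, feature_dim=64):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(16, feature_dim, 3, stride=2, padding=1), nn.ReLU())
        self.avg_pool = nn.AvgPool2d(7, stride=1)
        self.last_linear = nn.Linear(feature_dim, 1000)

    def forward(self, x):
        x = self.layer4(self.layer1(x))
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        return self.last_linear(x)


def build_model(backbone, feature_dim, num_classes=1103):
    # Accept arbitrary input sizes by pooling every feature map down to 1x1.
    backbone.avg_pool = nn.AdaptiveAvgPool2d(1)
    # Replace the classifier head: batch norm, dropout, and linear layers.
    backbone.last_linear = nn.Sequential(
        nn.BatchNorm1d(feature_dim),
        nn.Dropout(p=0.25),                   # assumed rate
        nn.Linear(feature_dim, feature_dim),  # assumed hidden layer
        nn.ReLU(),
        nn.BatchNorm1d(feature_dim),
        nn.Dropout(p=0.5),                    # assumed rate
        nn.Linear(feature_dim, num_classes),
    )
    return backbone


model = build_model(TinyBackbone(feature_dim=64), feature_dim=64)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 96, 96))  # non-standard input size
print(logits.shape)  # torch.Size([2, 1103])
```

The value of `feature_dim` must match the number of channels the backbone produces before pooling, which is the main thing to check when swapping in a different model.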

And that’s it.

We are done with defining the model.

Now it’s time to move on to the training loop.

PyTorch does not come with “fit” and “predict” functions.

We need a training loop! For training, we need the dataset loaders from torch, an optimizer and a learning rate scheduler.

Please note that we have defined the loss criterion inside the training function.

The loss function we are using here is Binary Cross Entropy with Logits (BCEWithLogitsLoss).
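A minimal training loop along these lines might look as follows. The function name `train_model`, its arguments, and the synthetic data in the usage part are assumptions for illustration; validation is omitted to keep the sketch short.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, data_loader, optimizer, scheduler, num_epochs, device="cpu"):
    # The loss criterion is defined inside the training function.
    criterion = nn.BCEWithLogitsLoss()
    model.to(device)
    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss, seen = 0.0, 0
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)            # raw scores, no sigmoid applied
            loss = criterion(logits, labels)  # labels are multi-hot floats
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            seen += images.size(0)
        scheduler.step()                      # one scheduler step per epoch
        history.append(running_loss / seen)
    return history


# Usage with synthetic data standing in for real images and labels.
torch.manual_seed(0)
images = torch.randn(16, 3, 8, 8)
labels = (torch.rand(16, 5) > 0.7).float()
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
losses = train_model(model, loader, optimizer, scheduler, num_epochs=2)
print(losses)
```

BCEWithLogitsLoss applies the sigmoid internally, so the model should output raw logits and the targets should be float tensors of zeros and ones.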

Let’s define a few things before the training. We use transforms from torchvision to define some basic augmentations.

You can use any other library to make the augmentations and plug them into the transforms.Compose function.

Next, we initialize the CollectionsDataset class.

Once that is ready, we use the torch DataLoader, which generates batches of images from the CollectionsDataset instance.

The next step is to define the optimizer and scheduler.

Here, plist defines the layers we want to fine-tune.

We have taken layer4 and the last_linear layer, with different learning rates, for fine-tuning.

We have kept the other layers as they are.

We have used the Adam optimizer and a simple step learning rate scheduler.
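As a rough sketch, per-layer learning rates can be expressed as parameter groups passed to the optimizer. `StandInNet` is a hypothetical model that only reproduces the attribute names `layer4` and `last_linear`, and all learning rates, the step size, and gamma are assumed values.

```python
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler


class StandInNet(nn.Module):
    """Hypothetical model exposing `layer4` and `last_linear` like the backbone."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.last_linear = nn.Linear(8, 4)

    def forward(self, x):
        return self.last_linear(self.layer4(self.layer1(x)))


model = StandInNet()

# Only the listed layers are handed to the optimizer, each with its own
# learning rate; parameters of other layers are never updated by it.
plist = [
    {"params": model.layer4.parameters(), "lr": 1e-5},       # assumed value
    {"params": model.last_linear.parameters(), "lr": 5e-3},  # assumed value
]
optimizer = optim.Adam(plist, lr=1e-3)

# Multiply every group's learning rate by gamma every `step_size` epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

print([group["lr"] for group in optimizer.param_groups])
```

Layers left out of plist still take part in the forward and backward pass; setting `requires_grad = False` on their parameters would additionally skip their gradient computation.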

Now, we are ready to train the model!

Making inference from the model is also quite easy. Note that I am not doing any augmentations for test data :)

The full code is available here: https://github.


It uses se_resnext_101_32x4d and has some other goodies like validation data, NVIDIA Apex to support mixed precision training, etc.

I will cover that sometime later.
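For the inference step mentioned earlier, a minimal sketch could look like this. The function name `predict`, the 0.5 threshold, and the random tensors standing in for test images are assumptions; the loader is expected to yield batches of images only, with no augmentation applied.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def predict(model, data_loader, threshold=0.5, device="cpu"):
    """Return per-class probabilities and thresholded multi-label predictions."""
    model.to(device)
    model.eval()  # disables dropout, uses running batch-norm statistics
    all_probs = []
    with torch.no_grad():
        for images in data_loader:
            logits = model(images.to(device))
            # Sigmoid turns each logit into an independent class probability.
            all_probs.append(torch.sigmoid(logits).cpu())
    probs = torch.cat(all_probs)
    return probs, probs > threshold


# Usage with random tensors standing in for preprocessed test images.
test_images = [torch.randn(3, 16, 16) for _ in range(5)]
test_loader = DataLoader(test_images, batch_size=2, shuffle=False)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 7))
probs, predictions = predict(model, test_loader)
print(probs.shape, predictions.dtype)
```

The threshold decides which classes are reported for each image and is usually tuned on validation data.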

Join the Data Science India (DSI) Slack team: https://ds-india.


Add me on LinkedIn: bit.ly/thakurabhishek
