Deep Learning and Medical Image Analysis for Malaria Detection with fastai

```python
# You do not want to fantasize data.
# Warping, for example, will leave your images badly distorted, so don't do it!
# This dataset is big, so don't rotate the images either. Let's stick to flipping.
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)

# Create the DataBunch!
# Remember that you'll have images that are bigger than 128×128 and images
# that are smaller, so squish them all in order to occupy exactly 128×128 pixels.
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=bs)
#print('Transforms = ', len(tfms))

# Save the DataBunch in case the training goes south,
# so you won't have to regenerate it.
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')

# Show the statistics of the Bunch...
print(data.classes)
data
```

The print() will output the transforms and the classes:

```
Transforms = 2
['Parasitized', 'Uninfected']
```

The last line, 'data', will simply output the return value of the ImageDataBunch instance:

```
ImageDataBunch;

Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;

Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Uninfected
Path: data;

Test: None
```

Look at your DataBunch to see if the augmentations are acceptable…

```python
data.show_batch(rows=5, figsize=(15,15))
```

Training: resnet34

If you do not know what to use, a Residual Network with 34 layers is a good choice to start with.

Not too small and not too big… In the tutorials listed above the authors used a custom, but small, ResNet (PyImagesearch) and a VGG19 (TowardsDataScience). We will employ off-the-shelf fast.ai residual networks (ResNets). Let's create our first network:

```python
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.model
```

The last line will output the architecture of the network as a text stream.

It will look like this:

```
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  ...
```

…and so on. Even a “small” network such as a ResNet34 is still very large. Do not bother trying to understand the output right now; you can read more about residual networks later. There are lots of introductory postings about ResNets.
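
If you want a concrete sense of “very large” without reading the whole printout, you can count the trainable parameters with plain PyTorch (a small optional check, not in the original notebook; the exact total depends on the classification head fastai attaches for our two classes):

```python
# Count the parameters of the model created by cnn_learner.
# The ResNet34 backbone alone has roughly 21 million parameters.
total = sum(p.numel() for p in learn.model.parameters())
print(f'{total:,} parameters')
```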

Training strategy

Here comes one of the great strengths of fast.ai: easy-to-use HYPOs (hyperparameter optimization strategies). Hyperparameter optimization is a somewhat arcane subdiscipline of CNNs: these networks have so many parameters that choosing which ones to set to non-standard values in order to get better performance out of the network is a complex issue and a field of study in itself. The fast.ai library provides a few very advanced yet easy-to-use HYPOs that help immensely in implementing much better CNNs quickly.

We will employ the fit_one_cycle method, based on the 1cycle policy developed by Leslie N. Smith — see below for details:

- https://docs.fast.ai/callbacks.one_cycle.html
- A disciplined approach to neural network hyper-parameters: Part 1 — learning rate, batch size, momentum, and weight decay — https://arxiv.org/abs/1803.09820
- Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates — https://arxiv.org/abs/1708.07120

Since this method is fast, we will employ only 10 epochs in this first transfer learning stage.

We will also save the network at each epoch, if the performance gets better: https://docs.fast.ai/callbacks.html#SaveModelCallback

```python
learn.fit_one_cycle(10, callbacks=[SaveModelCallback(learn, every='epoch',
                                   monitor='accuracy', name='malaria-1')])
# Save it!
learn.save('malaria-stage-1')
# Deploy it!
exportStageTo(learn, path)
```

This will produce a table with the per-epoch losses and metrics as output. The table shows an accuracy of 96.4% for the validation set, and this with transfer learning only! The error_rate fast.ai shows is always the one computed on the validation set. As a comparison, consider that Adrian Rosebrock achieved 97% with his custom ResNet in the PyImagesearch posting.
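
If you are curious about what the 1cycle policy actually did to the learning rate and momentum over these 10 epochs, the Recorder attached to every Learner can plot the schedule after training (a small optional check, not part of the original notebook; plot_lr() is a fastai v1 Recorder method):

```python
# Plot the learning-rate schedule applied by fit_one_cycle;
# show_moms=True also shows the cyclical momentum schedule next to it.
learn.recorder.plot_lr(show_moms=True)
```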

Results for ResNet34

Let's see what additional results we've got. We will first look at which instances the model confused with one another, and try to see whether what the model predicted was reasonable or not. In this case the mistakes look reasonable (none of them seems obviously naive), which is an indicator that our classifier is working correctly. Furthermore, we will plot the confusion matrix. This is also very simple in fast.ai.

```python
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(data.valid_ds)==len(losses)==len(idxs)
```

Look at your 9 worst results (without employing a heatmap, at first):

```python
interp.plot_top_losses(9, figsize=(20,11), heatmap=False)
```

Now do the same, but highlight with a heatmap what induced the wrong classification:

```python
interp.plot_top_losses(9, figsize=(20,11), heatmap=True)
```

Show the confusion matrix

fast.ai's ClassificationInterpretation class has a high-level instance method that allows for fast and easy plotting of confusion matrices, showing you at a glance how well the CNN has performed. It doesn't make much sense with only two classes, but we'll do it anyway: it generates beautiful pictures… You can set the size and the resolution of the resulting plot; we'll use 5×5 inches at 100 dpi.

```python
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
```

Show your learning curve

It is interesting to look at the training and the validation curves. They will show us whether the network learned in a steady way or oscillated (which can indicate bad-quality data), and whether the result is OK or we are overfitting or underfitting our network.

Again, fast.ai has high-level methods that'll help us. Each fast.ai cnn_learner has an automatically created instance of a Recorder, which records epoch, loss, opt and metric data during training. The plot_losses() method will create a graph with the training and validation curves:

```python
learn.recorder.plot_losses()
```

This result looks so good that it doesn't make sense to fine-tune the network.

If we look attentively, we'll see that the validation loss gets worse than the training loss at about 500 batches, indicating that the network is probably starting to overfit at this point. This is an indication that we have definitely trained enough, at least for this ResNet model. The overfitting we observed could also mean that we are employing a network model that is overkill for the complexity of the data, i.e. that we are training a network that learns individual instances rather than a generalization of the dataset. One very simple and practical way to test this hypothesis is to try to learn the dataset with a simpler network and see what happens.

Let's employ a smaller network and try it all again…

ResNet18

This network is much simpler; let's see if it does the job. The ResNet18 is much smaller, so we'll have more GPU RAM available. We will create the DataBunch again, this time with a bigger batch size…

```python
# Limit your augmentations: it's medical data!
# You do not want to fantasize data.
# Warping, for example, will leave your images badly distorted, so don't do it!
# This dataset is big, so don't rotate the images either. Let's stick to flipping.
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)

# Create the DataBunch!
# Remember that you'll have images that are bigger than 128×128 and images
# that are smaller, so squish them to occupy exactly 128×128 pixels.
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=512)
#print('Transforms = ', len(tfms))

# Save the DataBunch in case the training goes south,
# so you won't have to regenerate it.
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')

# Show the statistics of the Bunch...
print(data.classes)
data
```

Observe that we stuck with our valid_pct = 0.2: we will still let fast.ai randomly choose 20% of the dataset as the validation set.
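
One caveat, not addressed in the original code: valid_pct draws those 20% at random, so each run (and each of our two DataBunches) gets a different validation split. If you want the ResNet18 results to be measured on the same split as the ResNet34, you can seed NumPy's random number generator before building the DataBunch, which is the usual fastai v1 idiom for making the split reproducible:

```python
import numpy as np

np.random.seed(42)  # fix the random 20% validation split before from_folder()
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=512)
```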

The DataBunch code above will output something like this:

```
Transforms = 2
['Parasitized', 'Uninfected']
```

and:

```
ImageDataBunch;

Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;

Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Parasitized
Path: data;

Test: None
```

Now, create the learner:

```python
learn18 = cnn_learner(data, models.resnet18, metrics=error_rate)
```

If your Colab environment doesn't have the pretrained weights for the ResNet18, fast.ai will automatically download them:

```
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.torch/models/resnet18-5c106cde.pth
46827520it [00:01, 28999302.58it/s]
```

Look at the model:

```python
learn18.model
```

This will list the structure of your net.

It is much smaller than the ResNet34, but still has a lot of layers. The output will look like this:

```
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
  ...
```

…and so on.

Let's train it

We will again use the fit_one_cycle HYPO training strategy, limiting the training to 10 epochs to see how this smaller network behaves:

```python
learn18.fit_one_cycle(10, callbacks=[SaveModelCallback(learn18, every='epoch',
                                     monitor='accuracy', name='malaria18-1')])
# Save the network
learn18.save('malaria18-stage-1')
# Deploy it also
exportStageTo(learn18, path)
```

The resulting table shows that the network learned to an accuracy of roughly 96.1% and suggests that it should not be trained further: the loss between epoch #8 and #9 decreased by 0.005 but the accuracy remained the same, suggesting that the network has started overfitting. Let's generate a ClassificationInterpretation and look at the confusion matrix and the loss curves.

```python
interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
```

This confusion matrix is slightly, but only very slightly, worse than the one we generated for the ResNet34. Is the ResNet18 less well suited for this problem? Let's look at the losses: the graph shows that the ResNet18 started overfitting a bit after about 290 batches. Remember that our bs is 512 here and was 256 for the ResNet34.
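
Those two batch counts are roughly comparable points in training. A quick back-of-the-envelope check, using the 22,047 training items reported in the DataBunch summaries above (integer division ignores the last partial batch):

```python
train_items = 22047
print(train_items // 256)  # ~86 batches per epoch for the ResNet34 run (bs=256)
print(train_items // 512)  # ~43 batches per epoch for the ResNet18 run (bs=512)
# So ~500 batches is roughly epoch 6 of the ResNet34 run,
# and ~290 batches is roughly epoch 7 of the ResNet18 run.
```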

Let's see if we can make this better by fine-tuning the network.

Fine-Tune it!

Here we will introduce another fast.ai HYPO: automatically chosen variable learning rates. We will let fast.ai choose which learning rate to use for each epoch and each layer group, providing only a range of learning rates we consider adequate. We will train the network for 30 epochs.
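
How do you know which range is adequate? The original notebook simply picks slice(1e-4, 1e-5), but a common fastai v1 way to choose it yourself is to run the learning-rate finder on the unfrozen model and read the plot (a sketch, not part of the original code):

```python
# Sketch: use the LR finder to pick a sensible max_lr range before fine-tuning.
learn18.unfreeze()
learn18.lr_find()
learn18.recorder.plot()  # choose a range that ends well before the loss blows up
```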

```python
# Unfreeze the network
learn18.unfreeze()
# Learning rates range: max_lr=slice(1e-4,1e-5)
learn18.fit_one_cycle(30, max_lr=slice(1e-4,1e-5),
                      callbacks=[SaveModelCallback(learn18, every='epoch',
                                 monitor='accuracy', name='malaria18')])
# Save as stage 2
learn18.save('malaria18-stage-2')
# Deploy
exportStageTo(learn18, path)
```

97% accuracy! That is exactly what Adrian Rosebrock achieved with his custom Keras ResNet implementation in the PyImagesearch posting, which presents the best accuracy results among the three references above.

The validation loss, however, was getting worse over the last epochs, which indicates that we have been overfitting from about epoch #20 on. If you want to deploy this network, I would suggest loading the results from epoch 20 and generating a deployment network from them: it does not seem to get any better, not with this network.
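
As a sketch of how you could do that with the per-epoch checkpoints SaveModelCallback wrote (fastai v1 names them '&lt;name&gt;_&lt;epoch&gt;' under the models/ folder; epoch numbering may be 0-based, so check which files you actually have):

```python
# Roll back to the weights saved around epoch 20 of the fine-tuning run
# and export that state for deployment instead of the final one.
learn18.load('malaria18_19')   # or 'malaria18_20', depending on the indexing
exportStageTo(learn18, path)
```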

Look at the results

```python
interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
```

This is better than what we've had before.
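
The loss curves discussed next can be reproduced the same way as for the ResNet34, only from the learn18 recorder (this call is not shown in the original listing, but it is the same fastai v1 method used earlier):

```python
# Plot the training and validation losses recorded over the 30 fine-tuning epochs
learn18.recorder.plot_losses()
```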

Let's look at the loss curves: here we see that the network seems to start to overfit after about 500 batches, which confirms the suspicion we drew from the results table above. The validation loss starts to grow in the last third of the training, suggesting that this part of the training only overfitted the network.

What have we learned?

Compared to other approaches, where the authors employ pure Keras or TensorFlow.Keras to solve the problem, with fast.ai we were able to solve the same malaria blood smear classification problem with much less code, while using high-level hyperparameter optimization strategies that allowed us to train much faster. At the same time, a set of high-level functions also lets us easily inspect the results both as tables and as graphs. With the off-the-shelf residual network models, pre-trained on ImageNet, that fast.ai provides, we obtained accuracy results that are 1% better than two of the three previous works above (including the scientific paper published in PeerJ) and equal to the best-performing of them. This shows that fast.ai is a very promising alternative to more traditional CNN frameworks, especially if the task at hand is a “standard” deep learning task such as image classification, object detection or semantic segmentation that can be solved by fine-tuning off-the-shelf pre-trained network models.
