UNet

To a computer, an image is nothing but a matrix of numbers, and understanding the nuances hidden in these matrices has obsessed mathematicians for years.

But with the emergence of artificial intelligence, and convolutional neural network (CNN) architectures in particular, research has progressed like never before.

Many problems previously considered intractable are now showing astounding results.

One such problem is image segmentation.

In Image Segmentation, the machine has to partition the image into different segments, each of them representing a different entity.

Image Segmentation Example

As you can see above, the image is split into two segments: one representing the cat, the other the background.

Image segmentation is useful in many fields from self-driving cars to satellites.

Perhaps the most important of them all is medical imaging.

The subtleties in medical images are quite complex and sometimes even challenging for trained physicians.

A machine that can understand these nuances and can identify necessary areas can make a profound impact in medical care.

Convolutional Neural Networks gave decent results on easier image segmentation problems but made little progress on complex ones.

That’s where UNet comes into the picture.

UNet was first designed especially for medical image segmentation.

It showed such good results that it has since been used in many other fields.

In this article, we’ll talk about why and how UNet works.

If you don’t know the intuition behind CNNs, please read this first.

You can check out UNet in action here.

The Intuition Behind UNet

The main idea behind a CNN is to learn the feature mapping of an image and exploit it to build more nuanced feature mappings.

This works well in classification problems, as the image is converted into a vector which is then used for classification.

But in image segmentation, we not only need to convert the feature map into a vector but also reconstruct an image from that vector.

This is a mammoth task, because converting a vector into an image is a lot tougher than the reverse.

The whole idea of UNet revolves around this problem.

While converting an image into a vector, we have already learned a feature mapping of the image, so why not use the same mapping to convert the vector back into an image?

This is the recipe behind UNet.

Use the same feature maps that were used for contraction to expand a vector into a segmented image.

This preserves the structural integrity of the image, which reduces distortion enormously.

Let’s understand the architecture in more detail.

How UNet Works

UNet Architecture

The architecture looks like a ‘U’, which justifies its name.

This architecture consists of three sections: the contraction, the bottleneck, and the expansion section.

The contraction section is made of many contraction blocks.

Each block takes an input and applies two 3x3 convolution layers followed by a 2x2 max pooling layer.

The number of kernels (feature maps) doubles after each block, so that the architecture can learn complex structures effectively.
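A contraction block of this kind can be sketched in PyTorch as follows. This is a minimal illustration, not the author’s exact module: the names are made up, and padded convolutions are used for simplicity, whereas the original paper uses unpadded ones.

```python
import torch
import torch.nn as nn

def contraction_block(in_channels, out_channels):
    # Two 3x3 convolutions (each followed by ReLU) and a 2x2 max pooling.
    # Padding is set to 1 here so the convolutions preserve spatial size;
    # the original paper uses unpadded convolutions instead.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

# Channel counts double from block to block: e.g. 1 -> 64 -> 128 -> 256 ...
block = contraction_block(1, 64)
x = torch.randn(1, 1, 128, 128)
print(block(x).shape)  # torch.Size([1, 64, 64, 64])
```

Note how the max pooling halves the spatial resolution while the convolutions double the channel count, which is exactly the trade described above.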

The bottommost layer (the bottleneck) mediates between the contraction section and the expansion section.

It uses two 3x3 CNN layers followed by a 2x2 up-convolution layer.

But the heart of this architecture lies in the expansion section.

Similar to the contraction section, it also consists of several expansion blocks.

Each block passes the input through two 3x3 CNN layers followed by a 2x2 upsampling (up-convolution) layer.

Also, after each block the number of feature maps used by the convolutional layers is halved to maintain symmetry.

In addition, every input is appended with the feature maps of the corresponding contraction layer.

This ensures that the features learned while contracting the image are used to reconstruct it.
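The expansion block and the skip connection can be sketched as below. Again, this is an illustrative sketch under assumed names and shapes, not the author’s exact code: the crop-and-concat helper center-crops the contraction feature map (which may be larger) and appends it along the channel dimension.

```python
import torch
import torch.nn as nn

def expansion_block(in_channels, mid_channels, out_channels):
    # Two 3x3 convolutions followed by a 2x2 transposed ("up") convolution
    # that doubles the spatial resolution.
    return nn.Sequential(
        nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(mid_channels, out_channels, kernel_size=2, stride=2),
    )

def crop_and_concat(upsampled, bypass):
    # Center-crop the contraction feature map to the upsampled size,
    # then append it along the channel dimension (the skip connection).
    dh = (bypass.size(2) - upsampled.size(2)) // 2
    dw = (bypass.size(3) - upsampled.size(3)) // 2
    bypass = bypass[:, :, dh:dh + upsampled.size(2), dw:dw + upsampled.size(3)]
    return torch.cat((upsampled, bypass), dim=1)

up = torch.randn(1, 64, 56, 56)    # upsampled expansion features
skip = torch.randn(1, 64, 64, 64)  # corresponding contraction features
merged = crop_and_concat(up, skip)
print(merged.shape)  # torch.Size([1, 128, 56, 56])
```

After concatenation the channel count doubles, which is why the following convolutions halve it again to maintain the symmetry mentioned above.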

The number of expansion blocks is the same as the number of contraction blocks.

After that, the resultant mapping passes through one final convolution layer (a 1x1 convolution in the original paper) with the number of feature maps equal to the number of desired segments.

Loss Calculation in UNet

What kind of loss would one use for such an intricate segmentation task? Well, it is defined in the paper itself.

“The energy function is computed by a pixel-wise soft-max over the final feature map combined with the cross entropy loss function.”

UNet uses a rather novel loss weighting scheme for each pixel, such that there is a higher weight at the border of segmented objects.

This loss weighting scheme helped the U-Net model segment cells in biomedical images in a discontinuous fashion such that individual cells may be easily identified within the binary segmentation map.
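The border-emphasizing weight map from the paper can be sketched roughly as below. This is a simplified version, assuming instance-labeled masks and omitting the paper’s class-balancing term; the function name and parameter defaults (`w0=10`, `sigma=5`, as suggested in the paper) are only illustrative.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def unet_weight_map(labels, w0=10.0, sigma=5.0):
    """Simplified UNet border weights: labels is a 2D int array where
    0 = background and 1..K index individual cell instances."""
    ids = [i for i in np.unique(labels) if i > 0]
    if len(ids) < 2:
        return np.ones(labels.shape)
    # Distance from every pixel to each individual cell.
    dists = np.stack([distance_transform_edt(labels != i) for i in ids])
    dists.sort(axis=0)
    d1, d2 = dists[0], dists[1]  # nearest and second-nearest cell
    # w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)) spikes on thin background
    # gaps between touching cells, as in the paper's weight formula.
    w = w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
    return 1.0 + w * (labels == 0)
```

Background pixels squeezed between two cells receive a much larger weight than distant background or cell interiors, which is what forces the network to learn the thin separating borders.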

First, a pixel-wise softmax is applied to the resultant image, followed by the cross-entropy loss function.

So we are classifying each pixel into one of the classes.

The idea is that even in segmentation, every pixel has to lie in some category, and we just need to make sure that it does.

So we have converted a segmentation problem into a multiclass classification one, and it performs very well compared to traditional loss functions.
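This "classification per pixel" view can be made concrete with a short PyTorch check. The shapes here are arbitrary examples; the point is that `CrossEntropyLoss` on the full score maps is exactly the same as flattening the image and classifying each pixel independently.

```python
import torch
import torch.nn as nn

# Suppose the network outputs one score map per class: (batch, classes, H, W).
batch, classes, H, W = 2, 2, 8, 8
logits = torch.randn(batch, classes, H, W)
labels = torch.randint(0, classes, (batch, H, W))

# CrossEntropyLoss applies the pixel-wise softmax internally and accepts
# (batch, classes, H, W) logits with (batch, H, W) integer targets directly.
criterion = nn.CrossEntropyLoss()
loss_direct = criterion(logits, labels)

# Equivalent view: flatten to one row per pixel and classify each pixel.
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, classes)
flat_labels = labels.reshape(-1)
loss_flat = criterion(flat_logits, flat_labels)

print(torch.allclose(loss_direct, loss_flat))  # True
```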

UNet Implementation

I implemented the UNet model using the PyTorch framework.

You can check out the UNet module here.

Optical coherence tomography images with diabetic macular edema are used for segmentation.

You can check out UNet in action here.

The UNet module in the above code represents the whole architecture of UNet.

contraction_block and expansive_block are used to create the contraction section and the expansion section respectively.

The function crop_and_concat appends the output of the contraction layer to the new expansion-layer input.

The training part can be written as:

```python
unet = Unet(in_channel=1, out_channel=2)  # out_channel is the number of desired segments
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.99)

optimizer.zero_grad()
outputs = unet(inputs)
# permute so that the number of desired segments is on the 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# reshape the outputs and labels to calculate the pixel-wise softmax loss
outputs = outputs.reshape(m * width_out * height_out, 2)
labels = labels.reshape(m * width_out * height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
```

Conclusion

Image segmentation is an important problem, and new research papers are published every day.

UNet has contributed significantly to such research.

Many new architectures are inspired by UNet.

But still, there is so much to explore.

There are many variants of this architecture in the industry, and hence it is necessary to understand the original in order to understand them better.

So if you have any doubts please comment below or refer to the resources page.

Resources

- UNet original paper
- UNet Pytorch implementation
- UNet Tensorflow implementation
- More about Semantic Segmentation
- Practical Image Segmentation

Author’s Note

This tutorial is the second article in my series of DeepResearch articles.

If you like this tutorial, please let me know in the comments, and if you don’t, please let me know in the comments in more detail.

If you have any doubts or any criticism just flood the comments with it.

I’ll reply as soon as I can.

If you like this tutorial please share it with your peers.

