Understanding U-Net (2024)

Table of contents
1. The task at hand
2. Encoder-Decoder
3. Skip connections
4. Implementation details
a. Loss functions
b. Up-sampling methods
c. To pad or not to pad?
5. U-Net in action

The task at hand

U-Net was developed for semantic segmentation. When a neural network takes images as input, we can choose to classify objects either generally or by instance. We can predict what object is in the image (image classification), where all objects are located (image localization/semantic segmentation), or where individual objects are located (object detection/instance segmentation). The figure below shows the differences between these computer vision tasks. To keep things simple, we only consider a single class with one label (binary classification).

In classification tasks, we output a vector of size k, where k is the number of classes. In detection tasks, we output a vector (x, y, height, width, class) for each bounding box. In segmentation tasks, however, we need to output an image with the same spatial dimensions as the input. This is quite an engineering challenge: how can a neural network extract relevant features from the input image and then project them into segmentation masks?
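
To make the difference in outputs concrete, here is a minimal PyTorch sketch contrasting a classification head with a segmentation head on the same features. The toy backbone, the number of classes and the image size are arbitrary assumptions for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical toy backbone: one conv layer producing 16 feature channels.
backbone = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())

k = 5  # number of classes (illustrative)

# Classification head: collapse the spatial dimensions, then map to k scores.
cls_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, k))

# Segmentation head: a 1x1 convolution keeps H and W, giving one score per pixel.
seg_head = nn.Conv2d(16, 1, kernel_size=1)

images = torch.randn(2, 3, 64, 64)  # batch of 2 RGB images, 64 x 64
features = backbone(images)         # shape (2, 16, 64, 64)

class_scores = cls_head(features)   # shape (2, k)
mask_logits = seg_head(features)    # shape (2, 1, 64, 64)
print(class_scores.shape, mask_logits.shape)
```

The classification head throws the spatial layout away, while the segmentation head must keep it all the way to the output.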

Encoder-Decoder

If you are not familiar with encoder-decoders, I recommend reading this article:

Encoder-decoders are relevant because they produce outputs similar to what we want: outputs with the same dimensions as the input. Can we apply the encoder-decoder concept to image segmentation? We can generate a single-channel binary mask and train the network using cross-entropy loss. Our network consists of two parts: the encoder, which extracts relevant features from the image, and the decoder, which takes the extracted features and reconstructs a segmentation mask.

In the encoder, I used convolutional layers followed by ReLU and MaxPool as the feature extractors. In the decoder, I used transposed convolutions to increase the size of the feature maps and decrease the number of channels. I used padding to keep the feature maps the same size after each convolution.

One thing you might notice is that, unlike classification networks, this network doesn't have a fully connected (linear) layer. This is an example of a fully convolutional network (FCN). FCNs have been shown to work well on segmentation tasks, starting with the paper by Long, Shelhamer, and Darrell, “Fully Convolutional Networks for Semantic Segmentation” [1].
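
A minimal sketch of such an encoder-decoder in PyTorch follows. It assumes two encoder levels, 16 base channels and a single-channel input; these are illustrative choices, not the network used for the figures in this article:

```python
import torch
import torch.nn as nn


class EncoderDecoder(nn.Module):
    """Minimal encoder-decoder for 1-channel masks (no skip connections)."""

    def __init__(self, in_chans=1, base=16):
        super().__init__()
        # Encoder: Conv -> ReLU -> MaxPool; padding=1 keeps H, W through each
        # 3x3 conv, and each MaxPool halves H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_chans, base, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Decoder: each 2x2 transposed conv with stride 2 doubles H and W
        # and halves the channel count.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(base * 2, base, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(base, base // 2, 2, stride=2), nn.ReLU(),
        )
        # A 1x1 conv maps features to one logit per pixel; no Linear layer anywhere.
        self.head = nn.Conv2d(base // 2, 1, 1)

    def forward(self, x):
        return self.head(self.decoder(self.encoder(x)))


model = EncoderDecoder()
shapes = {}
for size in (64, 96):
    out = model(torch.randn(1, 1, size, size))
    shapes[size] = tuple(out.shape)
print(shapes)  # {64: (1, 1, 64, 64), 96: (1, 1, 96, 96)}
```

Because the model contains no linear layer, the same weights accept inputs of different sizes, as long as each side is divisible by 4 (two 2x poolings).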

However, this network has a problem. As we add more layers to the encoder and decoder, we effectively “shrink” the feature maps more and more. As a result, the encoder may discard detailed features in favor of more general ones. In medical image segmentation, every pixel classified as diseased or normal can be important. How can we make sure that this encoder-decoder network uses features that are both general and detailed?

Skip connections

Because deep neural networks can “forget” certain features as they pass information through successive layers, skip connections can reintroduce those features and make learning more effective. Skip connections were popularized by Residual Networks (ResNet), where they improved classification accuracy and led to smoother gradient flow. U-Net uses the same idea: every decoder block incorporates the feature map from its corresponding encoder block (U-Net concatenates the two, whereas ResNet adds them). This is a defining feature of U-Net.
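
The difference between the two kinds of skip connection is easiest to see on tensor shapes. In this sketch the shapes are hypothetical values at one decoder level:

```python
import torch

# Hypothetical shapes at one decoder level: (batch, channels, H, W).
enc_features = torch.randn(1, 64, 32, 32)  # saved from the matching encoder level
upsampled = torch.randn(1, 64, 32, 32)     # decoder features after up-sampling

# U-Net style: concatenate along the channel dimension (channel counts add up).
unet_skip = torch.cat((enc_features, upsampled), dim=1)

# ResNet style, for contrast: element-wise addition (channel count unchanged).
resnet_skip = enc_features + upsampled

print(unet_skip.shape, resnet_skip.shape)
# torch.Size([1, 128, 32, 32]) torch.Size([1, 64, 32, 32])
```

Concatenation keeps the encoder features intact next to the decoder features and lets the following convolution decide how to combine them, which is why the first decoder convolution must accept the doubled channel count.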

U-Net has two defining qualities:

  1. An encoder-decoder network that extracts more general features the deeper it goes.
  2. Skip connections that reintroduce detailed features into the decoder.

These two qualities mean that U-Net can segment using features that are both detailed and general. U-Net was originally introduced for biomedical image processing, where segmentation accuracy is very important [2].

Implementation details

The previous sections gave a very general overview of U-Net and why it works. However, details stand between general understanding and actual implementation. Here, I give an overview of some implementation choices for U-Net.

Loss functions

Because the targets are binary masks (a pixel's value is 1 when it belongs to the object), a common loss function for comparing the output with the ground truth is the categorical cross-entropy loss (or binary cross-entropy in the single-label case).
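
In PyTorch, both options exist as built-in losses. A minimal sketch with random tensors (the shapes and class count are arbitrary):

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 1, 32, 32)                 # raw network outputs (no sigmoid)
target = (torch.rand(4, 1, 32, 32) > 0.5).float()  # binary ground-truth mask

# Single label: BCEWithLogitsLoss applies the sigmoid internally.
bce = nn.BCEWithLogitsLoss()
loss_binary = bce(logits, target)

# Multiple classes: CrossEntropyLoss takes (N, C, H, W) logits
# and (N, H, W) integer class indices.
num_classes = 3
multi_logits = torch.randn(4, num_classes, 32, 32)
multi_target = torch.randint(0, num_classes, (4, 32, 32))
loss_multi = nn.CrossEntropyLoss()(multi_logits, multi_target)
print(loss_binary.item(), loss_multi.item())
```

`BCEWithLogitsLoss` is generally preferred over a separate sigmoid followed by `BCELoss`, because combining the two is more numerically stable.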

In the original U-Net paper, a pixel-wise weight map is added to the loss function. This weight does two things: it compensates for class imbalance, and it gives higher importance to the borders that separate touching objects. In many U-Net implementations I've found online, this additional weight is rarely used.
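
For reference, the weighted loss from the original paper [2] can be written as below, with the sign flipped so that it is minimized. Here p_ℓ(x)(x) is the softmax probability of the true class ℓ(x) at pixel x, w_c is a weight map that balances class frequencies, and d_1 and d_2 are the distances to the border of the nearest and second-nearest cell; the paper reports w_0 = 10 and σ ≈ 5 pixels:

```latex
\mathcal{L} = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \, \log p_{\ell(\mathbf{x})}(\mathbf{x}),
\qquad
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \exp\!\left( -\frac{\left( d_1(\mathbf{x}) + d_2(\mathbf{x}) \right)^2}{2\sigma^2} \right)
```

The exponential term is large only for pixels squeezed between two nearby objects, which is what pushes the network to learn thin separating borders.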

Another commonly used loss function is the dice loss. Dice loss measures how similar two sets of pixels (the predicted and ground-truth masks) are by comparing their intersection with their total area. Note that the dice coefficient is not the same as Intersection-over-Union (IoU): they measure similar things but have different denominators. The higher the dice coefficient, the lower the dice loss.
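
Written out for a predicted mask P and a ground-truth mask G, treated as sets of foreground pixels, the two measures differ only in the denominator, and one can be converted into the other:

```latex
\begin{aligned}
\mathrm{Dice}(P, G) &= \frac{2\,|P \cap G|}{|P| + |G|}, &
\mathrm{IoU}(P, G) &= \frac{|P \cap G|}{|P \cup G|} = \frac{|P \cap G|}{|P| + |G| - |P \cap G|}, \\
\mathcal{L}_{\mathrm{Dice}} &= 1 - \mathrm{Dice}(P, G), &
\mathrm{Dice} &= \frac{2\,\mathrm{IoU}}{1 + \mathrm{IoU}}.
\end{aligned}
```

For example, if |P| = |G| = 4 pixels and they share 2 pixels, then Dice = 4/8 = 0.5 while IoU = 2/6 ≈ 0.33, and indeed 2(1/3) / (1 + 1/3) = 0.5.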

In practice, a small epsilon term is added to the dice formula to avoid division by 0 (epsilon is typically 1). Some implementations, such as the one in Milletari et al. [3], square the pixel values in the denominator before summing them. Compared to cross-entropy loss, dice loss is more robust to imbalanced segmentation masks, which are typical in biomedical image segmentation tasks.
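
A minimal PyTorch sketch of a soft dice loss for single-label masks follows. It assumes sigmoid outputs and places epsilon in both the numerator and the denominator (a common choice, not the only one); the function name and its `squared` flag, which switches to the V-Net-style denominator, are illustrative:

```python
import torch


def soft_dice_loss(logits, target, eps=1.0, squared=False):
    """Soft dice loss for binary masks, averaged over the batch.

    logits: (N, 1, H, W) raw outputs; target: (N, 1, H, W) with values in {0, 1}.
    eps: smoothing term, added here to numerator and denominator (an assumption).
    squared: if True, square the terms in the denominator, as in V-Net [3].
    """
    probs = torch.sigmoid(logits).flatten(1)  # (N, H*W) probabilities
    target = target.flatten(1).float()
    intersection = (probs * target).sum(dim=1)
    if squared:
        denom = (probs ** 2).sum(dim=1) + (target ** 2).sum(dim=1)
    else:
        denom = probs.sum(dim=1) + target.sum(dim=1)
    # dice = (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps), per sample
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


# Usage: a 4x4 square of foreground in an 8x8 mask.
target = torch.zeros(2, 1, 8, 8)
target[:, :, 2:6, 2:6] = 1.0
perfect_logits = (target * 2 - 1) * 20  # large positive inside, large negative outside
print(soft_dice_loss(perfect_logits, target).item())   # close to 0
print(soft_dice_loss(-perfect_logits, target).item())  # close to 1
```

A side effect of epsilon is that an empty prediction on an empty mask yields a dice coefficient of 1 (loss 0) instead of 0/0.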

Up-sampling methods

Another detail is the choice of up-sampling method for the decoder. Here are some common methods:

Bilinear interpolation. This method computes each output pixel by linearly interpolating between neighboring input pixels, first along one axis and then along the other. Usually, up-scaling with this method is followed by a convolution layer.
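
In PyTorch, this can be sketched with `nn.Upsample` followed by a convolution; the channel counts (128 to 64) and the input size are arbitrary assumptions:

```python
import torch
import torch.nn as nn

# Bilinear up-sampling (x2) has no learnable weights; the 3x3 convolution
# after it is learnable and here also reduces channels from 128 to 64.
up_bilinear = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(128, 64, kernel_size=3, padding=1),
)

x = torch.randn(1, 128, 16, 16)
print(up_bilinear(x).shape)  # torch.Size([1, 64, 32, 32])
```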

Max-unpooling. This method is the opposite of max-pooling. It reuses the indices recorded by the max-pooling operation and places the maximum values back at those indices; all other values are set to 0. Typically, a convolution layer follows max-unpooling to “smooth out” the missing values.
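
The pooling indices have to be kept from the encoder side. A small PyTorch example on a 4 x 4 input shows where the values end up (the follow-up smoothing convolution is omitted):

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, return_indices=True)  # also return where each max came from
unpool = nn.MaxUnpool2d(2)

x = torch.tensor([[[[1., 2., 3., 4.],
                    [5., 6., 7., 8.],
                    [9., 10., 11., 12.],
                    [13., 14., 15., 16.]]]])
pooled, indices = pool(x)           # [[6, 8], [14, 16]]
restored = unpool(pooled, indices)  # maxima back in place, zeros elsewhere
# restored[0, 0] ==
# [[ 0,  0,  0,  0],
#  [ 0,  6,  0,  8],
#  [ 0,  0,  0,  0],
#  [ 0, 14,  0, 16]]
print(restored)
```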

Deconvolution / transposed convolution. Many blog posts have been written about deconvolution. I recommend this article as a good visual guide.

Deconvolution has two steps: insert zeros between (and padding around) the pixels of the original image, then apply a regular convolution. In the original U-Net, a 2x2 transposed convolution with stride 2 is used to change both the spatial resolution and the channel depth.
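
Here is a 2x2, stride-2 transposed convolution in PyTorch, with channel counts chosen for illustration. With no padding, the output side is (H_in - 1) * stride + kernel_size, so 28 becomes 56:

```python
import torch
import torch.nn as nn

# Doubles H and W (stride 2, kernel 2) and, in this example, halves the channels.
up_conv = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)

x = torch.randn(1, 512, 28, 28)
y = up_conv(x)
print(y.shape)  # torch.Size([1, 256, 56, 56])
```

With kernel size equal to stride, each input pixel produces its own non-overlapping 2x2 output patch, which avoids the checkerboard overlap that larger kernels with stride 2 can cause.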

Pixel shuffling. This method appears in super-resolution networks such as SRGAN. With an up-scaling factor r, we first use a convolution to go from a C x H x W feature map to a (C·r²) x H x W one. Pixel shuffle then “reshuffles” these channels in a mosaic fashion to produce an output of size C x (H·r) x (W·r).
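
In PyTorch, the two steps map onto a convolution followed by `nn.PixelShuffle`; C = 64 and r = 2 are illustrative values:

```python
import torch
import torch.nn as nn

C, r = 64, 2  # output channels and up-scaling factor (illustrative)

up_shuffle = nn.Sequential(
    # Step 1: expand C channels to C * r^2 at the same H x W.
    nn.Conv2d(C, C * r ** 2, kernel_size=3, padding=1),
    # Step 2: rearrange (C * r^2, H, W) into (C, H * r, W * r).
    nn.PixelShuffle(r),
)

x = torch.randn(1, C, 16, 16)
print(up_shuffle(x).shape)  # torch.Size([1, 64, 32, 32])
```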

To pad or not to pad?

A convolution layer with a kernel larger than 1x1 and no padding produces an output that is smaller than its input. This is a problem for U-Net. Recall from the U-Net architecture in the earlier section that we concatenate each encoder feature map with its decoded counterpart. If we don't use padding, the decoded feature map will be spatially smaller than the encoded feature map it is concatenated with.
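
A quick PyTorch check of the size difference follows; the channel count and input size are arbitrary, and `padding="same"` requires PyTorch 1.9 or later:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 64, 64)

valid = nn.Conv2d(1, 8, kernel_size=3, padding=0)      # no padding ("valid")
same = nn.Conv2d(1, 8, kernel_size=3, padding="same")  # zero padding keeps H, W

print(valid(x).shape)  # torch.Size([1, 8, 62, 62]): each side shrinks by k - 1 = 2
print(same(x).shape)   # torch.Size([1, 8, 64, 64])
```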

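However, the original U-Net paper didn't use padding. The paper notes that the output segmentation map only contains pixels for which the full context is available in the input image; I believe the goal was to avoid introducing segmentation errors at the image margin. Instead of padding, the authors center-cropped the encoded feature maps before concatenation. For an input of size 572 x 572, the output is 388 x 388, so only about 46% of the input pixels (388² / 572²) receive a prediction. If you want to run U-Net without padding, you need to run it on overlapping tiles to segment the full image; the paper calls this an overlap-tile strategy and extrapolates the missing context at the image border by mirroring the input.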
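
The 572 to 388 figure can be reproduced by tracing sizes through the network. This plain-Python helper is hypothetical and assumes the original configuration: five levels, two unpadded 3x3 convolutions per level, 2x2 max-pooling, and 2x2 stride-2 up-sampling:

```python
def unet_valid_output_size(size, depth=5, convs_per_level=2, kernel=3):
    """Trace the spatial size through an unpadded U-Net (square inputs).

    Each unpadded 3x3 conv removes kernel - 1 = 2 pixels; each 2x2 max-pool
    halves the size; each 2x2 stride-2 transposed conv doubles it.
    """
    shrink = convs_per_level * (kernel - 1)
    # Contracting path: convs at every level, pooling at all but the bottom one.
    for level in range(depth):
        size -= shrink
        if level < depth - 1:
            if size % 2:
                raise ValueError("feature map size must be even before pooling")
            size //= 2
    # Expanding path: up-sample, then convs (skip features are center-cropped to match).
    for _ in range(depth - 1):
        size = size * 2 - shrink
    return size


out = unet_valid_output_size(572)
print(out, f"{out ** 2 / 572 ** 2:.0%}")  # 388 46%
```

The even-size check matters: an input like 570 would hit an odd size before a pooling step, which is why unpadded U-Nets only accept specific input sizes.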

U-Net in action

Here, I implemented a very simple U-Net-like network to segment only ellipses. The network is only 3 levels deep, uses “same” padding, and is trained with binary cross-entropy loss. More complex networks can use more convolution layers at each resolution, or extend the depth as they see fit.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    # Consists of (Conv -> ReLU) x layers, followed by MaxPool.
    def __init__(self, in_chans, out_chans, layers=2, sampling_factor=2, padding="same"):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_chans, out_chans, 3, 1, padding=padding))
        self.encoder.append(nn.ReLU())
        for _ in range(layers - 1):
            self.encoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.encoder.append(nn.ReLU())
        self.mp = nn.MaxPool2d(sampling_factor)

    def forward(self, x):
        for enc in self.encoder:
            x = enc(x)
        mp_out = self.mp(x)
        # Return the pooled output (input to the next level) and the
        # pre-pooling features (used by the skip connection).
        return mp_out, x


class DecoderBlock(nn.Module):
    # Consists of a transposed convolution (2x2 when sampling_factor=2),
    # an optional skip concatenation, then (Conv -> ReLU) x layers.
    def __init__(self, in_chans, out_chans, layers=2, skip_connection=True, sampling_factor=2, padding="same"):
        super().__init__()
        # With a skip connection, the concatenated input has in_chans channels
        # (in_chans // 2 up-sampled + in_chans // 2 from the encoder);
        # without it, only the in_chans // 2 up-sampled channels remain.
        skip_factor = 1 if skip_connection else 2
        self.decoder = nn.ModuleList()
        self.tconv = nn.ConvTranspose2d(in_chans, in_chans // 2, sampling_factor, sampling_factor)

        self.decoder.append(nn.Conv2d(in_chans // skip_factor, out_chans, 3, 1, padding=padding))
        self.decoder.append(nn.ReLU())

        for _ in range(layers - 1):
            self.decoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.decoder.append(nn.ReLU())

        self.skip_connection = skip_connection
        self.padding = padding

    def forward(self, x, enc_features=None):
        x = self.tconv(x)
        if self.skip_connection:
            if self.padding != "same":
                # Center-crop enc_features to the spatial size of x
                # (height and width handled separately).
                h, w = x.shape[-2:]
                top = (enc_features.size(-2) - h) // 2
                left = (enc_features.size(-1) - w) // 2
                enc_features = enc_features[:, :, top:top + h, left:left + w]
            x = torch.cat((enc_features, x), dim=1)
        for dec in self.decoder:
            x = dec(x)
        return x


class UNet(nn.Module):
    # With padding="same", H and W should be divisible by sampling_factor ** (depth - 1).
    def __init__(self, nclass=1, in_chans=1, depth=5, layers=2, sampling_factor=2, skip_connection=True, padding="same"):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Channels double at each encoder level: 64, 128, 256, ...
        out_chans = 64
        for _ in range(depth):
            self.encoder.append(EncoderBlock(in_chans, out_chans, layers, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans * 2

        # Channels halve at each decoder level.
        out_chans = in_chans // 2
        for _ in range(depth - 1):
            self.decoder.append(DecoderBlock(in_chans, out_chans, layers, skip_connection, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans // 2
        # Add a 1x1 convolution to produce final classes
        self.logits = nn.Conv2d(in_chans, nclass, 1, 1)

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x, enc_output = enc(x)
            encoded.append(enc_output)
        # The deepest level's pre-pooling features form the bottleneck;
        # its pooled output is not used.
        x = encoded.pop()
        for dec in self.decoder:
            enc_output = encoded.pop()
            x = dec(x, enc_output)

        # Return the logits (apply a sigmoid to get probabilities)
        return self.logits(x)
```
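
The article does not show how the ellipse data is generated (the author's code is linked at the end). As a stand-in, here is a hypothetical generator for image/mask pairs with one rotated ellipse each; the function name, the value ranges and the noise level are all assumptions, not the author's setup:

```python
import math

import torch


def make_ellipse_sample(size=64, generator=None):
    """Hypothetical toy sample: one filled, rotated ellipse on a noisy background.

    Returns (image, mask), each of shape (1, size, size); mask is 1 inside the ellipse.
    """
    g = generator
    # Random center, semi-axes and rotation angle (the ranges are arbitrary choices).
    cy, cx = (torch.rand(2, generator=g) * 0.5 + 0.25) * size
    ay, ax = (torch.rand(2, generator=g) * 0.2 + 0.1) * size
    theta = torch.rand(1, generator=g).item() * math.pi

    ys, xs = torch.meshgrid(
        torch.arange(size, dtype=torch.float32),
        torch.arange(size, dtype=torch.float32),
        indexing="ij",
    )
    # Rotate pixel coordinates into the ellipse's frame, then test the ellipse equation.
    dy, dx = ys - cy, xs - cx
    u = dx * math.cos(theta) + dy * math.sin(theta)
    v = -dx * math.sin(theta) + dy * math.cos(theta)
    mask = ((u / ax) ** 2 + (v / ay) ** 2 <= 1).float()

    # Image = mask plus Gaussian noise, so the network has to recover the shape.
    image = mask + 0.3 * torch.randn(size, size, generator=g)
    return image.unsqueeze(0), mask.unsqueeze(0)


# Usage: build a small, reproducible batch.
g = torch.Generator().manual_seed(0)
pairs = [make_ellipse_sample(64, g) for _ in range(8)]
images = torch.stack([img for img, _ in pairs])  # (8, 1, 64, 64)
masks = torch.stack([msk for _, msk in pairs])   # (8, 1, 64, 64)
```

The resulting tensors match what the `UNet` class above expects with in_chans=1 and nclass=1 (64 is divisible by 2 ** (depth - 1) for depth up to 7), and `masks` can be used directly as the target of `nn.BCEWithLogitsLoss`.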

As we can see, the U-Net can produce acceptable segmentation even without skip connections, but the added skip connections can introduce finer details (see the join between the two ellipses on the right).

If I were to explain U-Net in one sentence, it would be that U-Net is like an encoder-decoder for images, but with skip connections to make sure fine details are not lost. U-Nets are used in many segmentation tasks, and in recent years they have made their way into image generation tasks as well (for example, as the denoising network in diffusion models).

If you want to see the code that I used to produce the figures and train my U-Net, here is the GitHub link. Happy coding!

If you enjoyed reading this article and would like to read more like it in the future, consider following me on Medium or LinkedIn.

References:

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for semantic segmentation.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015.

[2] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. “U-Net: Convolutional networks for biomedical image segmentation.” International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). Springer, Cham, 2015.

[3] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. “V-Net: Fully convolutional neural networks for volumetric medical image segmentation.” 2016 Fourth International Conference on 3D Vision (3DV). IEEE, 2016.
