Understanding UNet for Image Segmentation - Tutorial 1.0
This tutorial assumes that you know the following:
- Deep Learning Fundamentals
- Convolutional Neural Networks
- Linear Algebra and Calculus
This tutorial aims to teach you how to understand UNet from scratch and implement it with your own dataset, step by step. We will learn slowly together, asking questions along the way. To keep the learning sequential and digestible, we will split it into multiple tutorials. Please open Google Colab and implement everything as you read, and stay with me for a while to understand what is happening inside.
UNet was developed primarily for biomedical image segmentation. Wikipedia gives a decent introduction. I will first revise the idea of basic semantic segmentation.
Let's say you are segmenting an image into four classes: tiger (1), ground (2), grass (3), and water (4). Mathematically, every pixel (i, j) in the image is assigned exactly one label from {1, 2, 3, 4}. So segmentation is a function from the set of pixels of an image to {1, 2, 3, 4}. If the image dimension is height (H) x width (W), then we want to learn a function f() that maps the image, viewed as a point in the HW-dimensional Euclidean space, to a label in {1, 2, 3, 4} for every pixel. In this tutorial, we will learn this function using a convolutional neural network called UNet. A loss function then measures the difference between the predicted labels and the actual labels over all the pixels of an image, and that loss is minimized by an iterative optimization algorithm. We will pick a small but important part of the entire process in this tutorial.
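To make this concrete, here is a minimal sketch (with hypothetical shapes, not the paper's data) of how per-pixel labels and a per-pixel loss fit together in PyTorch. Note that PyTorch expects class indices starting at 0, so the four classes become {0, 1, 2, 3}.
import torch
import torch.nn as nn

H, W = 388, 388                                #hypothetical output size, matching Figure 1
logits = torch.rand(1, 4, H, W)                #predicted scores for the 4 classes, one per pixel
labels = torch.randint(0, 4, (1, H, W))        #ground-truth class index per pixel
loss = nn.CrossEntropyLoss()(logits, labels)   #per-pixel cross-entropy, averaged over all pixels
print(loss.item())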
We will understand the feedforward structure of the UNet network in this tutorial. We will track the dimensions of the inputs and outputs through UNet, which will help you revise your ideas about convolutions and the other operators of convolutional neural networks, and understand how UNet works. If you have trained neural networks in PyTorch before, you must be familiar with the following piece of code.
for (i, (x, y)) in enumerate(Loader):
Understand that by the time the data reaches this loop, it has an extra dimension: the batch size. In other words, the optimization is done in batches. This batching has interesting mathematics behind it: in expectation, the mini-batch gradient equals the full-batch gradient, so batch-based optimization targets the same objective as the original optimization. We assume the input size to be (B, C, H, W), where B = batch size, C = channels, H = height, and W = width. In the feedforward network, the batch size stays constant. So, while referring to the image and the features, I will only mention (C, H, W), but in the code output you will see (B, C, H, W). These objects are called "tensors", which extend matrices to higher dimensions, with similar matrix operations and something extra. We will use a proxy image, which is a random tensor input. This is just to demonstrate the dimensions through the UNet and to help you understand the network fundamentally; you can use the same trick to understand any network. Here is how you create a random tensor of size (B, C, H, W) in PyTorch.
import torch
B = 1 #batch size
C = 1 #channels
H = 572 #height (the input tile size used in the UNet paper)
W = 572 #width
image = torch.rand(B,C,H,W)
print(image.shape)
The original UNet paper from this group uses the following architecture.
It mentions the input size and the dimensions at each step, but you need to know one important thing: "How do you build this neural network from scratch?" This skill will help you translate an idea in a research paper into your own deep-learning model. We will understand how to build this network, and also print each output dimension. However, the architecture diagram has some missing information, which is given later in the paper in a paragraph. I will refer to that portion of the paragraph with annotations to help us build the network from scratch.
Let's write down the entire architecture piece by piece in bullet form. Then, we will learn how to combine the pieces into a single compact form using both the architecture diagram (Figure 1) and the colorful paragraph (Figure 2) as references.
We will understand the fundamental blocks, and see how they are used to construct the entire architecture.
Fundamental Blocks
ic = input channels, oc = output channels
- (Input) Input Image
- (Conv1) Convolution (ic, oc, Kernel Size = 3, Stride = 1, Padding = 0)
- (ReLu) ReLu
- (Maxpool) MaxPooling (Kernel Size = 2, Stride = 2)
- (Upsample) Upsample (Scale Factor = 2)
- (Conv2) Convolution (ic, oc, Kernel Size = 2, Stride = 1, Padding = 1)
- (Convf) Convolution (ic, oc, Kernel Size = 1, Stride = 1, Padding = 0)
- (Output) Output Segmentation Masks
From Figure 1, we see that the author indicates, with the arrow legend at the bottom-right corner, that a few blocks always occur together. Thus, derived blocks are created to name them explicitly. Upsample and MaxPool are architectural inverses of each other. There is an actual inverse map of MaxPool too, called MaxUnpool. Haha.
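As a small aside (not part of UNet, just to satisfy curiosity), here is a minimal sketch of that inverse pair: nn.MaxPool2d can return the indices of the maxima, and nn.MaxUnpool2d uses them to undo the pooling.
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, stride = 2, return_indices = True)
unpool = nn.MaxUnpool2d(2, stride = 2)
x = torch.rand(1, 1, 4, 4)
y, idx = pool(x)       #y has shape (1, 1, 2, 2)
z = unpool(y, idx)     #z has shape (1, 1, 4, 4), zero everywhere except at the max positions
print(y.shape, z.shape)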
Derived Blocks
- (Conv-ReLu) Conv1 + ReLu
- (Up-Conv) Upsample + Conv2
The Full Architecture
Input -> Conv-ReLu x 2 -> Maxpool -> Conv-ReLu x 2 -> Maxpool -> Conv-ReLu x 2 -> Maxpool -> Conv-ReLu x 2 -> Maxpool -> Conv-ReLu x 2 -> Up-Conv -> Conv-ReLu x 2 -> Up-Conv -> Conv-ReLu x 2 -> Up-Conv -> Conv-ReLu x 2 -> Up-Conv -> Conv-ReLu x 2 -> Convf -> Output
Note that Batch Normalization was first published only shortly before this work, so it was not part of UNet's original architecture. Nowadays, however, batch normalization is commonly used after each convolution layer and before the ReLU in UNet-style architectures. We will not use it here, because we are focused on building the paper's architecture and do not want to deviate from the core discussion. We will now implement this architecture from scratch using these building blocks in PyTorch.
From the full architecture, it may seem that [Conv-ReLu x 2] always occurs together with MaxPool or Up-Conv. But if you count, [Conv-ReLu x 2] appears five times in the contraction path and four times in the expansion path, while MaxPool appears four times in the contraction section and Up-Conv four times in the expansion section. Also, we need the outputs of the [Conv-ReLu x 2] layers from the contraction portion so that we can crop and concatenate them in the expansion region. This mismatch in the counts of the fundamental blocks naturally leads us to define the following core blocks.
- Conv-ReLu x 2 (1 time)
- MaxPool (it already exists as a block)
- MaxPool + Conv-ReLu x 2 (4 times)
- Up-Conv (used in the next one)
- Up-Conv + Conv-ReLu x 2 (4 times)
- Convf (1) (it already exists as a block, since it is just convolution)
Yes, you will see these 10 blocks appear in the final UNet structure one by one, arranged symmetrically like the letter U, from which the network derives its name.
Let's start building one by one. Before you start, remember to load the following libraries.
import torch
import torch.nn as nn
import torch.nn.functional as F
This first structure will show you in detail how to implement each of the blocks in more than one way, and how to debug the shapes (or weights) if needed.
Conv-ReLu x 2
class ConvReLux2(nn.Module):
    def __init__(self, ic, oc):
        super(ConvReLux2, self).__init__()
        self.convreLux2 = nn.Sequential(
            nn.Conv2d(ic, oc, 3, 1, 0),
            nn.ReLU(inplace = True),
            nn.Conv2d(oc, oc, 3, 1, 0),
            nn.ReLU(inplace = True)
        )

    def forward(self, x):
        x = self.convreLux2(x)
        return x
model = ConvReLux2(1,64)
model.eval()
output = model(image)
output.shape
If you run this, you will get the following, which aligns perfectly with Figure 1.
torch.Size([1, 64, 568, 568])
Understand that I am using nn.Sequential here because I do not need any intermediate values in the forward function, and it saves a lot of typing. For demonstration purposes, I could also have written it as follows. This version helps you inspect the shape of the output at each layer, which is something you often do while debugging.
class ConvReLux2(nn.Module):
    def __init__(self, ic, oc):
        super(ConvReLux2, self).__init__()
        self.layer1 = nn.Conv2d(ic, oc, 3, 1, 0)
        self.layer2 = nn.ReLU(inplace = True)
        self.layer3 = nn.Conv2d(oc, oc, 3, 1, 0)
        self.layer4 = nn.ReLU(inplace = True)

    def forward(self, x):
        print(x.shape)
        x = self.layer1(x)
        print(x.shape)
        x = self.layer2(x)
        print(x.shape)
        x = self.layer3(x)
        print(x.shape)
        x = self.layer4(x)
        print(x.shape)
        return x
model = ConvReLux2(1,64)
model.eval()
output = model(image)
This will give you the following output, which matches Figure 1 perfectly again. We are doing it right. Yaay!
torch.Size([1, 1, 572, 572])
torch.Size([1, 64, 570, 570])
torch.Size([1, 64, 570, 570])
torch.Size([1, 64, 568, 568])
torch.Size([1, 64, 568, 568])
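As a quick sanity check (my own aside, not from the paper), each 3x3 convolution with stride 1 and no padding follows the standard output-size formula, which is exactly why the spatial size drops by 2 at every convolution (572 -> 570 -> 568):
def conv_out(size, kernel = 3, stride = 1, padding = 0):
    #standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

print(conv_out(572))   #570
print(conv_out(570))   #568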
Let's now create the rest of the blocks in an optimized way.
MaxPool + Conv-ReLu x 2
class MaxPoolConvReLux2(nn.Module):
    def __init__(self, ic, oc):
        super(MaxPoolConvReLux2, self).__init__()
        self.maxpoolconvrelux2 = nn.Sequential(
            nn.MaxPool2d(2, stride = 2),
            ConvReLux2(ic, oc)
        )

    def forward(self, x):
        x = self.maxpoolconvrelux2(x)
        return x
Now, if you want to see the output of this layer when it takes the previous layer's output as its input, you can do the following.
model1 = ConvReLux2(1,64)
model1.eval()
output1 = model1(image)
print(output1.shape)
model2 = MaxPoolConvReLux2(64,128)
model2.eval()
output2 = model2(output1)
print(output2.shape)
This will give you the results as follows:
torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
Fun, right? These are exactly the input and output shapes from Figure 1 (568 is halved to 284 by max pooling, and the two 3x3 convolutions then give 282 and 280). However, checking every block like this is time-consuming, so we will leave exploring the output shapes of the remaining blocks as an exercise for you. I hope you are following along in Google Colab.
Up-Conv
Up-Conv consists of upsampling followed by a convolution of kernel size 2 that halves the number of feature channels. Let's create such a block.
class UpConv(nn.Module):
    def __init__(self, ic, oc):
        super(UpConv, self).__init__()
        self.upconv = nn.Sequential(
            nn.Upsample(scale_factor = 2),
            nn.Conv2d(ic, oc, 2, 1, 1)
        )

    def forward(self, x):
        x = self.upconv(x)
        #crop the extra row and column introduced by the padding
        x = x[:, :, :-1, :-1]
        return x
We will check this block's output on a custom tensor whose size matches the output of the lowest layer of the network, which is the input that flows into this block.
b = 1 #batch size
c = 1024 #channels
h = 28 #height
w = 28 #width
image1 = torch.rand(b,c,h,w)
model3 = UpConv(1024,512)
model3.eval()
output3 = model3(image1)
print(output3.shape)
You will get the following output.
torch.Size([1, 512, 56, 56])
However, you should ask why I used x = x[:, :, :-1, :-1]. This is essentially cropping one layer of pixels off the right and bottom of the feature map to bring it back to the proper dimension given in the paper. Without the crop, the output would have shape (57, 57). This happened because I used one extra layer of padding in the convolution. Without the padding, the output shape would be (55, 55), after which I could instead have added a layer of zero padding to reach (56, 56). To avoid this confusion, there is a module called ConvTranspose2d, which performs both operations together while simply doubling the spatial size. It was first introduced in this paper, which was published just before the UNet paper. You can find a beautiful tutorial on convolutions in this article, and you can visit here to see all the different types of convolutions in PyTorch. You can read the documentation for a more detailed understanding. We will move on to one of the most exciting blocks: Up-Conv + Conv-ReLu x 2.
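For illustration only (this is not the block we use in this tutorial), a sketch of the nn.ConvTranspose2d alternative with kernel size 2 and stride 2, which exactly doubles the spatial size while halving the channels, could look like this (it reuses the imports from earlier):
upconv_alt = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2)
x = torch.rand(1, 1024, 28, 28)
print(upconv_alt(x).shape)   #torch.Size([1, 512, 56, 56])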
This block is tricky, albeit interesting, because of the concatenation of the corresponding contraction block's output with an internal feature of each block in the expansion section. Let's see how we tackle this.
Up-Conv + Conv-ReLu x 2
class UpConvConvReLux2(nn.Module):
    def __init__(self, ic, oc):
        super(UpConvConvReLux2, self).__init__()
        self.upconv = UpConv(ic, oc)
        self.doubleconv = ConvReLux2(ic, oc)

    def forward(self, x1, x2):
        #this reduces the feature size from ic to oc
        x1 = self.upconv(x1)
        #cropping x2 to the shape of x1
        _, _, h1, w1 = x1.shape
        x2 = x2[:, :, :h1, :w1]
        #concatenation increases the feature size from oc back to ic
        x = torch.cat((x1,x2), dim = 1)
        print(x.shape) #shape of the concatenated tensor, for inspection
        x = self.doubleconv(x)
        return x
b = 1 #batch size
c = 512 #channels
h = 64 #height
w = 64 #width
image2 = torch.rand(b,c,h,w)
model4 = UpConvConvReLux2(1024, 512)
model4.eval()
output4 = model4(image1, image2)
print(output4.shape)
This leads to the following output: first the shape of the concatenated tensor printed inside forward, then the shape of the final output. However, there are a few important things to mention.
torch.Size([1, 1024, 56, 56])
torch.Size([1, 512, 52, 52])
Observe that upsampling doubles the spatial size of the image, and the convolution inside upconv halves the number of features; please refer to the architecture diagram. Then the feature map of matching scale from the left contraction path is concatenated with the output of the up-conv, making it eligible for the next round of double convolution. Observe that the feature map coming from the contraction step is larger, so it needs to be cropped to the size of the expansion module. To verify this with a demo input, we generate a new tensor with the same dimensions as in the paper. Isn't it beautiful that this matches? Now we are ready to build the final UNet structure. This is pretty straightforward, given that we have already solved the interesting challenge of the last module.
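One design note before we assemble everything: the cropping above (x2[:, :, :h1, :w1]) takes the top-left corner of the contraction feature map, while the "copy and crop" arrows in the paper take its central region. Both produce the right shape; only the cropped region differs. If you want to mimic the paper more closely, a minimal center-crop sketch that could replace the two cropping lines is:
def center_crop(x, h, w):
    #crop the central (h, w) window out of a (B, C, H, W) tensor
    _, _, H, W = x.shape
    top = (H - h) // 2
    left = (W - w) // 2
    return x[:, :, top:top + h, left:left + w]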
UNet
class UNet(nn.Module):
    def __init__(self, ic, oc):
        super(UNet, self).__init__()
        self.inconv = ConvReLux2(ic, 64)
        self.downconv1 = MaxPoolConvReLux2(64,128)
        self.downconv2 = MaxPoolConvReLux2(128,256)
        self.downconv3 = MaxPoolConvReLux2(256,512)
        self.downconv4 = MaxPoolConvReLux2(512,1024)
        self.upconv4 = UpConvConvReLux2(1024,512)
        self.upconv3 = UpConvConvReLux2(512, 256)
        self.upconv2 = UpConvConvReLux2(256, 128)
        self.upconv1 = UpConvConvReLux2(128, 64)
        self.outconv = nn.Conv2d(64, oc, 1, 1)

    def forward(self, x):
        print(x.shape)
        x1 = self.inconv(x)
        print(x1.shape)
        x2 = self.downconv1(x1)
        print(x2.shape)
        x3 = self.downconv2(x2)
        print(x3.shape)
        x4 = self.downconv3(x3)
        print(x4.shape)
        x5 = self.downconv4(x4)
        print(x5.shape)
        x6 = self.upconv4(x5, x4)
        print(x6.shape)
        x7 = self.upconv3(x6, x3)
        print(x7.shape)
        x8 = self.upconv2(x7, x2)
        print(x8.shape)
        x9 = self.upconv1(x8, x1)
        print(x9.shape)
        x10 = self.outconv(x9)
        print(x10.shape)
        return x10
b = 1 #batch size
c = 1 #channels
h = 572 #height
w = 572 #width
image = torch.rand(b,c,h,w)
model = UNet(1,3)
model.eval()
output = model(image)
This will lead to a long output. But the last output will be of the following shape.
torch.Size([1, 3, 388, 388])
Voila! We have done it step by step. This is exactly the output shape in the original paper. In this final architecture, we take the pieces and sew them together from one end to the other by calling the respective core modules of the architecture. It feels surreal when you hit bugs, debug them, and finally get it right. You can build any architecture using this process: understand the core steps and operations, and sew them together. You can get the entire code used in this article on GitHub here.
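As a small optional check (my own addition, not from the paper), you can also count the trainable parameters of the model we just built:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")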
There is one small catch. One important statement that we did not address is the following: "To allow a seamless tiling of the output segmentation map (see Figure 2), it is important to select the input tile size such that all 2x2 max-pooling operations are applied to a layer with an even x- and y-size."
This means you have to be careful in selecting an image size so that it can pass smoothly through to the end. There are four max-pooling operations, each of which halves the spatial size, so the feature map entering every max-pooling step must have an even height and width. This is not guaranteed for an arbitrary image size. There is also another issue, and a question that should come to your mind: why should the shapes of the input and output images differ at all? Shouldn't they be the same, because we are doing segmentation? In the UNet paper, the authors deliberately predict the segmentation only for a cropped central region of each input tile. However, for your own UNet architecture for segmentation, you may have the following two needs:
- Odd-sized input or intermediate images
- A segmentation mask with the same shape as the input image.
To solve these issues, I have created a general script. You can get it on GitHub here. A minimal sketch of the core idea follows below.
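This is not that script, just a minimal sketch of the most common fix: using padding = 1 in every 3x3 convolution so that the Conv-ReLu blocks no longer shrink the feature map. Then the output mask can have the same height and width as the input, provided the input size is divisible by 16 (because of the four max-pooling steps). It reuses the imports from earlier.
class PaddedConvReLux2(nn.Module):
    #same as ConvReLux2, but padding = 1 keeps the spatial size unchanged
    def __init__(self, ic, oc):
        super(PaddedConvReLux2, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ic, oc, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(oc, oc, 3, 1, 1),
            nn.ReLU(inplace = True)
        )

    def forward(self, x):
        return self.block(x)

print(PaddedConvReLux2(1, 64)(torch.rand(1, 1, 256, 256)).shape)   #torch.Size([1, 64, 256, 256])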
I hope you have enjoyed this process if you are still here. Haha. It took me quite a few hours to write this down in an organized manner for you. I need to know whether this is helpful to you; otherwise, there is no reason to spend the time writing it. If it helps, please share your views in the comments, and do like and share this with people who may need it. Thank you for your time. The more people it reaches, the more it motivates me to create longer, useful content. :D