UNet
Nov 27, 2021

1. What is UNet?

UNet is a popular deep convolutional neural network used for object (image) segmentation.

2. How to learn UNet?

When I approach a new subject, I usually follow this path:

Following this path has several advantages:

Until recently I thought that understanding how something works required building a replica myself from the ground up, but this attitude has several disadvantages:

3. Learning

To get a general idea of image segmentation I watched part of Stanford's CS231n course. It gave me an overview of what I was going to do and why.

After that I managed to find an excellent implementation of UNet on GitHub. What made it excellent was that I could download the code and run it, at least on an EC2 instance, since my laptop's GPU has only 2 GB of memory. And it worked :)

After that I started to analyse the code. I needed to learn several things:

I also wrote an implementation of a dumb net, which mimics the most important functionality of UNet:

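import torch
import torch.nn as nn


class DumbNet(nn.Module):
    """Toy encoder-decoder: one downsampling block, one upsampling block."""

    def __init__(self) -> None:
        super().__init__()
        # Encoder: the stride-2 convolution halves height and width (for even sizes)
        self.down = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # Decoder: bilinear upsampling doubles height and width again
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # 1x1 convolution maps the 16 feature channels to 3 output channels
        self.out = nn.Conv2d(16, 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_down = self.down(x)  # (N, 16, H/2, W/2) for even H and W
        x_up = self.up(x_down)  # (N, 16, H, W)
        x_out = self.out(x_up)  # (N, 3, H, W)
        return x_out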

The net has a single downsampling stage followed by a single upsampling stage.
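To see how the spatial size changes on the way through DumbNet, it helps to do the arithmetic by hand. The sketch below is plain Python, with no PyTorch needed. It uses the standard output-size formula for a convolution with dilation 1 and assumes that upsampling with scale_factor=2 exactly doubles the size. The helper names conv2d_out_size and dumbnet_sizes are made up for this illustration.

```python
from typing import Tuple


def conv2d_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a 2D convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def dumbnet_sizes(size: int) -> Tuple[int, int]:
    """Return (size after the down block, size after the up block) for one axis."""
    # down: Conv2d(kernel_size=3, padding=1, stride=2) roughly halves the size
    down = conv2d_out_size(size, kernel_size=3, stride=2, padding=1)
    # up: Upsample(scale_factor=2) doubles the size; the following 3x3 conv
    # with padding=1 and the final 1x1 conv both leave it unchanged
    up = conv2d_out_size(down * 2, kernel_size=3, stride=1, padding=1)
    up = conv2d_out_size(up, kernel_size=1)
    return down, up


print(dumbnet_sizes(64))  # (32, 64)
print(dumbnet_sizes(65))  # (33, 66)
```

With an odd input size the output comes back one pixel larger than the input, so inputs with even height and width are the safe choice for this net.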

4. Where to go next?

In practice you may not want to write your own version of UNet or copy-paste code from an implementation published on GitHub. Instead, you can use the segmentation models package, which already implements many interesting segmentation models.

But tweaking this package is another story :)