UNet for Semantic Segmentation

Introduction

All of the CNNs that we have explored so far assign a single class label to an entire image or scene. This task is generally referred to as scene labeling or scene classification. However, in the geospatial sciences and when working with satellite images, we are commonly interested in classifying each individual pixel in the image separately. Landcover classification of satellite images is an example of a task that requires pixel-level predictions. Another term for pixel-level classification is semantic segmentation.

Is it possible to learn spatial patterns using CNNs while also obtaining a pixel-level classification? It turns out that this is possible using fully convolutional neural network architectures and their derivatives. In this module, we will specifically explore the UNet architecture. It was introduced in the following publication, which at the time of this writing has over 55,000 citations!

Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

UNet is a famous and powerful technique, and more recently introduced architectures draw from it or expand upon it in some way. Examples include UNet++ and DeepLabv3+. In this module, we will begin exploring UNet by building one from scratch using PyTorch. Note that this is a modified version of the original architecture presented in the paper above. In the following two modules, we will explore training a UNet and altering the encoder component of UNet by using the VGGNet and ResNet architectures. In the last three modules relating to semantic segmentation, we will explore different semantic segmentation methods using the Segmentation Models package, learn to generate image chips from geospatial data, and use a UNet to predict back to spatial data to create a georeferenced map output.

In this module, our primary objective is to explore the UNet architecture so that you understand its components and how it is implemented.

Preparation

Since we are only building a UNet and will not be loading data; defining a loss function, assessment metrics, or an optimizer; or implementing a training loop, I only need to import the torch package and the torch.nn subpackage. I also import the torchsummary package so that I can summarize the model. Lastly, I set the device to the GPU if one is available.

import torch 
import torch.nn as nn
from torchsummary import summary
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0

Define a UNet

Since a UNet has many components, I will simplify my implementation by first defining some units as functions that can be used repeatedly within the UNet definition. I will then be able to call these functions to implement these components of the model.

First, I define a function called double_conv(). This function performs the following process:

2D Convolution –> 2D Batch Normalization –> ReLU Activation –> 2D Convolution –> 2D Batch Normalization –> ReLU Activation

This function accepts two parameters: the number of input channels or feature maps and the number of output channels or feature maps. The first 2D convolution accepts the defined number of input channels and returns the defined number of output channels. The second 2D convolution accepts the defined number of output channels and returns the same number of channels. Note the use of inplace=True when applying the ReLU activation functions. This is used to reduce memory consumption by applying the transformation in place. For each 2D convolution, a kernel size of 3x3 is used with a stride and padding of 1. This will result in returning arrays with the same height and width as the input image or feature maps. All of these steps are defined inside of nn.Sequential().

def double_conv(inChannels, outChannels):
  return nn.Sequential(
      nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True),
      nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
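
As a quick, optional check (assuming the function above has been defined; the testBlock and x names here are just for illustration), passing a random tensor through double_conv() confirms that only the channel dimension changes while the spatial dimensions are preserved:

testBlock = double_conv(3, 16)
x = torch.rand(1, 3, 256, 256)   # a single 3-band 256x256 chip
print(testBlock(x).shape)        # expected: torch.Size([1, 16, 256, 256])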

Next, I define a function called up_conv(). It performs the following operations:

2D Transpose Convolution –> 2D Batch Normalization –> ReLU Activation

This function will be used repeatedly in the decoder component of the UNet to increase the size of the array in the spatial dimensions using trainable filters. This is the first demonstrated use of 2D transpose convolution since it was introduced in the Convolutional Neural Network Building Blocks module. For the transpose convolution, I am using a kernel size of 2x2 and a stride of 2. This will result in doubling the size of the feature maps in the spatial dimensions.

def up_conv(inChannels, outChannels):
  return nn.Sequential(
      nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
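
Similarly, a quick sketch (with an illustrative bottleneck-sized input) shows that up_conv() doubles the spatial dimensions:

upBlock = up_conv(256, 256)
x = torch.rand(1, 256, 16, 16)   # e.g., 16x16 feature maps
print(upBlock(x).shape)          # expected: torch.Size([1, 256, 32, 32])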

I am now ready to build the UNet architecture by subclassing the nn.Module class. A UNet consists of two components. The encoder learns kernels to capture spatial context information at different scales using a series of 2D convolution and max pooling operations. Within each block of the encoder, I will make use of the double_conv() function. Between each block, I will apply max pooling with a kernel size of 2x2 and a stride of 2 to decrease the array size in the spatial dimensions by half.

In the decoder component, the array’s size is increased in the spatial dimensions using 2D transpose convolution, effectively undoing the downsampling resulting from the max pooling operations performed in the encoder. More specifically, it is undoing the change in the spatial dimension sizes, not undoing the impact of applying the kernels. This will result in returning the original spatial dimensions of the input data. In order to make a prediction at each cell location in the final block of the decoder, 2D convolution with a kernel size of 1x1, a stride of 1, and an output size equal to the number of classes being differentiated is applied. This results in a predicted logit for each class at each pixel location. These logits can then be converted to probabilities using a softmax activation. However, this may not be required, depending on what loss function is applied. I will not apply a softmax activation here.
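
For example, if a model returns logits with shape (batch, classes, height, width), a softmax over the channel dimension yields per-pixel class probabilities, and an argmax yields a hard class map. Here is a minimal sketch using a hypothetical logits tensor:

logits = torch.rand(2, 10, 256, 256)   # hypothetical logits: 2 chips, 10 classes
probs = torch.softmax(logits, dim=1)   # per-pixel class probabilities
preds = torch.argmax(probs, dim=1)     # hard class labels with shape (2, 256, 256)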

You may have noticed that there are no fully connected layers in the architecture. This has a very convenient added benefit: the UNet will be able to accept training data with different heights and widths. In the CNNs that we have already explored, the size of the array in the spatial dimensions after the final max pooling operation determines the number of inputs to the first fully connected layer after the array is flattened, so changing the input array size would change the number of inputs to that layer. Not relying on fully connected layers therefore allows for some generalization since the network can be trained on and can infer to image chips of varying sizes.

Let’s explore the content within the __init__() constructor method. This UNet implementation, called myUNet, accepts the following parameters:

  1. encoderChn = a list of channel sizes for the output of each encoder block
  2. decoderChn = a list of channel sizes for the output of each decoder block
  3. inChn = the number of channels in the input image
  4. botChn = the number of channels in the bottleneck layer
  5. nCls = the number of classes being differentiated

The encoder contains 4 blocks. The first block consists of just double 2D convolution. The next three blocks start with a max pooling operation with a kernel size of 2x2 and a stride of 2 followed by double 2D convolution. The bottleneck is the transition from the encoder to the decoder component. It begins with a max pooling operation followed by double 2D convolution. If the input image had spatial dimensions of 256x256, the spatial dimension sizes would change as follows through the encoder and bottleneck:

256x256 –> 128x128 –> 64x64 –> 32x32 –> 16x16

Similar to the encoder, the decoder consists of 4 blocks. Each block starts with applying the up_conv() function, which includes the 2D transpose convolution. This process increases the array size in the spatial dimensions. Next, the double 2D convolution is applied to learn filters as part of the decoder. As the data pass through the decoder, the spatial dimensions change as follows:

16x16 –> 32x32 –> 64x64 –> 128x128 –> 256x256

Thus, by the time the data reach the last block of the decoder, the spatial dimensions have been restored. The final step is to apply a 2D convolution with a kernel size of 1x1, a stride of 1, and an output size equal to the number of classes being differentiated.
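
As a small illustration (with hypothetical channel counts), a 1x1 convolution remaps only the channel dimension and leaves the spatial dimensions untouched:

head = nn.Conv2d(16, 10, kernel_size=(1,1), stride=1)   # 16 feature maps --> 10 class logits
x = torch.rand(1, 16, 256, 256)
print(head(x).shape)   # expected: torch.Size([1, 10, 256, 256])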

As with our other networks, the forward() method describes how the data pass through the network. First, the data pass sequentially through each of the encoder blocks followed by the bottleneck block. Between the application of 2D transpose convolution and double 2D convolution in each decoder block, a concatenation is performed where the output feature maps from the encoder block with the same spatial dimension sizes are appended to the current set of arrays. This is the skip connection component of the UNet architecture. By concatenating the outputs from the encoder to the associated decoder block, the information from the encoder is passed to the decoder so that it can be integrated into the decoder operations. This is one key component of the UNet architecture and is one reason for its success. If you return to the __init__() constructor method, you will see that each double_conv() implementation has an input size equal to the number of channels being fed to it from the prior 2D transpose layer plus the number of channels from the associated encoder block. Again, this is required due to the use of skip connections. Note that the concatenation occurs along the second dimension (index 1), which is the channel dimension.

Lastly, the data are passed through the classification head, which will output a logit for each of the differentiated classes.

class myUNet(nn.Module):
  def __init__(self, encoderChn, decoderChn, inChn, botChn, nCls):
    super().__init__()
    self.encoderChn = encoderChn
    self.decoderChn = decoderChn 
    self.botChn = botChn
    self.nCls = nCls

    self.encoder1 = double_conv(inChn, encoderChn[0])
    
    self.encoder2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),
                                  double_conv(encoderChn[0], encoderChn[1]))
    
    self.encoder3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),
                                  double_conv(encoderChn[1], encoderChn[2]))
    
    self.encoder4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),
                                  double_conv(encoderChn[2], encoderChn[3]))
    
    self.bottleneck = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),
                                    double_conv(encoderChn[3], botChn))

    self.decoder1up = up_conv(botChn, botChn)
    self.decoder1 = double_conv(encoderChn[3]+botChn, decoderChn[0])

    self.decoder2up = up_conv(decoderChn[0], decoderChn[0])
    self.decoder2 = double_conv(encoderChn[2]+decoderChn[0], decoderChn[1])

    self.decoder3up = up_conv(decoderChn[1], decoderChn[1])
    self.decoder3 = double_conv(encoderChn[1]+decoderChn[1], decoderChn[2])

    self.decoder4up = up_conv(decoderChn[2], decoderChn[2])
    self.decoder4 = double_conv(encoderChn[0]+decoderChn[2], decoderChn[3])

    self.classifier = nn.Conv2d(decoderChn[3], nCls, kernel_size=(1,1))

  def forward(self, x):

    #Encoder
    encoder1 = self.encoder1(x)
    encoder2 = self.encoder2(encoder1)
    encoder3 = self.encoder3(encoder2)
    encoder4 = self.encoder4(encoder3)

    #Bottleneck
    x = self.bottleneck(encoder4)

    #Decoder
    x = self.decoder1up(x)
    x = torch.concat([x, encoder4], dim=1)
    x = self.decoder1(x)

    x = self.decoder2up(x)
    x = torch.concat([x, encoder3], dim=1)
    x = self.decoder2(x)

    x = self.decoder3up(x)
    x = torch.concat([x, encoder2], dim=1)
    x = self.decoder3(x)

    x = self.decoder4up(x)
    x = torch.concat([x, encoder1], dim=1)
    x = self.decoder4(x)

    #Classifier head
    x = self.classifier(x)

    return x

Once the myUNet subclass is defined, I create an instance of the class called model.

model = myUNet(encoderChn=[16,32,64,128], decoderChn=[128,64,32,16], inChn=3, botChn=256, nCls=10).to(device)
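
Since the architecture is fully convolutional, the same model can accept chips of different spatial sizes, as long as the height and width are divisible by 16 so that the skip connections align after the four max pooling operations. Here is a quick sketch with random, illustrative inputs:

x256 = torch.rand(1, 3, 256, 256).to(device)
x128 = torch.rand(1, 3, 128, 128).to(device)
print(model(x256).shape)   # expected: torch.Size([1, 10, 256, 256])
print(model(x128).shape)   # expected: torch.Size([1, 10, 128, 128])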

I think it is helpful to think through the shapes of the data in each component of the network. If I pass in an image with spatial dimensions of 256x256 pixels, the shapes would be as follows:

  • Encoder Block 1: Input = (3,256,256); Output = (16,256,256)

  • Encoder Block 2: Input = (16,256,256); Output = (32,128,128)

  • Encoder Block 3: Input = (32,128,128); Output = (64,64,64)

  • Encoder Block 4: Input = (64,64,64); Output = (128,32,32)

  • Bottleneck: Input: (128,32,32); Output = (256,16,16)

  • 2D Transpose: Input: (256,16,16); Output = (256,32,32)

  • Decoder Block 1: Input: (256+128,32,32); Output = (128,32,32)

  • 2D Transpose: Input: (128,32,32); Output = (128,64,64)

  • Decoder Block 2: Input = (64+128,64,64); Output = (64,64,64)

  • 2D Transpose: Input = (64,64,64); Output = (64,128,128)

  • Decoder Block 3: Input = (32+64,128,128); Output = (32,128,128)

  • 2D Transpose: Input = (32,128,128); Output = (32,256,256)

  • Decoder Block 4: Input = (16+32,256,256); Output = (16,256,256)

  • Classification Head: Input = (16,256,256); Output = (10,256,256)

In the sequence above, note how the spatial dimensions decrease through the encoder and then increase through the decoder such that the final output has spatial dimensions equal to those of the input image. Also, the 2D transpose operations double the spatial dimensions, effectively restoring the array size prior to max pooling. In each decoder block, the concatenation of the channels from the prior 2D transpose convolution and the encoder block with the same spatial dimension size represents the skip connection.

If you are having trouble understanding this sequence, it might be helpful to take some time to draw out a UNet architecture and write out the associated array sizes at each step or block.
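
It can also help to reduce a skip connection to a single torch.concat() call on hypothetical tensors with matching spatial dimensions; only the channel dimension grows:

fromDecoder = torch.rand(1, 256, 32, 32)   # output of the first 2D transpose convolution
fromEncoder = torch.rand(1, 128, 32, 32)   # feature maps carried over from encoder block 4
merged = torch.concat([fromDecoder, fromEncoder], dim=1)
print(merged.shape)   # expected: torch.Size([1, 384, 32, 32]), i.e., 256+128 channels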

Summarize Model

The summary() function from torchsummary can be very useful for exploring network architectures. This function requires a model object and an input array size. Again, take some time to explore the output and make sure you understand why each step results in arrays of a specific shape. Note that the spatial dimensions decrease in size only when max pooling is applied and increase in size only when 2D transpose convolution is applied. The 2D convolution, ReLU activation, and 2D batch normalization layers do not impact the spatial dimension sizes. The 2D convolution, 2D transpose convolution, and 2D batch normalization layers have trainable parameters. The max pooling and ReLU activation layers do not have trainable parameters.

In the printout, -1 represents the unknown size of the batch dimension. So, if the training batch size is 32, -1 would be replaced with 32 at all stages of the model since the batch size is not impacted by the operations.

Lastly, it should be noted that UNet and UNet-like architectures can become very large and have many trainable parameters. This fairly simple UNet has over 2 million trainable parameters, and the model has an estimated total size of nearly 300 MB to store the input, the forward and backward passes, and the model parameters. There are no non-trainable parameters in the model. This is because none of the layers have been frozen. In a later module, we will explore how layers can be frozen so that their weights cannot be updated during the learning process. This is sometimes done when applying transfer learning and can be applied to all or only some of the layers and/or training epochs.

summary(model, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 256, 256]             448
       BatchNorm2d-2         [-1, 16, 256, 256]              32
              ReLU-3         [-1, 16, 256, 256]               0
            Conv2d-4         [-1, 16, 256, 256]           2,320
       BatchNorm2d-5         [-1, 16, 256, 256]              32
              ReLU-6         [-1, 16, 256, 256]               0
         MaxPool2d-7         [-1, 16, 128, 128]               0
            Conv2d-8         [-1, 32, 128, 128]           4,640
       BatchNorm2d-9         [-1, 32, 128, 128]              64
             ReLU-10         [-1, 32, 128, 128]               0
           Conv2d-11         [-1, 32, 128, 128]           9,248
      BatchNorm2d-12         [-1, 32, 128, 128]              64
             ReLU-13         [-1, 32, 128, 128]               0
        MaxPool2d-14           [-1, 32, 64, 64]               0
           Conv2d-15           [-1, 64, 64, 64]          18,496
      BatchNorm2d-16           [-1, 64, 64, 64]             128
             ReLU-17           [-1, 64, 64, 64]               0
           Conv2d-18           [-1, 64, 64, 64]          36,928
      BatchNorm2d-19           [-1, 64, 64, 64]             128
             ReLU-20           [-1, 64, 64, 64]               0
        MaxPool2d-21           [-1, 64, 32, 32]               0
           Conv2d-22          [-1, 128, 32, 32]          73,856
      BatchNorm2d-23          [-1, 128, 32, 32]             256
             ReLU-24          [-1, 128, 32, 32]               0
           Conv2d-25          [-1, 128, 32, 32]         147,584
      BatchNorm2d-26          [-1, 128, 32, 32]             256
             ReLU-27          [-1, 128, 32, 32]               0
        MaxPool2d-28          [-1, 128, 16, 16]               0
           Conv2d-29          [-1, 256, 16, 16]         295,168
      BatchNorm2d-30          [-1, 256, 16, 16]             512
             ReLU-31          [-1, 256, 16, 16]               0
           Conv2d-32          [-1, 256, 16, 16]         590,080
      BatchNorm2d-33          [-1, 256, 16, 16]             512
             ReLU-34          [-1, 256, 16, 16]               0
  ConvTranspose2d-35          [-1, 256, 32, 32]         262,400
      BatchNorm2d-36          [-1, 256, 32, 32]             512
             ReLU-37          [-1, 256, 32, 32]               0
           Conv2d-38          [-1, 128, 32, 32]         442,496
      BatchNorm2d-39          [-1, 128, 32, 32]             256
             ReLU-40          [-1, 128, 32, 32]               0
           Conv2d-41          [-1, 128, 32, 32]         147,584
      BatchNorm2d-42          [-1, 128, 32, 32]             256
             ReLU-43          [-1, 128, 32, 32]               0
  ConvTranspose2d-44          [-1, 128, 64, 64]          65,664
      BatchNorm2d-45          [-1, 128, 64, 64]             256
             ReLU-46          [-1, 128, 64, 64]               0
           Conv2d-47           [-1, 64, 64, 64]         110,656
      BatchNorm2d-48           [-1, 64, 64, 64]             128
             ReLU-49           [-1, 64, 64, 64]               0
           Conv2d-50           [-1, 64, 64, 64]          36,928
      BatchNorm2d-51           [-1, 64, 64, 64]             128
             ReLU-52           [-1, 64, 64, 64]               0
  ConvTranspose2d-53         [-1, 64, 128, 128]          16,448
      BatchNorm2d-54         [-1, 64, 128, 128]             128
             ReLU-55         [-1, 64, 128, 128]               0
           Conv2d-56         [-1, 32, 128, 128]          27,680
      BatchNorm2d-57         [-1, 32, 128, 128]              64
             ReLU-58         [-1, 32, 128, 128]               0
           Conv2d-59         [-1, 32, 128, 128]           9,248
      BatchNorm2d-60         [-1, 32, 128, 128]              64
             ReLU-61         [-1, 32, 128, 128]               0
  ConvTranspose2d-62         [-1, 32, 256, 256]           4,128
      BatchNorm2d-63         [-1, 32, 256, 256]              64
             ReLU-64         [-1, 32, 256, 256]               0
           Conv2d-65         [-1, 16, 256, 256]           6,928
      BatchNorm2d-66         [-1, 16, 256, 256]              32
             ReLU-67         [-1, 16, 256, 256]               0
           Conv2d-68         [-1, 16, 256, 256]           2,320
      BatchNorm2d-69         [-1, 16, 256, 256]              32
             ReLU-70         [-1, 16, 256, 256]               0
           Conv2d-71         [-1, 10, 256, 256]             170
================================================================
Total params: 2,315,322
Trainable params: 2,315,322
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 281.75
Params size (MB): 8.83
Estimated Total Size (MB): 291.33
----------------------------------------------------------------
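
As noted above, there are no non-trainable parameters because no layers have been frozen. We will save the details for a later module, but a frozen layer is simply one whose parameters have requires_grad set to False; as a hypothetical sketch, freezing the first encoder block would look something like this, and its parameters would then be reported as non-trainable:

# Hypothetical example: freeze the first encoder block so its weights are not updated
for param in model.encoder1.parameters():
  param.requires_grad = False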

Concluding Remarks

Now that you know how to build a fairly simple UNet using PyTorch by subclassing nn.Module, in the next module we will train a UNet architecture and assess the resulting pixel-level predictions.