UNet Encoders

Introduction

The backbone or encoder component of a semantic segmentation model serves the same purpose as the convolutional layers in a CNN architecture designed for scene labeling tasks: it characterizes spatial patterns at varying spatial scales. As a result, the encoder component of semantic segmentation architectures, such as UNet, can be built from a variety of pre-defined CNN architectures, such as ResNets and VGGNets. Since these common architectures have been trained on large datasets, such as ImageNet, pre-trained weights can also be incorporated into the backbone or encoder component of a semantic segmentation model. This component of the network can then either be frozen or updated during the training process. Such a use of transfer learning may allow models to be trained with less training data and/or for fewer epochs while still obtaining adequate results.

In this short module, I will demonstrate augmenting UNet with pre-defined backbones. The first example will use a VGGNet-16 architecture as the model backbone, and the second example will use a ResNet backbone. I will only define the model architectures and summarize them using torchsummary; I will not train the models. However, the methods used to train the UNet in the Train a UNet module could be applied to these architectures if desired.

Since I am defining architectures, I need to import torch and torch.nn. The torchsummary package is used to summarize the model architectures, while the backbones will be accessed using their torchvision implementations. Lastly, I define the GPU as the device.

import torch 
import torch.nn as nn
from torchsummary import summary
import torchvision.models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0

VGGNet-16 Encoder

I begin by creating a VGGNet backbone. The first step is to instantiate the backbone as implemented in the torchvision.models subpackage. I use a modified version of VGGNet-16 that incorporates batch normalization, available via the vgg16_bn() function. Since I will not actually train the model, I do not download the pre-trained weights.

vgg16 = torchvision.models.vgg16_bn(pretrained=False).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
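
As these warnings note, the pretrained argument is deprecated in newer releases of torchvision. For reference, below is a minimal sketch of the equivalent calls using the weights enum available in torchvision 0.13 and later; the behavior is the same as the code used in this module.

vgg16 = torchvision.models.vgg16_bn(weights=None).to(device)
#To initialize with ImageNet pre-trained weights instead:
#vgg16 = torchvision.models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.DEFAULT).to(device)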

I next obtain a summary of the VGGNet-16 architecture using an input array size of (3,256,256). In order to use this architecture within UNet, I will need to obtain the outputs before each max pooling operation so that the results can be passed to the decoder via the skip connections. The spatial dimensions of the arrays passed through the skip connections must match those of the feature maps with which they are concatenated.

summary(vgg16, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 256, 256]           1,792
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,928
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
         MaxPool2d-7         [-1, 64, 128, 128]               0
            Conv2d-8        [-1, 128, 128, 128]          73,856
       BatchNorm2d-9        [-1, 128, 128, 128]             256
             ReLU-10        [-1, 128, 128, 128]               0
           Conv2d-11        [-1, 128, 128, 128]         147,584
      BatchNorm2d-12        [-1, 128, 128, 128]             256
             ReLU-13        [-1, 128, 128, 128]               0
        MaxPool2d-14          [-1, 128, 64, 64]               0
           Conv2d-15          [-1, 256, 64, 64]         295,168
      BatchNorm2d-16          [-1, 256, 64, 64]             512
             ReLU-17          [-1, 256, 64, 64]               0
           Conv2d-18          [-1, 256, 64, 64]         590,080
      BatchNorm2d-19          [-1, 256, 64, 64]             512
             ReLU-20          [-1, 256, 64, 64]               0
           Conv2d-21          [-1, 256, 64, 64]         590,080
      BatchNorm2d-22          [-1, 256, 64, 64]             512
             ReLU-23          [-1, 256, 64, 64]               0
        MaxPool2d-24          [-1, 256, 32, 32]               0
           Conv2d-25          [-1, 512, 32, 32]       1,180,160
      BatchNorm2d-26          [-1, 512, 32, 32]           1,024
             ReLU-27          [-1, 512, 32, 32]               0
           Conv2d-28          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-29          [-1, 512, 32, 32]           1,024
             ReLU-30          [-1, 512, 32, 32]               0
           Conv2d-31          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-32          [-1, 512, 32, 32]           1,024
             ReLU-33          [-1, 512, 32, 32]               0
        MaxPool2d-34          [-1, 512, 16, 16]               0
           Conv2d-35          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-36          [-1, 512, 16, 16]           1,024
             ReLU-37          [-1, 512, 16, 16]               0
           Conv2d-38          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-39          [-1, 512, 16, 16]           1,024
             ReLU-40          [-1, 512, 16, 16]               0
           Conv2d-41          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-42          [-1, 512, 16, 16]           1,024
             ReLU-43          [-1, 512, 16, 16]               0
        MaxPool2d-44            [-1, 512, 8, 8]               0
AdaptiveAvgPool2d-45            [-1, 512, 7, 7]               0
           Linear-46                 [-1, 4096]     102,764,544
             ReLU-47                 [-1, 4096]               0
          Dropout-48                 [-1, 4096]               0
           Linear-49                 [-1, 4096]      16,781,312
             ReLU-50                 [-1, 4096]               0
          Dropout-51                 [-1, 4096]               0
           Linear-52                 [-1, 1000]       4,097,000
================================================================
Total params: 138,365,992
Trainable params: 138,365,992
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 420.64
Params size (MB): 527.82
Estimated Total Size (MB): 949.21
----------------------------------------------------------------

It is possible to obtain a list of all of the architecture’s components using the features property of the instantiated VGGNet-16 architecture. Printing the result, you can see that the feature extractor has 44 components, indexed 0 through 43. I will need to divide the model’s components as follows to extract the results at the correct locations in the architecture.

  • Layers 0-5 -> 1st skip connection (original array size)
  • Layers 6-12 -> 2nd skip connection (original array size/2)
  • Layers 13-22 -> 3rd skip connection (original array size/4)
  • Layers 23-32 -> 4th skip connection (original array size/8)
  • Layers 33-42 -> Bottleneck (original array size/16)

I will need to split the architecture before each max pooling operation so that I can match each encoder step with the associated decoder step that has the same array sizes in the spatial dimensions. The last max pooling layer is not used since I do not want to decrease the array size further. Instead, the data will enter the decoder component and the first 2D transpose convolution operation.

vgg16F = torchvision.models.vgg16_bn(pretrained=False).features
vgg16F
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU(inplace=True)
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (16): ReLU(inplace=True)
  (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (19): ReLU(inplace=True)
  (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (22): ReLU(inplace=True)
  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (26): ReLU(inplace=True)
  (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (32): ReLU(inplace=True)
  (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (36): ReLU(inplace=True)
  (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (39): ReLU(inplace=True)
  (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (42): ReLU(inplace=True)
  (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
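
Before incorporating these slices into a UNet, it is worth confirming that splitting the feature list at the boundaries above yields the expected array sizes. The following quick check is my own addition and is not part of the original workflow; it pushes a dummy batch through each slice and prints the resulting shapes.

#Pass a dummy batch through each encoder slice and print the output shapes
vgg16F.eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
  for start, stop in [(0, 6), (6, 13), (13, 23), (23, 33), (33, 43)]:
    x = nn.Sequential(*vgg16F[start:stop])(x)
    print(x.shape)
#Expected: (1,64,256,256), (1,128,128,128), (1,256,64,64), (1,512,32,32), (1,512,16,16)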

With what we learned about the VGGNet-16 architecture from the explorations above, let’s now incorporate it into a UNet. As I did in the UNet Architecture module, I define double_conv() and up_conv() helper functions. I next define the UNet architecture by subclassing nn.Module. Here are the key components.

  1. The __init__() constructor method defines the following parameters: the number of input channels (inChn), the number of classes being differentiated (nCls), and whether or not to use pre-trained weights (useWghts).
  2. The features from the VGGNet-16 architecture are extracted using the vgg16_bn() function from torchvision and the features property.
  3. A list of the encoder block output sizes (outSizes), which is used to size the decoder inputs, is defined.
  4. The components of the encoder are defined as the appropriate subset of layers from VGGNet-16. The components are combined using nn.Sequential(), and * is used to unpack the list of features. Again, I am breaking the model apart before each max pooling operation so that the array sizes in the spatial dimensions match between each encoder block and associated decoder block.
  5. I define the bottleneck as the last set of components of the VGGNet-16 model. The final max pooling layer is not used since we do not want to decrease the array size further.
  6. I define the decoder blocks, which each consist of upsampling using 2D transpose convolution and a series of two 2D convolution layers to learn filters.
  7. The last layer uses 2D convolution with a kernel size of 1x1 and a stride of 1. The number of outputs is equal to the number of classes being differentiated.
  8. The forward() method defines how data will pass through the architecture. Data first pass through each of the encoder blocks followed by the bottleneck. In each decoder block, upsampling is performed using 2D transpose convolution, the feature maps from the associated encoder block are concatenated, and the features pass through two 2D convolution layers. Lastly, the data pass through the final 2D convolution layer, which has a kernel size of 1x1 and a stride of 1. The output will be logits for each class. In the case of a binary classification, it could return only a logit for the positive case. Probabilities are not returned since a sigmoid or softmax activation is not being applied.
def double_conv(inChannels, outChannels):
  return nn.Sequential(
      nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True),
      nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
def up_conv(inChannels, outChannels):
    return nn.Sequential(
      nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
class myUNetVGG16(nn.Module):
  def __init__(self, inChn, nCls, useWghts=True):
    super().__init__()
    self.inChn = inChn
    self.nCls = nCls
    self.useWghts = useWghts

    self.base_model = torchvision.models.vgg16_bn(pretrained=useWghts).features
    self.outSizes = [64, 128, 256, 512, 512]

    self.encoder1 = nn.Sequential(*self.base_model[:6]) 
    self.encoder2 = nn.Sequential(*self.base_model[6:13]) 
    self.encoder3 = nn.Sequential(*self.base_model[13:23]) 
    self.encoder4 = nn.Sequential(*self.base_model[23:33]) 
    
    self.bottleneck = nn.Sequential(*self.base_model[33:43]) 

    self.decoder1up = up_conv(self.outSizes[4], 512) 
    self.decoder1 = double_conv(self.outSizes[3] + 512, 256) 

    self.decoder2up = up_conv(256, 256)
    self.decoder2 = double_conv(self.outSizes[2] + 256, 128)

    self.decoder3up = up_conv(128, 128)
    self.decoder3 = double_conv(self.outSizes[1] + 128, 64)

    self.decoder4up = up_conv(64, 64)
    self.decoder4 = double_conv(self.outSizes[0] + 64, 32)

    self.classifier = nn.Conv2d(32, nCls, kernel_size=(1,1))

  def forward(self, x):

    #Encoder
    encoder1 = self.encoder1(x)
    encoder2 = self.encoder2(encoder1)
    encoder3 = self.encoder3(encoder2)
    encoder4 = self.encoder4(encoder3)

    #Bottleneck
    x = self.bottleneck(encoder4)

    #Decoder
    x = self.decoder1up(x)
    x = torch.concat([x, encoder4], dim=1)
    x = self.decoder1(x)

    x = self.decoder2up(x)
    x = torch.concat([x, encoder3], dim=1)
    x = self.decoder2(x)

    x = self.decoder3up(x)
    x = torch.concat([x, encoder2], dim=1)
    x = self.decoder3(x)

    x = self.decoder4up(x)
    x = torch.concat([x, encoder1], dim=1)
    x = self.decoder4(x)

    #Classifier head
    x = self.classifier(x)

    return x

I instantiate the myUNetVGG16() architecture so that it accepts 3 input channels and outputs 10 class logits. I also initialize the model using the VGGNet-16 pre-trained weights available from torchvision.

Using the summary() function from torchsummary, you can see that the model has over 34 million trainable parameters. There are currently no non-trainable parameters. This is because, even though I downloaded the pre-trained weights, all parameters can still be updated during the learning process. In other words, the model will be initialized using these weights as opposed to random weights, but these layers and associated parameters are still trainable.

model = myUNetVGG16(inChn=3, nCls=10, useWghts=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_BN_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_BN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
summary(model, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 256, 256]           1,792
            Conv2d-2         [-1, 64, 256, 256]           1,792
       BatchNorm2d-3         [-1, 64, 256, 256]             128
       BatchNorm2d-4         [-1, 64, 256, 256]             128
              ReLU-5         [-1, 64, 256, 256]               0
              ReLU-6         [-1, 64, 256, 256]               0
            Conv2d-7         [-1, 64, 256, 256]          36,928
            Conv2d-8         [-1, 64, 256, 256]          36,928
       BatchNorm2d-9         [-1, 64, 256, 256]             128
      BatchNorm2d-10         [-1, 64, 256, 256]             128
             ReLU-11         [-1, 64, 256, 256]               0
             ReLU-12         [-1, 64, 256, 256]               0
        MaxPool2d-13         [-1, 64, 128, 128]               0
        MaxPool2d-14         [-1, 64, 128, 128]               0
           Conv2d-15        [-1, 128, 128, 128]          73,856
           Conv2d-16        [-1, 128, 128, 128]          73,856
      BatchNorm2d-17        [-1, 128, 128, 128]             256
      BatchNorm2d-18        [-1, 128, 128, 128]             256
             ReLU-19        [-1, 128, 128, 128]               0
             ReLU-20        [-1, 128, 128, 128]               0
           Conv2d-21        [-1, 128, 128, 128]         147,584
           Conv2d-22        [-1, 128, 128, 128]         147,584
      BatchNorm2d-23        [-1, 128, 128, 128]             256
      BatchNorm2d-24        [-1, 128, 128, 128]             256
             ReLU-25        [-1, 128, 128, 128]               0
             ReLU-26        [-1, 128, 128, 128]               0
        MaxPool2d-27          [-1, 128, 64, 64]               0
        MaxPool2d-28          [-1, 128, 64, 64]               0
           Conv2d-29          [-1, 256, 64, 64]         295,168
           Conv2d-30          [-1, 256, 64, 64]         295,168
      BatchNorm2d-31          [-1, 256, 64, 64]             512
      BatchNorm2d-32          [-1, 256, 64, 64]             512
             ReLU-33          [-1, 256, 64, 64]               0
             ReLU-34          [-1, 256, 64, 64]               0
           Conv2d-35          [-1, 256, 64, 64]         590,080
           Conv2d-36          [-1, 256, 64, 64]         590,080
      BatchNorm2d-37          [-1, 256, 64, 64]             512
      BatchNorm2d-38          [-1, 256, 64, 64]             512
             ReLU-39          [-1, 256, 64, 64]               0
             ReLU-40          [-1, 256, 64, 64]               0
           Conv2d-41          [-1, 256, 64, 64]         590,080
           Conv2d-42          [-1, 256, 64, 64]         590,080
      BatchNorm2d-43          [-1, 256, 64, 64]             512
      BatchNorm2d-44          [-1, 256, 64, 64]             512
             ReLU-45          [-1, 256, 64, 64]               0
             ReLU-46          [-1, 256, 64, 64]               0
        MaxPool2d-47          [-1, 256, 32, 32]               0
        MaxPool2d-48          [-1, 256, 32, 32]               0
           Conv2d-49          [-1, 512, 32, 32]       1,180,160
           Conv2d-50          [-1, 512, 32, 32]       1,180,160
      BatchNorm2d-51          [-1, 512, 32, 32]           1,024
      BatchNorm2d-52          [-1, 512, 32, 32]           1,024
             ReLU-53          [-1, 512, 32, 32]               0
             ReLU-54          [-1, 512, 32, 32]               0
           Conv2d-55          [-1, 512, 32, 32]       2,359,808
           Conv2d-56          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-57          [-1, 512, 32, 32]           1,024
      BatchNorm2d-58          [-1, 512, 32, 32]           1,024
             ReLU-59          [-1, 512, 32, 32]               0
             ReLU-60          [-1, 512, 32, 32]               0
           Conv2d-61          [-1, 512, 32, 32]       2,359,808
           Conv2d-62          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-63          [-1, 512, 32, 32]           1,024
      BatchNorm2d-64          [-1, 512, 32, 32]           1,024
             ReLU-65          [-1, 512, 32, 32]               0
             ReLU-66          [-1, 512, 32, 32]               0
        MaxPool2d-67          [-1, 512, 16, 16]               0
        MaxPool2d-68          [-1, 512, 16, 16]               0
           Conv2d-69          [-1, 512, 16, 16]       2,359,808
           Conv2d-70          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-71          [-1, 512, 16, 16]           1,024
      BatchNorm2d-72          [-1, 512, 16, 16]           1,024
             ReLU-73          [-1, 512, 16, 16]               0
             ReLU-74          [-1, 512, 16, 16]               0
           Conv2d-75          [-1, 512, 16, 16]       2,359,808
           Conv2d-76          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-77          [-1, 512, 16, 16]           1,024
      BatchNorm2d-78          [-1, 512, 16, 16]           1,024
             ReLU-79          [-1, 512, 16, 16]               0
             ReLU-80          [-1, 512, 16, 16]               0
           Conv2d-81          [-1, 512, 16, 16]       2,359,808
           Conv2d-82          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-83          [-1, 512, 16, 16]           1,024
      BatchNorm2d-84          [-1, 512, 16, 16]           1,024
             ReLU-85          [-1, 512, 16, 16]               0
             ReLU-86          [-1, 512, 16, 16]               0
  ConvTranspose2d-87          [-1, 512, 32, 32]       1,049,088
      BatchNorm2d-88          [-1, 512, 32, 32]           1,024
             ReLU-89          [-1, 512, 32, 32]               0
           Conv2d-90          [-1, 256, 32, 32]       2,359,552
      BatchNorm2d-91          [-1, 256, 32, 32]             512
             ReLU-92          [-1, 256, 32, 32]               0
           Conv2d-93          [-1, 256, 32, 32]         590,080
      BatchNorm2d-94          [-1, 256, 32, 32]             512
             ReLU-95          [-1, 256, 32, 32]               0
  ConvTranspose2d-96          [-1, 256, 64, 64]         262,400
      BatchNorm2d-97          [-1, 256, 64, 64]             512
             ReLU-98          [-1, 256, 64, 64]               0
           Conv2d-99          [-1, 128, 64, 64]         589,952
     BatchNorm2d-100          [-1, 128, 64, 64]             256
            ReLU-101          [-1, 128, 64, 64]               0
          Conv2d-102          [-1, 128, 64, 64]         147,584
     BatchNorm2d-103          [-1, 128, 64, 64]             256
            ReLU-104          [-1, 128, 64, 64]               0
 ConvTranspose2d-105        [-1, 128, 128, 128]          65,664
     BatchNorm2d-106        [-1, 128, 128, 128]             256
            ReLU-107        [-1, 128, 128, 128]               0
          Conv2d-108         [-1, 64, 128, 128]         147,520
     BatchNorm2d-109         [-1, 64, 128, 128]             128
            ReLU-110         [-1, 64, 128, 128]               0
          Conv2d-111         [-1, 64, 128, 128]          36,928
     BatchNorm2d-112         [-1, 64, 128, 128]             128
            ReLU-113         [-1, 64, 128, 128]               0
 ConvTranspose2d-114         [-1, 64, 256, 256]          16,448
     BatchNorm2d-115         [-1, 64, 256, 256]             128
            ReLU-116         [-1, 64, 256, 256]               0
          Conv2d-117         [-1, 32, 256, 256]          36,896
     BatchNorm2d-118         [-1, 32, 256, 256]              64
            ReLU-119         [-1, 32, 256, 256]               0
          Conv2d-120         [-1, 32, 256, 256]           9,248
     BatchNorm2d-121         [-1, 32, 256, 256]              64
            ReLU-122         [-1, 32, 256, 256]               0
          Conv2d-123         [-1, 10, 256, 256]             330
================================================================
Total params: 34,761,802
Trainable params: 34,761,802
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 1205.00
Params size (MB): 132.61
Estimated Total Size (MB): 1338.36
----------------------------------------------------------------
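
As a quick sanity check of my own, a random batch can also be pushed through the full network to confirm that the output logits have the expected shape of (batch size, number of classes, height, width). Since the network returns logits, a softmax over the class dimension would yield probabilities at inference time.

#Confirm that the logits have shape (batch, nCls, height, width)
model.eval()
with torch.no_grad():
  logits = model(torch.randn(2, 3, 256, 256).to(device))
print(logits.shape) #torch.Size([2, 10, 256, 256])
probs = torch.softmax(logits, dim=1) #convert to class probabilities if needed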

I next demonstrate using for loops to freeze the trainable parameters in the backbone or encoder layers. This is accomplished by iterating over the layers extracted from the VGGNet-16 model and then over the parameters within each layer, setting each parameter’s requires_grad property to False.

If I print the summary again, you can now see that only a subset of the total parameters is trainable. In other words, the parameters in the encoder component of the model that were defined using the VGGNet-16 model can no longer be updated. Only parameters in the decoder component will be trainable.

for l in model.base_model:
  for param in l.parameters():
    param.requires_grad = False
summary(model, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 256, 256]           1,792
            Conv2d-2         [-1, 64, 256, 256]           1,792
       BatchNorm2d-3         [-1, 64, 256, 256]             128
       BatchNorm2d-4         [-1, 64, 256, 256]             128
              ReLU-5         [-1, 64, 256, 256]               0
              ReLU-6         [-1, 64, 256, 256]               0
            Conv2d-7         [-1, 64, 256, 256]          36,928
            Conv2d-8         [-1, 64, 256, 256]          36,928
       BatchNorm2d-9         [-1, 64, 256, 256]             128
      BatchNorm2d-10         [-1, 64, 256, 256]             128
             ReLU-11         [-1, 64, 256, 256]               0
             ReLU-12         [-1, 64, 256, 256]               0
        MaxPool2d-13         [-1, 64, 128, 128]               0
        MaxPool2d-14         [-1, 64, 128, 128]               0
           Conv2d-15        [-1, 128, 128, 128]          73,856
           Conv2d-16        [-1, 128, 128, 128]          73,856
      BatchNorm2d-17        [-1, 128, 128, 128]             256
      BatchNorm2d-18        [-1, 128, 128, 128]             256
             ReLU-19        [-1, 128, 128, 128]               0
             ReLU-20        [-1, 128, 128, 128]               0
           Conv2d-21        [-1, 128, 128, 128]         147,584
           Conv2d-22        [-1, 128, 128, 128]         147,584
      BatchNorm2d-23        [-1, 128, 128, 128]             256
      BatchNorm2d-24        [-1, 128, 128, 128]             256
             ReLU-25        [-1, 128, 128, 128]               0
             ReLU-26        [-1, 128, 128, 128]               0
        MaxPool2d-27          [-1, 128, 64, 64]               0
        MaxPool2d-28          [-1, 128, 64, 64]               0
           Conv2d-29          [-1, 256, 64, 64]         295,168
           Conv2d-30          [-1, 256, 64, 64]         295,168
      BatchNorm2d-31          [-1, 256, 64, 64]             512
      BatchNorm2d-32          [-1, 256, 64, 64]             512
             ReLU-33          [-1, 256, 64, 64]               0
             ReLU-34          [-1, 256, 64, 64]               0
           Conv2d-35          [-1, 256, 64, 64]         590,080
           Conv2d-36          [-1, 256, 64, 64]         590,080
      BatchNorm2d-37          [-1, 256, 64, 64]             512
      BatchNorm2d-38          [-1, 256, 64, 64]             512
             ReLU-39          [-1, 256, 64, 64]               0
             ReLU-40          [-1, 256, 64, 64]               0
           Conv2d-41          [-1, 256, 64, 64]         590,080
           Conv2d-42          [-1, 256, 64, 64]         590,080
      BatchNorm2d-43          [-1, 256, 64, 64]             512
      BatchNorm2d-44          [-1, 256, 64, 64]             512
             ReLU-45          [-1, 256, 64, 64]               0
             ReLU-46          [-1, 256, 64, 64]               0
        MaxPool2d-47          [-1, 256, 32, 32]               0
        MaxPool2d-48          [-1, 256, 32, 32]               0
           Conv2d-49          [-1, 512, 32, 32]       1,180,160
           Conv2d-50          [-1, 512, 32, 32]       1,180,160
      BatchNorm2d-51          [-1, 512, 32, 32]           1,024
      BatchNorm2d-52          [-1, 512, 32, 32]           1,024
             ReLU-53          [-1, 512, 32, 32]               0
             ReLU-54          [-1, 512, 32, 32]               0
           Conv2d-55          [-1, 512, 32, 32]       2,359,808
           Conv2d-56          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-57          [-1, 512, 32, 32]           1,024
      BatchNorm2d-58          [-1, 512, 32, 32]           1,024
             ReLU-59          [-1, 512, 32, 32]               0
             ReLU-60          [-1, 512, 32, 32]               0
           Conv2d-61          [-1, 512, 32, 32]       2,359,808
           Conv2d-62          [-1, 512, 32, 32]       2,359,808
      BatchNorm2d-63          [-1, 512, 32, 32]           1,024
      BatchNorm2d-64          [-1, 512, 32, 32]           1,024
             ReLU-65          [-1, 512, 32, 32]               0
             ReLU-66          [-1, 512, 32, 32]               0
        MaxPool2d-67          [-1, 512, 16, 16]               0
        MaxPool2d-68          [-1, 512, 16, 16]               0
           Conv2d-69          [-1, 512, 16, 16]       2,359,808
           Conv2d-70          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-71          [-1, 512, 16, 16]           1,024
      BatchNorm2d-72          [-1, 512, 16, 16]           1,024
             ReLU-73          [-1, 512, 16, 16]               0
             ReLU-74          [-1, 512, 16, 16]               0
           Conv2d-75          [-1, 512, 16, 16]       2,359,808
           Conv2d-76          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-77          [-1, 512, 16, 16]           1,024
      BatchNorm2d-78          [-1, 512, 16, 16]           1,024
             ReLU-79          [-1, 512, 16, 16]               0
             ReLU-80          [-1, 512, 16, 16]               0
           Conv2d-81          [-1, 512, 16, 16]       2,359,808
           Conv2d-82          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-83          [-1, 512, 16, 16]           1,024
      BatchNorm2d-84          [-1, 512, 16, 16]           1,024
             ReLU-85          [-1, 512, 16, 16]               0
             ReLU-86          [-1, 512, 16, 16]               0
  ConvTranspose2d-87          [-1, 512, 32, 32]       1,049,088
      BatchNorm2d-88          [-1, 512, 32, 32]           1,024
             ReLU-89          [-1, 512, 32, 32]               0
           Conv2d-90          [-1, 256, 32, 32]       2,359,552
      BatchNorm2d-91          [-1, 256, 32, 32]             512
             ReLU-92          [-1, 256, 32, 32]               0
           Conv2d-93          [-1, 256, 32, 32]         590,080
      BatchNorm2d-94          [-1, 256, 32, 32]             512
             ReLU-95          [-1, 256, 32, 32]               0
  ConvTranspose2d-96          [-1, 256, 64, 64]         262,400
      BatchNorm2d-97          [-1, 256, 64, 64]             512
             ReLU-98          [-1, 256, 64, 64]               0
           Conv2d-99          [-1, 128, 64, 64]         589,952
     BatchNorm2d-100          [-1, 128, 64, 64]             256
            ReLU-101          [-1, 128, 64, 64]               0
          Conv2d-102          [-1, 128, 64, 64]         147,584
     BatchNorm2d-103          [-1, 128, 64, 64]             256
            ReLU-104          [-1, 128, 64, 64]               0
 ConvTranspose2d-105        [-1, 128, 128, 128]          65,664
     BatchNorm2d-106        [-1, 128, 128, 128]             256
            ReLU-107        [-1, 128, 128, 128]               0
          Conv2d-108         [-1, 64, 128, 128]         147,520
     BatchNorm2d-109         [-1, 64, 128, 128]             128
            ReLU-110         [-1, 64, 128, 128]               0
          Conv2d-111         [-1, 64, 128, 128]          36,928
     BatchNorm2d-112         [-1, 64, 128, 128]             128
            ReLU-113         [-1, 64, 128, 128]               0
 ConvTranspose2d-114         [-1, 64, 256, 256]          16,448
     BatchNorm2d-115         [-1, 64, 256, 256]             128
            ReLU-116         [-1, 64, 256, 256]               0
          Conv2d-117         [-1, 32, 256, 256]          36,896
     BatchNorm2d-118         [-1, 32, 256, 256]              64
            ReLU-119         [-1, 32, 256, 256]               0
          Conv2d-120         [-1, 32, 256, 256]           9,248
     BatchNorm2d-121         [-1, 32, 256, 256]              64
            ReLU-122         [-1, 32, 256, 256]               0
          Conv2d-123         [-1, 10, 256, 256]             330
================================================================
Total params: 34,761,802
Trainable params: 5,315,530
Non-trainable params: 29,446,272
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 1205.00
Params size (MB): 132.61
Estimated Total Size (MB): 1338.36
----------------------------------------------------------------
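
When training a model with a frozen encoder, only the unfrozen parameters need to be handed to the optimizer. Below is a minimal sketch assuming the AdamW optimizer and a placeholder learning rate; any optimizer could be substituted.

#Pass only the parameters that still require gradients to the optimizer
trainableParams = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainableParams, lr=0.001)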

ResNet Encoder

As a second example, I will now define a UNet architecture that can accept different versions of a ResNet architecture as the backbone. I begin by redefining the double_conv() and up_conv() functions. In the myUNetResNet() subclass definition, I define the following parameters:

  • inChn = number of input channels
  • nCls = number of output classes
  • resNet = which ResNet architecture to use (“18”, “34”, or “50”); the default is “18”
  • useWghts = whether or not to initialize using pre-trained weights

The backbone used will depend on the argument provided to the resNet parameter. Note that this will not change how the encoder blocks are defined since the list of top-level layers does not change. The type of ResNet varies based on the number of operations within each layer, but the total number of layers remains the same. Similar to the UNet with the VGGNet-16 backbone, the layers from the ResNet must be partitioned into the appropriate encoder blocks such that the correct array size is extracted to concatenate with the associated decoder block. The first operation in the ResNet architecture reduces the array size, so the first encoder block uses the defined double_conv() operation as opposed to the first set of layers from the ResNet. Portions of the ResNet are used to define the operations in the next four encoder blocks and the bottleneck. Once these components are defined, I define the decoder blocks, which consist of upsampling with 2D transpose convolution and learning new filters with a sequence of two 2D convolution layers.

The forward() method, as usual, defines how the data pass through the architecture. The input data first pass through the first encoder block. Since this block is not part of the ResNet architecture, the original input, as opposed to the output from the first encoder block, is also passed into the second encoder block. The 2nd through 5th encoder blocks and the bottleneck block are defined based on components of the ResNet architecture.

Next, the data pass through the decoder component of the architecture. Each block consists of upsampling with 2D transpose convolution, concatenation of the feature maps from the associated encoder block, and two 2D convolution layers to learn additional filters. Lastly, the data are passed through a 2D convolution layer with a kernel size of 1x1 and a stride of 1 to obtain the class logits.

Again, it is important here that the correct stage of the ResNet architecture be assigned to the correct encoder block so that the skip connections will deliver arrays with the correct spatial dimensions for concatenation with the upsampled output from the bottleneck or prior decoder block. The mapping of the ResNet layers to the encoder blocks is sketched below.
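
To make the partitioning concrete, the top-level children of a ResNet can be listed, as in the following sketch for ResNet-18 (my own addition; the indexing is the same for ResNet-34 and ResNet-50). Children 0 through 2 form the stem used by the second encoder block, child 3 is the max pooling layer grouped with the first residual stage, children 4 through 7 are the four residual stages, and the final pooling and fully connected layers are not used.

#List the top-level children of ResNet-18 to map the encoder splits
resnet18 = torchvision.models.resnet18(pretrained=False)
for i, layer in enumerate(resnet18.children()):
  print(i, type(layer).__name__)
#0 Conv2d, 1 BatchNorm2d, 2 ReLU, 3 MaxPool2d, 4-7 Sequential (residual stages),
#8 AdaptiveAvgPool2d, 9 Linear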

def double_conv(inChannels, outChannels):
  return nn.Sequential(
      nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True),
      nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
def up_conv(inChannels, outChannels):
    return nn.Sequential(
      nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
      nn.BatchNorm2d(outChannels),
      nn.ReLU(inplace=True)
  )
class myUNetResNet(nn.Module):
  def __init__(self, inChn, nCls, resNet = "18", useWghts=True):
    super().__init__()
    self.inChn = inChn
    self.nCls = nCls
    self.resNet = resNet
    self.useWghts = useWghts

    if(resNet == "34"):

      self.base_model = torchvision.models.resnet34(pretrained=useWghts)
      self.base_layers = list(self.base_model.children())
      self.outSizes = [64, 64, 128, 256, 512]
    
    elif(resNet == "50"):

      self.base_model = torchvision.models.resnet50(pretrained=useWghts)
      self.base_layers = list(self.base_model.children())
      self.outSizes = [64, 256, 512, 1024, 2048]
       
    else:

      self.base_model = torchvision.models.resnet18(pretrained=useWghts)
      self.base_layers = list(self.base_model.children())
      self.outSizes = [64, 64, 128, 256, 512]

    self.encoder1 = double_conv(inChn, 16)

    self.encoder2 = nn.Sequential(*self.base_layers[:3]) 
    self.encoder3 = nn.Sequential(*self.base_layers[3:5]) 
    self.encoder4 = self.base_layers[5] 
    self.encoder5 = self.base_layers[6]  
    self.bottleneck = self.base_layers[7] 

    self.decoder1up = up_conv(self.outSizes[4], 512) 
    self.decoder1 = double_conv(self.outSizes[3] + 512, 256) 

    self.decoder2up = up_conv(256, 256)
    self.decoder2 = double_conv(self.outSizes[2] + 256, 128)

    self.decoder3up = up_conv(128, 128)
    self.decoder3 = double_conv(self.outSizes[1] + 128, 64)

    self.decoder4up = up_conv(64, 64)
    self.decoder4 = double_conv(self.outSizes[0] + 64, 32)

    self.decoder5up = up_conv(32, 32)
    self.decoder5 = double_conv(16 + 32, 16)

    self.classifier = nn.Conv2d(16, nCls, kernel_size=(1,1))

  def forward(self, x):

    #Encoder
    encoder1 = self.encoder1(x)
    encoder2 = self.encoder2(x)
    encoder3 = self.encoder3(encoder2)
    encoder4 = self.encoder4(encoder3)
    encoder5 = self.encoder5(encoder4)

    #Bottleneck
    x = self.bottleneck(encoder5)

    #Decoder
    x = self.decoder1up(x)
    x = torch.concat([x, encoder5], dim=1)
    x = self.decoder1(x)

    x = self.decoder2up(x)
    x = torch.concat([x, encoder4], dim=1)
    x = self.decoder2(x)

    x = self.decoder3up(x)
    x = torch.concat([x, encoder3], dim=1)
    x = self.decoder3(x)

    x = self.decoder4up(x)
    x = torch.concat([x, encoder2], dim=1)
    x = self.decoder4(x)

    x = self.decoder5up(x)
    x = torch.concat([x, encoder1], dim=1)
    x = self.decoder5(x)

    #Classifier head
    x = self.classifier(x)

    return x

I instantiate the myUNetResNet() subclass so that it accepts 3 input channels, differentiates 10 classes, uses a ResNet-18 architecture as the backbone or encoder, and is initialized using pre-trained weights in the encoder. I print a summary to explore the model architecture, which has over 26 million trainable parameters.

model = myUNetResNet(inChn=3, nCls=10, resNet = "18", useWghts=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
summary(model, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 256, 256]             448
       BatchNorm2d-2         [-1, 16, 256, 256]              32
              ReLU-3         [-1, 16, 256, 256]               0
            Conv2d-4         [-1, 16, 256, 256]           2,320
       BatchNorm2d-5         [-1, 16, 256, 256]              32
              ReLU-6         [-1, 16, 256, 256]               0
            Conv2d-7         [-1, 64, 128, 128]           9,408
            Conv2d-8         [-1, 64, 128, 128]           9,408
       BatchNorm2d-9         [-1, 64, 128, 128]             128
      BatchNorm2d-10         [-1, 64, 128, 128]             128
             ReLU-11         [-1, 64, 128, 128]               0
             ReLU-12         [-1, 64, 128, 128]               0
        MaxPool2d-13           [-1, 64, 64, 64]               0
        MaxPool2d-14           [-1, 64, 64, 64]               0
           Conv2d-15           [-1, 64, 64, 64]          36,864
           Conv2d-16           [-1, 64, 64, 64]          36,864
      BatchNorm2d-17           [-1, 64, 64, 64]             128
      BatchNorm2d-18           [-1, 64, 64, 64]             128
             ReLU-19           [-1, 64, 64, 64]               0
             ReLU-20           [-1, 64, 64, 64]               0
           Conv2d-21           [-1, 64, 64, 64]          36,864
           Conv2d-22           [-1, 64, 64, 64]          36,864
      BatchNorm2d-23           [-1, 64, 64, 64]             128
      BatchNorm2d-24           [-1, 64, 64, 64]             128
             ReLU-25           [-1, 64, 64, 64]               0
             ReLU-26           [-1, 64, 64, 64]               0
       BasicBlock-27           [-1, 64, 64, 64]               0
       BasicBlock-28           [-1, 64, 64, 64]               0
           Conv2d-29           [-1, 64, 64, 64]          36,864
           Conv2d-30           [-1, 64, 64, 64]          36,864
      BatchNorm2d-31           [-1, 64, 64, 64]             128
      BatchNorm2d-32           [-1, 64, 64, 64]             128
             ReLU-33           [-1, 64, 64, 64]               0
             ReLU-34           [-1, 64, 64, 64]               0
           Conv2d-35           [-1, 64, 64, 64]          36,864
           Conv2d-36           [-1, 64, 64, 64]          36,864
      BatchNorm2d-37           [-1, 64, 64, 64]             128
      BatchNorm2d-38           [-1, 64, 64, 64]             128
             ReLU-39           [-1, 64, 64, 64]               0
             ReLU-40           [-1, 64, 64, 64]               0
       BasicBlock-41           [-1, 64, 64, 64]               0
       BasicBlock-42           [-1, 64, 64, 64]               0
           Conv2d-43          [-1, 128, 32, 32]          73,728
           Conv2d-44          [-1, 128, 32, 32]          73,728
      BatchNorm2d-45          [-1, 128, 32, 32]             256
      BatchNorm2d-46          [-1, 128, 32, 32]             256
             ReLU-47          [-1, 128, 32, 32]               0
             ReLU-48          [-1, 128, 32, 32]               0
           Conv2d-49          [-1, 128, 32, 32]         147,456
           Conv2d-50          [-1, 128, 32, 32]         147,456
      BatchNorm2d-51          [-1, 128, 32, 32]             256
      BatchNorm2d-52          [-1, 128, 32, 32]             256
           Conv2d-53          [-1, 128, 32, 32]           8,192
           Conv2d-54          [-1, 128, 32, 32]           8,192
      BatchNorm2d-55          [-1, 128, 32, 32]             256
      BatchNorm2d-56          [-1, 128, 32, 32]             256
             ReLU-57          [-1, 128, 32, 32]               0
             ReLU-58          [-1, 128, 32, 32]               0
       BasicBlock-59          [-1, 128, 32, 32]               0
       BasicBlock-60          [-1, 128, 32, 32]               0
           Conv2d-61          [-1, 128, 32, 32]         147,456
           Conv2d-62          [-1, 128, 32, 32]         147,456
      BatchNorm2d-63          [-1, 128, 32, 32]             256
      BatchNorm2d-64          [-1, 128, 32, 32]             256
             ReLU-65          [-1, 128, 32, 32]               0
             ReLU-66          [-1, 128, 32, 32]               0
           Conv2d-67          [-1, 128, 32, 32]         147,456
           Conv2d-68          [-1, 128, 32, 32]         147,456
      BatchNorm2d-69          [-1, 128, 32, 32]             256
      BatchNorm2d-70          [-1, 128, 32, 32]             256
             ReLU-71          [-1, 128, 32, 32]               0
             ReLU-72          [-1, 128, 32, 32]               0
       BasicBlock-73          [-1, 128, 32, 32]               0
       BasicBlock-74          [-1, 128, 32, 32]               0
           Conv2d-75          [-1, 256, 16, 16]         294,912
           Conv2d-76          [-1, 256, 16, 16]         294,912
      BatchNorm2d-77          [-1, 256, 16, 16]             512
      BatchNorm2d-78          [-1, 256, 16, 16]             512
             ReLU-79          [-1, 256, 16, 16]               0
             ReLU-80          [-1, 256, 16, 16]               0
           Conv2d-81          [-1, 256, 16, 16]         589,824
           Conv2d-82          [-1, 256, 16, 16]         589,824
      BatchNorm2d-83          [-1, 256, 16, 16]             512
      BatchNorm2d-84          [-1, 256, 16, 16]             512
           Conv2d-85          [-1, 256, 16, 16]          32,768
           Conv2d-86          [-1, 256, 16, 16]          32,768
      BatchNorm2d-87          [-1, 256, 16, 16]             512
      BatchNorm2d-88          [-1, 256, 16, 16]             512
             ReLU-89          [-1, 256, 16, 16]               0
             ReLU-90          [-1, 256, 16, 16]               0
       BasicBlock-91          [-1, 256, 16, 16]               0
       BasicBlock-92          [-1, 256, 16, 16]               0
           Conv2d-93          [-1, 256, 16, 16]         589,824
           Conv2d-94          [-1, 256, 16, 16]         589,824
      BatchNorm2d-95          [-1, 256, 16, 16]             512
      BatchNorm2d-96          [-1, 256, 16, 16]             512
             ReLU-97          [-1, 256, 16, 16]               0
             ReLU-98          [-1, 256, 16, 16]               0
           Conv2d-99          [-1, 256, 16, 16]         589,824
          Conv2d-100          [-1, 256, 16, 16]         589,824
     BatchNorm2d-101          [-1, 256, 16, 16]             512
     BatchNorm2d-102          [-1, 256, 16, 16]             512
            ReLU-103          [-1, 256, 16, 16]               0
            ReLU-104          [-1, 256, 16, 16]               0
      BasicBlock-105          [-1, 256, 16, 16]               0
      BasicBlock-106          [-1, 256, 16, 16]               0
          Conv2d-107            [-1, 512, 8, 8]       1,179,648
          Conv2d-108            [-1, 512, 8, 8]       1,179,648
     BatchNorm2d-109            [-1, 512, 8, 8]           1,024
     BatchNorm2d-110            [-1, 512, 8, 8]           1,024
            ReLU-111            [-1, 512, 8, 8]               0
            ReLU-112            [-1, 512, 8, 8]               0
          Conv2d-113            [-1, 512, 8, 8]       2,359,296
          Conv2d-114            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-115            [-1, 512, 8, 8]           1,024
     BatchNorm2d-116            [-1, 512, 8, 8]           1,024
          Conv2d-117            [-1, 512, 8, 8]         131,072
          Conv2d-118            [-1, 512, 8, 8]         131,072
     BatchNorm2d-119            [-1, 512, 8, 8]           1,024
     BatchNorm2d-120            [-1, 512, 8, 8]           1,024
            ReLU-121            [-1, 512, 8, 8]               0
            ReLU-122            [-1, 512, 8, 8]               0
      BasicBlock-123            [-1, 512, 8, 8]               0
      BasicBlock-124            [-1, 512, 8, 8]               0
          Conv2d-125            [-1, 512, 8, 8]       2,359,296
          Conv2d-126            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-127            [-1, 512, 8, 8]           1,024
     BatchNorm2d-128            [-1, 512, 8, 8]           1,024
            ReLU-129            [-1, 512, 8, 8]               0
            ReLU-130            [-1, 512, 8, 8]               0
          Conv2d-131            [-1, 512, 8, 8]       2,359,296
          Conv2d-132            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-133            [-1, 512, 8, 8]           1,024
     BatchNorm2d-134            [-1, 512, 8, 8]           1,024
            ReLU-135            [-1, 512, 8, 8]               0
            ReLU-136            [-1, 512, 8, 8]               0
      BasicBlock-137            [-1, 512, 8, 8]               0
      BasicBlock-138            [-1, 512, 8, 8]               0
 ConvTranspose2d-139          [-1, 512, 16, 16]       1,049,088
     BatchNorm2d-140          [-1, 512, 16, 16]           1,024
            ReLU-141          [-1, 512, 16, 16]               0
          Conv2d-142          [-1, 256, 16, 16]       1,769,728
     BatchNorm2d-143          [-1, 256, 16, 16]             512
            ReLU-144          [-1, 256, 16, 16]               0
          Conv2d-145          [-1, 256, 16, 16]         590,080
     BatchNorm2d-146          [-1, 256, 16, 16]             512
            ReLU-147          [-1, 256, 16, 16]               0
 ConvTranspose2d-148          [-1, 256, 32, 32]         262,400
     BatchNorm2d-149          [-1, 256, 32, 32]             512
            ReLU-150          [-1, 256, 32, 32]               0
          Conv2d-151          [-1, 128, 32, 32]         442,496
     BatchNorm2d-152          [-1, 128, 32, 32]             256
            ReLU-153          [-1, 128, 32, 32]               0
          Conv2d-154          [-1, 128, 32, 32]         147,584
     BatchNorm2d-155          [-1, 128, 32, 32]             256
            ReLU-156          [-1, 128, 32, 32]               0
 ConvTranspose2d-157          [-1, 128, 64, 64]          65,664
     BatchNorm2d-158          [-1, 128, 64, 64]             256
            ReLU-159          [-1, 128, 64, 64]               0
          Conv2d-160           [-1, 64, 64, 64]         110,656
     BatchNorm2d-161           [-1, 64, 64, 64]             128
            ReLU-162           [-1, 64, 64, 64]               0
          Conv2d-163           [-1, 64, 64, 64]          36,928
     BatchNorm2d-164           [-1, 64, 64, 64]             128
            ReLU-165           [-1, 64, 64, 64]               0
 ConvTranspose2d-166         [-1, 64, 128, 128]          16,448
     BatchNorm2d-167         [-1, 64, 128, 128]             128
            ReLU-168         [-1, 64, 128, 128]               0
          Conv2d-169         [-1, 32, 128, 128]          36,896
     BatchNorm2d-170         [-1, 32, 128, 128]              64
            ReLU-171         [-1, 32, 128, 128]               0
          Conv2d-172         [-1, 32, 128, 128]           9,248
     BatchNorm2d-173         [-1, 32, 128, 128]              64
            ReLU-174         [-1, 32, 128, 128]               0
 ConvTranspose2d-175         [-1, 32, 256, 256]           4,128
     BatchNorm2d-176         [-1, 32, 256, 256]              64
            ReLU-177         [-1, 32, 256, 256]               0
          Conv2d-178         [-1, 16, 256, 256]           6,928
     BatchNorm2d-179         [-1, 16, 256, 256]              32
            ReLU-180         [-1, 16, 256, 256]               0
          Conv2d-181         [-1, 16, 256, 256]           2,320
     BatchNorm2d-182         [-1, 16, 256, 256]              32
            ReLU-183         [-1, 16, 256, 256]               0
          Conv2d-184         [-1, 10, 256, 256]             170
================================================================
Total params: 26,910,586
Trainable params: 26,910,586
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 403.00
Params size (MB): 102.66
Estimated Total Size (MB): 506.41
----------------------------------------------------------------
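
Because the backbone is parameterized, the same class can be checked with a deeper variant. As a quick sketch of my own, I instantiate the ResNet-50 version below without pre-trained weights and confirm with a dummy batch that the output shape is unchanged; only the internal channel counts differ.

#The ResNet-50 variant yields the same output shape
model50 = myUNetResNet(inChn=3, nCls=10, resNet="50", useWghts=False).to(device)
model50.eval()
with torch.no_grad():
  print(model50(torch.randn(1, 3, 256, 256).to(device)).shape)
#torch.Size([1, 10, 256, 256])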

Similar to the VGGNet-16 example above, I can freeze the backbone parameters to reduce the number of trainable parameters in the model. This is accomplished by setting the requires_grad property of the parameters in the backbone or encoder layers extracted from the ResNet architecture to False. Printing the summary, you can see that only a subset of the parameters is now trainable.

for l in model.base_layers:
  for param in l.parameters():
    param.requires_grad = False
summary(model, (3,256,256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 256, 256]             448
       BatchNorm2d-2         [-1, 16, 256, 256]              32
              ReLU-3         [-1, 16, 256, 256]               0
            Conv2d-4         [-1, 16, 256, 256]           2,320
       BatchNorm2d-5         [-1, 16, 256, 256]              32
              ReLU-6         [-1, 16, 256, 256]               0
            Conv2d-7         [-1, 64, 128, 128]           9,408
            Conv2d-8         [-1, 64, 128, 128]           9,408
       BatchNorm2d-9         [-1, 64, 128, 128]             128
      BatchNorm2d-10         [-1, 64, 128, 128]             128
             ReLU-11         [-1, 64, 128, 128]               0
             ReLU-12         [-1, 64, 128, 128]               0
        MaxPool2d-13           [-1, 64, 64, 64]               0
        MaxPool2d-14           [-1, 64, 64, 64]               0
           Conv2d-15           [-1, 64, 64, 64]          36,864
           Conv2d-16           [-1, 64, 64, 64]          36,864
      BatchNorm2d-17           [-1, 64, 64, 64]             128
      BatchNorm2d-18           [-1, 64, 64, 64]             128
             ReLU-19           [-1, 64, 64, 64]               0
             ReLU-20           [-1, 64, 64, 64]               0
           Conv2d-21           [-1, 64, 64, 64]          36,864
           Conv2d-22           [-1, 64, 64, 64]          36,864
      BatchNorm2d-23           [-1, 64, 64, 64]             128
      BatchNorm2d-24           [-1, 64, 64, 64]             128
             ReLU-25           [-1, 64, 64, 64]               0
             ReLU-26           [-1, 64, 64, 64]               0
       BasicBlock-27           [-1, 64, 64, 64]               0
       BasicBlock-28           [-1, 64, 64, 64]               0
           Conv2d-29           [-1, 64, 64, 64]          36,864
           Conv2d-30           [-1, 64, 64, 64]          36,864
      BatchNorm2d-31           [-1, 64, 64, 64]             128
      BatchNorm2d-32           [-1, 64, 64, 64]             128
             ReLU-33           [-1, 64, 64, 64]               0
             ReLU-34           [-1, 64, 64, 64]               0
           Conv2d-35           [-1, 64, 64, 64]          36,864
           Conv2d-36           [-1, 64, 64, 64]          36,864
      BatchNorm2d-37           [-1, 64, 64, 64]             128
      BatchNorm2d-38           [-1, 64, 64, 64]             128
             ReLU-39           [-1, 64, 64, 64]               0
             ReLU-40           [-1, 64, 64, 64]               0
       BasicBlock-41           [-1, 64, 64, 64]               0
       BasicBlock-42           [-1, 64, 64, 64]               0
           Conv2d-43          [-1, 128, 32, 32]          73,728
           Conv2d-44          [-1, 128, 32, 32]          73,728
      BatchNorm2d-45          [-1, 128, 32, 32]             256
      BatchNorm2d-46          [-1, 128, 32, 32]             256
             ReLU-47          [-1, 128, 32, 32]               0
             ReLU-48          [-1, 128, 32, 32]               0
           Conv2d-49          [-1, 128, 32, 32]         147,456
           Conv2d-50          [-1, 128, 32, 32]         147,456
      BatchNorm2d-51          [-1, 128, 32, 32]             256
      BatchNorm2d-52          [-1, 128, 32, 32]             256
           Conv2d-53          [-1, 128, 32, 32]           8,192
           Conv2d-54          [-1, 128, 32, 32]           8,192
      BatchNorm2d-55          [-1, 128, 32, 32]             256
      BatchNorm2d-56          [-1, 128, 32, 32]             256
             ReLU-57          [-1, 128, 32, 32]               0
             ReLU-58          [-1, 128, 32, 32]               0
       BasicBlock-59          [-1, 128, 32, 32]               0
       BasicBlock-60          [-1, 128, 32, 32]               0
           Conv2d-61          [-1, 128, 32, 32]         147,456
           Conv2d-62          [-1, 128, 32, 32]         147,456
      BatchNorm2d-63          [-1, 128, 32, 32]             256
      BatchNorm2d-64          [-1, 128, 32, 32]             256
             ReLU-65          [-1, 128, 32, 32]               0
             ReLU-66          [-1, 128, 32, 32]               0
           Conv2d-67          [-1, 128, 32, 32]         147,456
           Conv2d-68          [-1, 128, 32, 32]         147,456
      BatchNorm2d-69          [-1, 128, 32, 32]             256
      BatchNorm2d-70          [-1, 128, 32, 32]             256
             ReLU-71          [-1, 128, 32, 32]               0
             ReLU-72          [-1, 128, 32, 32]               0
       BasicBlock-73          [-1, 128, 32, 32]               0
       BasicBlock-74          [-1, 128, 32, 32]               0
           Conv2d-75          [-1, 256, 16, 16]         294,912
           Conv2d-76          [-1, 256, 16, 16]         294,912
      BatchNorm2d-77          [-1, 256, 16, 16]             512
      BatchNorm2d-78          [-1, 256, 16, 16]             512
             ReLU-79          [-1, 256, 16, 16]               0
             ReLU-80          [-1, 256, 16, 16]               0
           Conv2d-81          [-1, 256, 16, 16]         589,824
           Conv2d-82          [-1, 256, 16, 16]         589,824
      BatchNorm2d-83          [-1, 256, 16, 16]             512
      BatchNorm2d-84          [-1, 256, 16, 16]             512
           Conv2d-85          [-1, 256, 16, 16]          32,768
           Conv2d-86          [-1, 256, 16, 16]          32,768
      BatchNorm2d-87          [-1, 256, 16, 16]             512
      BatchNorm2d-88          [-1, 256, 16, 16]             512
             ReLU-89          [-1, 256, 16, 16]               0
             ReLU-90          [-1, 256, 16, 16]               0
       BasicBlock-91          [-1, 256, 16, 16]               0
       BasicBlock-92          [-1, 256, 16, 16]               0
           Conv2d-93          [-1, 256, 16, 16]         589,824
           Conv2d-94          [-1, 256, 16, 16]         589,824
      BatchNorm2d-95          [-1, 256, 16, 16]             512
      BatchNorm2d-96          [-1, 256, 16, 16]             512
             ReLU-97          [-1, 256, 16, 16]               0
             ReLU-98          [-1, 256, 16, 16]               0
           Conv2d-99          [-1, 256, 16, 16]         589,824
          Conv2d-100          [-1, 256, 16, 16]         589,824
     BatchNorm2d-101          [-1, 256, 16, 16]             512
     BatchNorm2d-102          [-1, 256, 16, 16]             512
            ReLU-103          [-1, 256, 16, 16]               0
            ReLU-104          [-1, 256, 16, 16]               0
      BasicBlock-105          [-1, 256, 16, 16]               0
      BasicBlock-106          [-1, 256, 16, 16]               0
          Conv2d-107            [-1, 512, 8, 8]       1,179,648
          Conv2d-108            [-1, 512, 8, 8]       1,179,648
     BatchNorm2d-109            [-1, 512, 8, 8]           1,024
     BatchNorm2d-110            [-1, 512, 8, 8]           1,024
            ReLU-111            [-1, 512, 8, 8]               0
            ReLU-112            [-1, 512, 8, 8]               0
          Conv2d-113            [-1, 512, 8, 8]       2,359,296
          Conv2d-114            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-115            [-1, 512, 8, 8]           1,024
     BatchNorm2d-116            [-1, 512, 8, 8]           1,024
          Conv2d-117            [-1, 512, 8, 8]         131,072
          Conv2d-118            [-1, 512, 8, 8]         131,072
     BatchNorm2d-119            [-1, 512, 8, 8]           1,024
     BatchNorm2d-120            [-1, 512, 8, 8]           1,024
            ReLU-121            [-1, 512, 8, 8]               0
            ReLU-122            [-1, 512, 8, 8]               0
      BasicBlock-123            [-1, 512, 8, 8]               0
      BasicBlock-124            [-1, 512, 8, 8]               0
          Conv2d-125            [-1, 512, 8, 8]       2,359,296
          Conv2d-126            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-127            [-1, 512, 8, 8]           1,024
     BatchNorm2d-128            [-1, 512, 8, 8]           1,024
            ReLU-129            [-1, 512, 8, 8]               0
            ReLU-130            [-1, 512, 8, 8]               0
          Conv2d-131            [-1, 512, 8, 8]       2,359,296
          Conv2d-132            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-133            [-1, 512, 8, 8]           1,024
     BatchNorm2d-134            [-1, 512, 8, 8]           1,024
            ReLU-135            [-1, 512, 8, 8]               0
            ReLU-136            [-1, 512, 8, 8]               0
      BasicBlock-137            [-1, 512, 8, 8]               0
      BasicBlock-138            [-1, 512, 8, 8]               0
 ConvTranspose2d-139          [-1, 512, 16, 16]       1,049,088
     BatchNorm2d-140          [-1, 512, 16, 16]           1,024
            ReLU-141          [-1, 512, 16, 16]               0
          Conv2d-142          [-1, 256, 16, 16]       1,769,728
     BatchNorm2d-143          [-1, 256, 16, 16]             512
            ReLU-144          [-1, 256, 16, 16]               0
          Conv2d-145          [-1, 256, 16, 16]         590,080
     BatchNorm2d-146          [-1, 256, 16, 16]             512
            ReLU-147          [-1, 256, 16, 16]               0
 ConvTranspose2d-148          [-1, 256, 32, 32]         262,400
     BatchNorm2d-149          [-1, 256, 32, 32]             512
            ReLU-150          [-1, 256, 32, 32]               0
          Conv2d-151          [-1, 128, 32, 32]         442,496
     BatchNorm2d-152          [-1, 128, 32, 32]             256
            ReLU-153          [-1, 128, 32, 32]               0
          Conv2d-154          [-1, 128, 32, 32]         147,584
     BatchNorm2d-155          [-1, 128, 32, 32]             256
            ReLU-156          [-1, 128, 32, 32]               0
 ConvTranspose2d-157          [-1, 128, 64, 64]          65,664
     BatchNorm2d-158          [-1, 128, 64, 64]             256
            ReLU-159          [-1, 128, 64, 64]               0
          Conv2d-160           [-1, 64, 64, 64]         110,656
     BatchNorm2d-161           [-1, 64, 64, 64]             128
            ReLU-162           [-1, 64, 64, 64]               0
          Conv2d-163           [-1, 64, 64, 64]          36,928
     BatchNorm2d-164           [-1, 64, 64, 64]             128
            ReLU-165           [-1, 64, 64, 64]               0
 ConvTranspose2d-166         [-1, 64, 128, 128]          16,448
     BatchNorm2d-167         [-1, 64, 128, 128]             128
            ReLU-168         [-1, 64, 128, 128]               0
          Conv2d-169         [-1, 32, 128, 128]          36,896
     BatchNorm2d-170         [-1, 32, 128, 128]              64
            ReLU-171         [-1, 32, 128, 128]               0
          Conv2d-172         [-1, 32, 128, 128]           9,248
     BatchNorm2d-173         [-1, 32, 128, 128]              64
            ReLU-174         [-1, 32, 128, 128]               0
 ConvTranspose2d-175         [-1, 32, 256, 256]           4,128
     BatchNorm2d-176         [-1, 32, 256, 256]              64
            ReLU-177         [-1, 32, 256, 256]               0
          Conv2d-178         [-1, 16, 256, 256]           6,928
     BatchNorm2d-179         [-1, 16, 256, 256]              32
            ReLU-180         [-1, 16, 256, 256]               0
          Conv2d-181         [-1, 16, 256, 256]           2,320
     BatchNorm2d-182         [-1, 16, 256, 256]              32
            ReLU-183         [-1, 16, 256, 256]               0
          Conv2d-184         [-1, 10, 256, 256]             170
================================================================
Total params: 26,910,586
Trainable params: 4,557,562
Non-trainable params: 22,353,024
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 403.00
Params size (MB): 102.66
Estimated Total Size (MB): 506.41
----------------------------------------------------------------
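
As a quick sanity check, the trainable parameter count can also be computed directly from the model rather than read from the summary. Below is a minimal sketch, assuming model is the ResNet-backed UNet defined above; it also shows how, if you chose to fine-tune only the decoder, an optimizer could be restricted to the unfrozen parameters. AdamW is used here only as an example optimizer.

# Count trainable vs. total parameters directly from the model
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} of {total:,} parameters")

# When training with a frozen encoder, pass only the unfrozen
# parameters to the optimizer
optimizer = torch.optim.AdamW(
  (p for p in model.parameters() if p.requires_grad), lr=1e-3
)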

Concluding Remarks

You can now define a basic UNet architecture, train a UNet model, and build UNet variants that use common CNN architectures, such as VGGNet-16 and ResNet, as the backbone or encoder and that can accept pre-trained weights. However, other semantic segmentation architectures are more complex and harder to build from scratch, and you may want to combine a wide range of backbones with a range of segmentation architectures. In the next module, we will explore the Segmentation Models package, which builds on PyTorch and provides many semantic segmentation architectures, backbone encoders, and pre-trained weights without requiring you to define the models or their components yourself.
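
As a brief, hedged preview of that workflow, instantiating a UNet with a pre-trained ResNet encoder using the segmentation-models-pytorch package might look something like the sketch below. The specific encoder_name and classes values are illustrative; the package and its options are covered in the next module.

import segmentation_models_pytorch as smp

# Sketch: a UNet with a ResNet-34 encoder initialized with
# ImageNet weights; classes=10 mirrors the 10-class output
# used in the examples above
model = smp.Unet(
  encoder_name="resnet34",
  encoder_weights="imagenet",
  in_channels=3,
  classes=10,
)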