Build A Convolutional Neural Network

Introduction

Now that you understand how CNNs work and which layers are used to construct them, you are ready to build a CNN architecture in PyTorch by subclassing nn.Module. In this section, we will build a CNN. In the next section, we will prepare image data, then train and assess a CNN. The last two sections focused on CNNs will explore some famous CNN architectures and how to implement transfer learning. In this module, focus on understanding how the CNN architecture is defined in PyTorch, and do not worry yet about the other components needed to train and use it.

I first import the required packages. Next, I define the device: a GPU is used if one is available; otherwise, the CPU is used. Note that training a CNN on a CPU can be very slow, so if you plan to explore CNNs further, you will want access to a CUDA-enabled GPU.

import torch
import torchsummary
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0

Define Architecture

My goal is to build a CNN that allows the user to define the number of classes, number of input channels, number of feature maps in each convolution layer, and the number of nodes in each fully connected layer. The CNN will have the following general architecture.

  1. Four 2D convolutional layers.

  2. After each convolutional layer, batch normalization will be applied, followed by a rectified linear unit (ReLU) activation and 2x2 max pooling.

  3. After the convolutional component of the model, the resulting array will be flattened.

  4. The flattened array, or 1D vector, will be passed through 3 fully connected layers.

  5. After each of the first 2 fully connected layers, batch normalization will be applied followed by a ReLU activation.

  6. After the final fully connected layer, no batch normalization or activation will be applied, so the model will return raw logits.

I will need to know the height and width of the array after the final max pooling operation is applied. So, I write a simple helper function that calculates this size from the original input size and the number of convolution blocks. Because each block ends with a max pooling operation that halves the height and width, the function simply divides the original size by two once per block.

I test this simple function using an original input size of 128x128 and 4 convolution blocks. Given these inputs, the final height and width of the tensor after all 4 max pooling operations are applied will be 8x8 (128x128 –> 64x64 –> 32x32 –> 16x16 –> 8x8).

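def getDim(inDim, numBlks):
  # Each convolution block ends with 2x2 max pooling (stride 2), which halves
  # the height and width (rounding down), so divide by two once per block
  x = inDim
  for _ in range(numBlks):
    x = x // 2
  return int(x)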
getDim(128, 4)
8
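getDim() relies on every block halving the size. A slightly more general sketch uses the standard output-size formula for convolution and pooling layers, floor((in + 2 x padding - kernel_size) / stride) + 1, assuming a dilation of 1, and cross-checks the result against actual PyTorch layers. The helper name conv_out_size is hypothetical and not part of PyTorch.

```python
import torch
import torch.nn as nn

def conv_out_size(in_size, kernel_size, stride=1, padding=0):
    # Hypothetical helper: output-size formula for Conv2d/MaxPool2d
    # (dilation=1, floor rounding): floor((in + 2*padding - kernel) / stride) + 1
    return (in_size + 2 * padding - kernel_size) // stride + 1

size = 128
for _ in range(4):
    size = conv_out_size(size, kernel_size=3, stride=1, padding=1)  # 3x3 conv, padding 1: unchanged
    size = conv_out_size(size, kernel_size=2, stride=2)             # 2x2 max pool, stride 2: halved
print(size)  # 8

# Cross-check the formula against real layers on a dummy input
x = torch.zeros(1, 3, 128, 128)
conv = nn.Conv2d(3, 10, kernel_size=3, padding=1)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
print(conv(x).shape)        # torch.Size([1, 10, 128, 128])
print(pool(conv(x)).shape)  # torch.Size([1, 10, 64, 64])
```

The formula also shows why padding=1 with a 3x3 kernel preserves the spatial size: 128 + 2 - 3 + 1 = 128.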

I first define my CNN without using nn.Sequential(); later, I will define the same model using nn.Sequential(). In both cases, the model is built by subclassing nn.Module. Here is a walkthrough of the process.

  1. Begin by subclassing nn.Module. Inside the __init__() constructor method, I call super().__init__() and store all input parameters as attributes.

  2. I then define all of the required layers within the __init__() constructor method. A total of 4 nn.Conv2d() layers are defined. The first accepts the number of input channels and outputs the first requested number of output channels. Each subsequent nn.Conv2d() layer requires that its in_channels parameter equal the out_channels parameter of the prior convolution layer. Within all nn.Conv2d() layers, I use a kernel size of 3x3 and a padding of 1, so the output height and width equal the input height and width for that layer.

  3. I then define an associated nn.BatchNorm2d() layer for each nn.Conv2d() layer. The num_features parameter equals the number of output channels of the associated nn.Conv2d() layer.

  4. I define a single nn.ReLU() layer and a single nn.MaxPool2d() layer. These only need to be defined once since they have no trainable parameters and use the same settings each time they are called. For max pooling, a kernel size of 2x2 is used with a stride of 2, which halves the height and width of the array.

  5. I then define the fully connected layers. The first fully connected layer will have an input size equal to the number of output channels from the last convolution layer multiplied by the height and width of the tensor following the last max pooling application. All subsequent layers require that the number of input features be equivalent to the number of output features from the prior fully connected layer. The last fully connected layer must have an output size equal to the number of classes being differentiated.

  6. I define batch normalization layers to accompany the first and second fully connected layers. Since these layers output 1D feature vectors rather than 2D feature maps, I use nn.BatchNorm1d() instead of nn.BatchNorm2d(). The num_features parameter must equal the number of output features of the associated fully connected layer.

  7. The forward() method defines how the input data are passed through the network. Within each convolution block, the sequence is 2D Convolution –> Batch Normalization –> ReLU Activation –> Max Pooling.

  8. Once the data have been passed through the convolutional component of the architecture, they are flattened using view(). The batch dimension is kept, so each sample becomes a 1D vector of length lastDim x lastDim x outChn[3]; the -1 lets PyTorch infer the batch size.

  9. The flattened tensor is then passed through the fully connected layers. Note that no batch normalization or activation function is applied after the last fully connected layer.

  10. I am not applying a softmax activation. As a result, the logits will be returned as opposed to probabilities.
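Because the model returns raw logits, a softmax is only needed when class probabilities are wanted. The sketch below uses made-up logits and targets (hypothetical stand-ins for the model's output and the labels) to show the conversion and to show that nn.CrossEntropyLoss consumes logits directly, since it applies log-softmax internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical raw outputs (logits) for 4 samples and 10 classes,
# standing in for what the final fully connected layer returns
logits = torch.randn(4, 10)
targets = torch.tensor([1, 0, 9, 3])  # hypothetical class indices

# Convert logits to probabilities only when probabilities are needed
probs = F.softmax(logits, dim=1)
print(probs.sum(dim=1))  # each row sums to 1

# The predicted class is the same whether taken from logits or probabilities
preds = logits.argmax(dim=1)

# nn.CrossEntropyLoss expects raw logits and applies log-softmax itself,
# so adding a softmax inside the model would apply it twice
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)

# Equivalent manual computation: negative log-likelihood of log-softmax outputs
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(loss.item(), manual.item())
```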

class myCNN(nn.Module):
  def __init__(self, nCls, inChn, outChn, fcChn, lastDim):
    super().__init__()
    self.nCls = nCls
    self.inChn = inChn
    self.outChn = outChn
    self.fcChn = fcChn
    self.lastDim = lastDim

    self.conv1 = nn.Conv2d(in_channels=inChn, out_channels=outChn[0], kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=outChn[0], out_channels=outChn[1], kernel_size=(3,3), padding=1)
    self.conv3 = nn.Conv2d(in_channels=outChn[1], out_channels=outChn[2], kernel_size=(3,3), padding=1)
    self.conv4 = nn.Conv2d(in_channels=outChn[2], out_channels=outChn[3], kernel_size=(3,3), padding=1)

    self.bn1 = nn.BatchNorm2d(outChn[0])
    self.bn2 = nn.BatchNorm2d(outChn[1])
    self.bn3 = nn.BatchNorm2d(outChn[2])
    self.bn4 = nn.BatchNorm2d(outChn[3])

    self.relu = nn.ReLU()
    self.pool = nn.MaxPool2d(kernel_size=(2,2),stride=2)

    self.fc1 = nn.Linear(lastDim*lastDim*outChn[3], fcChn[0])
    self.fc2 = nn.Linear(fcChn[0],  fcChn[1])
    self.fc3 = nn.Linear(fcChn[1], nCls)

    self.bnfc1 = nn.BatchNorm1d(fcChn[0])
    self.bnfc2 = nn.BatchNorm1d(fcChn[1])
  
  def forward(self, x):
    x = self.pool(self.relu(self.bn1(self.conv1(x))))
    x = self.pool(self.relu(self.bn2(self.conv2(x))))
    x = self.pool(self.relu(self.bn3(self.conv3(x))))
    x = self.pool(self.relu(self.bn4(self.conv4(x))))
    x = x.view(-1, self.lastDim*self.lastDim*self.outChn[3])
    x = self.relu(self.bnfc1(self.fc1(x)))
    x = self.relu(self.bnfc2(self.fc2(x)))
    x = self.fc3(x)
    return x    
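One caveat about the view() call in forward(): x.view(-1, self.lastDim*self.lastDim*self.outChn[3]) will silently reshape an input of an unexpected size into a different batch size whenever the element count happens to divide evenly. The sketch below assumes a hypothetical 256x256 input (so the last feature maps are 16x16 while lastDim is still 8) and compares view() with torch.flatten(x, 1), which keeps the batch dimension and lets the first fully connected layer fail loudly instead.

```python
import torch
import torch.nn as nn

# Hypothetical feature maps after the last conv block for a 256x256 input:
# batch of 2, 40 channels, 16x16, while lastDim is still set to 8
feats = torch.randn(2, 40, 16, 16)
lastDim, lastChn = 8, 40

# view(-1, ...) folds the extra values into the batch dimension without error
bad = feats.view(-1, lastDim * lastDim * lastChn)
print(bad.shape)   # torch.Size([8, 2560]): the batch size changed from 2 to 8

# torch.flatten(x, 1) keeps the batch dimension and flattens everything else
good = torch.flatten(feats, 1)
print(good.shape)  # torch.Size([2, 10240])

# With flatten, a mismatched lastDim surfaces as an immediate shape error
fc1 = nn.Linear(lastDim * lastDim * lastChn, 268)
caught = False
try:
    fc1(good)
except RuntimeError:
    caught = True
print(caught)  # True
```

x.view(x.size(0), -1) is another way to keep the batch dimension fixed.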

Once the CNN architecture is defined, I instantiate the subclass. Here, I must define the number of classes to differentiate; the number of input channels; a list giving the number of kernels to learn in each convolution block, which must have the same length as the number of convolution blocks (four); a list giving the number of output features of the first and second fully connected layers; and the height and width of the tensor following the last max pooling operation. The last value is required to calculate the length of the flattened array before it is fed through the fully connected layers, and it is calculated using the helper function defined above.

The model is then moved to the device using the to() method.
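Rather than computing lastDim with getDim(), another option is to measure the flattened length by passing one dummy image through the convolutional layers. The sketch below rebuilds the four convolution blocks with nn.Sequential (conv_block is a hypothetical helper name) and assumes a 3-band 128x128 input. The probe runs in eval mode under torch.no_grad() so it does not update the BatchNorm running statistics; call train() again before training.

```python
import torch
import torch.nn as nn

def conv_block(in_c, out_c):
    # Hypothetical helper: Conv (3x3, padding 1) -> BatchNorm -> ReLU -> 2x2 max pool
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

outChn = [10, 20, 30, 40]
features = nn.Sequential(
    conv_block(3, outChn[0]),
    conv_block(outChn[0], outChn[1]),
    conv_block(outChn[1], outChn[2]),
    conv_block(outChn[2], outChn[3]),
)

# Probe with one dummy image to measure the flattened length per sample
features.eval()
with torch.no_grad():
    dummy = torch.zeros(1, 3, 128, 128)
    flat_len = features(dummy).flatten(1).shape[1]
print(flat_len)  # 2560 = 8 * 8 * 40
```

PyTorch also offers nn.LazyLinear, which infers its input size on the first forward pass; either approach avoids hard-coding lastDim.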

model = myCNN(nCls=10,
              inChn=3, 
              outChn=[10,20,30,40],
              fcChn=[268,512], 
              lastDim=getDim(128,4)).to(device)

The torchsummary package allows for summarizing a neural network. Its summary() function accepts a model and the shape of a single input sample, excluding the batch dimension. It reports the layers being used, the output shape of each layer, and the number of parameters in each layer. It also lists the total number of parameters, the number of trainable parameters, and the number of non-trainable parameters, and it estimates the model's memory consumption. Note that summary() runs on a CUDA device by default; on a CPU-only machine, pass device="cpu".

In the example below, I have performed the summarization using an array size of (3,128,128). This represents images with 3 bands and a height and width of 128 pixels. Note that the spatial dimensions of the array are halved after every max pooling operation. Most of the trainable parameters belong to the fully connected layers (the first one alone has 686,348), with far fewer in the convolution kernels. The batch normalization layers also have a small number of trainable parameters. The ReLU and max pooling layers have no trainable parameters.
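The per-layer counts in the summary can be checked by hand. A convolution layer has one weight per kernel element for every input/output channel pair plus one bias per output channel; a fully connected layer has one weight per input/output pair plus one bias per output; a batch normalization layer learns a scale and a shift per channel (its running mean and variance are stored as buffers, not parameters, so they are not counted).

```latex
P_{\text{conv}} = (k_h k_w C_{\text{in}} + 1)\,C_{\text{out}}, \qquad
P_{\text{linear}} = (n_{\text{in}} + 1)\,n_{\text{out}}, \qquad
P_{\text{BN}} = 2C

\text{Conv2d-1: } (3 \cdot 3 \cdot 3 + 1) \cdot 10 = 280
\text{Conv2d-13: } (3 \cdot 3 \cdot 30 + 1) \cdot 40 = 10{,}840
\text{Linear-17: } (8 \cdot 8 \cdot 40 + 1) \cdot 268 = 686{,}348
\text{BatchNorm1d-18: } 2 \cdot 268 = 536
```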

The first dimension, -1 throughout the network, represents the batch dimension. A value of -1 indicates that the batch size is not fixed.
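Because the batch size is not fixed, the model accepts any batch size, with one caveat tied to the nn.BatchNorm1d layers: in training mode they compute statistics across the batch and raise an error for a batch of one sample. The sketch below uses a standalone nn.BatchNorm1d(268), matching the first fully connected layer's width, to show this behavior.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(268)

# Training mode: batch statistics need more than one value per feature
bn.train()
try:
    bn(torch.randn(1, 268))
    single_ok_train = True
except ValueError:
    single_ok_train = False

# Evaluation mode: running statistics are used, so a batch of one works
bn.eval()
out = bn(torch.randn(1, 268))

# Any batch size larger than one works in training mode
bn.train()
out_batch = bn(torch.randn(16, 268))
print(single_ok_train, out.shape, out_batch.shape)
```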

Lastly, note that there are no non-trainable parameters. This is because none of the layers have been frozen. We will discuss freezing components of an architecture in later modules.

torchsummary.summary(model, (3,128,128))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 10, 128, 128]             280
       BatchNorm2d-2         [-1, 10, 128, 128]              20
              ReLU-3         [-1, 10, 128, 128]               0
         MaxPool2d-4           [-1, 10, 64, 64]               0
            Conv2d-5           [-1, 20, 64, 64]           1,820
       BatchNorm2d-6           [-1, 20, 64, 64]              40
              ReLU-7           [-1, 20, 64, 64]               0
         MaxPool2d-8           [-1, 20, 32, 32]               0
            Conv2d-9           [-1, 30, 32, 32]           5,430
      BatchNorm2d-10           [-1, 30, 32, 32]              60
             ReLU-11           [-1, 30, 32, 32]               0
        MaxPool2d-12           [-1, 30, 16, 16]               0
           Conv2d-13           [-1, 40, 16, 16]          10,840
      BatchNorm2d-14           [-1, 40, 16, 16]              80
             ReLU-15           [-1, 40, 16, 16]               0
        MaxPool2d-16             [-1, 40, 8, 8]               0
           Linear-17                  [-1, 268]         686,348
      BatchNorm1d-18                  [-1, 268]             536
             ReLU-19                  [-1, 268]               0
           Linear-20                  [-1, 512]         137,728
      BatchNorm1d-21                  [-1, 512]           1,024
             ReLU-22                  [-1, 512]               0
           Linear-23                   [-1, 10]           5,130
================================================================
Total params: 849,336
Trainable params: 849,336
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 7.13
Params size (MB): 3.24
Estimated Total Size (MB): 10.55
----------------------------------------------------------------
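torchsummary builds its table by attaching forward hooks to the model's layers. A minimal sketch of the same idea, assuming the same layer stack rebuilt as one flat nn.Sequential (conv_block is a hypothetical helper), records each layer's output shape and parameter count during a single dummy forward pass and then totals the parameters with numel().

```python
import torch
import torch.nn as nn

def conv_block(in_c, out_c):
    # Hypothetical helper returning the four layers of one convolution block
    return [nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c), nn.ReLU(), nn.MaxPool2d(2, 2)]

# Same layer stack as the tutorial's model (10/20/30/40 kernels, FC 268/512, 10 classes)
model = nn.Sequential(
    *conv_block(3, 10), *conv_block(10, 20), *conv_block(20, 30), *conv_block(30, 40),
    nn.Flatten(),
    nn.Linear(8 * 8 * 40, 268), nn.BatchNorm1d(268), nn.ReLU(),
    nn.Linear(268, 512), nn.BatchNorm1d(512), nn.ReLU(),
    nn.Linear(512, 10),
)

rows = []
def record(module, inputs, output):
    # Forward hook: runs after module.forward() and logs type, output shape, parameter count
    n_params = sum(p.numel() for p in module.parameters())
    rows.append((type(module).__name__, tuple(output.shape), n_params))

handles = [m.register_forward_hook(record) for m in model]

model.eval()  # eval mode so the probe does not update BatchNorm running statistics
with torch.no_grad():
    model(torch.zeros(2, 3, 128, 128))
for h in handles:
    h.remove()  # detach the hooks once the probe is done

for name, shape, n in rows:
    print(f"{name:>12} {str(shape):>22} {n:>10,}")

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 849336 849336
```

Unlike the torchsummary table, this version also lists the nn.Flatten layer (with zero parameters) because flattening is done by a module here rather than by view() in forward().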

Define Architecture using nn.Sequential()

In the code block below, I define the same network as above but now make use of nn.Sequential(). After instantiating the model and summarizing it with torchsummary, you can see that the resulting model has the same layers, output shapes, and parameter counts as the one defined above. The only difference is that the ReLU layers here use inplace=True, which saves memory but does not change the outputs. So, either method is valid for defining the model architecture. Whether to use nn.Sequential() comes down to personal preference; however, it can simplify the code in many cases. I generally prefer nn.Sequential() as I find the code more readable.

class myCNNSeq(nn.Module):
  def __init__(self, nCls, inChn, outChn, fcChn, lastDim):
    super().__init__()
    self.nCls = nCls
    self.inChn = inChn
    self.outChn = outChn
    self.fcChn = fcChn
    self.lastDim = lastDim

    self.cnnLyrs = nn.Sequential(
        nn.Conv2d(in_channels=inChn, out_channels=outChn[0], kernel_size=(3,3), padding=1),
        nn.BatchNorm2d(outChn[0]),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2,2),
        nn.Conv2d(in_channels=outChn[0], out_channels=outChn[1], kernel_size=(3,3), padding=1),
        nn.BatchNorm2d(outChn[1]),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2,2),
        nn.Conv2d(in_channels=outChn[1], out_channels=outChn[2], kernel_size=(3,3), padding=1),
        nn.BatchNorm2d(outChn[2]),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2,2),
        nn.Conv2d(in_channels=outChn[2], out_channels=outChn[3], kernel_size=(3,3), padding=1),
        nn.BatchNorm2d(outChn[3]),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2,2)   
    )

    self.fcLyrs = nn.Sequential(
        nn.Linear(lastDim*lastDim*outChn[3], fcChn[0]),
        nn.BatchNorm1d(fcChn[0]),
        nn.ReLU(inplace=True),
        nn.Linear(fcChn[0], fcChn[1]),
        nn.BatchNorm1d(fcChn[1]),
        nn.ReLU(inplace=True),
        nn.Linear(fcChn[1], nCls)
    )   
  
  def forward(self,x):
    x = self.cnnLyrs(x)
    x = x.view(-1, self.lastDim*self.lastDim*self.outChn[3])
    x = self.fcLyrs(x)
    return x

model = myCNNSeq(nCls=10,
                 inChn=3,
                 outChn=[10,20,30,40],
                 fcChn=[268,512],
                 lastDim=getDim(128,4)).to(device)
torchsummary.summary(model, (3,128,128))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 10, 128, 128]             280
       BatchNorm2d-2         [-1, 10, 128, 128]              20
              ReLU-3         [-1, 10, 128, 128]               0
         MaxPool2d-4           [-1, 10, 64, 64]               0
            Conv2d-5           [-1, 20, 64, 64]           1,820
       BatchNorm2d-6           [-1, 20, 64, 64]              40
              ReLU-7           [-1, 20, 64, 64]               0
         MaxPool2d-8           [-1, 20, 32, 32]               0
            Conv2d-9           [-1, 30, 32, 32]           5,430
      BatchNorm2d-10           [-1, 30, 32, 32]              60
             ReLU-11           [-1, 30, 32, 32]               0
        MaxPool2d-12           [-1, 30, 16, 16]               0
           Conv2d-13           [-1, 40, 16, 16]          10,840
      BatchNorm2d-14           [-1, 40, 16, 16]              80
             ReLU-15           [-1, 40, 16, 16]               0
        MaxPool2d-16             [-1, 40, 8, 8]               0
           Linear-17                  [-1, 268]         686,348
      BatchNorm1d-18                  [-1, 268]             536
             ReLU-19                  [-1, 268]               0
           Linear-20                  [-1, 512]         137,728
      BatchNorm1d-21                  [-1, 512]           1,024
             ReLU-22                  [-1, 512]               0
           Linear-23                   [-1, 10]           5,130
================================================================
Total params: 849,336
Trainable params: 849,336
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 7.13
Params size (MB): 3.24
Estimated Total Size (MB): 10.55
----------------------------------------------------------------
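Because nn.Sequential() accepts any list of layers, the blocks can also be generated in a loop, so the number of convolution blocks and fully connected layers follows the lengths of outChn and fcChn instead of being fixed. The class below is a hypothetical variant, not the tutorial's model: LoopCNN and its inDim parameter (the original input height and width) are illustrative names, and lastDim is derived inside the constructor.

```python
import torch
import torch.nn as nn

class LoopCNN(nn.Module):
    # Hypothetical variant: one conv block per entry in outChn and one
    # Linear -> BatchNorm1d -> ReLU group per entry in fcChn
    def __init__(self, nCls, inChn, outChn, fcChn, inDim):
        super().__init__()
        layers, prev = [], inChn
        for c in outChn:
            layers += [nn.Conv2d(prev, c, kernel_size=3, padding=1),
                       nn.BatchNorm2d(c), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)]
            prev = c
        self.cnnLyrs = nn.Sequential(*layers)

        lastDim = inDim // 2 ** len(outChn)  # each pool halves H and W (rounding down)
        fc, prev = [], lastDim * lastDim * outChn[-1]
        for n in fcChn:
            fc += [nn.Linear(prev, n), nn.BatchNorm1d(n), nn.ReLU(inplace=True)]
            prev = n
        fc.append(nn.Linear(prev, nCls))  # raw logits: no batch norm or activation
        self.fcLyrs = nn.Sequential(*fc)

    def forward(self, x):
        x = self.cnnLyrs(x)
        x = torch.flatten(x, 1)  # keep the batch dimension
        return self.fcLyrs(x)

model = LoopCNN(nCls=10, inChn=3, outChn=[10, 20, 30, 40], fcChn=[268, 512], inDim=128)
n_params = sum(p.numel() for p in model.parameters())
out = model(torch.randn(2, 3, 128, 128))
print(n_params, out.shape)  # 849336 torch.Size([2, 10])
```

With the tutorial's settings, this variant has the same 849,336 parameters; passing, for example, three entries in outChn would build three convolution blocks instead.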

Concluding Remarks

You now know how to build a CNN architecture in PyTorch by subclassing the nn.Module class. This architecture was constructed using a small set of layer types: 2D convolution, 2D batch normalization, ReLU activation, max pooling, fully connected layers, and 1D batch normalization. Further, the model was generalized so that the user can define the number of input channels, the number of classes to differentiate, the number of kernels learned in each convolution layer, and the sizes of the intermediate fully connected layers, and can accommodate different input sizes through the lastDim parameter.

Now, we are ready to train a CNN.