Transfer Learning and CNNs

Introduction

In this last section focused on CNNs for scene classification or scene labeling, we will experiment with transfer learning for classifying the EuroSatAllBands dataset. Specifically, we will use a modified ResNet-34 architecture initialized using pre-trained weights from ImageNet. The majority of the code that you will see in this module is not new, so I will keep my explanations brief. As this module demonstrates, once you have learned the key components of implementing deep learning with PyTorch, you can adapt your own or another analyst's code to a new purpose or problem.

Preparation

This first section is identical to the first section of the Train a CNN module. Here is a quick review of the steps required to prepare the EuroSatAllBands data for input into a CNN architecture designed for scene labeling.

  1. Import the required libraries. This includes numpy, pandas, matplotlib, seaborn, os, torch, torch.nn, torch.utils.data.dataset, torch.utils.data, rasterio, torchmetrics, torchsummary, torchvision, and torchvision.transforms. I also import some specific assessment functions from scikit-learn.
  2. Set the device variable to the GPU if one is available.
  3. Set the folder path to the data to the folder variable.
  4. Read in the data tables of image chips using pandas. Here, I am also updating the file paths since I am working on a different computer with a different directory structure.
  5. Define a DataSet subclass to obtain the band statistics.
  6. Instantiate a DataSet for the training set.
  7. Define a DataLoader for the DataSet.
  8. Define a function to calculate the pixel-level band means and standard deviations.
  9. Calculate the band means and standard deviations.
  10. Define a DataSet subclass for use in the training and validation process that accepts a DataFrame containing the information for each image chip, the band means and standard deviations, and transforms.
  11. Further prepare the band means and standard deviations to create tensors with the same shape as the input image chips (10, 64, 64).
  12. Define transforms for the training data. Here, I am using random horizontal and vertical flips.
  13. Instantiate an instance of the DataSet class for the training data that uses the defined transforms and an instance of the validation data that does not apply the transforms.
  14. Define the DataLoader for the training and validation data. I am using a batch size of 32 here as this worked on my hardware. You may need to change this depending on your system and GPU specifications.
  15. Perform checks to make sure the batch shapes and data types are correct. Also, check a single image chip.
  16. Define a function to display a batch of images and their associated labels.
  17. Visualize a batch of images and their associated labels.
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt 
import seaborn as sns

import os
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report


import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

import rasterio as rio

import torchmetrics as tm

from torchsummary import summary

import torchvision
import torchvision.transforms as transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
folder = "C:/myFiles/work/dl/eurosat/EuroSATallBands/"
train = pd.read_csv(folder+"mytrain.csv")
test = pd.read_csv(folder+"mytest.csv")
val = pd.read_csv(folder+"myval.csv")
class EuroSat(Dataset):
    
    def __init__(self, df):
        super().__init__()
        self.df = df
    
    def __getitem__(self, idx):
        # Read the image path and numeric label from the DataFrame row
        image_name = self.df.iloc[idx, 1]
        label = self.df.iloc[idx, 3]
        label = np.array(label)
        # Read all bands with rasterio, then subset to the 10 bands used here
        source = rio.open(image_name)
        image = source.read()
        source.close()
        image = image.astype('float32')
        image = image[[1,2,3,4,5,6,7,8,11,12], :, :]
        # Convert to tensors; the label must be a long integer for the loss
        image = torch.from_numpy(image)
        label = torch.from_numpy(label)
        label = label.long()
        return image, label
        
    def __len__(self):
        return len(self.df)
trainDS = EuroSat(train)
print(len(trainDS))
16558
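
As an optional check, you can index the DataSet directly to pull a single sample before building the DataLoader. This is a minimal sketch using the trainDS object defined above.

# Optional single-sample check (a sketch; assumes trainDS from above):
# confirm the chip tensor shape and the label data type.
img0, lbl0 = trainDS[0]
print(img0.shape, img0.dtype, lbl0.dtype)
# Expected: torch.Size([10, 64, 64]) torch.float32 torch.int64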
trainDL = torch.utils.data.DataLoader(trainDS, batch_size=32, shuffle=True, sampler=None,
num_workers=0, pin_memory=False, drop_last=False)
#https://www.binarystudy.com/2021/04/how-to-calculate-mean-standard-deviation-images-pytorch.html
def batch_mean_and_sd(loader, inChn):
    
    cnt = 0
    fst_moment = torch.empty(inChn)
    snd_moment = torch.empty(inChn)

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        # Per-band sums of the values and of the squared values for the batch
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2,
                                  dim=[0, 2, 3])
        # Update the running first and second moments
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    # Standard deviation from the moments: sqrt(E[x^2] - E[x]^2)
    mean, std = fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)
    return mean, std
band_stats = batch_mean_and_sd(trainDL, 10)
class EuroSat(Dataset):
    
    def __init__(self, df, mnImg, sdImg, transform):
        self.df = df
        self.mnImg = mnImg
        self.sdImg = sdImg
        self.transform = transform
    
    def __getitem__(self, idx):
        image_name = self.df.iloc[idx, 1]
        label = self.df.iloc[idx, 3]
        label = np.array(label)
        source = rio.open(image_name)
        image = source.read()
        source.close()
        image = image[[1,2,3,4,5,6,7,8,11,12], :, :]
        # Normalize each band using the training-set means and standard deviations
        image = np.subtract(image, self.mnImg)
        image = np.divide(image, self.sdImg)
        image = image.astype('float32')
        image = torch.from_numpy(image)
        label = torch.from_numpy(label)
        label = label.long()
        # Apply augmentations only when transforms are provided (i.e., for training)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
        
    def __len__(self):
        return len(self.df)
bndMns = np.array(band_stats[0].tolist())
bndSDs = np.array(band_stats[1].tolist())
mnImg = np.repeat(bndMns[0], 64*64).reshape((64,64,1))
for b in range(1,len(bndMns)):
    mnImg2 = np.repeat(bndMns[b], 64*64).reshape((64,64,1))
    mnImg = np.dstack([mnImg, mnImg2])
mnImg = np.transpose(mnImg, (2,0,1))

sdImg = np.repeat(bndSDs[0], 64*64).reshape((64,64,1))
for b in range(1,len(bndSDs)):
    sdImg2 = np.repeat(bndSDs[b], 64*64).reshape((64,64,1))
    sdImg = np.dstack([sdImg, sdImg2])
sdImg = np.transpose(sdImg, (2,0,1))

print(mnImg.shape)
(10, 64, 64)
print(sdImg.shape)
(10, 64, 64)
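
As an aside, the same (10, 64, 64) arrays can be constructed more concisely with NumPy broadcasting. The sketch below is equivalent to the np.repeat()/np.dstack() loops above.

# Equivalent construction using broadcasting (a sketch; produces the same
# (10, 64, 64) arrays as the loops above).
mnImg = np.broadcast_to(bndMns.reshape(-1, 1, 1), (len(bndMns), 64, 64)).copy()
sdImg = np.broadcast_to(bndSDs.reshape(-1, 1, 1), (len(bndSDs), 64, 64)).copy()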
myTransforms = transforms.Compose(
    [transforms.RandomHorizontalFlip(p=0.3), 
    transforms.RandomVerticalFlip(p=0.3),]
    )
trainDS = EuroSat(train, mnImg, sdImg, transform=myTransforms)
valDS = EuroSat(val, mnImg, sdImg, transform=None)
trainDL = torch.utils.data.DataLoader(trainDS, batch_size=32, shuffle=True, sampler=None,
num_workers=0, pin_memory=False, drop_last=True)

valDL = torch.utils.data.DataLoader(valDS, batch_size=32, shuffle=False, sampler=None,
num_workers=0, pin_memory=False, drop_last=True)
batch = next(iter(trainDL))
images, labels = batch
print(f'Batch Image Shape: {images.shape}, Batch Label Shape: {labels.shape}')
Batch Image Shape: torch.Size([32, 10, 64, 64]), Batch Label Shape: torch.Size([32])
print(f'Batch Image Data Type: {images.dtype}, Batch Label Data Type: {labels.dtype}')
Batch Image Data Type: torch.float32, Batch Label Data Type: torch.int64
print(f'Batch Image Band Means: {torch.mean(images, dim=(0,2,3))}')
Batch Image Band Means: tensor([-0.0640, -0.0715, -0.0596, -0.0516, -0.0974, -0.0890, -0.0774, -0.0627,
        -0.0407, -0.0720])
print(f'Batch Label Minimum: {torch.min(labels, dim=0)}, Batch Label Maximum: {torch.max(labels, dim=0)}')
Batch Label Minimum: torch.return_types.min(
values=tensor(0),
indices=tensor(1)), Batch Label Maximum: torch.return_types.max(
values=tensor(9),
indices=tensor(0))
testImg = images[1]
testMsk = labels[1]
print(f'Image Shape: {testImg.shape}, Label Shape: {testMsk.shape}')
Image Shape: torch.Size([10, 64, 64]), Label Shape: torch.Size([])
print(f'Image Data Type: {testImg.dtype}, Label Data Type: {testMsk.dtype}')
Image Data Type: torch.float32, Label Data Type: torch.int64
def img_display(img, mnImg, sdImg):
    # Undo the normalization so the bands are back on their original scale
    img = np.multiply(img, sdImg)
    img = np.add(img, mnImg)
    # Select the red, green, and blue bands and reorder to (H, W, C) for plotting
    image_vis = img[[2,1,0],:,:]
    image_vis = image_vis.permute(1,2,0)
    # Rescale the reflectance values (roughly 0 to 4,000) to the 0 to 255 range
    image_vis = (image_vis.numpy()/4000)*255
    image_vis = image_vis.astype('uint8')
    return image_vis

batch = next(iter(trainDL))
images, labels = batch

cover_types = {0: 'Annual Crop', 
1: 'Forest', 
2: 'Herb Veg', 
3: 'Highway', 
4: 'Industrial',
5: 'Pasture',
6: 'Perm Crop',
7: 'Residential',
8: 'River',
9: 'SeaLake'}
fig, axis = plt.subplots(4, 8, figsize=(15, 10))
for i, ax in enumerate(axis.flat):
    with torch.no_grad():
        image, label = images[i], labels[i]
        ax.imshow(img_display(image, mnImg, sdImg)) # add image
        ax.set(title = f"{cover_types[label.item()]}") # add label
        ax.axis('off')
plt.show()

[Figure: a 4 x 8 grid of image chips from one training batch, each titled with its land cover label]

Instantiate Pre-trained Model

The next set of code comes from the prior modules associated with CNN architectures. I first define a function to freeze weights. I next define a function to instantiate a ResNet model. This function allows the user to select a ResNet architecture (“18”, “34”, “50”, “101”, or “152”), the number of input channels, the number of classes being differentiated, whether or not to freeze the parameters/weights associated with the convolutional component of the model, and whether or not to use the pre-trained weights from ImageNet.

I instantiate a model instance using the defined function. I will use a ResNet-34 architecture that expects 10 input channels, differentiates 10 classes, does not have any weights frozen (i.e., all weights will be able to be updated), and that uses pre-trained weights from ImageNet. So, this model will be initialized using the ImageNet weights, except for the first convolutional and batch normalization layers and the fully connected layer at the end of the model, which must be newly initialized because their shapes are changed to accept 10 input channels and predict 10 classes. However, all weights will be trainable. This is because this problem is very different from the ImageNet use case. Thus, it is expected that the parameters/weights learned from ImageNet within the convolutional component of the model will need to be updated to allow the model to learn spatial patterns useful for this specific problem.

Once the model is initialized, I summarize it using torchsummary and the tensor shape of the EuroSatAllBands dataset: (10, 64, 64). This model has > 21 million trainable parameters. Since no parameters/weights were frozen, there are no non-trainable parameters.

#https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
def set_parameter_requires_grad(model, freeze=True):
    if freeze == True:
        for param in model.parameters():
            param.requires_grad = False
# https://stackoverflow.com/questions/62629114/how-to-modify-resnet-50-with-4-channels-as-input-using-pre-trained-weights-in-py
# https://discuss.pytorch.org/t/transfer-learning-usage-with-different-input-size/20744

def initialize_model(resNet, nChn, nCls, freeze=True, pretrained=True):
  # Select the requested ResNet variant; default to ResNet-34
  if resNet == "18":
    model = torchvision.models.resnet18(pretrained=pretrained)

  elif resNet == "34":
    model = torchvision.models.resnet34(pretrained=pretrained)

  elif resNet == "50":
    model = torchvision.models.resnet50(pretrained=pretrained)

  elif resNet == "101":
    model = torchvision.models.resnet101(pretrained=pretrained)

  elif resNet == "152":
    model = torchvision.models.resnet152(pretrained=pretrained)

  else:
    model = torchvision.models.resnet34(pretrained=pretrained)
  
  # Optionally freeze the pre-trained weights before adding new layers so
  # that only the newly initialized layers remain trainable
  if pretrained == True:
    set_parameter_requires_grad(model, freeze)
  
  # Replace the fully connected layer so the output size matches the number
  # of classes being differentiated; this layer is always newly initialized
  num_ftrs = model.fc.in_features
  model.fc = nn.Linear(num_ftrs, nCls)
  
  # Replace the first convolution and batch normalization layers when the
  # number of input channels differs from the 3 expected by ImageNet models;
  # these layers are also newly initialized
  if nChn != 3:
      model.conv1 = nn.Conv2d(nChn, 64, kernel_size=7, stride=2, padding=3, bias=False)
      model.bn1 = nn.BatchNorm2d(64)
  
  return model
model = initialize_model(resNet="34", nChn=10, nCls=10, freeze=False, pretrained=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
summary(model, (10, 64, 64))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]          31,360
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
         MaxPool2d-4           [-1, 64, 16, 16]               0
            Conv2d-5           [-1, 64, 16, 16]          36,864
       BatchNorm2d-6           [-1, 64, 16, 16]             128
              ReLU-7           [-1, 64, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          36,864
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
       BasicBlock-11           [-1, 64, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]          36,864
      BatchNorm2d-13           [-1, 64, 16, 16]             128
             ReLU-14           [-1, 64, 16, 16]               0
           Conv2d-15           [-1, 64, 16, 16]          36,864
      BatchNorm2d-16           [-1, 64, 16, 16]             128
             ReLU-17           [-1, 64, 16, 16]               0
       BasicBlock-18           [-1, 64, 16, 16]               0
           Conv2d-19           [-1, 64, 16, 16]          36,864
      BatchNorm2d-20           [-1, 64, 16, 16]             128
             ReLU-21           [-1, 64, 16, 16]               0
           Conv2d-22           [-1, 64, 16, 16]          36,864
      BatchNorm2d-23           [-1, 64, 16, 16]             128
             ReLU-24           [-1, 64, 16, 16]               0
       BasicBlock-25           [-1, 64, 16, 16]               0
           Conv2d-26            [-1, 128, 8, 8]          73,728
      BatchNorm2d-27            [-1, 128, 8, 8]             256
             ReLU-28            [-1, 128, 8, 8]               0
           Conv2d-29            [-1, 128, 8, 8]         147,456
      BatchNorm2d-30            [-1, 128, 8, 8]             256
           Conv2d-31            [-1, 128, 8, 8]           8,192
      BatchNorm2d-32            [-1, 128, 8, 8]             256
             ReLU-33            [-1, 128, 8, 8]               0
       BasicBlock-34            [-1, 128, 8, 8]               0
           Conv2d-35            [-1, 128, 8, 8]         147,456
      BatchNorm2d-36            [-1, 128, 8, 8]             256
             ReLU-37            [-1, 128, 8, 8]               0
           Conv2d-38            [-1, 128, 8, 8]         147,456
      BatchNorm2d-39            [-1, 128, 8, 8]             256
             ReLU-40            [-1, 128, 8, 8]               0
       BasicBlock-41            [-1, 128, 8, 8]               0
           Conv2d-42            [-1, 128, 8, 8]         147,456
      BatchNorm2d-43            [-1, 128, 8, 8]             256
             ReLU-44            [-1, 128, 8, 8]               0
           Conv2d-45            [-1, 128, 8, 8]         147,456
      BatchNorm2d-46            [-1, 128, 8, 8]             256
             ReLU-47            [-1, 128, 8, 8]               0
       BasicBlock-48            [-1, 128, 8, 8]               0
           Conv2d-49            [-1, 128, 8, 8]         147,456
      BatchNorm2d-50            [-1, 128, 8, 8]             256
             ReLU-51            [-1, 128, 8, 8]               0
           Conv2d-52            [-1, 128, 8, 8]         147,456
      BatchNorm2d-53            [-1, 128, 8, 8]             256
             ReLU-54            [-1, 128, 8, 8]               0
       BasicBlock-55            [-1, 128, 8, 8]               0
           Conv2d-56            [-1, 256, 4, 4]         294,912
      BatchNorm2d-57            [-1, 256, 4, 4]             512
             ReLU-58            [-1, 256, 4, 4]               0
           Conv2d-59            [-1, 256, 4, 4]         589,824
      BatchNorm2d-60            [-1, 256, 4, 4]             512
           Conv2d-61            [-1, 256, 4, 4]          32,768
      BatchNorm2d-62            [-1, 256, 4, 4]             512
             ReLU-63            [-1, 256, 4, 4]               0
       BasicBlock-64            [-1, 256, 4, 4]               0
           Conv2d-65            [-1, 256, 4, 4]         589,824
      BatchNorm2d-66            [-1, 256, 4, 4]             512
             ReLU-67            [-1, 256, 4, 4]               0
           Conv2d-68            [-1, 256, 4, 4]         589,824
      BatchNorm2d-69            [-1, 256, 4, 4]             512
             ReLU-70            [-1, 256, 4, 4]               0
       BasicBlock-71            [-1, 256, 4, 4]               0
           Conv2d-72            [-1, 256, 4, 4]         589,824
      BatchNorm2d-73            [-1, 256, 4, 4]             512
             ReLU-74            [-1, 256, 4, 4]               0
           Conv2d-75            [-1, 256, 4, 4]         589,824
      BatchNorm2d-76            [-1, 256, 4, 4]             512
             ReLU-77            [-1, 256, 4, 4]               0
       BasicBlock-78            [-1, 256, 4, 4]               0
           Conv2d-79            [-1, 256, 4, 4]         589,824
      BatchNorm2d-80            [-1, 256, 4, 4]             512
             ReLU-81            [-1, 256, 4, 4]               0
           Conv2d-82            [-1, 256, 4, 4]         589,824
      BatchNorm2d-83            [-1, 256, 4, 4]             512
             ReLU-84            [-1, 256, 4, 4]               0
       BasicBlock-85            [-1, 256, 4, 4]               0
           Conv2d-86            [-1, 256, 4, 4]         589,824
      BatchNorm2d-87            [-1, 256, 4, 4]             512
             ReLU-88            [-1, 256, 4, 4]               0
           Conv2d-89            [-1, 256, 4, 4]         589,824
      BatchNorm2d-90            [-1, 256, 4, 4]             512
             ReLU-91            [-1, 256, 4, 4]               0
       BasicBlock-92            [-1, 256, 4, 4]               0
           Conv2d-93            [-1, 256, 4, 4]         589,824
      BatchNorm2d-94            [-1, 256, 4, 4]             512
             ReLU-95            [-1, 256, 4, 4]               0
           Conv2d-96            [-1, 256, 4, 4]         589,824
      BatchNorm2d-97            [-1, 256, 4, 4]             512
             ReLU-98            [-1, 256, 4, 4]               0
       BasicBlock-99            [-1, 256, 4, 4]               0
          Conv2d-100            [-1, 512, 2, 2]       1,179,648
     BatchNorm2d-101            [-1, 512, 2, 2]           1,024
            ReLU-102            [-1, 512, 2, 2]               0
          Conv2d-103            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-104            [-1, 512, 2, 2]           1,024
          Conv2d-105            [-1, 512, 2, 2]         131,072
     BatchNorm2d-106            [-1, 512, 2, 2]           1,024
            ReLU-107            [-1, 512, 2, 2]               0
      BasicBlock-108            [-1, 512, 2, 2]               0
          Conv2d-109            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-110            [-1, 512, 2, 2]           1,024
            ReLU-111            [-1, 512, 2, 2]               0
          Conv2d-112            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-113            [-1, 512, 2, 2]           1,024
            ReLU-114            [-1, 512, 2, 2]               0
      BasicBlock-115            [-1, 512, 2, 2]               0
          Conv2d-116            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-117            [-1, 512, 2, 2]           1,024
            ReLU-118            [-1, 512, 2, 2]               0
          Conv2d-119            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-120            [-1, 512, 2, 2]           1,024
            ReLU-121            [-1, 512, 2, 2]               0
      BasicBlock-122            [-1, 512, 2, 2]               0
AdaptiveAvgPool2d-123            [-1, 512, 1, 1]               0
          Linear-124                   [-1, 10]           5,130
================================================================
Total params: 21,311,754
Trainable params: 21,311,754
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.16
Forward/backward pass size (MB): 7.86
Params size (MB): 81.30
Estimated Total Size (MB): 89.32
----------------------------------------------------------------
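
Since freeze=False was used, every parameter should be trainable. A minimal sanity check (a sketch, assuming the model object defined above) counts parameters by their requires_grad attribute.

# Count trainable vs. frozen parameters to confirm the freeze argument
# behaved as expected (a sketch; assumes model from above).
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable: {n_train:,}; Frozen: {n_frozen:,}")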

I am now ready to train the model. I first instantiate an instance of the AdamW optimizer with the default learning rate. I also instantiate the loss metric (cross entropy loss) and the overall accuracy, class-aggregated F1-score, and Cohen’s Kappa assessment metrics provided by torchmetrics.

I will train the model for a total of 50 epochs. As I did in the Train a CNN module, I will only save a model to disk if the aggregated F1-score calculated for the validation data improves.

optimizer = torch.optim.AdamW(model.parameters())
criterion = nn.CrossEntropyLoss().to(device)
acc = tm.Accuracy(task="multiclass", num_classes=10).to(device)
f1 = tm.F1Score(task="multiclass", num_classes=10).to(device)
kappa = tm.CohenKappa(task="multiclass", num_classes=10).to(device)
epochs = 50
saveFolder = "C:/myFiles/work/dl/eurosat_resnet_models/"
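
As an optional variant that I am not using here, transfer learning workflows sometimes assign the pre-trained backbone a smaller learning rate than the newly initialized layers. The sketch below shows one way to do this with optimizer parameter groups; the learning rates are illustrative.

# Optional alternative (a sketch, not used in this module): give the newly
# initialized layers (conv1, bn1, and fc) a larger learning rate than the
# pre-trained backbone using optimizer parameter groups.
new_params = (list(model.conv1.parameters()) + list(model.bn1.parameters())
              + list(model.fc.parameters()))
new_ids = {id(p) for p in new_params}
backbone_params = [p for p in model.parameters() if id(p) not in new_ids]
optimizer_alt = torch.optim.AdamW([
    {"params": backbone_params, "lr": 1e-4},
    {"params": new_params, "lr": 1e-3},
])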

The training loop is the same as the one used in the Train a CNN module. Here is a review of the key components.

  1. I must iterate over the epochs and the training batches.
  2. Backpropagation and an optimization step will be performed after each training batch is processed.
  3. The validation data will be predicted after each complete iteration over the training data (i.e., one training epoch).
  4. I am saving the training and validation losses and assessment metrics to list objects.
  5. Metrics are aggregated across batches using the compute() and reset() methods from torchmetrics.
  6. A model is only being saved to disk if the class-aggregated F1-score for the validation samples improves.

If you decide to run the training loop, it will likely take several hours to execute. On my machine it took ~3 hours to train for 50 epochs using one GPU and a batch size of 32. If you do not want to run the training loop, a model file has been provided.

eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []

f1VMax = 0.0

# Loop over epochs
for epoch in range(1, epochs+1):
    # Make sure layers such as batch normalization are in training mode
    model.train()
    # Loop over training batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        # Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients
        optimizer.zero_grad()
        # Predict data
        outputs = model(inputs)
        # Calculate loss
        loss = criterion(outputs, targets)

        # Calculate metrics
        accT = acc(outputs, targets)
        f1T = f1(outputs, targets)
        kappaT = kappa(outputs, targets)
        
        # Backpropagate
        loss.backward()

        # Update parameters
        optimizer.step()

    # Accumulate metrics at end of training epoch
    accT = acc.compute()
    f1T = f1.compute()
    kappaT = kappa.compute()

    # Print Losses and metrics at end of each training epoch   
    print(f'Epoch: {epoch}, Training Loss: {loss.item():.4f}, Training Accuracy: {accT:.4f}, Training F1: {f1T:.4f}, Training Kappa: {kappaT:.4f}')

    # Append results
    eNum.append(epoch)
    t_loss.append(loss.item())
    t_acc.append(accT.detach().cpu().numpy())
    t_f1.append(f1T.detach().cpu().numpy())
    t_kappa.append(kappaT.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Loop over validation batches; switch to evaluation mode first so that
    # batch normalization uses its running statistics
    model.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valDL):
            # Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Predict data
            outputs = model(inputs)
            # Calculate validation loss
            loss_v = criterion(outputs, targets)

            # Calculate metrics
            accV = acc(outputs, targets)
            f1V = f1(outputs, targets)
            kappaV = kappa(outputs, targets)
            
    # Accumulate metrics at end of validation epoch
    accV = acc.compute()
    f1V = f1.compute()
    kappaV = kappa.compute()

    # Print validation loss and metrics
    print(f'Validation Loss: {loss_v.item():.4f}, Validation Accuracy: {accV:.4f}, Validation F1: {f1V:.4f}, Validation Kappa: {kappaV:.4f}')

    # Append results
    v_loss.append(loss_v.item())
    v_acc.append(accV.detach().cpu().numpy())
    v_f1.append(f1V.detach().cpu().numpy())
    v_kappa.append(kappaV.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Save model if validation F1-score improves
    f1V2 = f1V.detach().cpu().numpy()
    if f1V2 > f1VMax:
        f1VMax = f1V2
        torch.save(model.state_dict(), saveFolder + 'eurosat_resnet_model.pt')
        print(f'Model saved for epoch {epoch}.')

Once the training loop executes, I next explore the training process by merging all of the saved losses and metrics into a single DataFrame. I then save this DataFrame to disk as a CSV file. I plot the training and validation losses along with the class-aggregated F1-score for the training and validation data.

Generally, these graphs suggest that the learning process progressed as expected. There is no evidence of overfitting to the training data.

SeNum = pd.Series(eNum, name="epoch")
St_loss = pd.Series(t_loss, name="training_loss")
St_acc = pd.Series(t_acc, name="training_accuracy")
St_f1 = pd.Series(t_f1, name="training_f1")
St_kappa = pd.Series(t_kappa, name="training_kappa")
Sv_loss = pd.Series(v_loss, name="val_loss")
Sv_acc = pd.Series(v_acc, name="val_accuracy")
Sv_f1 = pd.Series(v_f1, name="val_f1")
Sv_kappa = pd.Series(v_kappa, name="val_kappa")
resultsDF = pd.concat([SeNum, St_loss, St_acc, St_f1, St_kappa, Sv_loss, Sv_acc, Sv_f1, Sv_kappa], axis=1)
resultsDF.to_csv(saveFolder + "resultsCNN_ResNet.csv")
resultsDF = pd.read_csv(saveFolder + "resultsCNN_ResNet.csv")
plt.rcParams['figure.figsize'] = [10, 10]
firstPlot = resultsDF.plot(x='epoch', y="training_loss")
resultsDF.plot(x='epoch', y="val_loss", ax=firstPlot)
plt.show()

plt.rcParams['figure.figsize'] = [10, 10]
firstPlot = resultsDF.plot(x='epoch', y="training_f1")
resultsDF.plot(x='epoch', y="val_f1", ax=firstPlot)
plt.show()

Model Assessment

I next assess the model using the withheld testing data. In order to load the saved model weights, as opposed to using the weights after the 50 training epochs, I re-instantiate an instance of the model, read in the saved weights, and load them into the model's state dictionary. I next instantiate a DataSet and DataLoader for the testing samples. The testing samples are normalized using the band means and standard deviations of the training data, and no data augmentations or transforms are applied.

I then predict the testing data batches in a loop. Again, it is important that the model be in evaluation mode so that layers such as batch normalization use their running statistics, and the predictions are made inside a torch.no_grad() context so that no gradients are computed. The metrics from torchmetrics are accumulated across batches using the compute() method.

I print the assessment metrics. The results look pretty good. I achieved an overall accuracy greater than 97% when predicting new data.

model = initialize_model(resNet="34", nChn=10, nCls=10, freeze=True, pretrained=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
best_weights = torch.load(saveFolder+'eurosat_resnet_model.pt')
model.load_state_dict(best_weights)
<All keys matched successfully>
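
If you are working on a machine without a GPU, torch.load() accepts a map_location argument that remaps the saved GPU tensors onto the CPU. A sketch using the same file as above:

# CPU-only alternative (a sketch): remap the saved GPU tensors to the CPU.
best_weights = torch.load(saveFolder + 'eurosat_resnet_model.pt',
                          map_location=torch.device('cpu'))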
testDS = EuroSat(test, mnImg, sdImg, transform=None)
testDL = torch.utils.data.DataLoader(testDS, batch_size=32, shuffle=False, sampler=None,
num_workers=0, pin_memory=False, drop_last=True)
model.eval()
ResNet(
  (conv1): Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=10, bias=True)
)
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testDL):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss_v = criterion(outputs, targets)
        accV = acc(outputs, targets)
        f1V = f1(outputs, targets)
        kappaV = kappa(outputs, targets)
accV = acc.compute()
f1V = f1.compute()
kappaV = kappa.compute()
acc.reset()
f1.reset()
kappa.reset()
print(accV)
tensor(0.9702, device='cuda:0')
print(f1V)
tensor(0.9702, device='cuda:0')
print(kappaV)
tensor(0.9668, device='cuda:0')
print(loss_v.item())
3.403398659429513e-05

The class-level assessment metrics are obtained below.

cm = tm.ConfusionMatrix(task="multiclass", num_classes=10).to(device)
f1 = tm.F1Score(task="multiclass", num_classes=10, average="none").to(device)
recall = tm.Recall(task="multiclass", num_classes=10, average="none").to(device)
precision = tm.Precision(task="multiclass", num_classes=10, average="none").to(device)
model.eval()
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testDL):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Update each metric with the batch predictions
        cmV = cm(outputs, targets)
        f1V = f1(outputs, targets)
        pV = precision(outputs, targets)
        rV = recall(outputs, targets)
cmV = cm.compute()
f1V = f1.compute()
pV = precision.compute()
rV = recall.compute()
cm.reset()
f1.reset()
precision.reset()
recall.reset()
print(cmV)
print(f1V)
tensor([0.9713, 0.9892, 0.9507, 0.9503, 0.9783, 0.9484, 0.9488, 0.9728, 0.9763,
        0.9986], device='cuda:0')
print(pV)
tensor([0.9564, 0.9867, 0.9369, 0.9850, 0.9632, 0.9324, 0.9727, 0.9931, 0.9630,
        1.0000], device='cuda:0')
print(rV)
tensor([0.9867, 0.9917, 0.9650, 0.9180, 0.9940, 0.9650, 0.9260, 0.9533, 0.9900,
        0.9972], device='cuda:0')
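
The confusion matrix tensor returned by cm.compute() can also be rendered as a labeled heatmap using the seaborn import from earlier. This is a sketch that assumes the cmV tensor and the cover_types dictionary defined above.

# Visualize the 10 x 10 confusion matrix as a labeled heatmap (a sketch;
# assumes cmV from cm.compute() and the cover_types dictionary above).
cmDF = pd.DataFrame(cmV.cpu().numpy(),
                    index=list(cover_types.values()),
                    columns=list(cover_types.values()))
plt.figure(figsize=(10, 8))
sns.heatmap(cmDF, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("Reference")
plt.show()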

Concluding Remarks

The goal of this module was to explore the use of transfer learning by initializing a model architecture using parameters/weights learned from a prior dataset, in this case ImageNet, as opposed to initializing the model parameters randomly. Since this was a very different problem, I did not freeze any of the model parameters/weights. Instead, all parameters were updated during the learning process, but starting from the pre-trained parameters/weights as opposed to a random initialization.

Transfer learning can be a very powerful technique, especially when your training dataset is not large. These methods are especially useful when you are using a famous or common architecture, such as VGGNet-16 or a ResNet architecture, that has already been trained using a large dataset. As you will see in the semantic segmentation sections, these architectures, pre-trained models, and transfer learning can be used in the encoder component of semantic segmentation models. So, we will continue to explore these techniques throughout the next set of modules relating to semantic segmentation.