Convolutional neural network¶
Outline¶
Architectures
Train and test functions
CNN models
MNIST
CIFAR-10
Sources:
Deep learning - cs231n.stanford.edu
CNN - Stanford cs231n
Pytorch - WWW tutorials - github tutorials - github examples
MNIST and pytorch: - MNIST nextjournal.com/gkoehler/pytorch-mnist - MNIST github/pytorch/examples - MNIST kaggle
Architectures¶
Sources:
[zhenye-na.github.io](https://zhenye-na.github.io/2018/12/01/cnn-deep-leearning-ai-week2.html)
LeNet¶
The first Convolutional Networks were developed by Yann LeCun in the 1990s.
AlexNet¶
(2012, Alex Krizhevsky, Ilya Sutskever and Geoff Hinton)
Deeper, bigger.
Featured Convolutional Layers stacked on top of each other (previously it was common to only have a single CONV layer always immediately followed by a POOL layer).
ReLU (Rectified Linear Unit) for the non-linear part, instead of Tanh or Sigmoid.
The advantage of ReLU over the sigmoid is that it trains much faster, because the derivative of the sigmoid becomes very small in its saturating regions, so the updates to the weights almost vanish. This is called the vanishing gradient problem.
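A minimal sketch of this saturation effect using PyTorch autograd (the input value 5 is an arbitrary point in the saturating region, chosen only for illustration):
import torch

# Gradient of sigmoid vs. ReLU at a large (saturating) input
x = torch.tensor([5.0], requires_grad=True)
torch.sigmoid(x).backward()
print("sigmoid grad at x=5:", x.grad.item())  # ~0.0066, nearly vanished

x = torch.tensor([5.0], requires_grad=True)
torch.relu(x).backward()
print("ReLU grad at x=5:", x.grad.item())     # 1.0, gradient passes through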
Dropout: reduces over-fitting by inserting a Dropout layer after every FC layer. A Dropout layer has an associated probability p and is applied to every neuron of the response map separately: it randomly switches off the activation with probability p.
Why does DropOut work?
The idea behind dropout is similar to model ensembles. Because different sets of neurons are switched off, each configuration represents a different architecture, and all these architectures are trained in parallel, with a weight given to each subset and the sum of the weights being one. For n neurons attached to dropout, the number of subset architectures formed is 2^n, so the prediction amounts to an average over this ensemble of models. This provides a structured model regularization which helps to avoid over-fitting. Another view of why dropout helps is that, since neurons are switched off at random, they tend to avoid developing co-adaptations among themselves, which enables them to develop meaningful features independently of the others.
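The train/eval behaviour described above can be checked directly with nn.Dropout; a minimal sketch (the input tensor of ones is just an illustration):
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training mode: activations randomly zeroed,
print(drop(x))  # survivors scaled by 1/(1-p) so the expectation is unchanged

drop.eval()     # evaluation mode: dropout is a no-op
print(drop(x))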
Data augmentation is carried out to reduce over-fitting. It includes mirroring and cropping the images to increase the variation of the training data-set.
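With torchvision, this kind of augmentation is typically expressed as a transform pipeline; a minimal sketch (the flip/crop parameters are illustrative, not AlexNet's exact recipe):
import torchvision.transforms as transforms

# Random mirroring and cropping applied on the fly to each training image
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # mirroring
    transforms.RandomCrop(28, padding=4),  # pad, then crop back to 28x28
    transforms.ToTensor(),
])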
GoogLeNet. (Szegedy et al. from Google, 2014) Its main contribution was the development of an
Inception Module that dramatically reduced the number of parameters in the network (4M, compared to AlexNet with 60M).
There are also several followup versions to the GoogLeNet, most recently Inception-v4.
VGGNet. (Karen Simonyan and Andrew Zisserman 2014)
16 CONV/FC layers and, appealingly, an extremely homogeneous architecture.
Only performs 3x3 convolutions and 2x2 pooling from the beginning to the end. Replaces large kernel-sized filters (11 and 5 in the first and second convolutional layers, respectively) with multiple 3x3 filters stacked one after another.
For a given receptive field (the effective area of the input image on which the output depends), multiple stacked small kernels are better than a single large kernel, because the multiple non-linear layers increase the depth of the network, which enables it to learn more complex features at a lower cost. For example, three 3x3 convolutions stacked with stride 1 have a 7x7 receptive field, but with C channels per layer they use 3 x (3 x 3 x C x C) = 27C^2 weights, compared with 7 x 7 x C x C = 49C^2 for a single 7x7 convolution (checked numerically in the sketch below).
Uses a lot more memory and parameters (140M).
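The parameter comparison above can be verified directly in PyTorch; a minimal sketch with C = 64 channels (an arbitrary choice) and biases ignored:
import torch.nn as nn

C = 64  # arbitrary number of channels for the comparison

# Three stacked 3x3 convolutions: 3 * (3*3*C*C) = 27 C^2 weights
stack3x3 = nn.Sequential(*[nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)
                           for _ in range(3)])

# A single 7x7 convolution with the same receptive field: 7*7*C*C = 49 C^2 weights
single7x7 = nn.Conv2d(C, C, kernel_size=7, padding=3, bias=False)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(stack3x3), "vs", count(single7x7))  # 110592 vs 200704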
ResNet. (Kaiming He et al. 2015)
Resnet block variants (Source):
Skip connections
Batch normalization.
ResNets are currently the state-of-the-art CNN models and the default choice in practice (as of May 10, 2016). See also the more recent developments that tweak the original architecture: Kaiming He et al., Identity Mappings in Deep Residual Networks (published March 2016).
Architectures: general guidelines¶
ConvNets stack CONV, POOL and FC layers.
Trend towards smaller filters and deeper architectures: stack 3x3 filters instead of 5x5.
Trend towards getting rid of POOL/FC layers (just CONV).
Historically, architectures looked like [(CONV-RELU) x N -> POOL?] x M -> (FC-RELU) x K -> SOFTMAX, where N is usually up to ~5, M is large and 0 <= K <= 2,
but recent advances such as ResNet/GoogLeNet have challenged this paradigm (a sketch of the historical pattern follows).
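As an illustration only, the historical pattern above could be written as the following nn.Sequential; all channel and layer-size choices are arbitrary, and the flattened size assumes 32x32 inputs (e.g. CIFAR-10):
import torch.nn as nn

# [(CONV-RELU) x N -> POOL?] x M -> (FC-RELU) x K -> SOFTMAX, here with N=2, M=2, K=1
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(32 * 8 * 8, 128), nn.ReLU(),  # 32x32 input -> 8x8 after two 2x2 poolings
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1),
)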
Train function¶
%matplotlib inline
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms
from torchvision import models
#
from pathlib import Path
import matplotlib.pyplot as plt
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu' # Force CPU
# %load train_val_model.py
import numpy as np
import torch
import time
import copy


def train_val_model(model, criterion, optimizer, dataloaders, num_epochs=25,
                    scheduler=None, log_interval=None):
    # Note: uses the global `device` defined above
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Store losses and accuracies across epochs
    losses, accuracies = dict(train=[], val=[]), dict(train=[], val=[])

    for epoch in range(num_epochs):
        if log_interval is not None and epoch % log_interval == 0:
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            nsamples = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                nsamples += inputs.shape[0]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if scheduler is not None and phase == 'train':
                scheduler.step()

            #nsamples = dataloaders[phase].dataset.data.shape[0]
            epoch_loss = running_loss / nsamples
            epoch_acc = running_corrects.double() / nsamples

            losses[phase].append(epoch_loss)
            accuracies[phase].append(epoch_acc)

            if log_interval is not None and epoch % log_interval == 0:
                print('{} Loss: {:.4f} Acc: {:.2f}%'.format(
                    phase, epoch_loss, 100 * epoch_acc))

            # deep copy the model if it is the best on the validation set
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        if log_interval is not None and epoch % log_interval == 0:
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.2f}%'.format(100 * best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, losses, accuracies
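The scheduler and log_interval arguments control learning-rate decay and printing frequency. The scheduler is not used in the runs below; a minimal sketch of how it could be passed, assuming model and dataloaders are defined as in the following sections:
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Hypothetical usage: divide the learning rate by 10 every 10 epochs
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
model, losses, accuracies = train_val_model(model, nn.NLLLoss(), optimizer, dataloaders,
                                            num_epochs=30, scheduler=scheduler,
                                            log_interval=10)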
CNN models¶
LeNet-5¶
Here we implement LeNet-5 with ReLU activations. Sources: (1), (2).
import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    """
    layers: (nb channels in input layer,
             nb channels in 1st conv,
             nb channels in 2nd conv,
             nb neurons for 1st FC: TO BE TUNED,
             nb neurons for 2nd FC,
             nb neurons for 3rd FC,
             nb neurons in output FC: TO BE TUNED)
    """
    def __init__(self, layers=(1, 6, 16, 1024, 120, 84, 10), debug=False):
        super(LeNet5, self).__init__()
        self.layers = layers
        self.debug = debug
        self.conv1 = nn.Conv2d(layers[0], layers[1], 5, padding=2)
        self.conv2 = nn.Conv2d(layers[1], layers[2], 5)
        self.fc1 = nn.Linear(layers[3], layers[4])
        self.fc2 = nn.Linear(layers[4], layers[5])
        self.fc3 = nn.Linear(layers[5], layers[6])

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # same shape / 2
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -4 / 2
        if self.debug:
            print("### DEBUG: Shape of last convnet=", x.shape[1:],
                  ". FC size=", np.prod(x.shape[1:]))
        x = x.view(-1, self.layers[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
VGGNet-like: conv-relu blocks¶
# Defining the network (MiniVGGNet)
import torch.nn as nn
import torch.nn.functional as F


class MiniVGGNet(torch.nn.Module):

    def __init__(self, layers=(1, 16, 32, 1024, 120, 84, 10), debug=False):
        super(MiniVGGNet, self).__init__()
        self.layers = layers
        self.debug = debug

        # Conv block 1
        self.conv11 = nn.Conv2d(in_channels=layers[0], out_channels=layers[1],
                                kernel_size=3, stride=1, padding=0, bias=True)
        self.conv12 = nn.Conv2d(in_channels=layers[1], out_channels=layers[1],
                                kernel_size=3, stride=1, padding=0, bias=True)

        # Conv block 2
        self.conv21 = nn.Conv2d(in_channels=layers[1], out_channels=layers[2],
                                kernel_size=3, stride=1, padding=0, bias=True)
        self.conv22 = nn.Conv2d(in_channels=layers[2], out_channels=layers[2],
                                kernel_size=3, stride=1, padding=1, bias=True)

        # Fully connected layers
        self.fc1 = nn.Linear(layers[3], layers[4])
        self.fc2 = nn.Linear(layers[4], layers[5])
        self.fc3 = nn.Linear(layers[5], layers[6])

    def forward(self, x):
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv21(x))
        x = F.relu(self.conv22(x))
        x = F.max_pool2d(x, 2)
        if self.debug:
            print("### DEBUG: Shape of last convnet=", x.shape[1:],
                  ". FC size=", np.prod(x.shape[1:]))
        x = x.view(-1, self.layers[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
ResNet-like Model:¶
Stack multiple ResNet blocks.
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf #
# See section 4.2 for the model architecture on CIFAR-10 #
# Some part of the code was referenced from below #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py #
# ---------------------------------------------------------------------------- #
import torch.nn as nn
import torch.nn.functional as F


# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual  # skip connection
        out = self.relu(out)
        return out


# ResNet
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)
        #return out
ResNet9
MNIST digit classification¶
from pathlib import Path
from torchvision import datasets, transforms
import os
WD = os.path.join(Path.home(), "data", "pystatml", "dl_mnist_pytorch")
os.makedirs(WD, exist_ok=True)
os.chdir(WD)
print("Working dir is:", os.getcwd())
os.makedirs("data", exist_ok=True)
os.makedirs("models", exist_ok=True)
def load_mnist(batch_size_train, batch_size_test):

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           # Normalize with the MNIST global mean and std
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size_test, shuffle=True)
    return train_loader, test_loader
train_loader, val_loader = load_mnist(64, 1000)
dataloaders = dict(train=train_loader, val=val_loader)
# Info about the dataset
data_shape = dataloaders["train"].dataset.data.shape[1:]
D_in = np.prod(data_shape)
D_out = len(dataloaders["train"].dataset.targets.unique())
print("Datasets shape", {x: dataloaders[x].dataset.data.shape for x in ['train', 'val']})
print("N input features", D_in, "N output", D_out)
Working dir is: /home/ed203246/data/pystatml/dl_mnist_pytorch
Datasets shape {'train': torch.Size([60000, 28, 28]), 'val': torch.Size([10000, 28, 28])}
N input features 784 N output 10
LeNet¶
Dry run in debug mode to get the shape of the last convnet layer.
model = LeNet5((1, 6, 16, 1, 120, 84, 10), debug=True)
batch_idx, (data_example, target_example) = next(enumerate(train_loader))
print(model)
_ = model(data_example)
LeNet5(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=1, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
### DEBUG: Shape of last convnet= torch.Size([16, 5, 5]) . FC size= 400
Set the first FC layer size to 400.
model = LeNet5((1, 6, 16, 400, 120, 84, 10)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()
# Explore the model
for parameter in model.parameters():
    print(parameter.shape)

print("Total number of parameters =",
      np.sum([np.prod(parameter.shape) for parameter in model.parameters()]))
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=5, log_interval=2)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])
Total number of parameters = 61706
Epoch 0/4
----------
train Loss: 0.7807 Acc: 75.65%
val Loss: 0.1586 Acc: 94.96%
Epoch 2/4
----------
train Loss: 0.0875 Acc: 97.33%
val Loss: 0.0776 Acc: 97.47%
Epoch 4/4
----------
train Loss: 0.0592 Acc: 98.16%
val Loss: 0.0533 Acc: 98.30%
Training complete in 1m 29s
Best val Acc: 98.30%
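As a sanity check, the 61706 parameters printed above can be recovered by hand from the layer shapes listed before training:
# conv1: 6*(1*5*5)+6, conv2: 16*(6*5*5)+16, fc1: 120*400+120, fc2: 84*120+84, fc3: 10*84+10
print(6 * 1 * 5 * 5 + 6
      + 16 * 6 * 5 * 5 + 16
      + 120 * 400 + 120
      + 84 * 120 + 84
      + 10 * 84 + 10)  # 61706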
MiniVGGNet¶
model = MiniVGGNet(layers=(1, 16, 32, 1, 120, 84, 10), debug=True)
print(model)
_ = model(data_example)
MiniVGGNet(
(conv11): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
(conv12): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
(conv21): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
(conv22): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1): Linear(in_features=1, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
### DEBUG: Shape of last convnet= torch.Size([32, 5, 5]) . FC size= 800
Set the first FC layer size to 800.
model = MiniVGGNet((1, 16, 32, 800, 120, 84, 10)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()
# Explore the model
for parameter in model.parameters():
    print(parameter.shape)

print("Total number of parameters =",
      np.sum([np.prod(parameter.shape) for parameter in model.parameters()]))
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=5, log_interval=2)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
torch.Size([16, 1, 3, 3])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([32, 16, 3, 3])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([120, 800])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])
Total number of parameters = 123502
Epoch 0/4
----------
train Loss: 1.4180 Acc: 48.27%
val Loss: 0.2277 Acc: 92.68%
Epoch 2/4
----------
train Loss: 0.0838 Acc: 97.41%
val Loss: 0.0587 Acc: 98.14%
Epoch 4/4
----------
train Loss: 0.0495 Acc: 98.43%
val Loss: 0.0407 Acc: 98.63%
Training complete in 3m 10s
Best val Acc: 98.63%
Reduce the size of training dataset¶
Reduce the size of the training dataset by considering only 10 minibatches of size 16.
train_loader, val_loader = load_mnist(16, 1000)
train_size = 10 * 16
# Stratified sub-sampling
targets = train_loader.dataset.targets.numpy()
nclasses = len(set(targets))
indices = np.concatenate([np.random.choice(np.where(targets == lab)[0],
                                           int(train_size / nclasses), replace=False)
                          for lab in set(targets)])
np.random.shuffle(indices)
train_loader = torch.utils.data.DataLoader(train_loader.dataset, batch_size=16,
sampler=torch.utils.data.SubsetRandomSampler(indices))
# Check train subsampling
train_labels = np.concatenate([labels.numpy() for inputs, labels in train_loader])
print("Train size=", len(train_labels), " Train label count=", {lab:np.sum(train_labels == lab) for lab in set(train_labels)})
print("Batch sizes=", [inputs.size(0) for inputs, labels in train_loader])
# Put together train and val
dataloaders = dict(train=train_loader, val=val_loader)
# Info about the dataset
data_shape = dataloaders["train"].dataset.data.shape[1:]
D_in = np.prod(data_shape)
D_out = len(dataloaders["train"].dataset.targets.unique())
print("Datasets shape", {x: dataloaders[x].dataset.data.shape for x in ['train', 'val']})
print("N input features", D_in, "N output", D_out)
Train size= 160 Train label count= {0: 16, 1: 16, 2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 16, 8: 16, 9: 16}
Batch sizes= [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
Datasets shape {'train': torch.Size([60000, 28, 28]), 'val': torch.Size([10000, 28, 28])}
N input features 784 N output 10
LeNet5
model = LeNet5((1, 6, 16, 400, 120, 84, D_out)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=100, log_interval=20)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/99
----------
train Loss: 2.3086 Acc: 11.88%
val Loss: 2.3068 Acc: 14.12%
Epoch 20/99
----------
train Loss: 0.8060 Acc: 76.25%
val Loss: 0.8522 Acc: 72.84%
Epoch 40/99
----------
train Loss: 0.0596 Acc: 99.38%
val Loss: 0.6188 Acc: 82.67%
Epoch 60/99
----------
train Loss: 0.0072 Acc: 100.00%
val Loss: 0.6888 Acc: 83.08%
Epoch 80/99
----------
train Loss: 0.0033 Acc: 100.00%
val Loss: 0.7546 Acc: 82.96%
Training complete in 3m 10s
Best val Acc: 83.46%
MiniVGGNet
model = MiniVGGNet((1, 16, 32, 800, 120, 84, 10)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=100, log_interval=20)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/99
----------
train Loss: 2.3040 Acc: 10.00%
val Loss: 2.3025 Acc: 10.32%
Epoch 20/99
----------
train Loss: 2.2963 Acc: 10.00%
val Loss: 2.2969 Acc: 10.35%
Epoch 40/99
----------
train Loss: 2.1158 Acc: 37.50%
val Loss: 2.0764 Acc: 38.06%
Epoch 60/99
----------
train Loss: 0.0875 Acc: 97.50%
val Loss: 0.7315 Acc: 80.50%
Epoch 80/99
----------
train Loss: 0.0023 Acc: 100.00%
val Loss: 1.0397 Acc: 81.69%
Training complete in 5m 38s
Best val Acc: 82.02%
CIFAR-10 dataset¶
from pathlib import Path
WD = os.path.join(Path.home(), "data", "pystatml", "dl_cifar10_pytorch")
os.makedirs(WD, exist_ok=True)
os.chdir(WD)
print("Working dir is:", os.getcwd())
os.makedirs("data", exist_ok=True)
os.makedirs("models", exist_ok=True)
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
num_epochs = 5
learning_rate = 0.001
# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])
# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='data/',
                                             train=True,
                                             transform=transform,
                                             download=True)
val_dataset = torchvision.datasets.CIFAR10(root='data/',
                                           train=False,
                                           transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=100,
                                         shuffle=False)
# Put together train and val
dataloaders = dict(train=train_loader, val=val_loader)
# Info about the dataset
data_shape = dataloaders["train"].dataset.data.shape[1:]
D_in = np.prod(data_shape)
D_out = len(set(dataloaders["train"].dataset.targets))
print("Datasets shape:", {x: dataloaders[x].dataset.data.shape for x in ['train', 'val']})
print("N input features:", D_in, "N output:", D_out)
Working dir is: /home/ed203246/data/pystatml/dl_cifar10_pytorch
Files already downloaded and verified
Datasets shape: {'train': (50000, 32, 32, 3), 'val': (10000, 32, 32, 3)}
N input features: 3072 N output: 10
LeNet¶
model = LeNet5((3, 6, 16, 1, 120, 84, D_out), debug=True)
batch_idx, (data_example, target_example) = next(enumerate(train_loader))
print(model)
_ = model(data_example)
LeNet5(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=1, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
### DEBUG: Shape of last convnet= torch.Size([16, 6, 6]) . FC size= 576
Set the first FC layer to 576 neurons.
SGD with momentum: lr=0.001, momentum=0.5
model = LeNet5((3, 6, 16, 576, 120, 84, D_out)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
criterion = nn.NLLLoss()
# Explore the model
for parameter in model.parameters():
    print(parameter.shape)

print("Total number of parameters =",
      np.sum([np.prod(parameter.shape) for parameter in model.parameters()]))
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
torch.Size([6, 3, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 576])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])
Total number of parameters = 83126
Epoch 0/24
----------
train Loss: 2.3041 Acc: 10.00%
val Loss: 2.3033 Acc: 10.00%
Epoch 5/24
----------
train Loss: 2.2991 Acc: 11.18%
val Loss: 2.2983 Acc: 11.00%
Epoch 10/24
----------
train Loss: 2.2860 Acc: 10.36%
val Loss: 2.2823 Acc: 10.60%
Epoch 15/24
----------
train Loss: 2.1759 Acc: 18.83%
val Loss: 2.1351 Acc: 20.74%
Epoch 20/24
----------
train Loss: 2.0159 Acc: 25.35%
val Loss: 1.9878 Acc: 26.90%
Training complete in 7m 26s
Best val Acc: 28.98%
Increase the learning rate and momentum: lr=0.01, momentum=0.9
model = LeNet5((3, 6, 16, 576, 120, 84, D_out)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/24
----------
train Loss: 2.0963 Acc: 21.65%
val Loss: 1.8211 Acc: 33.49%
Epoch 5/24
----------
train Loss: 1.3500 Acc: 51.34%
val Loss: 1.2278 Acc: 56.40%
Epoch 10/24
----------
train Loss: 1.1569 Acc: 58.79%
val Loss: 1.0933 Acc: 60.95%
Epoch 15/24
----------
train Loss: 1.0724 Acc: 62.12%
val Loss: 0.9863 Acc: 65.34%
Epoch 20/24
----------
train Loss: 1.0131 Acc: 64.41%
val Loss: 0.9720 Acc: 66.14%
Training complete in 7m 17s
Best val Acc: 67.87%
Adaptive learning rate: Adam
model = LeNet5((3, 6, 16, 576, 120, 84, D_out)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/24
----------
train Loss: 1.8411 Acc: 30.21%
val Loss: 1.5768 Acc: 41.22%
Epoch 5/24
----------
train Loss: 1.3185 Acc: 52.17%
val Loss: 1.2181 Acc: 55.71%
Epoch 10/24
----------
train Loss: 1.1724 Acc: 57.89%
val Loss: 1.1244 Acc: 59.17%
Epoch 15/24
----------
train Loss: 1.0987 Acc: 60.98%
val Loss: 1.0153 Acc: 63.82%
Epoch 20/24
----------
train Loss: 1.0355 Acc: 63.01%
val Loss: 0.9901 Acc: 64.90%
Training complete in 7m 30s
Best val Acc: 66.88%
MiniVGGNet¶
model = MiniVGGNet(layers=(3, 16, 32, 1, 120, 84, D_out), debug=True)
print(model)
_ = model(data_example)
MiniVGGNet(
(conv11): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
(conv12): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
(conv21): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
(conv22): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1): Linear(in_features=1, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
### DEBUG: Shape of last convnet= torch.Size([32, 6, 6]) . FC size= 1152
Set the first FC layer to 1152 neurons.
SGD with large momentum and learning rate
model = MiniVGGNet((3, 16, 32, 1152, 120, 84, D_out)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/24
----------
train Loss: 2.3027 Acc: 10.14%
val Loss: 2.3010 Acc: 10.00%
Epoch 5/24
----------
train Loss: 1.4829 Acc: 46.08%
val Loss: 1.3860 Acc: 50.39%
Epoch 10/24
----------
train Loss: 1.0899 Acc: 61.43%
val Loss: 1.0121 Acc: 64.59%
Epoch 15/24
----------
train Loss: 0.8825 Acc: 69.02%
val Loss: 0.7788 Acc: 72.73%
Epoch 20/24
----------
train Loss: 0.7805 Acc: 72.73%
val Loss: 0.7222 Acc: 74.72%
Training complete in 15m 19s
Best val Acc: 76.62%
Adam
model = MiniVGGNet((3, 16, 32, 1152, 120, 84, D_out)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/24
----------
train Loss: 1.8591 Acc: 30.74%
val Loss: 1.5424 Acc: 43.46%
Epoch 5/24
----------
train Loss: 1.1562 Acc: 58.46%
val Loss: 1.0811 Acc: 61.87%
Epoch 10/24
----------
train Loss: 0.9630 Acc: 65.69%
val Loss: 0.8669 Acc: 68.94%
Epoch 15/24
----------
train Loss: 0.8634 Acc: 69.38%
val Loss: 0.7933 Acc: 72.33%
Epoch 20/24
----------
train Loss: 0.8033 Acc: 71.75%
val Loss: 0.7737 Acc: 73.57%
Training complete in 15m 37s
Best val Acc: 74.86%
ResNet¶
model = ResNet(ResidualBlock, [2, 2, 2], num_classes=D_out).to(device) # 195738 parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,
num_epochs=25, log_interval=5)
_ = plt.plot(losses['train'], '-b', losses['val'], '--r')
Epoch 0/24
----------
train Loss: 1.4169 Acc: 48.11%
val Loss: 1.5213 Acc: 48.08%
Epoch 5/24
----------
train Loss: 0.6279 Acc: 78.09%
val Loss: 0.6652 Acc: 77.49%
Epoch 10/24
----------
train Loss: 0.4772 Acc: 83.57%
val Loss: 0.5314 Acc: 82.09%
Epoch 15/24
----------
train Loss: 0.4010 Acc: 86.09%
val Loss: 0.6457 Acc: 79.03%
Epoch 20/24
----------
train Loss: 0.3435 Acc: 88.07%
val Loss: 0.4887 Acc: 84.34%
Training complete in 103m 30s
Best val Acc: 85.66%