Code for Network In Network

Implementation in 100 lines of Python · Network In Network
Abstract (original paper)

We propose a novel deep network structure called "Network In Network" (NIN) to enhance model discriminability for local patches within the receptive field. The conventional convolutional layer uses linear filters followed by a nonlinear activation function to scan the input. Instead, we build micro neural networks with more complex structures to abstract the data within the receptive field. We instantiate the micro neural network with a multilayer perceptron, which is a potent function approximator. The feature maps are obtained by sliding the micro networks over the input in a similar manner as CNN; they are then fed into the next layer. Deep NIN can be implemented by stacking multiple of the above described structures. With enhanced local modeling via the micro network, we are able to utilize global average pooling over feature maps in the classification layer, which is easier to interpret and less prone to overfitting than traditional fully connected layers. We demonstrated the state-of-the-art classification performances with NIN on CIFAR-10 and CIFAR-100, and reasonable performances on SVHN and MNIST datasets.

Source: Network In Network (2013-12-13). See: paper link.
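
The two ideas above are the mlpconv layer, a standard convolution whose outputs are further abstracted by 1x1 convolutions (a small MLP shared across all spatial positions), and global average pooling in place of fully connected classifier layers. Below is a minimal PyTorch sketch of a single mlpconv block followed by global average pooling; the channel sizes are illustrative and not those used in the paper or in the code further down.

import torch
import torch.nn as nn

# one mlpconv block: a 5x5 convolution followed by two 1x1 convolutions
# (the shared "micro network"), each with a ReLU nonlinearity
mlpconv = nn.Sequential(
    nn.Conv2d(1, 32, 5, padding=2), nn.ReLU(),
    nn.Conv2d(32, 16, 1), nn.ReLU(),
    nn.Conv2d(16, 10, 1), nn.ReLU(),
)

x = torch.randn(8, 1, 28, 28)    # a batch of MNIST-sized images
maps = mlpconv(x)                # (8, 10, 28, 28): one feature map per class
scores = maps.mean(dim=(2, 3))   # global average pooling -> (8, 10) class scores
print(scores.shape)              # torch.Size([8, 10])

Because the final mlpconv layer produces one feature map per class, global average pooling yields the class scores directly, with no trainable classifier weights.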

Code

Network In Network in 100 lines (Python)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from keras.datasets.mnist import load_data
import seaborn as sns
sns.set_theme()

# load mnist; scale images to [-1, 1] and cast labels to int64 for NLLLoss
(trainX, trainy), (testX, testy) = load_data()
trainX = np.float32(trainX) / 127.5 - 1.
testX = np.float32(testX) / 127.5 - 1.
trainy = trainy.astype(np.int64)
testy = testy.astype(np.int64)


class NiN(nn.Module):

    def __init__(self):
        super(NiN, self).__init__()

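        # mlpconv blocks: each 5x5 convolution is followed by 1x1 convolutions
        # (the "micro network"); all weights are initialised from N(0, 0.05^2)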
        conv1 = nn.Conv2d(1, 96, 5, padding=2)
        nn.init.normal_(conv1.weight, mean=0.0, std=0.05)
        cccp1 = nn.Conv2d(96, 64, 1)
        nn.init.normal_(cccp1.weight, mean=0.0, std=0.05)
        cccp2 = nn.Conv2d(64, 48, 1)
        nn.init.normal_(cccp2.weight, mean=0.0, std=0.05)
        conv2 = nn.Conv2d(48, 128, 5, padding=2)
        nn.init.normal_(conv2.weight, mean=0.0, std=0.05)
        cccp3 = nn.Conv2d(128, 96, 1)
        nn.init.normal_(cccp3.weight, mean=0.0, std=0.05)
        cccp4 = nn.Conv2d(96, 48, 1)
        nn.init.normal_(cccp4.weight, mean=0.0, std=0.05)
        conv3 = nn.Conv2d(48, 128, 5, padding=2)
        nn.init.normal_(conv3.weight, mean=0.0, std=0.05)
        cccp5 = nn.Conv2d(128, 96, 1)
        nn.init.normal_(cccp5.weight, mean=0.0, std=0.05)
        cccp6 = nn.Conv2d(96, 10, 1)
        nn.init.normal_(cccp6.weight, mean=0.0, std=0.05)

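        # three mlpconv blocks with max pooling and dropout in between,
        # ending in a 10-channel map that is globally average-pooled per class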
        self.model = nn.Sequential(conv1, nn.ReLU(),
                                   cccp1, nn.ReLU(),
                                   cccp2, nn.ReLU(),
                                   nn.MaxPool2d(3, stride=2, padding=1),
                                   nn.Dropout(p=0.5),
                                   conv2, nn.ReLU(),
                                   cccp3, nn.ReLU(),
                                   cccp4, nn.ReLU(),
                                   nn.MaxPool2d(3, stride=2, padding=1),
                                   nn.Dropout(p=0.5),
                                   conv3, nn.ReLU(),
                                   cccp5, nn.ReLU(),
                                   cccp6,
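                                   # global average pooling (7x7 -> 1x1) replaces fully connected layers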
                                   nn.AvgPool2d(7, stride=1, padding=0))
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        log_prob = self.logsoftmax(self.model(x).squeeze(-1).squeeze(-1))
        return log_prob


def train(model, optimizer, scheduler, loss_fct, batch_size, trainX, trainy, 
          testX, testy, device):

    testing_accuracy = []
    initial_lr = optimizer.param_groups[0]['lr']
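    # run full passes over the training set until ReduceLROnPlateau has
    # decayed the learning rate to roughly 1% of its initial value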
    while (optimizer.param_groups[0]['lr'] > (0.01 * initial_lr + 1e-15)):

        # Sample batch
        idx = torch.randperm(trainX.shape[0])

        train_batch_accuracy = 0.
        for indices in idx.chunk(int(np.ceil(trainX.shape[0] / batch_size))):
            x = trainX[indices]
            y = trainy[indices]

            log_prob = model(torch.from_numpy(x).unsqueeze(1).to(device))
            loss = loss_fct(log_prob, torch.from_numpy(y).to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_batch_accuracy += ((log_prob.argmax(-1) == torch.from_numpy(
                y).to(device)).sum().item() / y.shape[0])
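        # reduce the lr once the summed per-batch training accuracy plateaus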
        scheduler.step(train_batch_accuracy)

        # Testing (no gradient tracking needed for evaluation)
        model.eval()
        with torch.no_grad():
            log_prob = model(torch.from_numpy(testX).unsqueeze(1).to(device))
        testing_accuracy.append((log_prob.argmax(-1) == torch.from_numpy(
            testy).to(device)).sum().item() / testy.shape[0])
        model.train()

    return testing_accuracy


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = NiN().to(device)
    loss_fct = torch.nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-5,
                          momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')
    testing_accuracy = train(model, optimizer, scheduler, loss_fct, 128,
                             trainX, trainy, testX, testy, device)

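    # plot the test error (%) per epoch; the y-limits are centred on 0.47,
    # presumably the MNIST error rate (%) reported in the paper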
    plt.plot((1 - np.array(testing_accuracy)) * 100)
    plt.gca().set_ylim([0.47 - 0.15, 0.47 + 0.45])
    plt.xlabel('Epoch', fontsize=13)
    plt.ylabel('Test Error (%)', fontsize=13)
    plt.savefig('Imgs/nin.png', bbox_inches='tight')
    plt.close()
