Source code for avalanche.models.icarl_resnet

from typing import Union, Sequence, Callable

import torch
from torch.nn import Module, Sequential, BatchNorm2d, Conv2d, ReLU,\
    ConstantPad3d, Identity, AdaptiveAvgPool2d, Linear
from torch import Tensor
from torch.nn.init import zeros_, kaiming_normal_
from torch.nn.modules.flatten import Flatten
import torch.nn.functional as F


class IdentityShortcut(Module):
    def __init__(self, transform_function: Callable[[Tensor], Tensor]):
        super(IdentityShortcut, self).__init__()
        self.transform_function = transform_function

    def forward(self, x: Tensor) -> Tensor:
        return self.transform_function(x)


def conv3x3(in_planes: int, out_planes: int,
            stride: Union[int, Sequence[int]] = 1):
    return Conv2d(in_planes, out_planes,
                  kernel_size=3, stride=stride, padding=1, bias=False)


def batch_norm(num_channels: int) -> BatchNorm2d:
    return BatchNorm2d(num_channels)


class ResidualBlock(Module):

    def __init__(self, input_num_filters: int,
                 increase_dim: bool = False,
                 projection: bool = False,
                 last: bool = False):
        super().__init__()
        self.last: bool = last

        if increase_dim:
            first_stride = (2, 2)
            out_num_filters = input_num_filters * 2
        else:
            first_stride = (1, 1)
            out_num_filters = input_num_filters

        self.direct = Sequential(
            conv3x3(input_num_filters, out_num_filters, stride=first_stride),
            batch_norm(out_num_filters),
            ReLU(True),
            conv3x3(out_num_filters, out_num_filters, stride=(1, 1)),
            batch_norm(out_num_filters),
        )

        self.shortcut: Module

        # add shortcut connections
        if increase_dim:
            if projection:
                # projection shortcut, as option B in paper
                self.shortcut = Sequential(
                    Conv2d(input_num_filters, out_num_filters,
                           kernel_size=(1, 1), stride=(2, 2), bias=False),
                    batch_norm(out_num_filters)
                )
            else:
                # identity shortcut, as option A in paper
                self.shortcut = Sequential(
                    IdentityShortcut(lambda x: x[:, :, ::2, ::2]),
                    ConstantPad3d((0, 0, 0, 0,
                                   out_num_filters // 4,
                                   out_num_filters // 4), 0.0)
                )
        else:
            self.shortcut = Identity()

    def forward(self, x):
        if self.last:
            return self.direct(x) + self.shortcut(x)
        else:
            return torch.relu(self.direct(x) + self.shortcut(x))


class IcarlNet(Module):
    def __init__(self, num_classes: int, n=5, c=3):
        super().__init__()

        self.is_train = True
        input_dims = c
        output_dims = 16

        first_conv = Sequential(
            conv3x3(input_dims, output_dims, stride=(1, 1)),
            batch_norm(16),
            ReLU(True)
        )

        input_dims = output_dims
        output_dims = 16

        # first stack of residual blocks, output is 16 x 32 x 32
        layers_list = []
        for _ in range(n):
            layers_list.append(ResidualBlock(input_dims))
        first_block = Sequential(*layers_list)

        input_dims = output_dims
        output_dims = 32

        # second stack of residual blocks, output is 32 x 16 x 16
        layers_list = [ResidualBlock(input_dims, increase_dim=True)]
        for _ in range(1, n):
            layers_list.append(ResidualBlock(output_dims))
        second_block = Sequential(*layers_list)

        input_dims = output_dims
        output_dims = 64

        # third stack of residual blocks, output is 64 x 8 x 8
        layers_list = [ResidualBlock(input_dims, increase_dim=True)]
        for _ in range(1, n-1):
            layers_list.append(ResidualBlock(output_dims))
        layers_list.append(ResidualBlock(output_dims, last=True))
        third_block = Sequential(*layers_list)
        final_pool = AdaptiveAvgPool2d(output_size=(1, 1))

        self.feature_extractor = Sequential(
            first_conv, first_block,
            second_block, third_block,
            final_pool, Flatten())

        input_dims = output_dims
        output_dims = num_classes

        self.classifier = Linear(input_dims, output_dims)

    def forward(self, x):
        x = self.feature_extractor(x)  # Already flattened
        x = self.classifier(x)
        return x


[docs]def make_icarl_net(num_classes: int, n=5, c=3) -> IcarlNet: """Create :py:class:`IcarlNet` network, the ResNet used in ICarl. :param num_classes: number of classes, network output size :param n: depth of each residual blocks stack :param c: number of input channels """ return IcarlNet(num_classes, n=n, c=c)
[docs]def initialize_icarl_net(m: Module): """Initialize the input network based on `kaiming_normal` with `mode=fan_in` for `Conv2d` and `Linear` blocks. Biases are initialized to zero. :param m: input network (should be IcarlNet). """ if isinstance(m, Conv2d): kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='relu') if m.bias is not None: zeros_(m.bias.data) elif isinstance(m, Linear): kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='sigmoid') if m.bias is not None: zeros_(m.bias.data)
__all__ = [ 'initialize_icarl_net', 'make_icarl_net', 'IcarlNet' ]