Source code for avalanche.models.icarl_resnet

from typing import Union, Sequence, Callable

import torch
from torch.nn import (
    Module,
    Sequential,
    BatchNorm2d,
    Conv2d,
    ReLU,
    ConstantPad3d,
    Identity,
    AdaptiveAvgPool2d,
    Linear,
    Flatten,
)
from torch import Tensor
from torch.nn.init import zeros_, kaiming_normal_


class IdentityShortcut(Module):
    def __init__(self, transform_function: Callable[[Tensor], Tensor]):
        super(IdentityShortcut, self).__init__()
        self.transform_function = transform_function

    def forward(self, x: Tensor) -> Tensor:
        return self.transform_function(x)


def conv3x3(in_planes: int, out_planes: int, stride: Union[int, Sequence[int]] = 1):
    return Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def batch_norm(num_channels: int) -> BatchNorm2d:
    return BatchNorm2d(num_channels)


class ResidualBlock(Module):
    def __init__(
        self,
        input_num_filters: int,
        increase_dim: bool = False,
        projection: bool = False,
        last: bool = False,
    ):
        super().__init__()
        self.last: bool = last

        if increase_dim:
            first_stride = (2, 2)
            out_num_filters = input_num_filters * 2
        else:
            first_stride = (1, 1)
            out_num_filters = input_num_filters

        self.direct = Sequential(
            conv3x3(input_num_filters, out_num_filters, stride=first_stride),
            batch_norm(out_num_filters),
            ReLU(True),
            conv3x3(out_num_filters, out_num_filters, stride=(1, 1)),
            batch_norm(out_num_filters),
        )

        self.shortcut: Module

        # add shortcut connections
        if increase_dim:
            if projection:
                # projection shortcut, as option B in paper
                self.shortcut = Sequential(
                    Conv2d(
                        input_num_filters,
                        out_num_filters,
                        kernel_size=(1, 1),
                        stride=(2, 2),
                        bias=False,
                    ),
                    batch_norm(out_num_filters),
                )
            else:
                # identity shortcut, as option A in paper
                self.shortcut = Sequential(
                    IdentityShortcut(lambda x: x[:, :, ::2, ::2]),
                    ConstantPad3d(
                        (
                            0,
                            0,
                            0,
                            0,
                            out_num_filters // 4,
                            out_num_filters // 4,
                        ),
                        0.0,
                    ),
                )
        else:
            self.shortcut = Identity()

    def forward(self, x):
        if self.last:
            return self.direct(x) + self.shortcut(x)
        else:
            return torch.relu(self.direct(x) + self.shortcut(x))

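
# Illustrative sketch (not part of the original source): with increase_dim=True
# the first 3x3 convolution uses stride 2 and doubles the channel count, while
# the identity shortcut subsamples spatially and zero-pads the extra channels
# (option A), so both branches line up before the addition. Assuming a
# CIFAR-like 32x32 input:
#
#     block = ResidualBlock(16, increase_dim=True)
#     y = block(torch.randn(8, 16, 32, 32))
#     # y.shape == torch.Size([8, 32, 16, 16])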

class IcarlNet(Module):
    def __init__(self, num_classes: int, n=5, c=3):
        super().__init__()
        self.is_train = True

        input_dims = c
        output_dims = 16

        first_conv = Sequential(
            conv3x3(input_dims, output_dims, stride=(1, 1)),
            batch_norm(16),
            ReLU(True),
        )

        input_dims = output_dims
        output_dims = 16

        # first stack of residual blocks, output is 16 x 32 x 32
        layers_list = []
        for _ in range(n):
            layers_list.append(ResidualBlock(input_dims))
        first_block = Sequential(*layers_list)

        input_dims = output_dims
        output_dims = 32

        # second stack of residual blocks, output is 32 x 16 x 16
        layers_list = [ResidualBlock(input_dims, increase_dim=True)]
        for _ in range(1, n):
            layers_list.append(ResidualBlock(output_dims))
        second_block = Sequential(*layers_list)

        input_dims = output_dims
        output_dims = 64

        # third stack of residual blocks, output is 64 x 8 x 8
        layers_list = [ResidualBlock(input_dims, increase_dim=True)]
        for _ in range(1, n - 1):
            layers_list.append(ResidualBlock(output_dims))
        layers_list.append(ResidualBlock(output_dims, last=True))
        third_block = Sequential(*layers_list)

        final_pool = AdaptiveAvgPool2d(output_size=(1, 1))

        self.feature_extractor = Sequential(
            first_conv,
            first_block,
            second_block,
            third_block,
            final_pool,
            Flatten(),
        )

        input_dims = output_dims
        output_dims = num_classes

        self.classifier = Linear(input_dims, output_dims)

    def forward(self, x):
        x = self.feature_extractor(x)  # Already flattened
        x = self.classifier(x)
        return x
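
# Illustrative note (not part of the original source): with the default n=5
# and c=3, the feature extractor follows the 32x32 ResNet layout used by
# iCaRL (16 x 32 x 32 -> 32 x 16 x 16 -> 64 x 8 x 8); global average pooling
# and flattening then yield a 64-dimensional feature vector, e.g.:
#
#     model = IcarlNet(num_classes=10)
#     feats = model.feature_extractor(torch.randn(4, 3, 32, 32))
#     # feats.shape == torch.Size([4, 64])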


def make_icarl_net(num_classes: int, n=5, c=3) -> IcarlNet:
    """Create a :py:class:`IcarlNet` network, the ResNet used in iCaRL.

    :param num_classes: number of classes, network output size
    :param n: depth of each residual block stack
    :param c: number of input channels
    """
    return IcarlNet(num_classes, n=n, c=c)

def initialize_icarl_net(m: Module):
    """Initialize the input network based on `kaiming_normal` with
    `mode=fan_in` for `Conv2d` and `Linear` blocks.
    Biases are initialized to zero.

    :param m: input network (should be IcarlNet).
    """
    if isinstance(m, Conv2d):
        kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            zeros_(m.bias.data)
    elif isinstance(m, Linear):
        kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="sigmoid")
        if m.bias is not None:
            zeros_(m.bias.data)

__all__ = ["initialize_icarl_net", "make_icarl_net", "IcarlNet"]
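

if __name__ == "__main__":
    # Usage sketch (not part of the original module): build the network with
    # make_icarl_net, apply the iCaRL weight initialization to every
    # submodule, and run a dummy CIFAR-sized batch through it. The input size
    # (3, 32, 32) and num_classes=100 are illustrative assumptions.
    model = make_icarl_net(num_classes=100)
    model.apply(initialize_icarl_net)

    dummy_batch = torch.randn(4, 3, 32, 32)
    logits = model(dummy_batch)
    print(logits.shape)  # torch.Size([4, 100])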