from typing import Union, Sequence, Callable
import torch
from torch.nn import (
Module,
Sequential,
BatchNorm2d,
Conv2d,
ReLU,
ConstantPad3d,
Identity,
AdaptiveAvgPool2d,
Linear,
)
from torch import Tensor
from torch.nn.init import zeros_, kaiming_normal_
from torch.nn.modules.flatten import Flatten
import torch.nn.functional as F


class IdentityShortcut(Module):
    """Wraps a plain callable as a Module so it can be placed inside a
    Sequential (here: the strided-slicing subsampling of the shortcut)."""

    def __init__(self, transform_function: Callable[[Tensor], Tensor]):
        super().__init__()
        self.transform_function = transform_function

    def forward(self, x: Tensor) -> Tensor:
        return self.transform_function(x)


def conv3x3(
    in_planes: int, out_planes: int, stride: Union[int, Sequence[int]] = 1
) -> Conv2d:
    """3x3 convolution with padding 1 and no bias."""
    return Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def batch_norm(num_channels: int) -> BatchNorm2d:
    return BatchNorm2d(num_channels)


class ResidualBlock(Module):
    """Basic residual block with two 3x3 convolutions. The shortcut is either
    an option-A zero-padded identity or an option-B 1x1 projection, as in
    He et al., "Deep Residual Learning for Image Recognition"."""

    def __init__(
        self,
        input_num_filters: int,
        increase_dim: bool = False,
        projection: bool = False,
        last: bool = False,
    ):
        super().__init__()
        self.last: bool = last
if increase_dim:
first_stride = (2, 2)
out_num_filters = input_num_filters * 2
else:
first_stride = (1, 1)
out_num_filters = input_num_filters
self.direct = Sequential(
conv3x3(input_num_filters, out_num_filters, stride=first_stride),
batch_norm(out_num_filters),
ReLU(True),
conv3x3(out_num_filters, out_num_filters, stride=(1, 1)),
batch_norm(out_num_filters),
)
self.shortcut: Module
# add shortcut connections
if increase_dim:
if projection:
                # projection shortcut, as option B in the paper
self.shortcut = Sequential(
Conv2d(
input_num_filters,
out_num_filters,
kernel_size=(1, 1),
stride=(2, 2),
bias=False,
),
batch_norm(out_num_filters),
)
            else:
                # identity shortcut, as option A in the paper: subsample
                # spatially by taking every other pixel, then zero-pad the
                # channel dimension (out_num_filters // 4 zeros on each side)
                # so that the channel count matches the direct path.
                self.shortcut = Sequential(
                    IdentityShortcut(lambda x: x[:, :, ::2, ::2]),
                    ConstantPad3d(
                        (
                            0,
                            0,
                            0,
                            0,
                            out_num_filters // 4,
                            out_num_filters // 4,
                        ),
                        0.0,
                    ),
                )
else:
self.shortcut = Identity()

    def forward(self, x: Tensor) -> Tensor:
        # When this is the last block, the final ReLU is skipped so the
        # extracted features are the raw residual sum.
        if self.last:
            return self.direct(x) + self.shortcut(x)
        else:
            return torch.relu(self.direct(x) + self.shortcut(x))
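

# Illustrative sketch, not part of the original module: a hypothetical helper
# (the name _check_option_a_shapes is my own) that verifies the option-A
# shortcut arithmetic under the assumption of CIFAR-like 32x32 inputs.
# Subsampling with x[:, :, ::2, ::2] halves H and W, and ConstantPad3d adds
# out_num_filters // 4 zero channels on each side, doubling the channel count
# so the shortcut matches the direct path when increase_dim=True.
def _check_option_a_shapes() -> None:
    block = ResidualBlock(16, increase_dim=True, projection=False)
    x = torch.randn(2, 16, 32, 32)
    assert block.shortcut(x).shape == (2, 32, 16, 16)
    assert block(x).shape == (2, 32, 16, 16)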


class IcarlNet(Module):
    """CIFAR-style ResNet used by iCaRL: a 3x3 stem, three stacks of ``n``
    residual blocks (16, 32 and 64 filters), global average pooling and a
    linear classifier."""

    def __init__(self, num_classes: int, n: int = 5, c: int = 3):
        super().__init__()
        self.is_train = True
        input_dims = c
        output_dims = 16
first_conv = Sequential(
conv3x3(input_dims, output_dims, stride=(1, 1)),
batch_norm(16),
ReLU(True),
)
input_dims = output_dims
output_dims = 16
# first stack of residual blocks, output is 16 x 32 x 32
layers_list = []
for _ in range(n):
layers_list.append(ResidualBlock(input_dims))
first_block = Sequential(*layers_list)
input_dims = output_dims
output_dims = 32
# second stack of residual blocks, output is 32 x 16 x 16
layers_list = [ResidualBlock(input_dims, increase_dim=True)]
for _ in range(1, n):
layers_list.append(ResidualBlock(output_dims))
second_block = Sequential(*layers_list)
input_dims = output_dims
output_dims = 64
# third stack of residual blocks, output is 64 x 8 x 8
layers_list = [ResidualBlock(input_dims, increase_dim=True)]
for _ in range(1, n - 1):
layers_list.append(ResidualBlock(output_dims))
layers_list.append(ResidualBlock(output_dims, last=True))
third_block = Sequential(*layers_list)
final_pool = AdaptiveAvgPool2d(output_size=(1, 1))
self.feature_extractor = Sequential(
first_conv,
first_block,
second_block,
third_block,
final_pool,
Flatten(),
)
input_dims = output_dims
output_dims = num_classes
self.classifier = Linear(input_dims, output_dims)

    def forward(self, x: Tensor) -> Tensor:
        x = self.feature_extractor(x)  # already flattened
        x = self.classifier(x)
        return x
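

# Another illustrative sketch (my own, not from the original source): with the
# default n=5 and 32x32 inputs, the three stacks produce 16x32x32, 32x16x16
# and 64x8x8 feature maps, so the flattened feature vector has 64 entries and
# the classifier maps it to num_classes logits. The helper name is hypothetical.
def _check_icarlnet_shapes() -> None:
    model = IcarlNet(num_classes=10, n=5, c=3)
    x = torch.randn(2, 3, 32, 32)
    assert model.feature_extractor(x).shape == (2, 64)
    assert model(x).shape == (2, 10)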


def make_icarl_net(num_classes: int, n: int = 5, c: int = 3) -> IcarlNet:
    """Create an :py:class:`IcarlNet`, the CIFAR-style ResNet used in iCaRL.

    :param num_classes: number of classes (network output size).
    :param n: number of residual blocks in each of the three stacks.
    :param c: number of input channels.
    """
    return IcarlNet(num_classes, n=n, c=c)


def initialize_icarl_net(m: Module):
    """Initialize a module with `kaiming_normal_` (`mode="fan_in"`), using
    `nonlinearity="relu"` for `Conv2d` and `nonlinearity="sigmoid"` for
    `Linear`. Biases are initialized to zero.

    Meant to be applied recursively, e.g. ``model.apply(initialize_icarl_net)``.

    :param m: the module to initialize.
    """
    if isinstance(m, Conv2d):
        kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            zeros_(m.bias.data)
    elif isinstance(m, Linear):
        kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="sigmoid")
        if m.bias is not None:
            zeros_(m.bias.data)


__all__ = ["initialize_icarl_net", "make_icarl_net", "IcarlNet"]
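

# Usage sketch, not part of the original module: build the network, initialize
# it recursively with initialize_icarl_net via Module.apply, and run a forward
# pass on a CIFAR-sized batch (32x32 RGB is assumed here).
if __name__ == "__main__":
    model = make_icarl_net(num_classes=10, n=5, c=3)
    model.apply(initialize_icarl_net)
    logits = model(torch.randn(4, 3, 32, 32))
    print(logits.shape)  # expected: torch.Size([4, 10])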