from typing import Union, Sequence, Callable
import torch
from torch.nn import Module, Sequential, BatchNorm2d, Conv2d, ReLU,\
ConstantPad3d, Identity, AdaptiveAvgPool2d, Linear
from torch import Tensor
from torch.nn.init import zeros_, kaiming_normal_
from torch.nn.modules.flatten import Flatten
import torch.nn.functional as F
class IdentityShortcut(Module):
def __init__(self, transform_function: Callable[[Tensor], Tensor]):
super(IdentityShortcut, self).__init__()
self.transform_function = transform_function
def forward(self, x: Tensor) -> Tensor:
return self.transform_function(x)
def conv3x3(in_planes: int, out_planes: int,
stride: Union[int, Sequence[int]] = 1):
return Conv2d(in_planes, out_planes,
kernel_size=3, stride=stride, padding=1, bias=False)
def batch_norm(num_channels: int) -> BatchNorm2d:
return BatchNorm2d(num_channels)
class ResidualBlock(Module):
def __init__(self, input_num_filters: int,
increase_dim: bool = False,
projection: bool = False,
last: bool = False):
super().__init__()
self.last: bool = last
if increase_dim:
first_stride = (2, 2)
out_num_filters = input_num_filters * 2
else:
first_stride = (1, 1)
out_num_filters = input_num_filters
self.direct = Sequential(
conv3x3(input_num_filters, out_num_filters, stride=first_stride),
batch_norm(out_num_filters),
ReLU(True),
conv3x3(out_num_filters, out_num_filters, stride=(1, 1)),
batch_norm(out_num_filters),
)
self.shortcut: Module
# add shortcut connections
if increase_dim:
if projection:
# projection shortcut, as option B in paper
self.shortcut = Sequential(
Conv2d(input_num_filters, out_num_filters,
kernel_size=(1, 1), stride=(2, 2), bias=False),
batch_norm(out_num_filters)
)
else:
# identity shortcut, as option A in paper
self.shortcut = Sequential(
IdentityShortcut(lambda x: x[:, :, ::2, ::2]),
ConstantPad3d((0, 0, 0, 0,
out_num_filters // 4,
out_num_filters // 4), 0.0)
)
else:
self.shortcut = Identity()
def forward(self, x):
if self.last:
return self.direct(x) + self.shortcut(x)
else:
return torch.relu(self.direct(x) + self.shortcut(x))
class IcarlNet(Module):
def __init__(self, num_classes: int, n=5, c=3):
super().__init__()
self.is_train = True
input_dims = c
output_dims = 16
first_conv = Sequential(
conv3x3(input_dims, output_dims, stride=(1, 1)),
batch_norm(16),
ReLU(True)
)
input_dims = output_dims
output_dims = 16
# first stack of residual blocks, output is 16 x 32 x 32
layers_list = []
for _ in range(n):
layers_list.append(ResidualBlock(input_dims))
first_block = Sequential(*layers_list)
input_dims = output_dims
output_dims = 32
# second stack of residual blocks, output is 32 x 16 x 16
layers_list = [ResidualBlock(input_dims, increase_dim=True)]
for _ in range(1, n):
layers_list.append(ResidualBlock(output_dims))
second_block = Sequential(*layers_list)
input_dims = output_dims
output_dims = 64
# third stack of residual blocks, output is 64 x 8 x 8
layers_list = [ResidualBlock(input_dims, increase_dim=True)]
for _ in range(1, n-1):
layers_list.append(ResidualBlock(output_dims))
layers_list.append(ResidualBlock(output_dims, last=True))
third_block = Sequential(*layers_list)
final_pool = AdaptiveAvgPool2d(output_size=(1, 1))
self.feature_extractor = Sequential(
first_conv, first_block,
second_block, third_block,
final_pool, Flatten())
input_dims = output_dims
output_dims = num_classes
self.classifier = Linear(input_dims, output_dims)
def forward(self, x):
x = self.feature_extractor(x) # Already flattened
x = self.classifier(x)
return x
[docs]def make_icarl_net(num_classes: int, n=5, c=3) -> IcarlNet:
"""Create :py:class:`IcarlNet` network, the ResNet used in
ICarl.
:param num_classes: number of classes, network output size
:param n: depth of each residual blocks stack
:param c: number of input channels
"""
return IcarlNet(num_classes, n=n, c=c)
[docs]def initialize_icarl_net(m: Module):
"""Initialize the input network based on `kaiming_normal`
with `mode=fan_in` for `Conv2d` and `Linear` blocks.
Biases are initialized to zero.
:param m: input network (should be IcarlNet).
"""
if isinstance(m, Conv2d):
kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='relu')
if m.bias is not None:
zeros_(m.bias.data)
elif isinstance(m, Linear):
kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='sigmoid')
if m.bias is not None:
zeros_(m.bias.data)
__all__ = [
'initialize_icarl_net',
'make_icarl_net',
'IcarlNet'
]