################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 1-05-2020 #
# Author(s): Vincenzo Lomonaco, Antonio Carta #
# E-mail: contact@continualai.org #
# Website: avalanche.continualai.org #
################################################################################
import torch.nn as nn
from avalanche.models.dynamic_modules import MultiTaskModule, \
MultiHeadClassifier
from avalanche.models.base_model import BaseModel
class SimpleMLP(nn.Module, BaseModel):
    """
    Multi-Layer Perceptron with custom parameters.
    It can be configured to have multiple layers and dropout.
    """

    def __init__(self, num_classes=10, input_size=28 * 28,
                 hidden_size=512, hidden_layers=1, drop_rate=0.5):
        """
        :param num_classes: output size
        :param input_size: input size; each sample is flattened to this
            many features before the first linear layer
        :param hidden_size: hidden layer size
        :param hidden_layers: number of hidden layers
        :param drop_rate: dropout rate applied after EVERY hidden layer.
            0 to disable
        """
        super().__init__()

        # First hidden layer: input_size -> hidden_size.
        layers = [nn.Linear(input_size, hidden_size),
                  nn.ReLU(inplace=True),
                  nn.Dropout(p=drop_rate)]
        # Remaining hidden layers are each wrapped in their own Sequential.
        # This keeps the same module tree (and therefore the same
        # state_dict keys) as earlier versions of this class.
        # Every dropout honours ``drop_rate``; previously these used the
        # nn.Dropout default, so drop_rate=0 could not disable them.
        for _ in range(hidden_layers - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(p=drop_rate)))

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_size, num_classes)
        self._input_size = input_size

    def forward(self, x):
        """
        Compute class scores for a batch.

        :param x: input batch; the first dimension is the batch size and the
            remaining elements of each sample must total ``input_size``.
        :return: output of the final linear layer, shape
            (batch_size, num_classes).
        """
        return self.classifier(self.get_features(x))

    def get_features(self, x):
        """
        Compute the hidden representation (output of the last hidden layer).

        :param x: input batch; the first dimension is the batch size and the
            remaining elements of each sample must total ``input_size``.
        :return: features of shape (batch_size, hidden_size).
        """
        # view() needs contiguous memory, so make x contiguous first.
        x = x.contiguous()
        x = x.view(x.size(0), self._input_size)
        return self.features(x)
class MTSimpleMLP(MultiTaskModule):
    """Multi-layer perceptron with multi-head classifier"""

    def __init__(self, input_size=28 * 28, hidden_size=512, drop_rate=0.5):
        """
        :param input_size: input size; each sample is flattened to this
            many features before the hidden layer
        :param hidden_size: hidden layer size
        :param drop_rate: dropout rate after the hidden layer. 0 to disable.
            The default (0.5) matches the value previously hard-coded here.
        """
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
        )
        # Output heads are created by MultiHeadClassifier; only the
        # input feature size is fixed here.
        self.classifier = MultiHeadClassifier(hidden_size)
        self._input_size = input_size

    def forward(self, x, task_labels):
        """
        :param x: input batch; the first dimension is the batch size and the
            remaining elements of each sample must total ``input_size``.
        :param task_labels: task label(s) for the batch, forwarded unchanged
            to the multi-head classifier.
        :return: output of the multi-head classifier.
        """
        # view() needs contiguous memory, so make x contiguous first.
        x = x.contiguous()
        x = x.view(x.size(0), self._input_size)
        x = self.features(x)
        return self.classifier(x, task_labels)
# Public names exported by ``from <this module> import *``.
__all__ = ['SimpleMLP', 'MTSimpleMLP']