Source code for avalanche.models.simple_cnn

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 1-05-2020                                                              #
# Author(s): Vincenzo Lomonaco, Antonio Carta                                  #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

import torch.nn as nn

from avalanche.models.dynamic_modules import MultiTaskModule, \
    MultiHeadClassifier


class SimpleCNN(nn.Module):
    """Convolutional Neural Network.

    Three convolutional stages feed a global (adaptive) max-pool, which
    collapses every feature map to a single value. A single linear layer
    then maps the resulting 64-dimensional embedding to class logits.

    Because of the adaptive pooling, the embedding size does not depend on
    the input's height/width. The input still has to be large enough to
    survive the two unpadded 3x3 convolutions and the two 2x2 max-pools.

    Args:
        num_classes: number of logits produced by the linear head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are listed in order and then wrapped in ``nn.Sequential``,
        # so submodule indices (and thus state_dict keys such as
        # ``features.0.weight``) are fixed by their position in this list.
        stages = [
            # Stage 1: 3 -> 32 channels, then halve the spatial size.
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            # Stage 2: 32 -> 64 channels, then halve the spatial size again.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            # Stage 3: 1x1 conv, then global max-pool down to 1x1 per channel.
            nn.Conv2d(64, 64, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d(1),
            nn.Dropout(p=0.25),
        ]
        self.features = nn.Sequential(*stages)
        # The head is kept inside a Sequential so its keys are
        # ``classifier.0.weight`` / ``classifier.0.bias``.
        head = nn.Linear(64, num_classes)
        self.classifier = nn.Sequential(head)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: image batch; assumed to be (batch, 3, H, W) since the first
                dimension is treated as the batch size when flattening.

        Returns:
            Logits of shape (batch, num_classes).
        """
        embedding = self.features(x)
        # Flatten (batch, 64, 1, 1) -> (batch, 64).
        batch_size = embedding.size(0)
        flat = embedding.view(batch_size, -1)
        return self.classifier(flat)
class MTSimpleCNN(SimpleCNN, MultiTaskModule):
    """Convolutional Neural Network with multi-head classifier.

    Reuses the ``SimpleCNN`` feature extractor, which produces a
    64-dimensional embedding per sample. The single linear head is replaced
    by a ``MultiHeadClassifier`` that is called with the task labels.
    """

    def __init__(self):
        super().__init__()
        # SimpleCNN.__init__ built a single-head classifier (default
        # num_classes=10); replace it with a multi-head classifier over the
        # 64-dimensional embedding.
        self.classifier = MultiHeadClassifier(64)

    def forward(self, x, task_labels):
        """Compute logits for a batch of images using task-specific heads.

        Args:
            x: image batch; assumed to be (batch, 3, H, W) like SimpleCNN.
            task_labels: task label(s) forwarded unchanged to
                ``MultiHeadClassifier`` (format defined by that class).

        Returns:
            Whatever ``MultiHeadClassifier`` returns for the embeddings.
        """
        x = self.features(x)
        # Flatten (batch, 64, 1, 1) -> (batch, 64). A bare ``squeeze()``
        # would also drop the batch dimension when batch == 1, handing the
        # classifier a (64,) tensor instead of (1, 64).
        x = x.view(x.size(0), -1)
        x = self.classifier(x, task_labels)
        return x
# Public API of this module.
__all__ = [
    "SimpleCNN",
    "MTSimpleCNN",
]