Source code for avalanche.models.pnn

import torch
import torch.nn.functional as F
from torch import nn

from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.benchmarks.utils.dataset_utils import ConstantSequence
from avalanche.models import MultiTaskModule, DynamicModule
from avalanche.models import MultiHeadClassifier


class LinearAdapter(nn.Module):
    """
    Linear adapter for Progressive Neural Networks.

    Holds one linear lateral connection per previous column and sums their
    outputs.
    """
    def __init__(self, in_features, out_features_per_column, num_prev_modules):
        """
        :param in_features: size of each input sample
        :param out_features_per_column: size of each output sample
        :param num_prev_modules: number of previous modules
        """
        super().__init__()
        # `forward` checks the number of inputs against this value, so it
        # must be stored on the instance.
        self.num_prev_modules = num_prev_modules

        # Eq. 1 - lateral connections
        # one layer for each previous column. Empty for the first task.
        self.lat_layers = nn.ModuleList([])
        for _ in range(num_prev_modules):
            m = nn.Linear(in_features, out_features_per_column)
            self.lat_layers.append(m)

    def forward(self, x):
        """ Sum the lateral projections of the previous columns' inputs.

        :param x: sequence with one tensor per previous module; the i-th
            tensor is fed to the i-th lateral layer.
        :return: the sum of the lateral layers' outputs. With no previous
            modules the sum is empty and the plain integer 0 is returned.
        """
        assert len(x) == self.num_prev_modules
        # `sum` starts from the integer 0, which is also the result for the
        # first column (no lateral connections).
        return sum(lat(el) for lat, el in zip(self.lat_layers, x))


class MLPAdapter(nn.Module):
    """
    MLP adapter for Progressive Neural Networks.

    Scales each previous column's input by a learned scalar, concatenates
    them along the feature dimension and passes the result through a
    two-layer MLP (``V`` -> activation -> ``U``).
    """
    def __init__(self, in_features, out_features_per_column, num_prev_modules,
                 activation=F.relu):
        """
        :param in_features: size of each input sample
        :param out_features_per_column: size of each output sample
        :param num_prev_modules: number of previous modules
        :param activation: activation function (default=ReLU)
        """
        super().__init__()
        self.num_prev_modules = num_prev_modules
        self.activation = activation

        if num_prev_modules == 0:
            return  # first adapter is empty

        # Eq. 2 - MLP adapter. Not needed for the first task.
        self.V = nn.Linear(in_features * num_prev_modules,
                           out_features_per_column)
        self.alphas = nn.Parameter(torch.randn(num_prev_modules))
        self.U = nn.Linear(out_features_per_column, out_features_per_column)

    def forward(self, x):
        """ Combine the inputs of the previous columns.

        :param x: sequence with one 2D tensor ``<batch_size, num_features>``
            per previous module. The sequence is not modified.
        :return: output of the adapter MLP, or the plain integer 0 when
            there are no previous modules.
        """
        if self.num_prev_modules == 0:
            return 0  # first adapter is empty

        assert len(x) == self.num_prev_modules
        assert len(x[0].shape) == 2, \
            "Inputs to MLPAdapter should have two dimensions: " \
            "<batch_size, num_features>."
        # Build a new list instead of overwriting the elements of `x`:
        # the caller's sequence must stay untouched (and may be a tuple).
        scaled = [self.alphas[i] * el for i, el in enumerate(x)]
        h = torch.cat(scaled, dim=1)
        return self.U(self.activation(self.V(h)))


class PNNColumn(nn.Module):
    """
    One column of a Progressive Neural Network layer.

    A column owns an input-to-hidden linear map (``itoh``) for its own input
    and an adapter that merges the inputs coming from the previous columns.
    """
    def __init__(self, in_features, out_features_per_column, num_prev_modules,
                 adapter='mlp'):
        """
        :param in_features: size of each input sample
        :param out_features_per_column:
            size of each output sample (single column)
        :param num_prev_modules: number of previous columns
        :param adapter: adapter type. One of {'linear', 'mlp'} (default='mlp')
        :raises ValueError: if `adapter` is not a supported adapter type.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features_per_column = out_features_per_column
        self.num_prev_modules = num_prev_modules

        self.itoh = nn.Linear(in_features, out_features_per_column)

        # Pick the adapter class first, then build it with the shared
        # constructor arguments.
        if adapter == 'linear':
            adapter_cls = LinearAdapter
        elif adapter == 'mlp':
            adapter_cls = MLPAdapter
        else:
            raise ValueError("`adapter` must be one of: {'mlp', `linear'}.")
        self.adapter = adapter_cls(in_features, out_features_per_column,
                                   num_prev_modules)

    def freeze(self):
        """ Disable gradient computation for every parameter of the column. """
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """ Compute the column output.

        :param x: sequence of inputs; the last element is this column's own
            input, all preceding elements are handed to the adapter.
        :return: adapter output plus the input-to-hidden projection.
        """
        out = self.adapter(x[:-1])
        out += self.itoh(x[-1])
        return out


class PNNLayer(MultiTaskModule):
    """ Progressive Neural Network layer.

        A layer is a growing list of columns, one per task label seen during
        training adaptation.

        The adaptation phase assumes that each experience is a separate task.
        Multiple experiences with the same task label or multiple task labels
        within the same experience will result in a runtime error.
        """
    def __init__(self, in_features, out_features_per_column, adapter='mlp'):
        """
        :param in_features: size of each input sample
        :param out_features_per_column:
            size of each output sample (single column)
        :param adapter: adapter type. One of {'linear', 'mlp'} (default='mlp')
        """
        super().__init__()
        self.in_features = in_features
        self.out_features_per_column = out_features_per_column
        self.adapter = adapter

        # maps a task label to the position of its column in `self.columns`
        self.task_to_module_idx = {}
        # the first column is created eagerly and has no previous columns
        self.columns = nn.ModuleList([
            PNNColumn(in_features, out_features_per_column, 0,
                      adapter=adapter)
        ])

    @property
    def num_columns(self):
        """ Number of columns currently held by the layer. """
        return len(self.columns)

    def train_adaptation(self, dataset: AvalancheDataset):
        """ Training adaptation for PNN layer.

        Registers the task label of `dataset`. The first registered label is
        bound to the column created in the constructor; every later label
        gets a freshly added column.

        :param dataset: dataset of the current experience. Its
            `targets_task_labels` must contain exactly one distinct label,
            and that label must not have been registered before.
        :return: None
        """
        super().train_adaptation(dataset)
        labels = dataset.targets_task_labels
        if isinstance(labels, ConstantSequence):
            # constant sequence: a single label by construction, so there is
            # no need to scan it for duplicates.
            unique_labels = [labels[0]]
        else:
            unique_labels = set(labels)
        assert len(unique_labels) == 1, \
            "PNN assumes a single task for each experience. Please use a " \
            "compatible benchmark."
        # pull the only label out of the collection
        task_label = next(iter(unique_labels))
        assert task_label not in self.task_to_module_idx, \
            "A new experience is using a previously seen task label. This is " \
            "not compatible with PNN, which assumes different task labels for" \
            " each training experience."

        if self.task_to_module_idx:
            # the new column is appended at the end, so its index equals the
            # current number of columns.
            self.task_to_module_idx[task_label] = self.num_columns
            self._add_column()
        else:
            # the first column already exists: just bind the label to it.
            self.task_to_module_idx[task_label] = 0

    def _add_column(self):
        """ Freeze all existing parameters and append a new column. """
        for p in self.parameters():
            p.requires_grad = False
        new_column = PNNColumn(self.in_features,
                               self.out_features_per_column,
                               self.num_columns,
                               adapter=self.adapter)
        self.columns.append(new_column)

    def forward_single_task(self, x, task_label):
        """ Forward.

        Runs every column up to and including the one bound to `task_label`;
        column `i` receives the first `i + 1` inputs.

        :param x: list of inputs.
        :param task_label: task label selecting the last column to run.
        :return: list with the output of each executed column.
        """
        last_col = self.task_to_module_idx[task_label]
        return [self.columns[idx](x[:idx + 1])
                for idx in range(last_col + 1)]


class PNN(MultiTaskModule):
    """
    Progressive Neural Network.

    The model assumes that each experience is a separate task.
    Multiple experiences with the same task label or multiple task labels
    within the same experience will result in a runtime error.
    """
    def __init__(self, num_layers=1, in_features=784,
                 hidden_features_per_column=100, adapter='mlp'):
        """
        :param num_layers: number of layers (default=1)
        :param in_features: size of each input sample
        :param hidden_features_per_column:
            number of hidden units for each column
        :param adapter: adapter type. One of {'linear', 'mlp'} (default='mlp')
        """
        super().__init__()
        assert num_layers >= 1
        self.num_layers = num_layers
        self.in_features = in_features
        self.out_features_per_columns = hidden_features_per_column

        self.layers = nn.ModuleList()
        # NOTE(review): the first layer is built without forwarding
        # `adapter`, so it always uses PNNLayer's default adapter. Confirm
        # whether this is intentional before changing it (it affects the
        # parameter layout of the model).
        self.layers.append(PNNLayer(in_features, hidden_features_per_column))
        for _ in range(num_layers - 1):
            lay = PNNLayer(hidden_features_per_column,
                           hidden_features_per_column,
                           adapter=adapter)
            self.layers.append(lay)
        self.classifier = MultiHeadClassifier(hidden_features_per_column)

    def forward_single_task(self, x, task_label):
        """ Forward.

        :param x: input tensor; it is reshaped to
            ``<batch_size, in_features>`` before entering the first layer.
        :param task_label: task label selecting the column and the
            classifier head.
        :return: classifier output computed from the task's column.
        """
        x = x.contiguous()
        x = x.view(x.size(0), self.in_features)

        num_columns = self.layers[0].num_columns
        col_idx = self.layers[-1].task_to_module_idx[task_label]

        # every column of the first layer receives the same flattened input
        x = [x for _ in range(num_columns)]
        for lay in self.layers:
            x = [F.relu(el) for el in lay(x, task_label)]
        return self.classifier(x[col_idx], task_label)