Source code for avalanche.models.packnet

"""
Implements (Mallya & Lazebnik, 2018) PackNet algorithm for fixed-network
parameter isolation. PackNet is a task-incremental learning algorithm that
uses task identities during testing.

Mallya, A., & Lazebnik, S. (2018). PackNet: Adding Multiple Tasks to a
Single Network by Iterative Pruning. 2018 IEEE/CVF Conference on Computer
Vision and Pattern Recognition, 7765-7773.
https://doi.org/10.1109/CVPR.2018.00810
"""

import typing as t

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from abc import ABC, abstractmethod
from enum import Enum
from avalanche.core import BaseSGDPlugin, Template
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models.simple_mlp import SimpleMLP
from typing import Union
from avalanche.training.templates.base_sgd import BaseSGDTemplate


class PackNetModule(ABC, nn.Module):
    """Defines the interface for implementing PackNet compatible PyTorch modules.

    The core idea of PackNet is to build a single network containing multiple
    task-specific subsets. Each subset builds on, and therefore shares
    parameters with, the previous subsets, but only the parameters that are
    not shared are mutable. This allows PackNet to isolate parameters for each
    task.

    Caution should be taken when optimizers with momentum are used, since they
    can cause parameters to be modified even when no gradient exists.

    PackNet has internal state that changes its behaviour and this class is
    responsible for ensuring that no invalid state transitions occur. When
    an invalid state transition is attempted, a `StateError` is raised.
    """

    class State(Enum):
        """PackNet requires a procedure to be followed and we model this with
        the following states.
        """

        TRAINING = 0
        """The PackNet module is training all of the unfrozen capacity"""
        POST_PRUNE = 1
        """The PackNet module is training only on the unpruned parameters that
        will be frozen next"""
        EVAL = 2
        """Activate a task-specific subset and mask the remaining parameters.
        This state freezes all parameters."""

    class StateError(RuntimeError):
        """An invalid state transition occured"""

    def __init__(self) -> None:
        super().__init__()
        init_state = self.State.TRAINING
        self._state: Tensor
        """The current state of the PackNet"""
        self._active_task: Tensor
        """The id of the task that is currently active"""
        self._task_count: Tensor
        """The number of tasks that have been trained"""
        self.register_buffer("_state", torch.tensor(init_state.value, dtype=torch.int))
        self.register_buffer("_active_task", torch.tensor(0, dtype=torch.int))
        self.register_buffer("_task_count", torch.tensor(0, dtype=torch.int))

    def prune(self, prune_proportion: float):
        """Prune a proportion of the unfrozen parameters from the module.

        The pruned parameters will be reused for the next task, while the
        remaining parameters will be fine-tuned further on the current task
        during the post-pruning phase.

        Prune may only be called when PackNet is in the `State.TRAINING` state.
        Prune will move PackNet to the `State.POST_PRUNE` state.

        :param prune_proportion: A proportion of the prunable parameters to
            prune. Must be between 0 and 1.
        """
        if not 0 <= prune_proportion <= 1:
            raise ValueError(
                f"`prune_proportion` must be between 0 and 1, got "
                f"{prune_proportion}"
            )
        self._state_guard(
            self.prune.__name__, [self.State.TRAINING], self.State.POST_PRUNE
        )
        self._prune(prune_proportion)

    def freeze_pruned(self):
        """
        Commit the current task, freezing its parameters so they become
        immutable and releasing the pruned parameters for reuse by the next
        task.

        This prevents subsequent tasks from affecting any parameters
        associated with the current task.

        This function can only be called when PackNet is in the
        `State.POST_PRUNE` state. It will then move PackNet to the `State.EVAL`
        state.
        """
        self._state_guard(
            self.freeze_pruned.__name__, [self.State.POST_PRUNE], self.State.EVAL
        )
        self._task_count += 1
        self._freeze_pruned()

    def activate_task(self, task_id: int):
        """Activates a task-specific subset of PackNet.

        When `task_id` refers to the task currently being trained (i.e. the
        most recent, not yet frozen task), the remaining capacity can still be
        trained. Otherwise, all parameters are frozen and the activated subset
        cannot be trained.

        This function can only be called when PackNet is in the `State.EVAL`,
        `State.TRAINING`, or `State.POST_PRUNE` state. PackNet moves to the
        `State.EVAL` state if `task_id` refers to a previously frozen task;
        otherwise, PackNet remains in its current state.

        :param task_id: The task to activate. Must be between 0 and the number
            of tasks seen so far.
        """
        if not (0 <= task_id <= self.task_count):
            raise ValueError(
                f"`task_id` must be between 0 and {self.task_count}, " f"got {task_id}"
            )
        if task_id != self.task_count:
            next_state = self.State.EVAL
        elif self.state == self.State.POST_PRUNE:
            next_state = self.State.POST_PRUNE
        else:
            next_state = self.State.TRAINING
        # Stop if the task is already active
        if task_id == self.active_task and self.state == next_state:
            return

        self._state_guard(
            self.activate_task.__name__,
            [self.State.EVAL, self.State.TRAINING],
            next_state,
        )
        self._activate_task(task_id)
        self._active_task.fill_(task_id)

    @abstractmethod
    def _prune(self, prune_proportion: float) -> None:
        """Implementation of `prune` once the state has been checked"""

    @abstractmethod
    def _freeze_pruned(self) -> None:
        """Implementation of `freeze_pruned` once the state has been checked"""

    @abstractmethod
    def _activate_task(self, task_id: int) -> None:
        """Implementation of `activate_task` once the state has been checked"""

    @property
    def active_task(self) -> int:
        """Returns the id of the task that is currently active.

        :return: The id of the task that is currently active.
        """
        return int(self._active_task.item())

    @property
    def task_count(self) -> int:
        """Counts the number of task-specific subsets in PackNet.

        :return: The number of task-specific subsets
        """
        return int(self._task_count.item())

    @property
    def state(self) -> State:
        return self.State(self._state.item())

    def _state_guard(
        self,
        func_name: str,
        previous: t.Sequence[State],
        next: State,
    ):
        """Ensure that the state is in the correct state and transition to the
        next correct state.
        """
        if self.state not in previous:
            previous_str = ", ".join([str(x) for x in previous])
            raise self.StateError(
                f"Calling `{func_name}` is only valid for `{previous_str}` "
                f"instead PackNet was in the `{self.state}` state"
            )
        self._state.fill_(next.value)


class WeightAndBiasPackNetModule(PackNetModule):
    """A PackNet module that has a weight and bias. This can be used to wrap
    many PyTorch modules such as `nn.Linear`, `nn.Conv2d` and
    `nn.ConvTranspose2d`.
    """

    def __init__(self, wrappee: nn.Module) -> None:
        super().__init__()
        self.wrappee: nn.Module = wrappee

        # The following checks ensure that the wrappee is compatible
        if not hasattr(wrappee, "weight") or not isinstance(
            wrappee.weight, nn.Parameter
        ):
            raise ValueError(f"weight must be defined in {wrappee}")
        self.has_bias = hasattr(wrappee, "bias") and isinstance(
            wrappee.bias, nn.Parameter
        )

        self.PRUNED_CODE: Tensor
        """Value used to mark a pruned weight during the post-pruning phase"""
        self.register_buffer("PRUNED_CODE", torch.tensor(255, dtype=torch.int))

        self.task_index: Tensor
        """Tracks which task each weight belongs to. Of the same shape as the
        weight tensor."""
        self.register_buffer(
            "task_index",
            torch.ones_like(self.wrappee.weight).byte() * self._task_count,
        )

        self.visible_mask: Tensor
        """Mask of weights that are visible. Can be computed from
        `task_index`"""
        self.register_buffer(
            "visible_mask", torch.ones_like(self.task_index, dtype=torch.bool)
        )

        self.unfrozen_mask: Tensor
        """Mask of weights that are mutable. Can be computed from
        `task_index`"""
        self.register_buffer(
            "unfrozen_mask", torch.ones_like(self.task_index, dtype=torch.bool)
        )

        wrappee.weight.register_hook(self._remove_gradient_hook)

    @property
    def pruned_mask(self) -> Tensor:
        """Return a mask of weights that have been pruned"""
        return self.task_index.eq(self.PRUNED_CODE)

    def _prune(self, prune_proportion: float):
        ranked = self._rank_prunable()
        prune_count = int(len(ranked) * prune_proportion)
        self._prune_weights(ranked[:prune_count])
        self.unfrozen_mask = self.task_index.eq(self._task_count)

    def available_weights(self) -> Tensor:
        return self.visible_mask * self.wrappee.weight

    def _activate_task(self, task_id: int):
        self.visible_mask.zero_()
        self.unfrozen_mask.zero_()
        self.visible_mask = self.task_index.less_equal(task_id)
        if task_id == self.task_count:
            self.unfrozen_mask = self.task_index.eq(task_id)

    def _freeze_pruned(self):
        self.task_index[self.pruned_mask] = self.task_count
        self.unfrozen_mask.zero_()
        if self.has_bias:
            self.wrappee.bias.requires_grad = False

    @property
    def device(self) -> torch.device:
        return self.wrappee.weight.device

    def _remove_gradient_hook(self, grad: Tensor) -> Tensor:
        """Zero out the gradients of frozen weights, preventing them from
        being modified after they have been frozen."""
        return grad * self.unfrozen_mask

    def _rank_prunable(self) -> Tensor:
        """
        Returns a 1D tensor of the weights ranked based on their absolute
        value, sorted in ascending order.
        """
        # "We use the simple heuristic to quantify the importance of the
        # weights using their absolute value." (Han et al., 2017)
        # Han, S., Pool, J., Narang, S., Mao, H., Gong, E., Tang, S., Elsen,
        # E., Vajda, P., Paluri, M., Tran, J., Catanzaro, B., & Dally, W. J.
        # (2017). DSD: Dense-Sparse-Dense Training for Deep Neural Networks.
        # ArXiv:1607.04381 [Cs]. http://arxiv.org/abs/1607.04381
        importance = self.wrappee.weight.abs()
        un_prunable = ~self.unfrozen_mask
        # Mark un-prunable weights using -1.0 so they can be cut out after
        # sorting
        importance[un_prunable] = -1.0
        rank = torch.argsort(importance.flatten())
        # Cut out the un-prunable weights
        return rank[un_prunable.count_nonzero() :]

    def _prune_weights(self, indices: Tensor):
        """Given a list of indices, prune the weights at those indices.

        Pruning simply marks the weight as pruned in the `task_index` and
        makes the weight invisible in the `visible_mask`.

        :param indices: A 1D tensor of indices to prune
        """
        self.task_index.flatten()[indices] = self.PRUNED_CODE.item()
        self.visible_mask.flatten()[indices] = False
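

# Illustrative sketch of the mask bookkeeping above (assumed values, not from
# the original source): while training task 0 with weights
# w = [0.9, -0.1, 0.4, -0.7], every entry of `task_index` is 0. Calling
# `prune(0.5)` ranks the weights by |w|, marks the two smallest (-0.1 and 0.4)
# with `PRUNED_CODE` (255) and hides them via `visible_mask`, leaving 0.9 and
# -0.7 to be fine-tuned in the post-pruning phase. `freeze_pruned()` then
# reassigns the pruned entries to the next task (`task_index = 1`), so they
# become the mutable capacity for task 1, while task 0's surviving weights are
# frozen.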


class _PnLinear(WeightAndBiasPackNetModule):
    """A decorator for the `nn.Linear` module making it PackNet compatible."""

    def __init__(self, wrappee: nn.Linear) -> None:
        self.wrappee: nn.Linear
        super().__init__(wrappee)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.available_weights(), self.wrappee.bias)


class _PnConvNd(WeightAndBiasPackNetModule):
    """A decorator for the `nn.Conv1d`, `nn.Conv2d` and `nn.Conv3d` modules
    making them PackNet compatible."""

    def __init__(self, wrappee: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]) -> None:
        super().__init__(wrappee)

    def forward(self, input: Tensor) -> Tensor:
        return self.wrappee._conv_forward(
            input, self.available_weights(), self.wrappee.bias
        )


class _PnConvTransposedNd(WeightAndBiasPackNetModule):
    """A decorator for the `nn.ConvTranspose1d`, `nn.ConvTranspose2d` and
    `nn.ConvTranspose3d` modules making them PackNet compatible."""

    def __init__(
        self,
        wrappee: Union[nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d],
    ) -> None:
        super().__init__(wrappee)

    def forward(
        self, input: Tensor, output_size: t.Optional[t.List[int]] = None
    ) -> Tensor:
        w = self.wrappee
        if w.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for transposed "
                "convolutions"
            )
        assert isinstance(w.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding"
        # because torch.Script does not support `Sequence[T]` or
        # `Tuple[T, ...]`.
        output_padding = w._output_padding(
            input, output_size, w.stride, w.padding, w.kernel_size, w.dilation
        )  # type: ignore[arg-type]
        # Dispatch to the functional matching the dimensionality of the
        # wrapped module
        if isinstance(w, nn.ConvTranspose1d):
            conv_transpose = F.conv_transpose1d
        elif isinstance(w, nn.ConvTranspose2d):
            conv_transpose = F.conv_transpose2d
        else:
            conv_transpose = F.conv_transpose3d
        return conv_transpose(
            input,
            self.available_weights(),
            w.bias,
            w.stride,
            w.padding,
            output_padding,
            w.groups,
            w.dilation,
        )


class PackNetModel(PackNetModule, MultiTaskModule):
    """
    PackNet implements the PackNet algorithm for parameter isolation. It is
    designed to automatically upgrade most models to support PackNet. But
    because of the nature of the strategy, it cannot be used with every model
    or PyTorch module, and not everything has been implemented yet. Here are
    some basic guidelines:

    - Stateless modules like :class:`torch.nn.ReLU`,
      :class:`torch.nn.Flatten`, or :class:`torch.nn.Dropout` should work
      fine.
    - Many normalization layers currently do not work.
    - Supports: :class:`nn.Linear`, :class:`nn.Conv1d`, :class:`nn.Conv2d`,
      :class:`nn.Conv3d`, :class:`nn.ConvTranspose1d`,
      :class:`nn.ConvTranspose2d`, :class:`nn.ConvTranspose3d`
    - If you want to use a custom module with state or parameters, ensure it
      implements :class:`PackNetModule`.

    Mallya, A., & Lazebnik, S. (2018). PackNet: Adding Multiple Tasks to a
    Single Network by Iterative Pruning. 2018 IEEE/CVF Conference on Computer
    Vision and Pattern Recognition, 7765-7773.
    https://doi.org/10.1109/CVPR.2018.00810
    """

    @staticmethod
    def wrap(wrappee: nn.Module):
        """Upgrade a PyTorch module and all of its submodules to be PackNet
        compatible. This is a recursive function that will wrap all submodules
        in PackNet compatible modules.

        :param wrappee: The module to wrap
        :raises ValueError: If the module is not supported
        :return: A PackNet compatible module
        """
        # Weight norm is not supported
        if hasattr(wrappee, "weight_g") and hasattr(wrappee, "weight_v"):
            raise ValueError("PackNet does not support weight norm")
        # Other norms are not supported
        if hasattr(wrappee, "running_mean") or hasattr(wrappee, "running_var"):
            raise ValueError(
                "The PackNet implementation does not yet support norms "
                f"{wrappee.__class__.__name__}"
            )

        # Base and recursive cases
        if isinstance(wrappee, PackNetModule):
            return wrappee
        elif isinstance(wrappee, nn.Linear):
            return _PnLinear(wrappee)
        elif isinstance(wrappee, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            return _PnConvNd(wrappee)
        elif isinstance(
            wrappee, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        ):
            return _PnConvTransposedNd(wrappee)
        elif isinstance(wrappee, nn.Sequential):
            # Wrap each submodule
            for i, x in enumerate(wrappee):
                wrappee[i] = PackNetModel.wrap(x)
            return wrappee

        # If the module has parameters of its own and has not been wrapped
        # yet, then it is not supported
        if len(list(wrappee.parameters(recurse=False))) != 0:
            raise ValueError(
                f"PackNet does not support the module {wrappee.__class__.__name__}"
            )

        for submodule_name, submodule in wrappee.named_children():
            setattr(wrappee, submodule_name, PackNetModel.wrap(submodule))
        return wrappee

    def __init__(self, wrappee: nn.Module) -> None:
        """Wrap a PyTorch module to make it PackNet compatible.

        :param wrappee: The module to wrap
        """
        super().__init__()
        self.wrappee: nn.Module = PackNetModel.wrap(wrappee)

    def _pn_apply(self, func: t.Callable[["PackNetModel"], None]):
        """Apply a function to all child PackNetModules.

        :param func: The function to apply
        """

        @torch.no_grad()
        def __pn_apply(module):
            # Apply the function to all child PackNetModules but not to
            # other parent PackNet modules
            if isinstance(module, PackNetModule) and not isinstance(
                module, PackNetModel
            ):
                func(module)

        self.apply(__pn_apply)

    def _prune(self, to_prune_proportion: float):
        """Call `prune` on all child PackNetModules.

        :param to_prune_proportion: The proportion of parameters to prune
            in each child PackNetModule
        """
        self._pn_apply(lambda x: x.prune(to_prune_proportion))

    def _freeze_pruned(self):
        """Call `freeze_pruned` on all child PackNetModules."""
        self._pn_apply(lambda x: x.freeze_pruned())

    def _activate_task(self, task_id: int):
        """Call `activate_task` on all child PackNetModules.

        :param task_id: The task to activate
        """
        self._pn_apply(lambda x: x.activate_task(task_id))

    def forward(self, input: Tensor, task_id: Tensor) -> Tensor:
        task_id_ = task_id[0]
        assert task_id.eq(task_id_).all(), "All task ids must be the same"
        self.activate_task(min(task_id_, self.task_count))
        return self.wrappee.forward(input)
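

# A minimal sketch (not part of the original source) showing how a plain
# PyTorch model is upgraded. The architecture and layer sizes are arbitrary;
# only module types listed in the `PackNetModel` docstring are assumed to
# work.
def _example_wrap_sequential() -> "PackNetModel":
    """Illustrative helper: wrap a small `nn.Sequential` with PackNet."""
    backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    # `nn.Linear` layers are replaced by `_PnLinear` wrappers, while the
    # stateless `nn.ReLU` is left untouched.
    return PackNetModel(backbone)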


class PackNetPlugin(BaseSGDPlugin):
    """A plugin calling PackNet's pruning and freezing procedures at the
    appropriate times. This plugin can only be used with `PackNet` models.
    """

    def __init__(
        self,
        post_prune_epochs: int,
        prune_proportion: t.Union[
            float, t.Callable[[int], float], t.List[float]
        ] = 0.5,
    ):
        """The PackNetPlugin calls PackNet's pruning and freezing procedures
        at the appropriate times.

        :param post_prune_epochs: The number of epochs to finetune the model
            after pruning the parameters. Must be less than the number of
            training epochs.
        :param prune_proportion: The proportion of parameters to prune
            during each task. Can be a float, a list of floats, or a function
            that takes the task id and returns a float. Each value must be
            between 0 and 1.
        """
        super().__init__()
        self.post_prune_epochs = post_prune_epochs
        self.total_epochs: Union[int, None] = None
        self.prune_proportion: t.Callable[[int], float] = prune_proportion

        if isinstance(prune_proportion, float):
            assert 0 <= self.prune_proportion <= 1, (
                f"`prune_proportion` must be between 0 and 1, got "
                f"{self.prune_proportion}"
            )
            self.prune_proportion = lambda _: prune_proportion
        elif isinstance(prune_proportion, list):
            assert all(0 <= x <= 1 for x in prune_proportion), (
                "all values in `prune_proportion` must be between 0 and 1,"
                f" got {prune_proportion}"
            )
            self.prune_proportion = lambda i: prune_proportion[i]
        else:
            self.prune_proportion = prune_proportion

    def before_training(self, strategy: "BaseSGDTemplate", *args, **kwargs):
        assert isinstance(
            strategy, BaseSGDTemplate
        ), "Strategy must be a `BaseSGDTemplate` or derived class."
        if not hasattr(strategy, "train_epochs"):
            raise ValueError(
                "`PackNetPlugin` can only be used with a `BaseStrategy` that "
                "has a `train_epochs` attribute."
            )

        # Check the scenario has enough epochs for the post-pruning phase
        self.total_epochs = strategy.train_epochs
        if self.post_prune_epochs >= self.total_epochs:
            raise ValueError(
                "`PackNetPlugin` can only be used with a `BaseStrategy` "
                "that has a `train_epochs` attribute greater than "
                f"{self.post_prune_epochs}. "
                f"Strategy has only {self.total_epochs} training epochs."
            )

    def before_training_exp(self, strategy: "BaseSGDTemplate", *args, **kwargs):
        # Reset the optimizer to prevent momentum from affecting the pruned
        # parameters
        strategy.optimizer = strategy.optimizer.__class__(
            strategy.model.parameters(), **strategy.optimizer.defaults
        )

    def before_training_epoch(self, strategy: "BaseSGDTemplate", *args, **kwargs):
        """When the initial training phase is over, prune the model and
        transition to the post-pruning phase.
        """
        epoch = strategy.clock.train_exp_epochs
        model = self._get_model(strategy)
        if epoch == (self.total_epochs - self.post_prune_epochs):
            model.prune(self.prune_proportion(strategy.clock.train_exp_counter))

    def after_training_exp(self, strategy: "Template", *args, **kwargs):
        """After each experience, commit the model so that the next experience
        does not interfere with the previous one.
        """
        model = self._get_model(strategy)
        model.freeze_pruned()

    def _get_model(self, strategy: "BaseSGDTemplate") -> PackNetModule:
        """Get the model from the strategy."""
        model = strategy.model
        if not isinstance(strategy.model, PackNetModule):
            raise ValueError(
                f"`PackNetPlugin` can only be used with a `PackNet` model, "
                f"got {type(strategy.model)}. Try wrapping your model with "
                "`PackNet` before using this plugin."
            )
        return model
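

# For illustration (assumed values, not from the original source), the three
# accepted forms of `prune_proportion` below are equivalent ways of pruning
# half of the remaining capacity after every task:
#
#   PackNetPlugin(post_prune_epochs=1, prune_proportion=0.5)
#   PackNetPlugin(post_prune_epochs=1, prune_proportion=[0.5, 0.5, 0.5])
#   PackNetPlugin(post_prune_epochs=1, prune_proportion=lambda task_id: 0.5)
#
# The list form needs one entry per experience; the callable receives the
# experience counter and returns the proportion for that task.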


def packnet_simple_mlp(
    num_classes=10,
    input_size=28 * 28,
    hidden_size=512,
    hidden_layers=1,
    drop_rate=0.5,
) -> PackNetModel:
    """
    Convenience function for creating a PackNet compatible :class:`SimpleMLP`
    model.

    :param num_classes: output size
    :param input_size: input size
    :param hidden_size: hidden layer size
    :param hidden_layers: number of hidden layers
    :param drop_rate: dropout rate. 0 to disable
    :return: A PackNet compatible model
    """
    return PackNetModel(
        SimpleMLP(num_classes, input_size, hidden_size, hidden_layers, drop_rate)
    )
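

# A hedged usage sketch (not part of the original source) showing how the
# pieces above are typically combined with a standard Avalanche strategy.
# `Naive` and `SplitMNIST` are assumed to be available from Avalanche; the
# hyper-parameter values are arbitrary.
def _packnet_usage_example():  # pragma: no cover - illustrative only
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.training.supervised import Naive

    # PackNet is task-incremental, so the benchmark must provide task ids
    benchmark = SplitMNIST(n_experiences=5, return_task_id=True)
    model = packnet_simple_mlp(num_classes=10)
    # Prune half of the remaining capacity after each task and fine-tune the
    # surviving weights for one epoch before freezing them
    plugin = PackNetPlugin(post_prune_epochs=1, prune_proportion=0.5)
    strategy = Naive(
        model,
        SGD(model.parameters(), lr=0.01),
        train_epochs=2,  # must be greater than `post_prune_epochs`
        plugins=[plugin],
    )
    for experience in benchmark.train_stream:
        strategy.train(experience)
    strategy.eval(benchmark.test_stream)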