Source code for avalanche.training.supervised.der

from collections import defaultdict
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    SupportsInt,
    Union,
)

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer

from avalanche.benchmarks.utils import make_avalanche_dataset
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute
from avalanche.benchmarks.utils.flat_data import FlatData
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.core import SupervisedPlugin
from avalanche.training.plugins.evaluation import (
    EvaluationPlugin,
    default_evaluator,
)
from avalanche.training.storage_policy import (
    BalancedExemplarsBuffer,
    ReservoirSamplingBuffer,
)
from avalanche.training.templates import SupervisedTemplate


@torch.no_grad()
def compute_dataset_logits(dataset, model, batch_size, device, num_workers=0):
    """Run ``model`` over ``dataset`` in eval mode and collect its outputs.

    The dataset is iterated in order (``shuffle=False``), so the i-th tensor
    in the returned list corresponds to ``dataset[i]``. The model's original
    training mode is restored afterwards, even if the forward pass raises.

    :param dataset: dataset whose items are sequences with the model input
        as first element (e.g. ``(x, y, t)``); further elements are ignored.
    :param model: module producing one output row per input sample.
    :param batch_size: mini-batch size used for inference.
    :param device: device the inputs are moved to before the forward pass.
    :param num_workers: number of DataLoader worker processes.
    :return: list with one detached CPU tensor per sample, in dataset order.
    """
    was_training = model.training
    model.eval()

    logits = []
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    try:
        for batch in loader:
            # Only the input is needed; targets/task labels/extra attributes
            # are ignored, whatever their number.
            x = batch[0].to(device)
            out = model(x).detach().cpu()
            # Clone each row so every entry owns its storage instead of
            # being a view into the whole batch output.
            logits.extend(torch.clone(row) for row in out)
    finally:
        # Restore the caller's mode even when the forward pass fails.
        model.train(was_training)

    return logits


class ClassBalancedBufferWithLogits(BalancedExemplarsBuffer):
    """
    Class-balanced replay buffer that also stores, for every sample, the
    logits the model produced when the sample was added to the buffer.

    Each seen class gets its own ``ReservoirSamplingBuffer``; the total
    capacity is split across classes with ``get_group_lengths``.
    """

    def __init__(
        self,
        max_size: int,
        adaptive_size: bool = True,
        total_num_classes: Optional[int] = None,
    ):
        """Init.

        :param max_size: The max capacity of the replay memory.
        :param adaptive_size: True if mem_size is divided equally over all
                            observed classes (keys in buffer_groups).
        :param total_num_classes: If adaptive size is False, the fixed number
                                  of classes to divide capacity over.
        """
        if not adaptive_size:
            assert (
                total_num_classes is not None and total_num_classes > 0
            ), "When fixed exp mem size, total_num_classes should be > 0."

        super().__init__(max_size, adaptive_size, total_num_classes)
        self.adaptive_size = adaptive_size
        self.total_num_classes = total_num_classes
        self.seen_classes: Set[int] = set()

    def update(self, strategy: "SupervisedTemplate", **kwargs):
        """Add the current experience's samples, with logits, to the buffer.

        The experience dataset (with eval transformations) is passed through
        ``strategy.model`` to compute logits, which are attached to every
        sample as a ``logits`` data attribute. Samples are then grouped by
        class into per-class reservoir buffers, and every class buffer is
        resized so capacity is balanced over all classes seen so far.

        :param strategy: strategy whose ``experience``, ``model``,
            ``train_mb_size`` and ``device`` are used.
        :param kwargs: ``num_workers`` (optional) is forwarded to the
            DataLoader used to compute the logits.
        """
        assert strategy.experience is not None
        new_data: AvalancheDataset = strategy.experience.dataset

        logits = compute_dataset_logits(
            new_data.eval(),
            strategy.model,
            strategy.train_mb_size,
            strategy.device,
            num_workers=kwargs.get("num_workers", 0),
        )
        new_data_with_logits = make_avalanche_dataset(
            new_data,
            data_attributes=[
                TensorDataAttribute(
                    FlatData([logits], discard_elements_not_in_indices=True),
                    name="logits",
                    use_in_getitem=True,
                )
            ],
        )

        # Group sample indices by class label. The int() conversion also
        # handles targets stored as single-element torch tensors.
        cl_idxs: Dict[int, List[int]] = defaultdict(list)
        targets: Sequence[SupportsInt] = getattr(new_data, "targets")
        for idx, target in enumerate(targets):
            cl_idxs[int(target)].append(idx)

        # One subset per class; subsets keep the attached logits.
        cl_datasets = {
            c: new_data_with_logits.subset(c_idxs)
            for c, c_idxs in cl_idxs.items()
        }
        self.seen_classes.update(cl_datasets.keys())

        # Capacity assigned to each class seen so far.
        lens = self.get_group_lengths(len(self.seen_classes))
        class_to_len = dict(zip(self.seen_classes, lens))

        # Merge the new samples into the per-class buffers.
        for class_id, new_data_c in cl_datasets.items():
            if class_id in self.buffer_groups:
                self.buffer_groups[class_id].update_from_dataset(new_data_c)
            else:
                new_buffer = ReservoirSamplingBuffer(class_to_len[class_id])
                new_buffer.update_from_dataset(new_data_c)
                self.buffer_groups[class_id] = new_buffer

        # Shrink every class buffer (including the ones just updated) to its
        # balanced share; this makes a per-class resize above unnecessary.
        for class_id, class_buf in self.buffer_groups.items():
            class_buf.resize(strategy, class_to_len[class_id])


class DER(SupervisedTemplate):
    """
    Implements the DER and the DER++ Strategy,
    from the "Dark Experience For General Continual Learning" paper,
    Buzzega et. al, https://arxiv.org/abs/2004.07211
    """

    def __init__(
        self,
        *,
        model: Module,
        optimizer: Optimizer,
        criterion: CriterionType = CrossEntropyLoss(),
        mem_size: int = 200,
        batch_size_mem: Optional[int] = None,
        alpha: float = 0.1,
        beta: float = 0.5,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin, Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch",
        **kwargs
    ):
        """
        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function. Note: it is only used on iterations
            without replay samples; when replay is active, ``training_epoch``
            uses ``F.cross_entropy`` directly.
        :param mem_size: int : Fixed total size of the class-balanced
            replay buffer.
        :param batch_size_mem: int : Size of the batch sampled from the
            buffer. Defaults to ``train_mb_size`` when None.
        :param alpha: float : Hyperparameter weighting the MSE loss between
            the current outputs and the stored logits of replay samples.
        :param beta: float : Hyperparameter weighting the CE loss on replay
            samples; when more than 0, DER++ is used instead of DER.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device where the model will be allocated.
        :param plugins: (optional) list of StrategyPlugins.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs (or iterations, see
            `peval_mode`) and at the end of the learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
            periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
        """
        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            peval_mode=peval_mode,
            **kwargs
        )
        if batch_size_mem is None:
            self.batch_size_mem = train_mb_size
        else:
            self.batch_size_mem = batch_size_mem
        self.mem_size = mem_size
        self.storage_policy = ClassBalancedBufferWithLogits(
            self.mem_size, adaptive_size=True
        )
        self.replay_loader = None
        self.alpha = alpha
        self.beta = beta

    def _before_training_exp(self, **kwargs):
        """Build the replay loader from the current buffer content.

        Replay is enabled only when the buffer holds at least
        ``batch_size_mem`` samples. ``drop_last=True`` then guarantees every
        replay batch has exactly ``batch_size_mem`` samples, which the loss
        slicing in ``training_epoch`` relies on. The loader is wrapped with
        ``cycle`` so that ``next()`` can be called on every iteration
        (assumes ``cycle`` restarts the loader when it is exhausted).
        """
        buffer = self.storage_policy.buffer
        if len(buffer) >= self.batch_size_mem:
            self.replay_loader = cycle(
                torch.utils.data.DataLoader(
                    buffer,
                    batch_size=self.batch_size_mem,
                    shuffle=True,
                    drop_last=True,
                    num_workers=kwargs.get("num_workers", 0),
                )
            )
        else:
            self.replay_loader = None

        super()._before_training_exp(**kwargs)

    def _after_training_exp(self, **kwargs):
        """Drop the replay loader and add the experience data to the buffer."""
        self.replay_loader = None  # Allow DER to be checkpointed
        self.storage_policy.update(self, **kwargs)
        super()._after_training_exp(**kwargs)

    def _before_forward(self, **kwargs):
        """Prepend a replay mini-batch to the current mini-batch.

        Replay inputs, targets and task labels are concatenated *before* the
        current ones, so the first ``batch_size_mem`` rows of the joint batch
        are replay samples. Their stored logits are kept in
        ``self.batch_logits`` for the MSE term of the loss.
        """
        super()._before_forward(**kwargs)
        if self.replay_loader is None:
            return None

        batch_x, batch_y, batch_tid, batch_logits = next(self.replay_loader)
        batch_x, batch_y, batch_tid, batch_logits = (
            batch_x.to(self.device),
            batch_y.to(self.device),
            batch_tid.to(self.device),
            batch_logits.to(self.device),
        )
        self.mbatch[0] = torch.cat((batch_x, self.mbatch[0]))
        self.mbatch[1] = torch.cat((batch_y, self.mbatch[1]))
        self.mbatch[2] = torch.cat((batch_tid, self.mbatch[2]))
        self.batch_logits = batch_logits

    def training_epoch(self, **kwargs):
        """Training epoch.

        Runs one pass over ``self.dataloader``. When replay is active the
        loss is ``CE(current) + alpha * MSE(replay outputs, stored logits)
        + beta * CE(replay)``; otherwise it is ``self.criterion()``.

        :param kwargs: forwarded to the strategy callbacks.
        """
        for self.mbatch in self.dataloader:
            if self._stop_training:
                break

            self._unpack_minibatch()
            self._before_training_iteration(**kwargs)

            self.optimizer.zero_grad()
            self.loss = self._make_empty_loss()

            # Forward
            self._before_forward(**kwargs)
            self.mb_output = self.forward()
            self._after_forward(**kwargs)

            if self.replay_loader is not None:
                # DER Loss computation. Replay samples occupy the first
                # batch_size_mem rows (see _before_forward).
                # NOTE: F.cross_entropy is used here instead of
                # self._criterion, so a custom criterion only affects
                # iterations without replay.

                self.loss += F.cross_entropy(
                    self.mb_output[self.batch_size_mem :],
                    self.mb_y[self.batch_size_mem :],
                )

                self.loss += self.alpha * F.mse_loss(
                    self.mb_output[: self.batch_size_mem],
                    self.batch_logits,
                )
                self.loss += self.beta * F.cross_entropy(
                    self.mb_output[: self.batch_size_mem],
                    self.mb_y[: self.batch_size_mem],
                )

                # There are a few differences compared to the authors'
                # implementation:
                # - Joint forward pass vs. 3 forward passes
                # - One replay batch vs. two replay batches
                # - Logits are stored from the non-transformed sample
                #   after training on the task vs. instantly on the
                #   transformed sample

            else:
                self.loss += self.criterion()

            self._before_backward(**kwargs)
            self.backward()
            self._after_backward(**kwargs)

            # Optimization step
            self._before_update(**kwargs)
            self.optimizer_step()
            self._after_update(**kwargs)

            self._after_training_iteration(**kwargs)
# Public API of this module: only the strategy itself is exported.
__all__ = [
    "DER",
]