Source code for avalanche.training.plugins.mir

import copy
from typing import TYPE_CHECKING
import torch
from avalanche.benchmarks.utils import concat_datasets
from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


def cycle(loader):
    """Iterate over ``loader`` indefinitely, restarting it when exhausted."""
    while True:
        for batch in loader:
            yield batch


def update_temp(model, grad, lr):
    """Return a deep copy of ``model`` moved one SGD step along ``grad``
    with learning rate ``lr``. The original model is left untouched."""
    model_copy = copy.deepcopy(model)
    for g, p in zip(grad, model_copy.parameters()):
        if g is not None:
            p.data = p.data - lr * g
    return model_copy
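
# Illustrative sketch (not part of the original module): ``update_temp``
# pairs with ``torch.autograd.grad`` to build a look-ahead copy of a model
# without touching the optimizer state. ``net`` and ``loss`` below are
# assumed placeholders for any module and scalar loss.
#
#     grads = torch.autograd.grad(loss, net.parameters(), allow_unused=True)
#     lookahead = update_temp(net, grads, lr=0.1)
#     # ``net`` is unchanged; ``lookahead`` is one SGD step ahead of it.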


class MIRPlugin(SupervisedPlugin):
    """
    Maximally Interfered Retrieval plugin.

    Implements the strategy defined in
    "Online Continual Learning with Maximally Interfered Retrieval"
    https://arxiv.org/abs/1908.04742

    This strategy has been designed and tested in the Online Setting
    (OnlineCLScenario). However, it can also be used in non-online
    scenarios.
    """
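
    # Selection criterion, in brief (explanatory note, not original code):
    # after a virtual SGD step on the current minibatch, each candidate
    # replay sample (x, y) drawn from the buffer is scored by
    #     interference(x, y) = loss_after_step(x, y) - loss_before_step(x, y)
    # and the ``batch_size_mem`` samples whose loss would increase the most
    # are replayed; see ``before_backward`` below.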

    def __init__(
        self,
        batch_size_mem: int,
        mem_size: int = 200,
        subsample: int = 200,
    ):
        """
        mem_size: int       : Fixed memory size
        subsample: int      : Size of the sample from which to look
                              for highest interfering exemplars
        batch_size_mem: int : Size of the batch sampled from
                              the bigger subsample batch
        """
        super().__init__()
        self.mem_size = mem_size
        self.subsample = subsample
        self.batch_size_mem = batch_size_mem
        self.storage_policy = ClassBalancedBuffer(
            max_size=self.mem_size, adaptive_size=True
        )
        self.replay_loader = None

    def before_backward(self, strategy, **kwargs):
        if self.replay_loader is None:
            return
        samples_x, samples_y, samples_tid = next(self.replay_loader)
        samples_x, samples_y, samples_tid = (
            samples_x.to(strategy.device),
            samples_y.to(strategy.device),
            samples_tid.to(strategy.device),
        )
        # Perform the temporary update with current data
        grad = torch.autograd.grad(
            strategy.loss,
            strategy.model.parameters(),
            retain_graph=True,
            allow_unused=True,
        )
        model_updated = update_temp(
            strategy.model, grad, strategy.optimizer.param_groups[0]["lr"]
        )
        # Selection of the most interfering samples, no grad required;
        # plus we put the model in eval mode so that the additional
        # forward passes don't influence the batch norm statistics
        # strategy.model.eval()
        # model_updated.eval()
        with torch.no_grad():
            _old_red_strategy = strategy._criterion.reduction
            strategy._criterion.reduction = "none"
            old_output = avalanche_forward(strategy.model, samples_x, samples_tid)
            old_loss = strategy._criterion(old_output, samples_y)
            new_output = avalanche_forward(model_updated, samples_x, samples_tid)
            new_loss = strategy._criterion(new_output, samples_y)
            loss_diff = new_loss - old_loss
            chosen_samples_indexes = torch.argsort(loss_diff)[
                len(samples_x) - self.batch_size_mem :
            ]
            strategy._criterion.reduction = _old_red_strategy
        # strategy.model.train()
        # Choose the samples and add their loss to the current loss
        chosen_samples_x, chosen_samples_y, chosen_samples_tid = (
            samples_x[chosen_samples_indexes],
            samples_y[chosen_samples_indexes],
            samples_tid[chosen_samples_indexes],
        )
        replay_output = avalanche_forward(
            strategy.model, chosen_samples_x, chosen_samples_tid
        )
        replay_loss = strategy._criterion(replay_output, chosen_samples_y)
        strategy.loss += replay_loss

    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        self.storage_policy.update(strategy, **kwargs)
        # Exclude classes that were in the last batch
        buffer = concat_datasets(
            [
                self.storage_policy.buffer_groups[key].buffer
                for key, _ in self.storage_policy.buffer_groups.items()
                if int(key) not in torch.unique(strategy.mb_y).cpu()
            ]
        )
        if len(buffer) > self.batch_size_mem:
            self.replay_loader = cycle(
                torch.utils.data.DataLoader(
                    buffer,
                    batch_size=self.subsample,
                    shuffle=True,
                )
            )
        else:
            self.replay_loader = None

__all__ = ["MIRPlugin"]
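
# A minimal usage sketch, assuming a standard Avalanche setup. SplitMNIST,
# SimpleMLP, and the hyperparameter values below are illustrative choices,
# not part of the original module; MIR targets the online setting, but the
# plugin attaches to any supervised strategy through ``plugins``.
if __name__ == "__main__":
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import SimpleMLP
    from avalanche.training.supervised import Naive

    benchmark = SplitMNIST(n_experiences=5)
    model = SimpleMLP(num_classes=benchmark.n_classes)
    strategy = Naive(
        model=model,
        optimizer=SGD(model.parameters(), lr=0.1),
        criterion=CrossEntropyLoss(),
        train_mb_size=10,
        train_epochs=1,
        plugins=[MIRPlugin(batch_size_mem=10, mem_size=200, subsample=50)],
    )
    for experience in benchmark.train_stream:
        strategy.train(experience)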