import copy
from typing import TYPE_CHECKING

import torch

from avalanche.benchmarks.utils import concat_datasets
from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate

def cycle(loader):
    """Yield the items of ``loader`` over and over, restarting it each pass.

    ``loader`` is re-iterated on every pass instead of being cached (which
    is what ``itertools.cycle`` does), so an iterable that produces a new
    order each time it is iterated keeps doing so.

    :param loader: Iterable to loop over; it should be re-iterable.
    :return: A generator over the items of ``loader``. It only terminates
        when a complete pass over ``loader`` produces no item at all
        (empty iterable, or a one-shot iterator that is already consumed).
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard an empty pass would make the outer loop
            # spin forever without ever yielding control to the caller.
            return

def update_temp(model, grad, lr):
    """Return a copy of ``model`` after one plain gradient-descent step.

    The input model is never modified: the step ``p <- p - lr * g`` is
    applied to the parameters of a deep copy.

    :param model: Model to copy.
    :param grad: Iterable of gradients, in the same order as
        ``model.parameters()``. ``None`` entries (parameters that received
        no gradient) are skipped.
    :param lr: Step size multiplying each gradient.
    :return: The updated deep copy of ``model``.
    """
    model_copy = copy.deepcopy(model)
    # The update is a raw parameter edit, it must not be tracked by autograd.
    with torch.no_grad():
        for g, p in zip(grad, model_copy.parameters()):
            if g is not None:
                p.sub_(lr * g)
    return model_copy

class MIRPlugin(SupervisedPlugin):
    """Maximally Interfered Retrieval plugin.

    Implements the strategy defined in
    "Online Continual Learning with Maximally Interfered Retrieval".

    This strategy has been designed and tested in the
    Online Setting (OnlineCLScenario). However, it
    can also be used in non-online scenarios.
    """

    def __init__(
        self,
        batch_size_mem: int,
        mem_size: int = 200,
        subsample: int = 200,
    ):
        """Create the plugin and its class-balanced storage policy.

        :param batch_size_mem: Size of the batch sampled from the bigger
            subsample batch.
        :param mem_size: Fixed memory size.
        :param subsample: Size of the sample from which to look for the
            highest interfering exemplars.
        """
        super().__init__()
        self.mem_size = mem_size
        self.subsample = subsample
        self.batch_size_mem = batch_size_mem
        self.storage_policy = ClassBalancedBuffer(
            max_size=self.mem_size, adaptive_size=True
        )
        # Endless iterator over replay candidates. Stays None until
        # after_training_exp finds more than batch_size_mem usable samples.
        self.replay_loader = None

    def before_backward(self, strategy: "SupervisedTemplate", **kwargs):
        """Add the loss of the most interfered replay samples to the loss.

        A candidate batch is drawn from the replay loader, a temporary copy
        of the model is updated with the gradient of the current loss, and
        the candidates whose loss grows the most under that update are
        replayed: their loss on the current model is added in place to
        ``strategy.loss``. Does nothing while no replay loader is available.

        :param strategy: The running strategy. This method reads its
            ``loss``, ``model``, ``optimizer``, ``device`` and ``_criterion``.
        :param kwargs: Unused.
        """
        if self.replay_loader is None:
            return

        samples_x, samples_y, samples_tid = next(self.replay_loader)
        samples_x, samples_y, samples_tid = (
            samples_x.to(strategy.device),
            samples_y.to(strategy.device),
            samples_tid.to(strategy.device),
        )

        # Perform the temporary update with current data. The graph is
        # retained because strategy.loss is differentiated again afterwards.
        grad = torch.autograd.grad(
            strategy.loss,
            strategy.model.parameters(),
            retain_graph=True,
            allow_unused=True,
        )
        model_updated = update_temp(
            strategy.model, grad, strategy.optimizer.param_groups[0]["lr"]
        )

        # Selection of the most interfering samples, no grad required.
        # NOTE: the models are NOT switched to eval mode here, so these
        # extra forward passes run in whatever mode the models are in.
        criterion = strategy._criterion
        previous_reduction = criterion.reduction
        # Per-sample losses are needed to rank the candidates.
        criterion.reduction = "none"
        try:
            with torch.no_grad():
                old_output = avalanche_forward(
                    strategy.model, samples_x, samples_tid
                )
                old_loss = criterion(old_output, samples_y)
                new_output = avalanche_forward(
                    model_updated, samples_x, samples_tid
                )
                new_loss = criterion(new_output, samples_y)
                loss_diff = new_loss - old_loss
                # The candidate batch may hold fewer than batch_size_mem
                # samples (e.g. last batch of a pass); clamp so that all of
                # them are kept instead of slicing from a negative offset.
                num_chosen = min(self.batch_size_mem, len(samples_x))
                chosen_samples_indexes = torch.topk(
                    loss_diff, num_chosen
                ).indices
        finally:
            # Always restore the criterion, even if a forward pass failed.
            criterion.reduction = previous_reduction

        # Choose the samples and add their loss to the current loss
        chosen_samples_x, chosen_samples_y, chosen_samples_tid = (
            samples_x[chosen_samples_indexes],
            samples_y[chosen_samples_indexes],
            samples_tid[chosen_samples_indexes],
        )
        replay_output = avalanche_forward(
            strategy.model, chosen_samples_x, chosen_samples_tid
        )
        replay_loss = criterion(replay_output, chosen_samples_y)
        strategy.loss += replay_loss

    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        """Update the stored samples and rebuild the replay loader.

        The replay candidates are the stored samples of every class that
        does not appear in the last mini-batch (``strategy.mb_y``). The
        loader is only created when there are more candidates than
        ``batch_size_mem``; otherwise replay is disabled.

        :param strategy: The running strategy.
        :param kwargs: Forwarded to the storage policy update.
        """
        self.storage_policy.update(strategy, **kwargs)

        # Exclude classes that were in the last batch
        last_batch_classes = set(torch.unique(strategy.mb_y).tolist())
        buffer = concat_datasets(
            [
                group.buffer
                for key, group in self.storage_policy.buffer_groups.items()
                if int(key) not in last_batch_classes
            ]
        )

        if len(buffer) > self.batch_size_mem:
            self.replay_loader = cycle(
                torch.utils.data.DataLoader(
                    buffer,
                    batch_size=self.subsample,
                    shuffle=True,
                )
            )
        else:
            self.replay_loader = None
# Names exported by a star-import of this module: only the plugin class.
__all__ = ["MIRPlugin"]