# Source code for the AGEM (Average Gradient Episodic Memory) plugin.

import random
import warnings
from typing import Any, Iterator, List, Optional

import torch
from torch import Tensor

from avalanche.benchmarks.utils import _make_taskaware_classification_dataset
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_loader import (
    GroupBalancedInfiniteDataLoader,
)
from avalanche.models import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin

class AGEMPlugin(SupervisedPlugin):
    """Average Gradient Episodic Memory Plugin.

    AGEM projects the gradient on the current minibatch by using an external
    episodic memory of patterns from previous experiences. If the dot product
    between the current gradient and the (average) gradient of a randomly
    sampled set of memory examples is negative, the gradient is projected.
    This plugin does not use task identities.
    """

    def __init__(self, patterns_per_experience: int, sample_size: int):
        """
        :param patterns_per_experience: number of patterns per experience in
            the memory.
        :param sample_size: number of patterns in memory sample when computing
            reference gradient.
        """
        super().__init__()

        self.patterns_per_experience = int(patterns_per_experience)
        self.sample_size = int(sample_size)

        # One AvalancheDataset for each experience
        self.buffers: List[AvalancheDataset] = []
        self.buffer_dataloader: Optional[GroupBalancedInfiniteDataLoader] = None
        # Placeholder iterator to avoid typing issues
        self.buffer_dliter: Iterator[Any] = iter([])
        # Placeholder Tensor to avoid typing issues
        self.reference_gradients: Tensor = torch.empty(0)

    @staticmethod
    def _flatten_gradients(strategy) -> Tensor:
        """Concatenate the gradients of all model parameters into one vector.

        Parameters without a gradient contribute a block of zeros of the same
        number of elements, so the layout of the returned vector depends only
        on the model's parameter list.
        """
        # gradient can be None for some head on multi-headed models
        return torch.cat(
            [
                (
                    p.grad.view(-1)
                    if p.grad is not None
                    else torch.zeros(p.numel(), device=strategy.device)
                )
                for p in strategy.model.parameters()
            ]
        )

    def before_training_iteration(self, strategy, **kwargs):
        """
        Compute reference gradient on memory sample.
        """
        if not self.buffers:
            # Nothing stored yet: no reference gradient to compute.
            return

        strategy.model.train()
        strategy.optimizer.zero_grad()

        mb = self.sample_from_memory()
        # Inputs and targets come first, the last element is used as task id.
        xref, yref, tid = mb[0], mb[1], mb[-1]
        xref, yref = xref.to(strategy.device), yref.to(strategy.device)

        out = avalanche_forward(strategy.model, xref, tid)
        loss = strategy._criterion(out, yref)
        loss.backward()

        # torch.cat builds a new tensor, so the stored reference survives
        # the zero_grad() call below.
        self.reference_gradients = self._flatten_gradients(strategy)
        strategy.optimizer.zero_grad()

    @torch.no_grad()
    def after_backward(self, strategy, **kwargs):
        """
        Project gradient based on reference gradients
        """
        if not self.buffers:
            return

        current_gradients = self._flatten_gradients(strategy)

        assert (
            current_gradients.shape == self.reference_gradients.shape
        ), "Different model parameters in AGEM projection"

        dotg = torch.dot(current_gradients, self.reference_gradients)
        if dotg < 0:
            # Remove the component of the current gradient that points
            # against the reference gradient.
            alpha2 = dotg / torch.dot(
                self.reference_gradients, self.reference_gradients
            )
            grad_proj = current_gradients - self.reference_gradients * alpha2

            # Write the projected vector back, parameter by parameter, using
            # the same ordering that produced the flattened vector.
            count = 0
            for p in strategy.model.parameters():
                n_param = p.numel()
                if p.grad is not None:
                    p.grad.copy_(grad_proj[count : count + n_param].view_as(p))
                count += n_param

    def after_training_exp(self, strategy, **kwargs):
        """Update replay memory with patterns from current experience."""
        self.update_memory(strategy.experience.dataset, **kwargs)

    def sample_from_memory(self):
        """
        Sample a minibatch from memory.
        Return a tuple of patterns (tensor), targets (tensor).
        """
        return next(self.buffer_dliter)

    @torch.no_grad()
    def update_memory(self, dataset, num_workers=0, **kwargs):
        """
        Update replay memory with patterns from current experience.

        :param dataset: dataset of the current experience; at most
            ``patterns_per_experience`` randomly chosen patterns are kept.
        :param num_workers: number of workers passed to the memory
            dataloader. A warning is issued when it is greater than 0.
        """
        if num_workers > 0:
            warnings.warn(
                "Num workers > 0 is known to cause heavy "
                "slowdowns in AGEM."
            )

        if len(dataset) > self.patterns_per_experience:
            # Keep a random subset of the required size.
            indices = list(range(len(dataset)))
            random.shuffle(indices)
            dataset = dataset.subset(indices[: self.patterns_per_experience])

        self.buffers.append(dataset)

        # The sample is split across buffers; keep at least one pattern per
        # buffer so the batch size never drops to 0 when there are more
        # buffers than sample_size.
        per_buffer_batch_size = max(1, self.sample_size // len(self.buffers))

        self.buffer_dataloader = GroupBalancedInfiniteDataLoader(
            self.buffers,
            batch_size=per_buffer_batch_size,
            num_workers=num_workers,
            pin_memory=False,
            persistent_workers=num_workers > 0,
        )
        self.buffer_dliter = iter(self.buffer_dataloader)