# Source code for the GEMPlugin (Gradient Episodic Memory) training plugin.

import numpy as np
import quadprog
import torch
from torch.utils.data import DataLoader

from avalanche.models import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin

class GEMPlugin(SupervisedPlugin):
    """
    Gradient Episodic Memory Plugin.

    GEM projects the gradient on the current minibatch by using an external
    episodic memory of patterns from previous experiences. The gradient on
    the current minibatch is projected so that the dot product with all the
    reference gradients of previous tasks remains positive.
    This plugin does not use task identities.
    """

    def __init__(self, patterns_per_experience: int, memory_strength: float):
        """
        :param patterns_per_experience: number of patterns per experience in
            the memory.
        :param memory_strength: offset to add to the projection direction
            in order to favour backward transfer (gamma in original paper).
        """
        super().__init__()

        self.patterns_per_experience = int(patterns_per_experience)
        self.memory_strength = memory_strength

        # Replay buffers keyed by experience index: inputs, targets and the
        # last element of each stored minibatch (passed to the model as the
        # task labels when the reference gradients are computed).
        self.memory_x, self.memory_y, self.memory_tid = {}, {}, {}

        # Reference gradients, rebuilt before every training iteration once
        # at least one experience has been completed.
        self.G = None

    @staticmethod
    def _flatten_gradients(strategy):
        """
        Concatenate the gradients of all model parameters into one 1-D tensor.

        Parameters whose ``grad`` is ``None`` contribute a block of zeros of
        matching size, so offsets into the result always line up with
        ``strategy.model.parameters()``.

        :param strategy: the strategy providing ``model`` and ``device``.
        :return: 1-D tensor holding every parameter gradient back to back.
        """
        return torch.cat(
            [
                p.grad.flatten()
                if p.grad is not None
                else torch.zeros(p.numel(), device=strategy.device)
                for p in strategy.model.parameters()
            ],
            dim=0,
        )

    def before_training_iteration(self, strategy, **kwargs):
        """
        Compute gradient constraints on previous memory samples from all
        experiences.

        For each past experience the stored patterns are forwarded through
        the model, the loss is back-propagated, and the flattened gradient
        is stored as one row of ``self.G``.
        """
        n_past = strategy.clock.train_exp_counter
        if n_past <= 0:
            # Nothing has been memorised yet: no constraints to build.
            return

        reference_grads = []
        for t in range(n_past):
            strategy.model.train()
            # Clear gradients so each row only reflects experience ``t``.
            strategy.optimizer.zero_grad()
            xref = self.memory_x[t].to(strategy.device)
            yref = self.memory_y[t].to(strategy.device)
            out = avalanche_forward(strategy.model, xref, self.memory_tid[t])
            loss = strategy._criterion(out, yref)
            loss.backward()

            reference_grads.append(self._flatten_gradients(strategy))

        self.G = torch.stack(reference_grads)  # (experiences, parameters)

    @torch.no_grad()
    def after_backward(self, strategy, **kwargs):
        """
        Project gradient based on reference gradients.

        The projection is applied only when the current gradient has a
        negative dot product with at least one reference gradient; in that
        case the parameter gradients are overwritten in place with the
        solution of the quadratic program.
        """
        if strategy.clock.train_exp_counter <= 0:
            return

        g = self._flatten_gradients(strategy)

        # One dot product per past experience; any negative entry means the
        # current update conflicts with that experience's reference gradient.
        if not (torch.mv(self.G, g) < 0).any():
            return

        v_star = self.solve_quadprog(g).to(strategy.device)

        # Reshape v_star back into the individual parameter gradients.
        num_pars = 0
        for p in strategy.model.parameters():
            curr_pars = p.numel()
            if p.grad is not None:
                p.grad.copy_(
                    v_star[num_pars : num_pars + curr_pars].view(p.size())
                )
            # Advance even when grad is None: the flattened vector holds a
            # zero block for that parameter.
            num_pars += curr_pars

        assert num_pars == v_star.numel(), "Error in projecting gradient"

    def after_training_exp(self, strategy, **kwargs):
        """
        Store patterns from the experience just completed in the replay
        memory, under the current ``train_exp_counter`` index.
        """
        self.update_memory(
            strategy.experience.dataset,
            strategy.clock.train_exp_counter,
            strategy.train_mb_size,
        )

    def _append_to_memory(self, t, x, y, tid):
        """
        Add one chunk of patterns to the buffers of experience ``t``,
        creating the buffers on first use and concatenating afterwards.
        """
        if t not in self.memory_x:
            self.memory_x[t] = x.clone()
            self.memory_y[t] = y.clone()
            self.memory_tid[t] = tid.clone()
        else:
            self.memory_x[t] = torch.cat((self.memory_x[t], x), dim=0)
            self.memory_y[t] = torch.cat((self.memory_y[t], y), dim=0)
            self.memory_tid[t] = torch.cat((self.memory_tid[t], tid), dim=0)

    @torch.no_grad()
    def update_memory(self, dataset, t, batch_size):
        """
        Update replay memory with patterns from current experience.

        Minibatches are read in dataset order and appended until
        ``patterns_per_experience`` patterns have been collected during
        this call; the last minibatch is truncated to fit.

        :param dataset: dataset of the current experience; each minibatch
            must expose the input at index 0, the target at index 1 and
            the task labels as its last element.
        :param t: key under which the patterns are stored.
        :param batch_size: minibatch size used to read the dataset.
        """
        dataloader = DataLoader(dataset, batch_size=batch_size)
        tot = 0
        for mbatch in dataloader:
            x, y, tid = mbatch[0], mbatch[1], mbatch[-1]
            if tot + x.size(0) <= self.patterns_per_experience:
                self._append_to_memory(t, x, y, tid)
            else:
                # Only part of this minibatch fits: keep the first ``diff``
                # patterns and stop reading.
                diff = self.patterns_per_experience - tot
                self._append_to_memory(t, x[:diff], y[:diff], tid[:diff])
                break
            tot += x.size(0)

    def solve_quadprog(self, g):
        """
        Solve quadratic programming with current gradient g and
        gradients matrix on previous tasks G.
        Adapted from the original GEM reference implementation.

        :param g: flattened gradient of the current minibatch.
        :return: projected gradient as a float32 CPU tensor with the same
            number of elements as ``g``.
        """
        # quadprog works on float64 NumPy arrays.
        memories_np = self.G.cpu().double().numpy()
        gradient_np = g.cpu().contiguous().view(-1).double().numpy()
        t = memories_np.shape[0]
        P = np.dot(memories_np, memories_np.transpose())
        # Symmetrise and add a small ridge to the diagonal before handing
        # the matrix to the solver.
        P = 0.5 * (P + P.transpose()) + np.eye(t) * 1e-3
        q = np.dot(memories_np, gradient_np) * -1
        G = np.eye(t)
        h = np.zeros(t) + self.memory_strength
        v = quadprog.solve_qp(P, q, G, h)[0]
        v_star = np.dot(v, memories_np) + gradient_np
        return torch.from_numpy(v_star).float()