# Source code for the CoPE (Continual Prototype Evolution) plugin.

from typing import Dict

import torch
from torch import Tensor
from torch.nn.functional import normalize
from torch.nn.modules import Module

# NOTE(review): the three `avalanche.training...` module paths below were
# stripped by the extraction (the lines read `from import ...`) and have been
# restored; confirm them against the installed Avalanche version.
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.utils import get_last_fc_layer, swap_last_fc_layer

class CoPEPlugin(SupervisedPlugin):
    """Continual Prototype Evolution plugin.

    Each class has a prototype for nearest-neighbor classification.
    The prototypes are updated continually with an exponentially moving
    average, using class-balanced replay to keep the prototypes up-to-date.
    The embedding space is optimized using the PseudoPrototypicalProxy-loss,
    exploiting both prototypes and batch information.

    This plugin doesn't use task identities in training or eval
    (data incremental) and is designed for online learning
    (1 epoch per task).
    """

    def __init__(
        self,
        mem_size=200,
        n_classes=10,
        p_size=100,
        alpha=0.99,
        T=0.1,
        max_it_cnt=1,
    ):
        """
        :param mem_size: max number of input samples in the replay memory.
        :param n_classes: total number of classes that will be encountered.
            This is used to output predictions for all classes, with zero
            probability for unseen classes.
        :param p_size: The prototype size, which equals the feature size of
            the last layer.
        :param alpha: The momentum for the exponentially moving average of
            the prototypes.
        :param T: The softmax temperature, used as a concentration parameter.
        :param max_it_cnt: How many processing iterations per batch
            (experience).
        """
        super().__init__()
        self.n_classes = n_classes
        self.it_cnt = 0
        self.max_it_cnt = max_it_cnt

        # Operational memory: replay memory
        self.mem_size = mem_size  # replay memory size
        self.storage_policy = ClassBalancedBuffer(
            max_size=self.mem_size, adaptive_size=True
        )

        # Operational memory: Prototypical memory
        # Scales with nb classes * feature size
        self.p_mem: Dict[int, Tensor] = {}
        self.p_size = p_size  # Prototype size determined on runtime
        self.tmp_p_mem = {}  # Intermediate to process batch for multiple times
        self.alpha = alpha
        self.p_init_adaptive = False  # Only create proto when class seen

        # PPP-loss. It holds a reference to the very same `p_mem` dict, so the
        # dict must only ever be mutated in place, never rebound.
        self.T = T
        self.ppp_loss = PPPloss(self.p_mem, T=self.T)

        self.initialized = False

    def before_training(self, strategy, **kwargs):
        """Enforce using the PPP-loss and add a NN-classifier."""
        if self.initialized:
            return

        strategy._criterion = self.ppp_loss
        print("Using the Pseudo-Prototypical-Proxy loss for CoPE.")

        # Normalize representation of last layer
        swap_last_fc_layer(
            strategy.model,
            torch.nn.Sequential(
                get_last_fc_layer(strategy.model)[1], L2Normalization()
            ),
        )

        # Static prototype init: create prototypes for all classes at once
        if not self.p_init_adaptive and len(self.p_mem) == 0:
            self._init_new_prototypes(
                torch.arange(0, self.n_classes).to(strategy.device)
            )

        self.initialized = True

    def before_training_exp(self, strategy, num_workers=0, shuffle=True, **kwargs):
        """
        Random retrieval from a class-balanced memory.
        Dataloader builds batches containing examples from both memories and
        the training dataset.
        This implementation requires the use of early stopping, otherwise the
        entire memory will be iterated.
        """
        # Reset the per-experience iteration counter first, so that early
        # stopping also counts from zero when there is nothing to replay yet.
        self.it_cnt = 0

        if len(self.storage_policy.buffer) == 0:
            return

        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            self.storage_policy.buffer,
            oversample_small_tasks=False,
            num_workers=num_workers,
            batch_size=strategy.train_mb_size,
            batch_size_mem=strategy.train_mb_size,
            shuffle=shuffle,
        )

    def after_training_iteration(self, strategy, **kwargs):
        """
        Implements early stopping, determining how many subsequent times a
        batch can be used for updates. The dataloader contains only data for
        the current experience (batch) and the entire memory.
        Multiple iterations will hence result in the original batch with new
        exemplars sampled from the memory for each iteration.
        """
        self.it_cnt += 1
        if self.it_cnt == self.max_it_cnt:
            strategy.stop_training()  # Stop processing the new-data batch

    def after_forward(self, strategy, **kwargs):
        """
        After the forward we can use the representations to update our running
        avg of the prototypes. This is in case we do multiple iterations of
        processing on the same batch.

        New prototypes are initialized for previously unseen classes.
        """
        if self.p_init_adaptive:  # Init prototypes for unseen classes in batch
            self._init_new_prototypes(strategy.mb_y)

        # Update batch info (when multiple iterations on same batch)
        self._update_running_prototypes(strategy)

    @torch.no_grad()
    def _init_new_prototypes(self, targets: Tensor):
        """Initialize prototypes for previously unseen classes.

        :param targets: The targets Tensor to make prototypes for.
        """
        y_unique: Tensor = torch.unique(targets).squeeze().view(-1)
        for idx in range(y_unique.size(0)):
            c: int = y_unique[idx].item()
            if c in self.p_mem:
                continue
            # New prototype: uniform in [-1, 1], then L2-normalized, with
            # shape (1, p_size), placed on the device of `targets`.
            random_init = torch.empty((1, self.p_size)).uniform_(-1, 1)
            self.p_mem[c] = (
                normalize(random_init, p=2, dim=1).detach().to(targets.device)
            )

    @torch.no_grad()
    def _update_running_prototypes(self, strategy):
        """Accumulate seen outputs of the network and keep counts."""
        y_unique = torch.unique(strategy.mb_y).squeeze().view(-1)
        for idx in range(y_unique.size(0)):
            c = y_unique[idx].item()
            idxs = torch.nonzero(strategy.mb_y == c).squeeze(1)
            # Sum of this class' outputs in the minibatch, kept as one row
            p_tmp_batch = (
                strategy.mb_output[idxs].sum(dim=0).unsqueeze(0).to(strategy.device)
            )

            p_init, cnt_init = self.tmp_p_mem.get(c, (0, 0))
            self.tmp_p_mem[c] = (p_init + p_tmp_batch, cnt_init + len(idxs))

    def after_training_exp(self, strategy, **kwargs):
        """After the current experience (batch), update prototypes and
        store observed samples for replay.
        """
        self._update_prototypes()  # Update prototypes
        self.storage_policy.update(strategy)  # Update memory

    @torch.no_grad()
    def _update_prototypes(self):
        """Update the prototypes based on the running averages."""
        for c, (p_sum, p_cnt) in self.tmp_p_mem.items():
            incr_p = normalize(p_sum / p_cnt, p=2, dim=1)  # L2 normalized
            old_p = self.p_mem[c].clone()
            new_p_momentum = (
                self.alpha * old_p + (1 - self.alpha) * incr_p
            )  # Momentum update
            # In-place dict update keeps the PPP-loss' shared view in sync
            self.p_mem[c] = normalize(new_p_momentum, p=2, dim=1).detach()
        self.tmp_p_mem = {}

    def after_eval_iteration(self, strategy, **kwargs):
        """Convert output scores to probabilities for other metrics like
        accuracy and forgetting. We only do it at this point because before
        this, we still need the embedding outputs to obtain the PPP-loss."""
        strategy.mb_output = self._get_nearest_neigbor_distr(strategy.mb_output)

    def _get_nearest_neigbor_distr(self, x: Tensor) -> Tensor:
        """
        Find closest prototype for output samples in batch x.

        The distance is the negated dot product between each (flattened)
        sample and each stored prototype; the class of the closest prototype
        is the prediction. Only classes that actually have a prototype take
        part, so class ids need not be contiguous or start at zero.

        :param x: Batch of network logits.
        :return: one-hot representation of the predicted class, of shape
            (batch size, n_classes). If no prototype exists yet, a uniform
            distribution over all classes is returned instead.
        """
        ns = x.size(0)

        if len(self.p_mem) == 0:
            # no prototypes yet, output uniform distr. all classes
            return torch.full(
                (ns, self.n_classes), 1.0 / self.n_classes, device=x.device
            )

        # Stack the stored prototypes row-wise, remembering each row's class
        class_ids = list(self.p_mem.keys())
        protos = torch.cat([self.p_mem[c] for c in class_ids]).to(x.device)
        proto_classes = torch.tensor(class_ids, dtype=torch.long, device=x.device)

        # Predict nearest mean: smallest negated dot product per sample
        dist = -torch.mm(x.view(ns, -1), protos.t())  # ns x nb prototypes
        classpred = proto_classes[dist.argmin(dim=1)]

        # Convert to 1-hot
        out = torch.zeros(ns, self.n_classes).to(x.device)
        out[torch.arange(ns, device=x.device), classpred] = 1
        return out  # return 1-of-C code, ns x nc
class L2Normalization(Module):
    """Module to L2-normalize the input. Typically used in last layer to
    normalize the embedding."""

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` with every row (dim=1) scaled to unit L2 norm."""
        return torch.nn.functional.normalize(x, p=2, dim=1)


class PPPloss(object):
    """Pseudo-Prototypical Proxy loss (PPP-loss).

    This is a contrastive loss using prototypes and representations of the
    samples in the batch to optimize the embedding space.
    """

    def __init__(self, p_mem: Dict, T=0.1):
        """
        :param p_mem: dictionary with keys the prototype identifier and
            values the prototype tensors. The dictionary is stored by
            reference, so later in-place updates by the owner are visible here.
        :param T: temperature of the softmax, serving as concentration
            density parameter.
        """
        self.T = T
        self.p_mem = p_mem

    def __call__(self, x, y):
        """
        The loss is calculated with one-vs-rest batches Bc and Bk,
        split into the attractor and repellor loss terms.
        We iterate over the possible batches while accumulating the losses per
        class c vs other-classes k.

        :param x: Batch of representations; flattened to (batch, features).
        :param y: Labels of the batch, one per sample.
        :return: The accumulated loss divided by the batch size.
        """
        loss = None
        bs = x.size(0)
        x = x.view(bs, -1)  # Batch x feature size
        y_unique = torch.unique(y).squeeze().view(-1)
        # The repellor needs other-class samples, i.e. at least 2 classes.
        # (With a single class the other-class batch is empty and contributes
        # nothing, so skipping it leaves the loss value unchanged.)
        include_repellor = y_unique.size(0) > 1

        # All prototypes, stacked row-wise in the key order of p_y
        p_y = torch.tensor(list(self.p_mem.keys())).to(x.device).detach()
        p_x = (
            torch.cat([self.p_mem[c.item()] for c in p_y]).to(x.device).detach()
        )

        for label_idx in range(y_unique.size(0)):  # Per-class operation
            c = y_unique[label_idx]

            # Make all-vs-rest batches per class (Bc=attractor, Bk=repellor set)
            Bc = x.index_select(0, torch.nonzero(y == c).squeeze(dim=1))
            Bk = x.index_select(0, torch.nonzero(y != c).squeeze(dim=1))

            is_c = p_y == c  # Selects the prototype row of class c
            pc = p_x[is_c]  # Class proto
            pk = p_x[~is_c].clone().detach()  # Other-class protos, same order

            # Accumulate loss for instances of class c
            sum_logLc = self.attractor(pc, pk, Bc)
            sum_logLk = self.repellor(pc, pk, Bc, Bk) if include_repellor else 0
            Loss_c = -sum_logLc - sum_logLk  # attractor + repellor for class c
            loss = Loss_c if loss is None else loss + Loss_c  # Update loss

        return loss / bs  # Make independent batch size

    def attractor(self, pc, pk, Bc):
        """
        Get the attractor loss terms for all instances in xc.

        :param pc: Prototype of the same class c.
        :param pk: Prototypes of the other classes.
        :param Bc: Batch of instances of the same class c.
        :return: Sum_{i, the part of same class c} log P(c|x_i^c)
        """
        # Rows: pseudo-prototypes (Bc), class proto, then other-class protos.
        # Detached, so gradients only flow through the Bc.t() factor below.
        m = torch.cat([Bc.clone(), pc, pk]).detach()  # Incl other-class proto
        pk_idx = m.shape[0] - pk.shape[0]  # from when starts p_k

        # Resulting distance columns are per-instance loss terms
        D = torch.mm(m, Bc.t()).div_(self.T).exp_()  # Distance matrix exp terms
        mask = torch.eye(*D.shape).bool().to(Bc.device)  # Exclude self-product
        Dm = D.masked_fill(mask, 0)  # Masked out products with self

        Lc_n, Lk_d = Dm[:pk_idx], Dm[pk_idx:].sum(dim=0)  # Num/denominator
        Pci = Lc_n / (Lc_n + Lk_d)  # Get probabilities per instance
        E_Pc = Pci.sum(0) / Bc.shape[0]  # Expectation over pseudo-prototypes
        return E_Pc.log_().sum()  # sum over all instances (sum i)

    def repellor(self, pc, pk, Bc, Bk):
        """
        Get the repellor loss terms for all pseudo-prototype instances in Bc.

        :param pc: Actual prototype of the same class c.
        :param pk: Prototypes of the other classes (k).
        :param Bc: Batch of instances of the same class c. Acting as
            pseudo-prototypes.
        :param Bk: Batch of instances of other-than-c classes (k).
        :return: Sum_{i, part of same class c} Sum_{x_j^k} log 1 - P(c|x_j^k)
        """
        union_ck = torch.cat([Bc.clone(), pc, pk]).detach()
        pk_idx = union_ck.shape[0] - pk.shape[0]

        # Distance other-class-k to prototypes (pc/pk) and pseudo-prototype (xc)
        D = torch.mm(union_ck, Bk.t()).div_(self.T).exp_()

        Lk_d = D[pk_idx:].sum(dim=0).unsqueeze(0)  # Numerator/denominator terms
        Lc_n = D[:pk_idx]
        Pki = Lc_n / (Lc_n + Lk_d)  # probability

        # Last row of Pki belongs to the real prototype pc; average it with
        # every pseudo-prototype row.
        E_Pk = (Pki[:-1] + Pki[-1].unsqueeze(0)) / 2  # Exp. pseudo/prototype
        inv_E_Pk = E_Pk.mul_(-1).add_(1).log_()  # log( (1 - Pk))
        return inv_E_Pk.sum()  # Sum over (pseudo-prototypes), and instances