import copy
import itertools
from math import ceil
from typing import TYPE_CHECKING, Optional, List

import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from avalanche.benchmarks.utils import AvalancheConcatDataset, \
    AvalancheTensorDataset, AvalancheSubset
from avalanche.models import TrainEvalModel, NCMClassifier
# NOTE(review): the module paths of the five avalanche.training imports
# below were absent from the reviewed copy ("from import X"); they are
# restored on the assumption of the usual Avalanche package layout --
# confirm against the installed Avalanche version.
from avalanche.training.losses import ICaRLLossPlugin
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.evaluation import default_logger
from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.training.strategies.base_strategy import BaseStrategy


class ICaRL(BaseStrategy):
    """iCaRL Strategy.

    Combines the given feature extractor and classifier into a
    ``TrainEvalModel`` (the classifier is used for training, an
    ``NCMClassifier`` for evaluation) and registers an ``_ICaRLPlugin``
    that maintains the exemplar memory.

    This strategy does not use task identities.
    """

    def __init__(self, feature_extractor: Module, classifier: Module,
                 optimizer: Optimizer, memory_size, buffer_transform,
                 fixed_memory, criterion=ICaRLLossPlugin(),
                 train_mb_size: int = 1, train_epochs: int = 1,
                 eval_mb_size: int = None, device=None,
                 plugins: Optional[List[StrategyPlugin]] = None,
                 evaluator: EvaluationPlugin = default_logger, eval_every=-1):
        """Init.

        :param feature_extractor: The feature extractor.
        :param classifier: The differentiable classifier that takes as input
            the output of the feature extractor.
        :param optimizer: The optimizer to use.
        :param memory_size: The number of patterns saved in the memory.
        :param buffer_transform: transform applied on buffer elements already
            modified by test_transform (if specified) before being used for
            replay
        :param fixed_memory: If True a memory of size memory_size is
            allocated and partitioned between samples from the observed
            experiences. If False every time a new class is observed
            memory_size samples of that class are added to the memory.
        :param criterion: The loss criterion. Defaults to an
            ``ICaRLLossPlugin`` instance. When the criterion is itself a
            ``StrategyPlugin`` it is also added to the list of plugins.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None. The list
            passed by the caller is copied, never modified.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is
            called only at the end of the learning experience. Values >0 mean
            that `eval` is called every `eval_every` epochs and at the end
            of the learning experience.
        """
        # NOTE(review): the default ``criterion`` is built once, when this
        # function is defined, so every strategy created without an explicit
        # criterion shares the same ICaRLLossPlugin object. Pass a fresh
        # instance when several strategies live in the same process.
        model = TrainEvalModel(feature_extractor,
                               train_classifier=classifier,
                               eval_classifier=NCMClassifier())

        icarl = _ICaRLPlugin(memory_size, buffer_transform, fixed_memory)

        # Work on a copy: extending ``plugins`` in place would silently add
        # the iCaRL plugins to a list the caller may reuse elsewhere.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(icarl)
        if isinstance(criterion, StrategyPlugin):
            plugins.append(criterion)

        super().__init__(
            model, optimizer, criterion=criterion,
            train_mb_size=train_mb_size, train_epochs=train_epochs,
            eval_mb_size=eval_mb_size, device=device, plugins=plugins,
            evaluator=evaluator, eval_every=eval_every)

class _ICaRLPlugin(StrategyPlugin):
    """iCaRL Plugin.

    iCaRL uses nearest class exemplar classification to prevent
    forgetting to occur at the classification layer. The feature extractor
    is continually learned using replay and distillation. The exemplars
    used for replay and classification are selected through herding.
    This plugin does not use task identities.
    """

    # Hard cap on herding iterations per class, so that selection stops even
    # if the requested number of distinct exemplars is never reached.
    _MAX_HERDING_STEPS = 1000

    def __init__(self, memory_size, buffer_transform=None,
                 fixed_memory=True):
        """
        :param memory_size: amount of patterns saved in the memory.
        :param buffer_transform: transform applied on buffer elements already
            modified by test_transform (if specified) before being used for
            replay
        :param fixed_memory: If True a memory of size memory_size is
            allocated and partitioned between samples from the observed
            experiences. If False every time a new class is observed
            memory_size samples of that class are added to the memory.
        """
        super().__init__()

        self.memory_size = memory_size
        self.buffer_transform = buffer_transform
        self.fixed_memory = fixed_memory

        # Parallel lists with one entry per class, in insertion order:
        # exemplar tensors, their labels and their herding ranks.
        self.x_memory = []
        self.y_memory = []
        self.order = []

        self.old_model = None
        self.observed_classes = []
        self.class_means = None
        self.embedding_size = None
        self.output_size = None
        self.input_size = None

    def after_train_dataset_adaptation(self, strategy: 'BaseStrategy',
                                       **kwargs):
        """Concatenate the stored exemplars to the adapted training dataset.

        Nothing is done while ``strategy.clock.train_exp_counter`` is 0.
        """
        if strategy.clock.train_exp_counter == 0:
            return

        # NOTE(review): the first argument was missing in the reviewed copy;
        # it is reconstructed as the concatenation of all stored exemplars,
        # moved to the CPU -- confirm against upstream.
        memory = AvalancheTensorDataset(
            torch.cat(self.x_memory).cpu(),
            list(itertools.chain.from_iterable(self.y_memory)),
            transform=self.buffer_transform, target_transform=None)

        strategy.adapted_dataset = AvalancheConcatDataset(
            (strategy.adapted_dataset, memory))

    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        """Record the classes of the current experience as observed."""
        tid = strategy.clock.train_exp_counter
        benchmark = strategy.experience.benchmark
        nb_cl = benchmark.n_classes_per_exp[tid]
        previous_seen_classes = sum(benchmark.n_classes_per_exp[:tid])

        self.observed_classes.extend(
            benchmark.classes_order[previous_seen_classes:
                                    previous_seen_classes + nb_cl])

    def before_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Cache input, output and embedding sizes from the first minibatch."""
        if self.input_size is not None:
            return

        with torch.no_grad():
            self.input_size = strategy.mb_x.shape[1:]
            self.output_size = strategy.model(strategy.mb_x).shape[1]
            self.embedding_size = strategy.model.feature_extractor(
                strategy.mb_x).shape[1]

    def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        """Refresh the exemplar memory and the class means."""
        strategy.model.eval()

        self.construct_exemplar_set(strategy)
        self.reduce_exemplar_set(strategy)
        self.compute_class_means(strategy)

    def compute_class_means(self, strategy):
        """Recompute the mean embedding of every class held in memory.

        For each class the mean is the average of the mean normalized
        embedding of its exemplars and of the same exemplars flipped along
        dimension 3 (the flip is applied to 4-dimensional inputs only),
        rescaled to unit norm. The resulting matrix is handed to
        ``strategy.model.eval_classifier``.
        """
        if self.class_means is None:
            n_classes = sum(strategy.experience.benchmark.n_classes_per_exp)
            self.class_means = torch.zeros(
                (self.embedding_size, n_classes)).to(strategy.device)

        for class_samples, class_labels in zip(self.x_memory, self.y_memory):
            label = class_labels[0]
            # NOTE(review): right-hand side missing in the reviewed copy;
            # reconstructed as a move to the strategy device -- confirm.
            class_samples = class_samples.to(strategy.device)

            D = self._normalized_embeddings(strategy, class_samples)

            if len(class_samples.shape) == 4:
                class_samples = torch.flip(class_samples, [3])
            D2 = self._normalized_embeddings(strategy, class_samples)

            # Uniform weights: multiplying by them averages over samples.
            div = torch.ones(class_samples.shape[0], device=strategy.device)
            div = div / class_samples.shape[0]

            # NOTE(review): the call name and first operand of these two
            # products were missing; reconstructed as torch.mm(D, ...) and
            # torch.mm(D2, ...) -- confirm.
            m1 = torch.mm(D, div.unsqueeze(1)).squeeze(1)
            m2 = torch.mm(D2, div.unsqueeze(1)).squeeze(1)

            self.class_means[:, label] = (m1 + m2) / 2
            self.class_means[:, label] /= torch.norm(
                self.class_means[:, label])

        strategy.model.eval_classifier.class_means = self.class_means

    def construct_exemplar_set(self, strategy: BaseStrategy):
        """Select exemplars for the classes of the current experience.

        Exemplars are picked greedily: at each step the sample whose
        normalized embedding has the largest dot product with the running
        target vector is chosen, and its rank is stored so that the set can
        later be shrunk by ``reduce_exemplar_set``.
        """
        tid = strategy.clock.train_exp_counter
        benchmark = strategy.experience.benchmark
        nb_cl = benchmark.n_classes_per_exp[tid]
        previous_seen_classes = sum(benchmark.n_classes_per_exp[:tid])

        nb_protos_cl = self._exemplars_per_class()
        new_classes = self.observed_classes[
            previous_seen_classes: previous_seen_classes + nb_cl]

        dataset = strategy.experience.dataset
        targets = torch.tensor(dataset.targets)

        for iter_dico in range(nb_cl):
            class_id = new_classes[iter_dico]
            cd = AvalancheSubset(dataset,
                                 torch.where(targets == class_id)[0])

            # A single batch as large as the subset loads the whole class.
            class_patterns, _, _ = next(iter(
                DataLoader(cd.eval(), batch_size=len(cd))))
            # NOTE(review): right-hand side missing in the reviewed copy;
            # reconstructed as a move to the strategy device -- confirm.
            class_patterns = class_patterns.to(strategy.device)

            # One column per sample, each scaled to unit norm.
            D = self._normalized_embeddings(strategy, class_patterns)

            mu = torch.mean(D, dim=1)
            order = torch.zeros(class_patterns.shape[0])
            w_t = mu

            i, added, selected = 0, 0, set()
            while added != nb_protos_cl and i < self._MAX_HERDING_STEPS:
                # NOTE(review): the call name and first operand were
                # missing; reconstructed as the product of the target
                # vector with every sample embedding -- confirm.
                tmp_t = torch.mm(w_t.unsqueeze(0), D)
                ind_max = int(torch.argmax(tmp_t))

                if ind_max not in selected:
                    order[ind_max] = 1 + added
                    added += 1
                    selected.add(ind_max)

                w_t = w_t + mu - D[:, ind_max]
                i += 1

            # Keep the samples whose rank lies in [1, nb_protos_cl].
            kept = torch.where((order > 0) & (order < nb_protos_cl + 1))[0]
            self.x_memory.append(class_patterns[kept])
            self.y_memory.append([class_id] * len(kept))
            self.order.append(order[kept])

    def reduce_exemplar_set(self, strategy: BaseStrategy):
        """Shrink the exemplar sets of previously stored classes.

        Only entries stored before the current experience are touched; in
        each of them the exemplars whose rank exceeds the current per-class
        budget are dropped.
        """
        tid = strategy.clock.train_exp_counter
        nb_cl = strategy.experience.benchmark.n_classes_per_exp

        nb_protos_cl = self._exemplars_per_class()

        for i in range(len(self.x_memory) - nb_cl[tid]):
            kept = torch.where(self.order[i] < nb_protos_cl + 1)[0]
            self.x_memory[i] = self.x_memory[i][kept]
            # All labels of an entry are equal, so truncating is enough.
            self.y_memory[i] = self.y_memory[i][:len(kept)]
            self.order[i] = self.order[i][kept]

    def _exemplars_per_class(self):
        """Return how many exemplars each class may currently hold."""
        if self.fixed_memory:
            return int(ceil(self.memory_size / len(self.observed_classes)))
        return self.memory_size

    @staticmethod
    def _normalized_embeddings(strategy, samples):
        """Return the transposed feature-extractor output, columns unit-norm.

        Assumes the feature extractor returns one row per sample -- TODO
        confirm.
        """
        with torch.no_grad():
            mapped_prototypes = strategy.model.feature_extractor(
                samples).detach()
        D = mapped_prototypes.T
        return D / torch.norm(D, dim=0)