Source code for avalanche.training.plugins.il2m

from typing import Optional

from packaging.version import parse
import torch
import numpy as np

from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ExemplarsBuffer, ExperienceBalancedBuffer
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader


[docs]class IL2MPlugin(SupervisedPlugin): """ Class Incremental Learning With Dual Memory (IL2M) plugin. Technique introduced in: Belouadah, E. and Popescu, A. "IL2M: Class Incremental Learning With Dual Memory." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. Implementation based on FACIL, as in: https://github.com/mmasana/FACIL/blob/master/src/approach/il2m.py """
[docs] def __init__( self, mem_size: int = 2000, batch_size: Optional[int] = None, batch_size_mem: Optional[int] = None, storage_policy: Optional[ExemplarsBuffer] = None, ): """ :param mem_size: replay buffer size. :param batch_size: the size of the data batch. If set to `None`, it will be set equal to the strategy's batch size. :param batch_size_mem: the size of the memory batch. If its value is set to `None` (the default value), it will be automatically set equal to the data batch size. :param storage_policy: The policy that controls how to add new exemplars in memory. """ super().__init__() self.mem_size = mem_size self.batch_size = batch_size self.batch_size_mem = batch_size_mem if storage_policy is not None: # Use other storage policy self.storage_policy = storage_policy assert storage_policy.max_size == self.mem_size else: # Default self.storage_policy = ExperienceBalancedBuffer( max_size=self.mem_size, adaptive_size=True ) # to store statistics for the classes as learned in the current incremental state self.current_classes_means = [] # to store statistics for past classes as learned in the incremental state in which they were first seen self.init_classes_means = [] # to store statistics for model confidence in different states (i.e. avg top-1 pred scores) self.models_confidence = [] # to store the mapping between classes and the incremental state in which they were first seen self.classes2exp = [] # total number of classes that will be seen self.n_classes = 0
def before_training_exp( self, strategy: SupervisedTemplate, num_workers: int = 0, shuffle: bool = True, drop_last: bool = False, **kwargs ): if len(self.init_classes_means) == 0: self.n_classes = len(strategy.experience.classes_seen_so_far) + len( strategy.experience.future_classes ) self.init_classes_means = [0 for _ in range(self.n_classes)] self.classes2exp = [-1 for _ in range(self.n_classes)] if len(self.storage_policy.buffer) == 0: # first experience. We don't use the buffer, no need to change # the dataloader. return batch_size = self.batch_size if batch_size is None: batch_size = strategy.train_mb_size batch_size_mem = self.batch_size_mem if batch_size_mem is None: batch_size_mem = strategy.train_mb_size assert strategy.adapted_dataset is not None other_dataloader_args = dict() if "ffcv_args" in kwargs: other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] if "persistent_workers" in kwargs: if parse(torch.__version__) >= parse("1.7.0"): other_dataloader_args["persistent_workers"] = kwargs[ "persistent_workers" ] strategy.dataloader = ReplayDataLoader( strategy.adapted_dataset, self.storage_policy.buffer, oversample_small_tasks=True, batch_size=batch_size, batch_size_mem=batch_size_mem, num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, **other_dataloader_args ) def after_training_exp(self, strategy: SupervisedTemplate, **kwargs): experience = strategy.experience self.current_classes_means = [0 for _ in range(self.n_classes)] classes_counts = [0 for _ in range(self.n_classes)] self.models_confidence.append(0) models_counts = 0 # compute the mean prediction scores that will be used to rectify scores in subsequent incremental states with torch.no_grad(): strategy.model.eval() for inputs, targets, _ in strategy.dataloader: inputs, targets = inputs.to(strategy.device), targets.to( strategy.device ) outputs = strategy.model(inputs.to(strategy.device)) scores = outputs.data.cpu().numpy() for i in range(len(targets)): target = targets[i].item() classes_counts[target] += 1 if target in experience.previous_classes: # compute the mean prediction scores for past classes of the current state self.current_classes_means[target] += scores[i, target] else: # compute the mean prediction scores for the new classes of the current state self.init_classes_means[target] += scores[i, target] # compute the mean top scores for the new classes of the current state self.models_confidence[-1] += np.max(scores[i,]) models_counts += 1 # normalize by corresponding number of samples for cls in experience.previous_classes: self.current_classes_means[cls] /= classes_counts[cls] for cls in experience.classes_in_this_experience: self.init_classes_means[cls] /= classes_counts[cls] self.models_confidence[-1] /= models_counts # store the mapping between classes and the incremental state in which they are first seen for cls in experience.classes_in_this_experience: self.classes2exp[cls] = experience.current_experience # update the buffer of exemplars self.storage_policy.post_adapt(strategy, strategy.experience) def after_eval_forward(self, strategy: SupervisedTemplate, **kwargs): old_classes = strategy.experience.previous_classes new_classes = strategy.experience.classes_in_this_experience if not old_classes: return outputs = strategy.mb_output targets = strategy.mbatch[1] # rectify predicted scores (Eq. 1 in the paper) for i in range(len(targets)): # if the top-1 class predicted by the network is a new one, rectify the score if outputs[i].argmax().item() in new_classes: for cls in old_classes: o_exp = self.classes2exp[cls] if ( self.current_classes_means[cls] == 0 ): # when evaluation is done before training continue outputs[i, cls] *= ( self.init_classes_means[cls] / self.current_classes_means[cls] ) * (self.models_confidence[-1] / self.models_confidence[o_exp])
# otherwise, rectification is not done because an old class is directly predicted __all__ = [ "IL2MPlugin", ]