# Source code for avalanche.training.plugins.gdumb

import copy
from typing import TYPE_CHECKING, Optional

from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


class GDumbPlugin(SupervisedPlugin, supports_distributed=True):
    """GDumb plugin.

    At each experience the model is trained from scratch using a buffer of
    samples collected from all the previous learning experiences.
    The buffer is updated at the start of each experience to add new classes
    or new examples of already encountered classes.
    In multitask scenarios, mem_size is the memory size for each task.
    This plugin can be combined with a Naive strategy to obtain the
    standard GDumb strategy.
    https://www.robots.ox.ac.uk/~tvg/publications/2020/gdumb.pdf
    """

    def __init__(self, mem_size: int = 200):
        """Create the plugin and its replay buffer.

        :param mem_size: value forwarded as ``max_size`` to the
            ``ClassBalancedBuffer`` storage policy.
        """
        super().__init__()
        self.mem_size = mem_size

        self.storage_policy = ClassBalancedBuffer(
            max_size=self.mem_size, adaptive_size=True
        )

        # Snapshot of the model; stays None until the first call to
        # before_train_dataset_adaptation.
        self.init_model = None

    def before_train_dataset_adaptation(
        self, strategy: "SupervisedTemplate", **kwargs
    ):
        """Reset model.

        On the first call a deep copy of ``strategy.model`` is stored as the
        snapshot. On every later call ``strategy.model`` is replaced by a
        fresh deep copy of that snapshot.

        :param strategy: the strategy whose model is snapshotted or reset.
        :param kwargs: unused by this callback.
        """
        if self.init_model is None:
            self.init_model = copy.deepcopy(strategy.model)
        else:
            strategy.model = copy.deepcopy(self.init_model)

        # NOTE(review): the stored snapshot (not strategy.model) is what gets
        # passed to model_adaptation -- presumably so the snapshot stays in
        # sync with the current experience; confirm against the template.
        strategy.model_adaptation(self.init_model)

    def before_eval_dataset_adaptation(
        self, strategy: "SupervisedTemplate", **kwargs
    ):
        """Run the strategy's model adaptation on the stored snapshot.

        :param strategy: the strategy whose ``model_adaptation`` is invoked.
        :param kwargs: unused by this callback.
        """
        # NOTE(review): init_model is still None if evaluation happens before
        # any training call; behavior then depends on how model_adaptation
        # treats None -- confirm.
        strategy.model_adaptation(self.init_model)

    def after_train_dataset_adaptation(
        self, strategy: "SupervisedTemplate", **kwargs
    ):
        """Update the buffer and make it the strategy's training dataset.

        :param strategy: the strategy passed to the storage policy update;
            its ``adapted_dataset`` is overwritten with the buffer.
        :param kwargs: forwarded unchanged to ``storage_policy.update``.
        """
        self.storage_policy.update(strategy, **kwargs)
        strategy.adapted_dataset = self.storage_policy.buffer
# Names exported by ``from <module> import *``.
__all__ = ["GDumbPlugin"]