# Source code for avalanche.training.plugins.generative_replay

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 05-03-2022                                                             #
# Author: Florian Mies                                                         #
# Website: https://github.com/travela                                          #
################################################################################

"""

All plugins related to Generative Replay.

"""

from copy import deepcopy
from typing import Optional, Any
from avalanche.core import SupervisedPlugin, Template
import torch


class GenerativeReplayPlugin(SupervisedPlugin):
    """
    Experience generative replay plugin.

    Samples a (frozen copy of a) generator model to obtain replay data for
    previous experiences. Two modes are supported:

    * default (``is_weighted_replay=False``): before every training
      iteration the replay samples are concatenated to the current
      minibatch (inputs, targets and task labels);
    * weighted (``is_weighted_replay=True``): the minibatch is left
      untouched; before the backward pass a separate loss is computed on the
      replay samples, scaled by ``weight_replay_loss`` and added to
      ``strategy.loss``.

    Replay targets are the predictions of a frozen copy of the solver
    (``strategy.model``). When the strategy's model is itself the generator,
    zero-valued mock targets are used instead.

    :param generator_strategy: In case the plugin is applied to a
        non-generative model (e.g. a simple classifier), this should contain
        an Avalanche strategy for a model that implements a 'generate' method
        (see avalanche.models.generator.Generator). Defaults to None, in
        which case the strategy's own model is used as the generator
        (e.g. when training a VAE with generative replay).
    :param untrained_solver: if True we assume this is the beginning of a
        continual learning task and add replay data only from the second
        experience onwards, otherwise we sample and add generative replay
        data before training the first experience. Defaults to True.
    :param replay_size: The user can specify the batch size of replays that
        should be added to each data batch. By default each data batch will
        be matched with replays of the same number.
    :param increasing_replay_size: If set to True (and ``replay_size`` is not
        given), the number of replay samples per data batch is the batch size
        multiplied by the index of the current experience (at least 1). The
        effect is that older experiences gradually increase in importance to
        the final loss.
    :param is_weighted_replay: If set to True, replay data is not appended to
        the minibatch; instead a separate replay loss, weighted by
        ``weight_replay_loss``, is added to the strategy loss.
    :param weight_replay_loss_factor: If is_weighted_replay is set to True,
        the user can specify a factor the weight will be multiplied by in
        each iteration, the default is 1.0
    :param weight_replay_loss: The user can specify the initial weight of the
        loss for the replay data. The default is 0.0001
    """

    def __init__(
        self,
        generator_strategy=None,
        untrained_solver: bool = True,
        replay_size: Optional[int] = None,
        increasing_replay_size: bool = False,
        is_weighted_replay: bool = False,
        weight_replay_loss_factor: float = 1.0,
        weight_replay_loss: float = 0.0001,
    ):
        """
        Init.
        """
        super().__init__()
        self.generator_strategy = generator_strategy
        # When no generator strategy is given, the generator is resolved in
        # `before_training` (the strategy's own model is used).
        if self.generator_strategy is not None:
            self.generator = generator_strategy.model
        else:
            self.generator = None
        self.untrained_solver = untrained_solver
        self.model_is_generator = False
        self.replay_size = replay_size
        self.increasing_replay_size = increasing_replay_size
        self.is_weighted_replay = is_weighted_replay
        self.weight_replay_loss_factor = weight_replay_loss_factor
        self.weight_replay_loss = weight_replay_loss
        # Frozen copies taken in `before_training_exp`.
        self.old_generator = None
        self.old_model = None

    def before_training(self, strategy, *args, **kwargs):
        """Checks whether we are using a user defined external generator or
        we use the strategy's model as the generator.
        If the generator is None after initialization
        we assume that strategy.model is the generator.
        (e.g. this would be the case when training a VAE with
        generative replay)"""
        if self.generator_strategy is None:
            self.generator_strategy = strategy
            self.generator = strategy.model
            self.model_is_generator = True

    def before_training_exp(
        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
    ):
        """
        Make frozen (eval-mode) deep copies of the generator and, unless the
        model is its own generator, of the solver before training a new
        experience. Nothing is copied while the solver is still untrained.
        """
        if self.untrained_solver:
            # The solver needs to be trained before labelling generated data and
            # the generator needs to be trained before we can sample.
            return
        self.old_generator = deepcopy(self.generator)
        self.old_generator.eval()
        if not self.model_is_generator:
            self.old_model = deepcopy(strategy.model)
            self.old_model.eval()

    def after_training_exp(
        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
    ):
        """
        Set untrained_solver boolean to False after (the first) experience,
        in order to start training with replay data from the second experience.
        """
        self.untrained_solver = False

    def _num_replays_to_generate(self, strategy) -> int:
        """Return how many replay samples to generate for the current minibatch."""
        if self.replay_size:
            return self.replay_size
        batch_size = len(strategy.mbatch[0])
        if self.increasing_replay_size:
            # Scale with the experience index; clamp to 1 so that replay is
            # still produced on experience 0 when untrained_solver=False.
            return batch_size * max(strategy.experience.current_experience, 1)
        return batch_size

    def _generate_replay(self, strategy, num_samples: int):
        """Sample replay inputs, targets and task labels on strategy.device.

        Must be called before strategy.mbatch is modified: target dtype and
        task label are taken from the current minibatch.
        """
        replay_x = self.old_generator.generate(num_samples).to(strategy.device)
        n_replays = replay_x.shape[0]
        if self.model_is_generator:
            # Mock labels: the generator does not use targets. Match the dtype
            # of the current targets so concatenation does not promote them.
            replay_y = torch.zeros(
                n_replays, dtype=strategy.mbatch[1].dtype, device=strategy.device
            )
        else:
            with torch.no_grad():
                replay_y = self.old_model(replay_x).argmax(dim=-1)
            replay_y = replay_y.to(strategy.device)
        # We implicitly assume a task-free case: every replay sample gets the
        # task label of the first sample of the current minibatch, keeping the
        # task-label dtype unchanged.
        task_labels = strategy.mbatch[-1]
        replay_t = torch.full(
            (n_replays,),
            task_labels[0].item(),
            dtype=task_labels.dtype,
            device=strategy.device,
        )
        return replay_x, replay_y, replay_t

    def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
        """
        Generate replay data and calculate the loss on the replay data.
        Add weighted loss to the total loss if the user has set the weight_replay_loss
        """
        super().before_backward(strategy, *args, **kwargs)
        if not self.is_weighted_replay:
            # If we are not using weighted loss, ignore this method
            return
        if self.untrained_solver:
            # do not generate on the first experience
            return

        replay_x, replay_y, replay_t = self._generate_replay(
            strategy, self._num_replays_to_generate(strategy)
        )

        # Temporarily swap the minibatch (and its output) for the replay data
        # so forward()/criterion() operate on it, then restore both so later
        # callbacks see the real minibatch. A reference is enough: a new list
        # is assigned, so the original one is never mutated.
        original_mbatch = strategy.mbatch
        original_output = getattr(strategy, "mb_output", None)
        strategy.mbatch = [replay_x, replay_y, replay_t]
        try:
            # criterion() reads strategy.mb_output, so the forward result on
            # the replay data must be stored there.
            strategy.mb_output = strategy.forward()
            strategy.loss += self.weight_replay_loss * strategy.criterion()
        finally:
            strategy.mbatch = original_mbatch
            strategy.mb_output = original_output
        self.weight_replay_loss *= self.weight_replay_loss_factor

    def before_training_iteration(self, strategy, **kwargs):
        """
        Generating and appending replay data to current minibatch before
        each training iteration.
        """
        if self.is_weighted_replay:
            # When using weighted loss, do not add replay data to the current minibatch
            return
        if self.untrained_solver:
            # The solver needs to be trained before labelling generated data and
            # the generator needs to be trained before we can sample.
            return

        # Generate before touching mbatch: the count and dtypes are taken from
        # the unmodified minibatch.
        replay_x, replay_y, replay_t = self._generate_replay(
            strategy, self._num_replays_to_generate(strategy)
        )
        strategy.mbatch[0] = torch.cat([strategy.mbatch[0], replay_x], dim=0)
        strategy.mbatch[1] = torch.cat([strategy.mbatch[1], replay_y], dim=0)
        strategy.mbatch[-1] = torch.cat([strategy.mbatch[-1], replay_t], dim=0)
class TrainGeneratorAfterExpPlugin(SupervisedPlugin):
    """
    TrainGeneratorAfterExpPlugin makes sure that after each experience of
    training the solver of a scholar model, we also train the generator on the
    data of the current experience.
    """

    def after_training_exp(self, strategy, **kwargs):
        """
        Train the generator strategy of every GenerativeReplayPlugin
        (including subclasses) attached to ``strategy`` on the experience
        that was just trained.

        The training method expects an Experience object with a 'dataset'
        parameter.
        """
        for plugin in strategy.plugins:
            if not isinstance(plugin, GenerativeReplayPlugin):
                continue
            generator_strategy = plugin.generator_strategy
            if generator_strategy is strategy:
                # The model is its own generator and has just been trained on
                # this experience; calling train() here would re-enter this
                # callback recursively.
                continue
            generator_strategy.train(strategy.experience)
# Public API of this module.
__all__ = [
    "GenerativeReplayPlugin",
    "TrainGeneratorAfterExpPlugin",
]