# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 20-11-2020                                                             #
# Author(s): Vincenzo Lomonaco                                                 #
# E-mail:                                              #
# Website:                                           #

from typing import Optional, Sequence, TYPE_CHECKING, Union

from torch.nn import Module
from torch.optim import Optimizer

from avalanche.benchmarks.scenarios import ClassificationExperience
from avalanche.benchmarks.utils import concat_classification_datasets
from avalanche.benchmarks.utils.utils import concat_datasets
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedTemplate
from avalanche.models import DynamicModule

if TYPE_CHECKING:
    from avalanche.core import SupervisedPlugin


class AlreadyTrainedError(Exception):
    """Raised by ``JointTraining.train`` when the strategy was already trained.

    ``JointTraining`` records a successful run in its ``_is_fitted`` flag and
    refuses any further call to ``train``: the whole stream must be passed in
    a single call.
    """

class JointTraining(SupervisedTemplate):
    """Joint training on the entire stream.

    JointTraining performs joint training (also called offline training) on
    the entire stream of data. This means that it is not a continual
    learning strategy but it can be used as an "offline" upper bound for
    them.

    .. warning::
        Currently :py:class:`JointTraining` adapts its own dataset.
        Please check that the plugins you are using do not implement
        :py:meth:`train_dataset_adaptation`. Otherwise, they are
        incompatible with :py:class:`JointTraining`.
    """

    # Sentinel meaning "create a fresh default evaluator for this instance".
    # Writing ``evaluator=default_evaluator()`` in the signature would call
    # ``default_evaluator`` only once, when the class is defined, so every
    # instance built without an explicit evaluator would share one evaluator
    # object. ``None`` cannot serve as the sentinel: it means "no logging".
    _DEFAULT_EVALUATOR = object()

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device="cpu",
        plugins: Optional[Sequence["SupervisedPlugin"]] = None,
        evaluator=_DEFAULT_EVALUATOR,
        eval_every=-1,
    ):
        """Init.

        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device to run the model.
        :param plugins: (optional) list of StrategyPlugins.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging. When omitted,
            a new ``default_evaluator()`` is created for this strategy.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is
            called only at the end of the learning experience. Values >0
            mean that `eval` is called every `eval_every` epochs and at the
            end of the learning experience.
        """
        if evaluator is JointTraining._DEFAULT_EVALUATOR:
            evaluator = default_evaluator()
        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
        )
        # JointTraining can be trained only once.
        self._is_fitted = False

    def train(
        self,
        experiences: Union[
            ClassificationExperience, Sequence[ClassificationExperience]
        ],
        eval_streams: Optional[
            Sequence[
                Union[
                    ClassificationExperience,
                    Sequence[ClassificationExperience],
                ]
            ]
        ] = None,
        **kwargs
    ):
        """Training loop.

        JointTraining concatenates all the experiences together and trains on
        all of them at the same time (a.k.a. offline training).

        :param experiences: single Experience or sequence.
        :param eval_streams: list of streams for evaluation.
            If None: use training experiences for evaluation.
            Use [] if you do not want to evaluate during training.
            The caller's sequence is never modified.
        :return: dictionary containing last recorded value for
            each metric name.
        :raises AlreadyTrainedError: if ``train`` already completed once on
            this strategy.
        """
        # Fail fast, before changing the training flag or the model mode.
        if self._is_fitted:
            raise AlreadyTrainedError(
                "JointTraining can be trained only once. "
                "Please call the train method once on the entire stream."
            )
        self.is_training = True
        self.model.train()

        # Normalize training and eval data.
        if isinstance(experiences, ClassificationExperience):
            experiences = [experiences]
        if eval_streams is None:
            eval_streams = [experiences]
        # Build a new list instead of assigning into ``eval_streams`` so the
        # caller's sequence is left untouched and tuples are accepted.
        eval_streams = [
            [stream] if isinstance(stream, ClassificationExperience) else stream
            for stream in eval_streams
        ]

        self._eval_streams = eval_streams
        self._experiences = experiences

        self._before_training(**kwargs)
        for self.experience in experiences:
            self._before_training_exp(**kwargs)
            self._train_exp(self.experience, eval_streams, **kwargs)
            self._after_training_exp(**kwargs)
            # Joint training only needs a single step because
            # it concatenates all the data at once.
            break
        self._after_training(**kwargs)

        res = self.evaluator.get_last_metrics()
        self._is_fitted = True
        return res

    def train_dataset_adaptation(self, **kwargs):
        """Concatenates all the datastream.

        The datasets of every experience passed to :meth:`train` are merged,
        in order, into ``self.adapted_dataset``; ``.train()`` is then called
        on the merged dataset (presumably selecting its training-time
        behavior -- confirm against the dataset class).
        """
        self.adapted_dataset = self._experiences[0].dataset
        for exp in self._experiences[1:]:
            self.adapted_dataset = concat_datasets(
                [self.adapted_dataset, exp.dataset]
            )
        self.adapted_dataset = self.adapted_dataset.train()

    def model_adaptation(self, model=None):
        """Adapts strategy's model for all experiences.

        Every :class:`DynamicModule` found in ``model.modules()`` is adapted
        to each training experience in turn, and the model is moved to
        ``self.device`` after each experience so that layers created during
        adaptation end up on the strategy's device.

        :param model: model to adapt. If None, ``self.model`` is used.
        :return: the adapted model.
        """
        if model is None:
            model = self.model

        for experience in self._experiences:
            for module in model.modules():
                if isinstance(module, DynamicModule):
                    module.adaptation(experience)
            model = model.to(self.device)
        return model
# Public API: the strategy plus the exception its ``train`` method raises,
# so callers (and star re-exports from the package) can catch it.
__all__ = ["JointTraining", "AlreadyTrainedError"]