################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 20-11-2020 #
# Author(s): Vincenzo Lomonaco #
# E-mail: contact@continualai.org #
# Website: avalanche.continualai.org #
################################################################################
from typing import Optional, Sequence, TYPE_CHECKING, Union
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import ConcatDataset
from avalanche.benchmarks.scenarios import Experience
from avalanche.benchmarks.utils import AvalancheConcatDataset
from avalanche.training.plugins.evaluation import default_logger
from avalanche.training.strategies import BaseStrategy
from avalanche.models import DynamicModule
if TYPE_CHECKING:
from avalanche.training.plugins import StrategyPlugin
class AlreadyTrainedError(Exception):
    """Raised when ``train`` is called on a strategy that was already fitted.

    :py:class:`JointTraining` consumes the whole stream in a single call to
    ``train``; a second call is rejected with this exception.
    """
class JointTraining(BaseStrategy):
    """ Joint training on the entire stream.

    JointTraining performs joint training (also called offline training) on
    the entire stream of data. This means that it is not a continual
    learning strategy but it can be used as an "offline" upper bound for
    them.

    .. warning::
        Currently :py:class:`JointTraining` adapts its own dataset
        (see :py:meth:`train_dataset_adaptation`).
        Please check that the plugins you are using do not adapt the
        training dataset themselves. Otherwise, they are incompatible
        with :py:class:`JointTraining`.
    """

    def __init__(self, model: Module, optimizer: Optimizer, criterion,
                 train_mb_size: int = 1, train_epochs: int = 1,
                 eval_mb_size: int = 1, device='cpu',
                 plugins: Optional[Sequence['StrategyPlugin']] = None,
                 evaluator=default_logger, eval_every=-1):
        """Init.

        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device to run the model.
        :param plugins: (optional) list of StrategyPlugins.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is
            called only at the end of the learning experience. Values >0
            mean that `eval` is called every `eval_every` epochs and at the
            end of the learning experience.
        """
        super().__init__(model=model,
                         optimizer=optimizer,
                         criterion=criterion,
                         train_mb_size=train_mb_size,
                         train_epochs=train_epochs,
                         eval_mb_size=eval_mb_size,
                         device=device,
                         plugins=plugins,
                         evaluator=evaluator,
                         eval_every=eval_every)
        # JointTraining can be trained only once.
        self._is_fitted = False

    def train(self, experiences: Union[Experience, Sequence[Experience]],
              eval_streams: Optional[Sequence[Union[Experience,
                                                    Sequence[
                                                        Experience]]]] = None,
              **kwargs):
        """ Joint training loop.

        All the given experiences are used together in a single training
        step: :py:meth:`train_dataset_adaptation` concatenates their
        datasets and :py:meth:`model_adaptation` adapts the model to each of
        them. This method can be called only once per strategy instance.

        It returns a dictionary with last recorded value for each metric.

        :param experiences: single Experience or sequence.
        :param eval_streams: list of streams for evaluation.
            If None: use training experiences for evaluation.
            Use [] if you do not want to evaluate during training.
            The sequence passed by the caller is not modified.
        :return: dictionary containing last recorded value for
            each metric name.
        :raises AlreadyTrainedError: if ``train`` was already called on
            this strategy.
        """
        # Reject a second call before touching any state, so that a failed
        # call leaves the strategy and the model exactly as they were.
        if self._is_fitted:
            raise AlreadyTrainedError(
                "JointTraining can be trained only once. "
                "Please call the train method once on the entire stream."
            )

        self.is_training = True
        self.model.train()
        self.model.to(self.device)

        # Normalize training and eval data.
        if isinstance(experiences, Experience):
            experiences = [experiences]
        if eval_streams is None:
            eval_streams = [experiences]
        else:
            # Build a new list instead of assigning into the caller's
            # sequence: the argument stays untouched and tuples work too.
            eval_streams = [
                [stream] if isinstance(stream, Experience) else stream
                for stream in eval_streams
            ]

        # Kept on the instance because the adaptation methods below need
        # every experience, not only the one handed to ``train_exp``.
        self._experiences = experiences

        self._before_training(**kwargs)
        for exp in experiences:
            self.train_exp(exp, eval_streams, **kwargs)
            # Joint training only needs a single step because
            # it concatenates all the data at once.
            break
        self._after_training(**kwargs)

        res = self.evaluator.get_last_metrics()
        self._is_fitted = True
        return res

    def train_dataset_adaptation(self, **kwargs):
        """ Concatenates the datasets of all the training experiences.

        The result is stored in ``self.adapted_dataset`` after calling its
        ``train()`` method. Reads ``self._experiences``, which is set by
        :py:meth:`train`.
        """
        self.adapted_dataset = self._experiences[0].dataset
        for exp in self._experiences[1:]:
            cat_data = AvalancheConcatDataset([self.adapted_dataset,
                                               exp.dataset])
            self.adapted_dataset = cat_data
        self.adapted_dataset = self.adapted_dataset.train()

    def model_adaptation(self, model=None):
        """ Adapts strategy's model for all experiences.

        Every :py:class:`DynamicModule` found in the model has its
        ``adaptation`` method called once per training experience, with
        that experience's dataset.

        :param model: the model to adapt. If None, ``self.model`` is used.
        :return: the adapted model, moved to ``self.device``.
        """
        if model is None:
            model = self.model

        for experience in self._experiences:
            # The module list is re-read for every experience on purpose:
            # NOTE(review): adaptation may register new submodules - confirm.
            for module in model.modules():
                if isinstance(module, DynamicModule):
                    module.adaptation(experience.dataset)
        model = model.to(self.device)
        return model
# Public API of this module. AlreadyTrainedError is exported as well so that
# callers can catch the exception raised by a second JointTraining.train call.
__all__ = ['JointTraining', 'AlreadyTrainedError']