# Source code for avalanche.training.templates.base

import warnings
from collections import defaultdict
from typing import Iterable, Sequence, Optional, Union, List

import torch
from torch.nn import Module

from avalanche.benchmarks import CLExperience, CLStream
from avalanche.core import BasePlugin
from avalanche.training.utils import trigger_plugins

# Type alias for any iterable of `CLExperience` objects; used to annotate the
# experience / evaluation-stream arguments and state of `BaseTemplate`.
ExpSequence = Iterable[CLExperience]

class BaseTemplate:
    """Base class for continual learning skeletons.

    **Training loop**
    The training loop is organized as follows::

        train
            train_exp  # for each experience

    **Evaluation loop**
    The evaluation loop is organized as follows::

        eval
            eval_exp  # for each experience

    Subclasses implement `_train_exp` and `_eval_exp`; plugin callbacks are
    fired around each phase via the `_before_*` / `_after_*` triggers.
    """

    # we need this only for type checking
    PLUGIN_CLASS = BasePlugin

    def __init__(
        self,
        model: Module,
        device="cpu",
        plugins: Optional[List[BasePlugin]] = None,
    ):
        """Init.

        :param model: PyTorch model.
        :param device: device for the model, any value accepted by
            ``torch.device``. ``None`` falls back to ``"cpu"``.
        :param plugins: optional list of plugins. ``None`` means no plugins.
            Plugins implementing callbacks not supported by
            `PLUGIN_CLASS` only trigger a warning.
        """
        self.model: Module = model
        """ PyTorch model. """

        if device is None:
            device = "cpu"
        self.device = torch.device(device)
        """ PyTorch device for the model (this base class does not move the
        model by itself). """

        self.plugins = [] if plugins is None else plugins
        """ List of plugins (`BasePlugin` instances). """

        # check plugin compatibility
        self._check_plugin_compatibility()

        ###################################################################
        # State variables. These are updated during the train/eval loops. #
        ###################################################################
        self.experience: Optional[CLExperience] = None
        """ Current experience. """

        self.is_training: bool = False
        """ True if the strategy is in training mode. """

        self.current_eval_stream: Optional[ExpSequence] = None
        """ Current evaluation stream. """

    @property
    def is_eval(self):
        """True if the strategy is in evaluation mode."""
        return not self.is_training

    def train(
        self,
        experiences: Union[CLExperience, ExpSequence],
        eval_streams: Optional[
            Sequence[Union[CLExperience, ExpSequence]]
        ] = None,
        **kwargs,
    ):
        """Training loop.

        If experiences is a single element trains on it.
        If it is a sequence, trains the model on each experience in order.
        This is different from joint training on the entire stream.

        This base implementation does not return anything (``None``).

        :param experiences: single Experience or sequence.
        :param eval_streams: sequence of streams for evaluation.
            If None: use training experiences for evaluation.
            Use [] if you do not want to evaluate during training.
            Experiences in `eval_streams` are grouped by stream name
            when calling `eval`. If you use multiple streams, they must
            have different names.
        :param kwargs: custom arguments, forwarded to the plugin triggers
            and to `_train_exp`.
        """
        self.is_training = True
        self._stop_training = False

        self.model.train()

        # Normalize training and eval data.
        if not isinstance(experiences, Iterable):
            experiences = [experiences]
        if eval_streams is None:
            eval_streams = [experiences]

        self._eval_streams = _group_experiences_by_stream(eval_streams)

        self._before_training(**kwargs)

        for self.experience in experiences:
            self._before_training_exp(**kwargs)
            self._train_exp(self.experience, eval_streams, **kwargs)
            self._after_training_exp(**kwargs)
        self._after_training(**kwargs)

    def _train_exp(self, experience: CLExperience, eval_streams, **kwargs):
        """Train on a single experience. Must be implemented by subclasses."""
        raise NotImplementedError()

    @torch.no_grad()
    def eval(
        self,
        exp_list: Union[CLExperience, CLStream],
        **kwargs,
    ):
        """Evaluate the current model on a series of experiences.

        The training state (current experience, `is_training` flag and each
        module's training mode) is saved before evaluating and restored
        afterwards, also when evaluation raises an exception, so `eval` can
        safely be called from inside the training loop.

        This base implementation does not return anything (``None``).

        :param exp_list: CL experience information (a single experience or
            an iterable of experiences).
        :param kwargs: custom arguments, forwarded to the plugin triggers
            and to `_eval_exp`.
        """
        # eval can be called inside the train method.
        # Save the shared state here to restore before returning.
        prev_train_state = self._save_train_state()
        try:
            self.is_training = False
            self.model.eval()

            if not isinstance(exp_list, Iterable):
                exp_list = [exp_list]
            self.current_eval_stream = exp_list

            self._before_eval(**kwargs)
            for self.experience in exp_list:
                self._before_eval_exp(**kwargs)
                self._eval_exp(**kwargs)
                self._after_eval_exp(**kwargs)
            self._after_eval(**kwargs)
        finally:
            # restore previous shared state, even if evaluation failed, so
            # the caller does not end up with a model stuck in eval mode.
            self._load_train_state(prev_train_state)

    def _eval_exp(self, **kwargs):
        """Evaluate on `self.experience`. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _save_train_state(self):
        """Save the training state, which may be modified by the eval loop.

        TODO: we probably need a better way to do this.

        :return: dict with keys ``"experience"``, ``"is_training"`` and
            ``"model_training_mode"`` (a ``{module_name: module.training}``
            map over ``self.model.named_modules()``).
        """
        # save each layer's training mode, to restore it later
        prev_model_training_modes = {
            name: layer.training for name, layer in self.model.named_modules()
        }
        return {
            "experience": self.experience,
            "is_training": self.is_training,
            "model_training_mode": prev_model_training_modes,
        }

    def _load_train_state(self, prev_state):
        """Restore a state previously returned by `_save_train_state`.

        :param prev_state: dict produced by `_save_train_state`.
        """
        # restore train-state variables and training mode.
        self.experience = prev_state["experience"]
        self.is_training = prev_state["is_training"]

        # restore each layer's training mode to original.
        # named_modules() visits a parent before its children, so the
        # recursive effect of a parent's train() call is then overridden by
        # each child's own saved mode.
        prev_training_modes = prev_state["model_training_mode"]
        for name, layer in self.model.named_modules():
            # Modules missing from the saved state were probably added
            # during the eval model's adaptation: set them to train mode.
            layer.train(mode=prev_training_modes.get(name, True))

    def _check_plugin_compatibility(self):
        """Check that the list of plugins is compatible with the template.

        This means checking that each plugin implements a subset of the
        callbacks supported by `PLUGIN_CLASS`. Incompatible plugins only
        trigger a warning; no exception is raised.
        """
        # TODO: ideally we would like to check the argument's type to check
        #  that it's a supertype of the template.
        #  I don't know if it's possible to do it in Python.

        def get_callbacks_from_object(obj):
            # Callbacks are identified purely by attribute name.
            return {x for x in dir(obj) if x.startswith(("before", "after"))}

        cb_supported = get_callbacks_from_object(self.PLUGIN_CLASS)
        for p in self.plugins:
            cb_p = get_callbacks_from_object(p)
            if not cb_p.issubset(cb_supported):
                warnings.warn(
                    f"Plugin {p} implements incompatible callbacks for template"
                    f" {self}. This may result in errors. Incompatible "
                    f"callbacks: {cb_p - cb_supported}",
                )

    #########################################################
    # Plugin Triggers                                       #
    #########################################################

    def _before_training_exp(self, **kwargs):
        trigger_plugins(self, "before_training_exp", **kwargs)

    def _after_training_exp(self, **kwargs):
        trigger_plugins(self, "after_training_exp", **kwargs)

    def _before_training(self, **kwargs):
        trigger_plugins(self, "before_training", **kwargs)

    def _after_training(self, **kwargs):
        trigger_plugins(self, "after_training", **kwargs)

    def _before_eval(self, **kwargs):
        trigger_plugins(self, "before_eval", **kwargs)

    def _after_eval(self, **kwargs):
        trigger_plugins(self, "after_eval", **kwargs)

    def _before_eval_exp(self, **kwargs):
        trigger_plugins(self, "before_eval_exp", **kwargs)

    def _after_eval_exp(self, **kwargs):
        trigger_plugins(self, "after_eval_exp", **kwargs)
def _group_experiences_by_stream(eval_streams): if len(eval_streams) == 1: return eval_streams exps = [] # First, we unpack the list of experiences. for exp in eval_streams: if isinstance(exp, Iterable): exps.extend(exp) else: exps.append(exp) # Then, we group them by stream. exps_by_stream = defaultdict(list) for exp in exps: sname = exps_by_stream[sname].append(exp) # Finally, we return a list of lists. return list(exps_by_stream.values())