"""
This module contains Protocols for some of the main components of Avalanche,
such as strategy plugins and the agent state.
Most of these protocols are checked dynamically at runtime, so it is often not
necessary to inherit explicitly from them or implement all the methods.
"""
from abc import ABC
from typing import Any, TypeVar, Generic, Protocol, runtime_checkable
from typing import TYPE_CHECKING
from avalanche.benchmarks import CLExperience
if TYPE_CHECKING:
    # Only static type checkers evaluate this import; at runtime the bound
    # below is kept as the string "BaseTemplate".
    from avalanche.training.templates.base import BaseTemplate

# Type variable for the template (strategy) type that plugins are
# parameterized on; bounded by `BaseTemplate`.
Template = TypeVar("Template", bound="BaseTemplate")
class Agent:
    """Avalanche Continual Learning Agent.

    The agent stores the state needed by continual learning training methods,
    such as optimizers, models, regularization losses.
    You can add any objects as attributes dynamically:

    .. code-block::

        agent = Agent()
        agent.replay = ReservoirSamplingBuffer(max_size=200)
        agent.loss = MaskedCrossEntropy()
        agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2)
        agent.model = my_model
        agent.opt = SGD(agent.model.parameters(), lr=0.001)
        agent.scheduler = ExponentialLR(agent.opt, gamma=0.999)

    Many CL objects will need to perform some operation before or
    after training on each experience. This is supported via the `Adaptable`
    Protocol, which requires the `pre_adapt` and `post_adapt` methods.
    To call the pre/post adaptation you can implement your training loop
    like in the following example:

    .. code-block::

        def train(agent, exp):
            agent.pre_adapt(exp)
            # do training here
            agent.post_adapt(exp)

    Objects that implement the `Adaptable` Protocol will be called by the Agent.
    Assigning a new object to an attribute unregisters the object it replaces.
    Re-assigning the same object does not register it twice.
    Deleting an attribute unregisters its object.

    You can also add additional functionality to the adaptation phases with
    hooks. For example:

    .. code-block::

        agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters())))
        # we update the lr scheduler after each experience (not every epoch!)
        agent.add_post_hooks(lambda a, e: a.scheduler.step())
    """

    def __init__(self, verbose=False):
        """Init.

        :param verbose: If True, print every time an adaptable object or hook
            is called during the adaptation. Useful for debugging.
        """
        self._updatable_objects = []
        self.verbose = verbose
        self._pre_hooks = []
        self._post_hooks = []

    @staticmethod
    def _is_adaptable(obj):
        """Return True if `obj` defines `pre_adapt` and/or `post_adapt`."""
        return hasattr(obj, "pre_adapt") or hasattr(obj, "post_adapt")

    def _unregister(self, obj):
        """Remove one registration of `obj` from the adaptable registry."""
        # Match by identity: `list.remove` compares with `==`, which some
        # objects (e.g. array-like types) overload.
        for idx, registered in enumerate(self._updatable_objects):
            if registered is obj:
                del self._updatable_objects[idx]
                return

    def __setattr__(self, name, value):
        """Set an attribute and keep the adaptable registry in sync.

        If `value` defines `pre_adapt` or `post_adapt`, it is registered.
        The object previously stored under `name`, if adaptable, is
        unregistered.
        """
        previous = self.__dict__.get(name)
        super().__setattr__(name, value)
        if previous is value:
            # Same object re-bound to the same name: registration unchanged.
            return
        if self._is_adaptable(previous):
            self._unregister(previous)
        if self._is_adaptable(value):
            self._updatable_objects.append(value)
            if self.verbose:
                print("Added updatable object ", value)

    def __delattr__(self, name):
        """Delete an attribute, unregistering it if it was adaptable."""
        previous = self.__dict__.get(name)
        super().__delattr__(name)
        if self._is_adaptable(previous):
            self._unregister(previous)

    def pre_adapt(self, exp):
        """Pre-adaptation.

        Remember to call this before training on a new experience.
        Calls `pre_adapt(agent, exp)` on each registered object that defines
        it, then each pre-adaptation hook, in registration order.

        :param exp: current experience
        """
        # Iterate over a snapshot: an adaptable object may re-assign agent
        # attributes, which mutates the registry during the loop.
        for uo in list(self._updatable_objects):
            if hasattr(uo, "pre_adapt"):
                uo.pre_adapt(self, exp)
                if self.verbose:
                    print("pre_adapt ", uo)
        for foo in self._pre_hooks:
            if self.verbose:
                print("pre_adapt hook ", foo)
            foo(self, exp)

    def post_adapt(self, exp):
        """Post-adaptation.

        Remember to call this after training on a new experience.
        Calls `post_adapt(agent, exp)` on each registered object that defines
        it, then each post-adaptation hook, in registration order.

        :param exp: current experience
        """
        # Snapshot for the same reason as in `pre_adapt`.
        for uo in list(self._updatable_objects):
            if hasattr(uo, "post_adapt"):
                uo.post_adapt(self, exp)
                if self.verbose:
                    print("post_adapt ", uo)
        for foo in self._post_hooks:
            if self.verbose:
                print("post_adapt hook ", foo)
            foo(self, exp)

    def add_pre_hooks(self, foo):
        """Add a pre-adaptation hook.

        Hooks take two arguments: `<agent, experience>`.

        :param foo: the hook function
        """
        self._pre_hooks.append(foo)

    def add_post_hooks(self, foo):
        """Add a post-adaptation hook.

        Hooks take two arguments: `<agent, experience>`.

        :param foo: the hook function
        """
        self._post_hooks.append(foo)
class Adaptable(Protocol):
    """Protocol describing adaptable objects.

    This class only documents the API of adaptable objects. Objects do not
    need to inherit from it, because the `Agent` looks the methods up
    dynamically.

    Adaptable objects need `pre_adapt` to run before each experience's
    training and `post_adapt` to run after it. An object may implement only
    the method it needs. The `Agent` calls a method only when it is present.
    """

    def pre_adapt(self, agent: Agent, exp: CLExperience):
        """Run by the agent before training on `exp`."""

    def post_adapt(self, agent: Agent, exp: CLExperience):
        """Run by the agent after training on `exp`."""
class BasePlugin(Generic[Template], ABC):
    """ABC for BaseTemplate plugins.

    A plugin is simply an object implementing some strategy callbacks.
    Plugins are called automatically during the strategy execution.

    Callbacks provide access before/after each phase of the execution.
    In general, for each method of the training and evaluation loops,
    `StrategyCallbacks` provide two functions `before_{method}` and
    `after_{method}`, called before and after the method, respectively.
    Therefore plugins can "inject" additional code by implementing callbacks.
    Each callback has a `strategy` argument that gives access to the state.

    In Avalanche, callbacks are used to implement continual strategies, metrics
    and loggers.

    Distributed-training support is declared per class. Use either the class
    keyword ``class P(BasePlugin, supports_distributed=True)`` or an
    assignment ``supports_distributed = True`` in the class body. The flag is
    not inherited from parent plugin classes.
    """

    supports_distributed: bool = False
    """
    A flag describing whether this plugin supports distributed training.
    """

    def __init__(self):
        """
        Initializes an instance of a base plugin.
        """
        super().__init__()

    def before_training(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `train` by the `BaseTemplate`."""
        pass

    def before_training_exp(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `train_exp` by the `BaseTemplate`."""
        pass

    def after_training_exp(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `train_exp` by the `BaseTemplate`."""
        pass

    def after_training(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `train` by the `BaseTemplate`."""
        pass

    def before_eval(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `eval` by the `BaseTemplate`."""
        pass

    def before_eval_exp(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `eval_exp` by the `BaseTemplate`."""
        pass

    def after_eval_exp(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `eval_exp` by the `BaseTemplate`."""
        pass

    def after_eval(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `eval` by the `BaseTemplate`."""
        pass

    def __init_subclass__(cls, supports_distributed: bool = False, **kwargs) -> None:
        """Set the `supports_distributed` flag on each new subclass.

        :param supports_distributed: value passed as a class keyword argument.
            When it is left False, a value assigned in the subclass body is
            kept instead of being overwritten.
        """
        # `__init_subclass__` runs after the subclass body has executed.
        # Assigning unconditionally would silently discard a
        # `supports_distributed = True` written in that body.
        # Classes that set the flag neither way still get False (not
        # inherited from their parents).
        if supports_distributed or "supports_distributed" not in cls.__dict__:
            cls.supports_distributed = supports_distributed
        return super().__init_subclass__(**kwargs)
class BaseSGDPlugin(BasePlugin[Template], ABC):
    """ABC for BaseSGDTemplate plugins.

    See `BaseSGDTemplate` for complete description of the train/eval loop.
    """

    def __init__(self):
        """
        Initializes an instance of a base SGD plugin.
        """
        super().__init__()

    def before_training_epoch(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `train_epoch` by the `BaseTemplate`."""
        pass

    def before_training_iteration(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before the start of a training iteration by the
        `BaseTemplate`."""
        pass

    def before_forward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `model.forward()` by the `BaseTemplate`."""
        pass

    def after_forward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `model.forward()` by the `BaseTemplate`."""
        pass

    def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `criterion.backward()` by the `BaseTemplate`."""
        pass

    def after_backward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `criterion.backward()` by the `BaseTemplate`."""
        pass

    def after_training_iteration(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after the end of a training iteration by the
        `BaseTemplate`."""
        pass

    def before_update(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `optimizer.update()` by the `BaseTemplate`."""
        pass

    def after_update(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `optimizer.update()` by the `BaseTemplate`."""
        pass

    def after_training_epoch(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `train_epoch` by the `BaseTemplate`."""
        pass

    def before_eval_iteration(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before the start of an evaluation iteration by the
        `BaseTemplate`."""
        pass

    def before_eval_forward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `model.forward()` during evaluation by the
        `BaseTemplate`."""
        pass

    def after_eval_forward(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `model.forward()` during evaluation by the
        `BaseTemplate`."""
        pass

    def after_eval_iteration(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after the end of an evaluation iteration by the
        `BaseTemplate`."""
        pass

    def before_train_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> Any:
        """Called before `train_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def after_train_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> Any:
        """Called after `train_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def before_eval_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> Any:
        """Called before `eval_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def after_eval_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> Any:
        """Called after `eval_dataset_adaptation` by the `BaseTemplate`."""
        pass
class SupervisedPlugin(BaseSGDPlugin[Template], ABC):
    """ABC for SupervisedTemplate plugins.

    See `BaseTemplate` for complete description of the train/eval loop.
    """

    def __init__(self):
        """
        Initializes an instance of a supervised plugin.
        """
        super().__init__()
class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC):
    """ABC for SupervisedMetaLearningTemplate plugins.

    See `BaseTemplate` for complete description of the train/eval loop.
    """

    def before_inner_updates(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `_inner_updates` by the `BaseTemplate`."""
        pass

    def after_inner_updates(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `_inner_updates` by the `BaseTemplate`."""
        pass

    def before_outer_update(self, strategy: Template, *args, **kwargs) -> Any:
        """Called before `_outer_updates` by the `BaseTemplate`."""
        pass

    def after_outer_update(self, strategy: Template, *args, **kwargs) -> Any:
        """Called after `_outer_updates` by the `BaseTemplate`."""
        pass