from abc import ABC
from typing import TypeVar, Generic
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # Only evaluated by static type checkers; at runtime `TYPE_CHECKING` is
    # False, so this module never imports `BaseTemplate` when executed.
    from avalanche.training.templates.base import BaseTemplate

# Type variable used as the return annotation of the plugin callbacks.
CallbackResult = TypeVar("CallbackResult")

# Type variable for the `strategy` argument of the callbacks. The bound is
# given as a string because `BaseTemplate` is not imported at runtime.
Template = TypeVar("Template", bound="BaseTemplate")
class BasePlugin(Generic[Template], ABC):
    """ABC for BaseTemplate plugins.

    A plugin is simply an object implementing some strategy callbacks.
    Plugins are called automatically during the strategy execution.

    Callbacks provide access before/after each phase of the execution.
    In general, for each method of the training and evaluation loops,
    `StrategyCallbacks` provide two functions `before_{method}` and
    `after_{method}`, called before and after the method, respectively.
    Therefore plugins can "inject" additional code by implementing callbacks.
    Each callback has a `strategy` argument that gives access to the state.

    In Avalanche, callbacks are used to implement continual strategies,
    metrics and loggers.

    Every callback defined here is a no-op that returns `None`; subclasses
    override only the callbacks they need.
    """

    def __init__(self):
        """Create the plugin. The base implementation stores no state."""
        pass

    def before_training(self, strategy: Template, *args, **kwargs):
        """Called before `train` by the `BaseTemplate`."""
        pass

    def before_training_exp(self, strategy: Template, *args, **kwargs):
        """Called before `train_exp` by the `BaseTemplate`."""
        pass

    def after_training_exp(self, strategy: Template, *args, **kwargs):
        """Called after `train_exp` by the `BaseTemplate`."""
        pass

    def after_training(self, strategy: Template, *args, **kwargs):
        """Called after `train` by the `BaseTemplate`."""
        pass

    def before_eval(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `eval` by the `BaseTemplate`."""
        pass

    def before_eval_exp(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `eval_exp` by the `BaseTemplate`."""
        pass

    def after_eval_exp(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `eval_exp` by the `BaseTemplate`."""
        pass

    def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult:
        """Called after `eval` by the `BaseTemplate`."""
        pass
class BaseSGDPlugin(BasePlugin[Template], ABC):
    """ABC for BaseSGDTemplate plugins.

    See `BaseSGDTemplate` for complete description of the train/eval loop.

    Adds epoch-, iteration-, forward-, backward- and update-level callbacks
    to those of `BasePlugin`. Every callback defined here is a no-op that
    returns `None`.
    """

    def before_training_epoch(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `train_epoch` by the `BaseTemplate`."""
        pass

    def before_training_iteration(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before the start of a training iteration by the
        `BaseTemplate`."""
        pass

    def before_forward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `model.forward()` by the `BaseTemplate`."""
        pass

    def after_forward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `model.forward()` by the `BaseTemplate`."""
        pass

    def before_backward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `criterion.backward()` by the `BaseTemplate`."""
        pass

    def after_backward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `criterion.backward()` by the `BaseTemplate`."""
        pass

    def after_training_iteration(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after the end of a training iteration by the
        `BaseTemplate`."""
        pass

    def before_update(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `optimizer.update()` by the `BaseTemplate`."""
        pass

    def after_update(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `optimizer.update()` by the `BaseTemplate`."""
        pass

    def after_training_epoch(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `train_epoch` by the `BaseTemplate`."""
        pass

    def before_eval_iteration(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before the start of an evaluation iteration by the
        `BaseTemplate`."""
        pass

    def before_eval_forward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `model.forward()` during evaluation by the
        `BaseTemplate`."""
        pass

    def after_eval_forward(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `model.forward()` during evaluation by the
        `BaseTemplate`."""
        pass

    def after_eval_iteration(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after the end of an evaluation iteration by the
        `BaseTemplate`."""
        pass
class SupervisedPlugin(BaseSGDPlugin[Template], ABC):
    """ABC for SupervisedTemplate plugins.

    See `BaseTemplate` for complete description of the train/eval loop.

    Adds callbacks around dataset adaptation to those of `BaseSGDPlugin`.
    Every callback defined here is a no-op that returns `None`.
    """

    def before_train_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `train_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def after_train_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `train_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def before_eval_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `eval_dataset_adaptation` by the `BaseTemplate`."""
        pass

    def after_eval_dataset_adaptation(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `eval_dataset_adaptation` by the `BaseTemplate`."""
        pass
class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC):
    """ABC for SupervisedMetaLearningTemplate plugins.

    See `BaseTemplate` for complete description of the train/eval loop.

    Adds callbacks around the inner and outer update steps to those of
    `SupervisedPlugin`. Every callback defined here is a no-op that returns
    `None`.
    """

    def before_inner_updates(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `_inner_updates` by the `BaseTemplate`."""
        pass

    def after_inner_updates(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `_inner_updates` by the `BaseTemplate`."""
        pass

    # NOTE(review): the two docstrings below keep the `_outer_updates`
    # spelling from the original text although the callbacks are named
    # `*_outer_update` -- confirm the template method name.
    def before_outer_update(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called before `_outer_updates` by the `BaseTemplate`."""
        pass

    def after_outer_update(
        self, strategy: Template, *args, **kwargs
    ) -> CallbackResult:
        """Called after `_outer_updates` by the `BaseTemplate`."""
        pass