from typing import Any, TYPE_CHECKING
from avalanche.core import StrategyCallbacks
if TYPE_CHECKING:
from avalanche.training import BaseStrategy
class StrategyPlugin(StrategyCallbacks[Any]):
    """
    Base class for strategy plugins.

    Every callback required by ``BaseStrategy`` is implemented here as a
    no-op, so a concrete plugin only needs to override the hooks it
    actually cares about; the remaining ones can be inherited unchanged.

    All callbacks share the same signature:

    :param strategy: the strategy instance passed by the caller
        (annotated as ``BaseStrategy``).
    :param kwargs: additional keyword arguments; ignored by these default
        implementations.
    :return: ``None`` -- the default implementations do nothing.
    """

    def __init__(self) -> None:
        super().__init__()

    # ------------------------------------------------------------------
    # Training-phase callbacks (no-ops by default).
    # ------------------------------------------------------------------

    def before_training(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_train_dataset_adaptation(
            self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_train_dataset_adaptation(
            self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_forward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_backward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_backward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_update(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_update(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_training(self, strategy: 'BaseStrategy', **kwargs):
        pass

    # ------------------------------------------------------------------
    # Evaluation-phase callbacks (no-ops by default).
    # ------------------------------------------------------------------

    def before_eval(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_eval_dataset_adaptation(
            self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_eval_dataset_adaptation(
            self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_eval(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def before_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        pass

    def after_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        pass