################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 30-12-2020 #
# Author(s): Lorenzo Pellegrini, Antonio Carta, Andrea Cossu #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
################################################################################
from abc import ABC, abstractmethod
from typing import TypeVar, Optional, TYPE_CHECKING
from typing_extensions import Protocol
from .metric_results import MetricValue
from .metric_utils import get_metric_name, phase_and_task
from ..core import StrategyCallbacks
if TYPE_CHECKING:
from .metric_results import MetricResult
from ..training.strategies import BaseStrategy
# Type of the value returned by a metric's `result` method.
TResult = TypeVar('TResult')
# Type variable bound to `PluginMetric` (and its subclasses).
TAggregated = TypeVar('TAggregated', bound='PluginMetric')


class Metric(Protocol[TResult]):
    """
    Definition of a standalone metric.

    A standalone metric exposes methods to reset its internal state and to
    emit a result. Emitting a result does NOT reset the internal state.

    How the internal state gets updated is up to each concrete
    implementation; standalone metrics such as :class:`Sum`, :class:`Mean`
    or :class:`Accuracy` usually expose an `update` method for this.

    A `Metric` can be used on its own by calling its methods directly.
    To integrate a metric with the training and evaluation flows
    automatically, use the :class:`PluginMetric` class instead: it receives
    events directly from the :class:`EvaluationPlugin` and can emit values
    on each callback. Usually a `PluginMetric` creates a `Metric` instance
    internally and is responsible for updating it and reporting its results.
    See :class:`PluginMetric` for more details.
    """

    def result(self, **kwargs) -> Optional[TResult]:
        """
        Obtains the value of the metric.

        :return: The value of the metric.
        """
        ...

    def reset(self, **kwargs) -> None:
        """
        Resets the metric internal state.

        :return: None.
        """
        ...
class PluginMetric(Metric[TResult], StrategyCallbacks['MetricResult'], ABC):
    """
    A metric that can be used together with :class:`EvaluationPlugin`.

    Every strategy callback is implemented here as a no-op returning
    ``None``, so concrete metrics only need to override the callbacks they
    are interested in. The `result` and `reset` methods are abstract and
    must be implemented by child classes.

    Remember to call the `super()` method when overriding
    `after_training_iteration` or `after_eval_iteration`.

    An instance of this class usually wraps a `Metric` instance, which it
    updates, resets and queries for results inside specific callbacks.
    """

    def __init__(self):
        """
        Creates an instance of a plugin metric.

        This constructor does nothing, so child classes can safely invoke it
        (``super().__init__()``) as the first statement of their own
        constructor.
        """

    @abstractmethod
    def result(self, **kwargs) -> Optional[TResult]:
        """Returns the current value of the metric. Must be overridden."""

    @abstractmethod
    def reset(self, **kwargs) -> None:
        """Resets the metric internal state. Must be overridden."""

    # ------------------------------------------------------------------ #
    # Training callbacks (no-ops by default).
    # ------------------------------------------------------------------ #
    def before_training(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_training_exp(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_train_dataset_adaptation(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_train_dataset_adaptation(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_training_epoch(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_training_iteration(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_forward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_forward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_backward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_backward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_training_iteration(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_update(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_update(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_training_epoch(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_training_exp(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_training(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    # ------------------------------------------------------------------ #
    # Evaluation callbacks (no-ops by default).
    # ------------------------------------------------------------------ #
    def before_eval(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_eval_dataset_adaptation(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_eval_dataset_adaptation(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_eval_exp(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_eval_exp(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_eval(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_eval_iteration(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def before_eval_forward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_eval_forward(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None

    def after_eval_iteration(
            self, strategy: 'BaseStrategy') -> 'MetricResult':
        return None
class GenericPluginMetric(PluginMetric[TResult]):
    """
    This class provides a generic implementation of a Plugin Metric.

    The user can subclass this class to easily implement custom plugin
    metrics. The wrapped standalone metric is reset and its result is
    emitted at the granularities selected with `reset_at` and `emit_at`.
    Only the callbacks belonging to the selected `mode` ('train' or 'eval')
    have any effect. In that mode, `update` is called after every iteration.
    Subclasses are expected to override `update` to feed the wrapped metric.
    """

    def __init__(self, metric, reset_at='experience', emit_at='experience',
                 mode='eval'):
        """
        Creates a generic plugin metric.

        :param metric: the standalone metric to wrap. Its ``reset()`` and
            ``result()`` methods are called without arguments.
        :param reset_at: when to reset the wrapped metric. One of
            'iteration', 'experience', 'stream', or (train mode only)
            'epoch'.
        :param emit_at: when to emit the metric value. Same accepted values
            as `reset_at`.
        :param mode: 'train' to react to training callbacks, 'eval' to react
            to evaluation callbacks.
        """
        super().__init__()
        assert mode in {'train', 'eval'}, \
            "mode must be 'train' or 'eval', got {!r}".format(mode)
        # 'epoch' is only accepted in train mode: the evaluation callbacks
        # handled below have no per-epoch hook.
        if mode == 'train':
            allowed = {'iteration', 'epoch', 'experience', 'stream'}
        else:
            allowed = {'iteration', 'experience', 'stream'}
        assert reset_at in allowed, \
            "invalid reset_at {!r} for mode {!r}".format(reset_at, mode)
        assert emit_at in allowed, \
            "invalid emit_at {!r} for mode {!r}".format(emit_at, mode)

        self._metric = metric
        self._reset_at = reset_at
        self._emit_at = emit_at
        self._mode = mode

    def reset(self, strategy=None) -> None:
        """
        Resets the wrapped metric.

        :param strategy: the current strategy. This default implementation
            does not use it. It defaults to ``None`` so the method can also
            be called without arguments, as the `Metric` protocol allows.
        """
        self._metric.reset()

    def result(self, strategy=None):
        """
        Returns the current result of the wrapped metric.

        :param strategy: the current strategy. This default implementation
            does not use it, and it defaults to ``None``.
        :return: whatever the wrapped metric's ``result()`` returns.
        """
        return self._metric.result()

    def update(self, strategy):
        """
        Hook called after every iteration of the active mode.

        The default implementation does nothing. Subclasses override it to
        update the wrapped metric from the strategy state.
        """
        pass

    def _package_result(self, strategy: 'BaseStrategy') -> 'MetricResult':
        """
        Wraps the current metric value into a list of `MetricValue`s.

        If `result` returns a dict, one `MetricValue` is produced per entry,
        and the key is passed to `get_metric_name` as `add_task`. The keys
        are presumably task labels; confirm against `get_metric_name`.
        Otherwise a single `MetricValue` is produced with ``add_task=True``.
        The experience is included in the metric name only when emitting at
        the 'experience' granularity. The x position used for plotting is
        ``strategy.clock.train_iterations``.
        """
        metric_value = self.result(strategy)
        add_exp = self._emit_at == 'experience'
        plot_x_position = strategy.clock.train_iterations

        if isinstance(metric_value, dict):
            return [
                MetricValue(
                    self,
                    get_metric_name(self, strategy, add_experience=add_exp,
                                    add_task=k),
                    v, plot_x_position)
                for k, v in metric_value.items()]

        metric_name = get_metric_name(self, strategy, add_experience=add_exp,
                                      add_task=True)
        return [MetricValue(self, metric_name, metric_value, plot_x_position)]

    # ------------------------------------------------------------------ #
    # Training callbacks (active only when mode == 'train').
    # ------------------------------------------------------------------ #
    def before_training(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().before_training(strategy)
        if self._reset_at == 'stream' and self._mode == 'train':
            # Pass the strategy like every other reset call; `reset` is
            # declared as reset(self, strategy) by subclasses.
            self.reset(strategy)

    def before_training_exp(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().before_training_exp(strategy)
        if self._reset_at == 'experience' and self._mode == 'train':
            self.reset(strategy)

    def before_training_epoch(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().before_training_epoch(strategy)
        if self._reset_at == 'epoch' and self._mode == 'train':
            self.reset(strategy)

    def before_training_iteration(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().before_training_iteration(strategy)
        if self._reset_at == 'iteration' and self._mode == 'train':
            self.reset(strategy)

    def after_training_iteration(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().after_training_iteration(strategy)
        if self._mode == 'train':
            self.update(strategy)
            if self._emit_at == 'iteration':
                return self._package_result(strategy)

    def after_training_epoch(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().after_training_epoch(strategy)
        if self._emit_at == 'epoch' and self._mode == 'train':
            return self._package_result(strategy)

    def after_training_exp(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().after_training_exp(strategy)
        if self._emit_at == 'experience' and self._mode == 'train':
            return self._package_result(strategy)

    def after_training(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().after_training(strategy)
        if self._emit_at == 'stream' and self._mode == 'train':
            return self._package_result(strategy)

    # ------------------------------------------------------------------ #
    # Evaluation callbacks (active only when mode == 'eval').
    # ------------------------------------------------------------------ #
    def before_eval(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().before_eval(strategy)
        if self._reset_at == 'stream' and self._mode == 'eval':
            self.reset(strategy)

    def before_eval_exp(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().before_eval_exp(strategy)
        if self._reset_at == 'experience' and self._mode == 'eval':
            self.reset(strategy)

    def after_eval_exp(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().after_eval_exp(strategy)
        if self._emit_at == 'experience' and self._mode == 'eval':
            return self._package_result(strategy)

    def after_eval(self, strategy: 'BaseStrategy') -> 'MetricResult':
        super().after_eval(strategy)
        if self._emit_at == 'stream' and self._mode == 'eval':
            return self._package_result(strategy)

    def before_eval_iteration(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().before_eval_iteration(strategy)
        if self._reset_at == 'iteration' and self._mode == 'eval':
            self.reset(strategy)

    def after_eval_iteration(self, strategy: 'BaseStrategy') \
            -> 'MetricResult':
        super().after_eval_iteration(strategy)
        if self._mode == 'eval':
            self.update(strategy)
            if self._emit_at == 'iteration':
                return self._package_result(strategy)
# Public API of this module.
__all__ = [
    'Metric',
    'PluginMetric',
    'GenericPluginMetric',
]