Source code for avalanche.evaluation.metric_definitions

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 30-12-2020                                                             #
# Author(s): Lorenzo Pellegrini, Antonio Carta, Andrea Cossu                   #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

from abc import ABC, abstractmethod
from typing import (
    Generic,
    TypeVar,
    Optional,
    TYPE_CHECKING,
    List,
    Union,
    overload,
    Literal,
    Protocol,
)
from .metric_results import MetricValue, MetricType, AlternativeValues
from .metric_utils import (
    get_metric_name,
    generic_get_metric_name,
    default_metric_name_template,
)

if TYPE_CHECKING:
    from .metric_results import MetricResult
    from ..training.templates import SupervisedTemplate

TResult_co = TypeVar("TResult_co", covariant=True)
TMetric = TypeVar("TMetric", bound="Metric")


class Metric(Protocol[TResult_co]):
    """Standalone metric.

    A standalone metric exposes methods to reset its internal state and
    to emit a result. Emitting a result does not automatically cause a reset
    of the internal state.

    The specific metric implementation exposes ways to update the internal
    state. Usually, standalone metrics like :class:`Sum`, :class:`Mean`,
    :class:`Accuracy`, ... expose an `update` method.

    The `Metric` class can be used as a standalone metric by directly calling
    its methods.
    In order to automatically integrate the metric with the training and
    evaluation flows, you can use the :class:`PluginMetric` class. That class
    receives events directly from the :class:`EvaluationPlugin` and can emit
    values on each callback. Usually, an instance of `Metric` is created
    within `PluginMetric`, which is then responsible for its update and
    results. See :class:`PluginMetric` for more details.
    """

    def result(self) -> Optional[TResult_co]:
        """
        Obtains the value of the metric.

        :return: The value of the metric.
        """
        pass

    def reset(self) -> None:
        """
        Resets the metric internal state.

        :return: None.
        """
        pass
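
# Example (illustrative sketch, not part of the original module): a minimal
# standalone metric satisfying the ``Metric`` protocol. Because ``Metric`` is
# a ``Protocol``, no explicit inheritance is needed; structural typing is
# enough. The class name ``_ExampleSum`` is hypothetical.
class _ExampleSum:
    def __init__(self) -> None:
        self._total: float = 0.0

    def update(self, value: float) -> None:
        # ``update`` is a convention followed by standalone metrics such as
        # :class:`Sum` and :class:`Mean`; it is not part of the protocol.
        self._total += value

    def result(self) -> Optional[float]:
        # Emitting a result does not reset the internal state.
        return self._total

    def reset(self) -> None:
        self._total = 0.0


# Usage: ``m: Metric[float] = _ExampleSum()`` type-checks structurally.
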
class PluginMetric(Metric[TResult_co], ABC):
    """A metric that can be used together with :class:`EvaluationPlugin`.

    This class leaves the implementation of the `result` and `reset` methods
    to child classes while providing an empty implementation of the callbacks
    invoked by the :class:`EvaluationPlugin`. Subclasses should implement
    `result`, `reset` and the desired callbacks to compute the specific
    metric.

    Remember to call the `super()` method when overriding
    `after_training_iteration` or `after_eval_iteration`.

    An instance of this class usually leverages a `Metric` instance to update,
    reset and emit metric results at appropriate times (during specific
    callbacks).
    """
    def __init__(self):
        """
        Creates an instance of a plugin metric.

        Child classes can safely invoke this (super) constructor as their
        first instruction.
        """
        pass
    @abstractmethod
    def result(self) -> Optional[TResult_co]:
        pass

    @abstractmethod
    def reset(self) -> None:
        pass

    def before_training(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_train_dataset_adaptation(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def after_train_dataset_adaptation(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def before_training_epoch(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_training_iteration(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def before_forward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_forward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_backward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_backward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_training_iteration(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def before_update(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_update(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_training_epoch(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_training(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_eval(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_eval_dataset_adaptation(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def after_eval_dataset_adaptation(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        pass

    def before_eval_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_eval_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_eval(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_eval_iteration(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def before_eval_forward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_eval_forward(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass

    def after_eval_iteration(self, strategy: "SupervisedTemplate") -> "MetricResult":
        pass
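
# Example (illustrative sketch, not part of the original module): a plugin
# metric that counts training iterations and emits the count after each
# training experience. The name ``_ExampleIterationCount`` and the metric
# name string below are hypothetical.
class _ExampleIterationCount(PluginMetric[int]):
    def __init__(self) -> None:
        super().__init__()
        self._count: int = 0

    def result(self) -> Optional[int]:
        return self._count

    def reset(self) -> None:
        self._count = 0

    def after_training_iteration(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        # Per the class docstring, call ``super()`` when overriding this
        # callback.
        super().after_training_iteration(strategy)
        self._count += 1

    def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        super().after_training_exp(strategy)
        # Returned values are collected and logged by the EvaluationPlugin.
        return [
            MetricValue(
                self,
                "ExampleIterationCount/train",
                self._count,
                strategy.clock.train_iterations,
            )
        ]
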
class GenericPluginMetric(PluginMetric[TResult_co], Generic[TResult_co, TMetric]):
    """
    This class provides a generic implementation of a Plugin Metric.
    The user can subclass this class to easily implement custom plugin
    metrics. The `reset_at`, `emit_at` and `mode` arguments control when the
    wrapped metric is reset and when its value is emitted; the "epoch"
    granularity is only available in "train" mode.
    """

    @overload
    def __init__(
        self,
        metric: TMetric,
        reset_at: Literal[
            "iteration", "epoch", "experience", "stream", "never"
        ] = "experience",
        emit_at: Literal["iteration", "epoch", "experience", "stream"] = "experience",
        mode: Literal["train"] = "train",
    ): ...

    @overload
    def __init__(
        self,
        metric: TMetric,
        reset_at: Literal["iteration", "experience", "stream", "never"] = "experience",
        emit_at: Literal["iteration", "experience", "stream"] = "experience",
        mode: Literal["eval"] = "eval",
    ): ...
    def __init__(
        self, metric: TMetric, reset_at="experience", emit_at="experience", mode="eval"
    ):
        super(GenericPluginMetric, self).__init__()
        assert mode in {"train", "eval"}
        if mode == "train":
            assert reset_at in {
                "iteration",
                "epoch",
                "experience",
                "stream",
                "never",
            }
            assert emit_at in {"iteration", "epoch", "experience", "stream"}
        else:
            assert reset_at in {"iteration", "experience", "stream", "never"}
            assert emit_at in {"iteration", "experience", "stream"}
        self._metric: TMetric = metric
        self._reset_at = reset_at
        self._emit_at = emit_at
        self._mode = mode
    def reset(self) -> None:
        self._metric.reset()

    def result(self):
        return self._metric.result()

    def update(self, strategy: "SupervisedTemplate"):
        pass

    def _package_result(self, strategy: "SupervisedTemplate") -> "MetricResult":
        metric_value = self.result()
        add_exp = self._emit_at == "experience"
        plot_x_position = strategy.clock.train_iterations

        if isinstance(metric_value, dict):
            metrics = []
            for k, v in metric_value.items():
                metric_name = get_metric_name(
                    self, strategy, add_experience=add_exp, add_task=k
                )
                metrics.append(MetricValue(self, metric_name, v, plot_x_position))
            return metrics
        else:
            metric_name = get_metric_name(
                self, strategy, add_experience=add_exp, add_task=True
            )
            return [MetricValue(self, metric_name, metric_value, plot_x_position)]

    def before_training(self, strategy: "SupervisedTemplate"):
        super().before_training(strategy)
        if self._reset_at == "stream" and self._mode == "train":
            self.reset()

    def before_training_exp(self, strategy: "SupervisedTemplate"):
        super().before_training_exp(strategy)
        if self._reset_at == "experience" and self._mode == "train":
            self.reset()

    def before_training_epoch(self, strategy: "SupervisedTemplate"):
        super().before_training_epoch(strategy)
        if self._reset_at == "epoch" and self._mode == "train":
            self.reset()

    def before_training_iteration(self, strategy: "SupervisedTemplate"):
        super().before_training_iteration(strategy)
        if self._reset_at == "iteration" and self._mode == "train":
            self.reset()

    def after_training_iteration(self, strategy: "SupervisedTemplate"):
        super().after_training_iteration(strategy)
        if self._mode == "train":
            self.update(strategy)
        if self._emit_at == "iteration" and self._mode == "train":
            return self._package_result(strategy)

    def after_training_epoch(self, strategy: "SupervisedTemplate"):
        super().after_training_epoch(strategy)
        if self._emit_at == "epoch" and self._mode == "train":
            return self._package_result(strategy)

    def after_training_exp(self, strategy: "SupervisedTemplate"):
        super().after_training_exp(strategy)
        if self._emit_at == "experience" and self._mode == "train":
            return self._package_result(strategy)

    def after_training(self, strategy: "SupervisedTemplate"):
        super().after_training(strategy)
        if self._emit_at == "stream" and self._mode == "train":
            return self._package_result(strategy)

    def before_eval(self, strategy: "SupervisedTemplate"):
        super().before_eval(strategy)
        if self._reset_at == "stream" and self._mode == "eval":
            self.reset()

    def before_eval_exp(self, strategy: "SupervisedTemplate"):
        super().before_eval_exp(strategy)
        if self._reset_at == "experience" and self._mode == "eval":
            self.reset()

    def after_eval_exp(self, strategy: "SupervisedTemplate"):
        super().after_eval_exp(strategy)
        if self._emit_at == "experience" and self._mode == "eval":
            return self._package_result(strategy)

    def after_eval(self, strategy: "SupervisedTemplate"):
        super().after_eval(strategy)
        if self._emit_at == "stream" and self._mode == "eval":
            return self._package_result(strategy)

    def after_eval_iteration(self, strategy: "SupervisedTemplate"):
        super().after_eval_iteration(strategy)
        if self._mode == "eval":
            self.update(strategy)
        if self._emit_at == "iteration" and self._mode == "eval":
            return self._package_result(strategy)

    def before_eval_iteration(self, strategy: "SupervisedTemplate"):
        super().before_eval_iteration(strategy)
        if self._reset_at == "iteration" and self._mode == "eval":
            self.reset()
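
# Example (illustrative sketch, not part of the original module): a custom
# plugin metric built on ``GenericPluginMetric``. It wraps the hypothetical
# ``_ExampleSum`` defined in the earlier sketch, accumulates the loss over
# each evaluation experience, and emits the total when the experience ends
# (``reset_at`` and ``emit_at`` are both "experience", ``mode`` is "eval").
# The class name ``_ExampleLossSum`` is hypothetical.
class _ExampleLossSum(GenericPluginMetric[float, _ExampleSum]):
    def __init__(self) -> None:
        super().__init__(
            _ExampleSum(), reset_at="experience", emit_at="experience", mode="eval"
        )

    def update(self, strategy: "SupervisedTemplate") -> None:
        # Assumes ``strategy.loss`` holds the loss of the current iteration.
        # The callbacks above take care of calling ``update``, ``reset`` and
        # ``_package_result`` at the right times.
        self._metric.update(float(strategy.loss))

    def __str__(self) -> str:
        # ``get_metric_name`` uses ``str(self)`` when composing the full
        # metric name.
        return "ExampleLossSum"
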
class _ExtendedPluginMetricValue:
    """
    A data structure used to describe a metric value.

    Mainly used to compose the final "name" or "path" of a metric.

    For the moment, this class should be considered an internal utility. Use
    it at your own risk!
    """

    def __init__(
        self,
        *,
        metric_name: str,
        metric_value: Union[MetricType, AlternativeValues],
        phase_name: str,
        stream_name: Optional[str],
        experience_id: Optional[int],
        task_label: Optional[int],
        plot_position: Optional[int] = None,
        **other_info
    ):
        super().__init__()
        self.metric_name = metric_name
        """
        The name of the metric, as a string (cannot be None).
        """
        self.metric_value = metric_value
        """
        The value of the metric (cannot be None).
        """
        self.phase_name = phase_name
        """
        The phase name, as a str (cannot be None).
        """
        self.stream_name = stream_name
        """
        The stream name, as a str (can be None if stream-agnostic).
        """
        self.experience_id = experience_id
        """
        The experience id, as an int (can be None if experience-agnostic).
        """
        self.task_label = task_label
        """
        The task label, as an int (can be None if task-agnostic).
        """
        self.plot_position = plot_position
        """
        The x position of the value, as an int. Can be left None, in which
        case a default position (the number of training iterations so far)
        is used when the value is packaged.
        """
        self.other_info = other_info
        """
        Additional info for this value, as a dictionary (may be empty).
        """


class _ExtendedGenericPluginMetric(
    GenericPluginMetric[List[_ExtendedPluginMetricValue], TMetric]
):
    """
    A generalized version of :class:`GenericPluginMetric` which supports
    emitting multiple metrics from a single metric instance.

    Child classes need to emit metrics via `result()` as a list of
    :class:`_ExtendedPluginMetricValue` instances. This is in contrast with
    :class:`GenericPluginMetric`, which expects a simpler
    "task_label -> value" dictionary.

    The resulting metric name is given by the implementation of the
    :meth:`metric_value_name` method.

    For the moment, this class should be considered an internal utility. Use
    it at your own risk!
    """

    def __init__(self, *args, **kwargs):
        """
        Creates an instance of an extended :class:`GenericPluginMetric`.

        :param args: The positional arguments to be passed to the
            :class:`GenericPluginMetric` constructor.
        :param kwargs: The named arguments to be passed to the
            :class:`GenericPluginMetric` constructor.
        """
        super().__init__(*args, **kwargs)

    def _package_result(self, strategy: "SupervisedTemplate") -> "MetricResult":
        emitted_values = self.result()
        default_plot_x_position = strategy.clock.train_iterations

        metrics = []
        for m_value in emitted_values:
            if not isinstance(m_value, _ExtendedPluginMetricValue):
                raise RuntimeError(
                    "Emitted a value that is not of type "
                    "_ExtendedPluginMetricValue"
                )
            m_name = self.metric_value_name(m_value)
            x_pos = m_value.plot_position
            if x_pos is None:
                x_pos = default_plot_x_position
            metrics.append(MetricValue(self, m_name, m_value.metric_value, x_pos))
        return metrics

    def metric_value_name(self, m_value: _ExtendedPluginMetricValue) -> str:
        return generic_get_metric_name(default_metric_name_template, vars(m_value))


__all__ = [
    "Metric",
    "PluginMetric",
    "GenericPluginMetric",
    "_ExtendedPluginMetricValue",
    "_ExtendedGenericPluginMetric",
]
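
# Example (illustrative sketch, not part of the original module; the classes
# above are internal utilities, so this only mirrors how the library itself
# could use them). A metric that accumulates the loss separately per task
# label during evaluation and emits one value per task at the end of the
# eval stream. All names are hypothetical; ``_ExampleSum`` is the sketch
# defined earlier, passed only to satisfy the wrapped-metric parameter.
class _ExamplePerTaskLoss(_ExtendedGenericPluginMetric[_ExampleSum]):
    def __init__(self) -> None:
        super().__init__(
            _ExampleSum(), reset_at="stream", emit_at="stream", mode="eval"
        )
        self._sums: dict = {}

    def reset(self) -> None:
        super().reset()
        self._sums = {}

    def update(self, strategy: "SupervisedTemplate") -> None:
        # Assumes single-task experiences, where ``experience.task_label`` is
        # defined, and that ``strategy.loss`` holds the iteration loss.
        t = strategy.experience.task_label
        self._sums[t] = self._sums.get(t, 0.0) + float(strategy.loss)

    def result(self) -> List[_ExtendedPluginMetricValue]:
        # ``_package_result`` names each value via ``metric_value_name`` and
        # substitutes a default plot position for the None left here.
        return [
            _ExtendedPluginMetricValue(
                metric_name="ExamplePerTaskLoss",
                metric_value=v,
                phase_name="eval",
                stream_name="test",
                experience_id=None,
                task_label=t,
            )
            for t, v in self._sums.items()
        ]
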