# Source code for avalanche.evaluation.metric_results

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 30-12-2020                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################
from dataclasses import dataclass
from typing import List, Optional, TYPE_CHECKING, Tuple, Union

from PIL.Image import Image
from matplotlib.figure import Figure
from torch import Tensor
from enum import Enum

if TYPE_CHECKING:
    from .metric_definitions import Metric

# Alias for an optional list of metric values: either ``None`` or a list of
# MetricValue instances. MetricValue is referenced by name (as a string)
# because the class is defined further down in this module.
MetricResult = Optional[List["MetricValue"]]


class LoggingType(Enum):
    """
    A type for MetricValues.

    It can be used by MetricValues to choose how they want to be visualized.
    For example, a 2D tensor could be a line plot or be used to create a
    histogram.
    """

    # Generic type. The logger will use the value type to decide how to
    # serialize it.
    ANY = 1
    IMAGE = 2
    # Matplotlib figure.
    FIGURE = 3
    HISTOGRAM = 4
# NOTE on LoggingType: you can add other members to it. All Tensorboard
# metrics are good candidates:
# https://pytorch.org/docs/stable/tensorboard.html
# Just remember to add explicit support to the loggers once you add them.
# If a metric is already printed correctly by the loggers (e.g. scalars)
# there is no need to add it.


@dataclass
class TensorImage:
    """
    Dataclass wrapping a single Tensor in its ``image`` field.

    The wrapper implements the ``__array__`` protocol so that the wrapped
    tensor can be handed to code expecting a NumPy array.
    """

    image: Tensor

    def __array__(self):
        """
        Return the wrapped tensor as a NumPy array.

        The tensor is detached and moved to the CPU before the conversion,
        because a plain ``Tensor.numpy()`` call raises for tensors that
        require grad or that are not stored on the CPU. For a CPU tensor
        that does not require grad the result is the same as before.
        """
        return self.image.detach().cpu().numpy()


# Every type a metric value can take.
MetricType = Union[float, int, str, Tensor, Image, TensorImage, Figure]


class AlternativeValues:
    """
    A container for alternative representations of the same metric value.
    """

    def __init__(self, *alternatives: MetricType):
        """
        Creates an instance of AlternativeValues.

        :param alternatives: The alternative representations. Their order
            matters: :meth:`best_supported_value` returns the first one
            that matches.
        """
        # ``*alternatives`` is a tuple of arbitrary length, hence the
        # ``...`` in the annotation (``Tuple[X]`` would mean a 1-tuple).
        self.alternatives: Tuple[MetricType, ...] = alternatives

    def best_supported_value(
        self, *supported_types: type
    ) -> Optional[MetricType]:
        """
        Retrieves a supported representation for this metric value.

        The alternatives are scanned in the order in which they were passed
        to the constructor and the first one that is an instance of any of
        the supported types is returned.

        :param supported_types: A list of supported value types.
        :return: The best supported representation. Returns None if no
            supported representation is found.
        """
        return next(
            (
                candidate
                for candidate in self.alternatives
                if isinstance(candidate, supported_types)
            ),
            None,
        )
class MetricValue:
    """
    The result of a Metric.

    A result has a name, a value and a "x" position in which the metric value
    should be plotted.

    The "value" field can also be an instance of "AlternativeValues", in which
    case it means that alternative representations exist for this value. For
    instance, the Confusion Matrix can be represented both as a Tensor and as
    an Image. It's up to the Logger, according to its capabilities, to decide
    which representation to use.
    """

    def __init__(
        self,
        origin: "Metric",
        name: str,
        value: Union[MetricType, AlternativeValues],
        x_plot: int,
        logging_type: LoggingType = LoggingType.ANY,
    ):
        """
        Creates an instance of MetricValue.

        :param origin: The originating Metric instance.
        :param name: The display name of this value. This value roughly
            corresponds to the name of the plot in which the value should
            be logged.
        :param value: The value of the metric. Can be a scalar value,
            a PIL Image, or a Tensor. If more than one possible
            representation of the same value exists, an instance of
            :class:`AlternativeValues` can be passed. For instance, the
            Confusion Matrix can be represented both as an Image and a
            Tensor, in which case an instance of :class:`AlternativeValues`
            carrying both the Tensor and the Image is more appropriate.
            The Logger instance will then select the most appropriate way
            to log the metric according to its capabilities.
        :param x_plot: The position of the value. This value roughly
            corresponds to the x-axis position of the value in a plot.
            When logging a singleton value, pass 0 as a value for this
            parameter.
        :param logging_type: determines how the metric should be logged.
        """
        # The constructor only stores its arguments; no validation or
        # conversion is performed.
        self.origin: "Metric" = origin
        self.name: str = name
        self.value: Union[MetricType, AlternativeValues] = value
        self.x_plot: int = x_plot
        self.logging_type: LoggingType = logging_type
# Names exported by ``from <this module> import *``.
# LoggingType is listed as well: it is the type of the ``logging_type``
# parameter of MetricValue, so star-import users need it in scope.
__all__ = [
    "MetricType",
    "MetricResult",
    "AlternativeValues",
    "MetricValue",
    "TensorImage",
    "LoggingType",
]