Source code for avalanche.evaluation.metrics.class_accuracy

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 26-05-2022                                                             #
# Author(s): Eli Verwimp, Lorenzo Pellegrini                                   #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

from typing import Dict, List, Set, Union, TYPE_CHECKING, Iterable, Optional
from collections import defaultdict, OrderedDict

import torch
from torch import Tensor
from avalanche.evaluation import (
    Metric,
    PluginMetric,
    _ExtendedGenericPluginMetric,
    _ExtendedPluginMetricValue,
)
from avalanche.evaluation.metric_utils import (
    default_metric_name_template,
    generic_get_metric_name,
)
from avalanche.evaluation.metrics import Mean

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


TrackedClassesType = Union[Dict[int, Iterable[int]], Iterable[int]]


[docs]class ClassAccuracy(Metric[Dict[int, Dict[int, float]]]): """ The Class Accuracy metric. This is a standalone metric used to compute more specific ones. Instances of this metric keeps the running average accuracy over multiple <prediction, target> pairs of Tensors, provided incrementally. The "prediction" and "target" tensors may contain plain labels or one-hot/logit vectors. Each time `result` is called, this metric emits the average accuracy for all classes seen and across all predictions made since the last `reset`. The set of classes to be tracked can be reduced (please refer to the constructor parameters). The reset method will bring the metric to its initial state. By default, this metric in its initial state will return a `{task_id -> {class_id -> accuracy}}` dictionary in which all accuracies are set to 0. """
[docs] def __init__(self, classes: Optional[TrackedClassesType] = None): """ Creates an instance of the standalone Accuracy metric. By default, this metric in its initial state will return an empty dictionary. The metric can be updated by using the `update` method while the running accuracies can be retrieved using the `result` method. By using the `classes` parameter, one can restrict the list of classes to be tracked and in addition can immediately create plots for yet-to-be-seen classes. :param classes: The classes to keep track of. If None (default), all classes seen are tracked. Otherwise, it can be a dict of classes to be tracked (as "task-id" -> "list of class ids") or, if running a task-free benchmark (with only task 0), a simple list of class ids. By passing this parameter, the plot of each class is created immediately (with a default value of 0.0) and plots will be aligned across all classes. In addition, this can be used to restrict the classes for which the accuracy should be logged. """ self.classes: Dict[int, Set[int]] = defaultdict(set) """ The list of tracked classes. """ self.dynamic_classes = False """ If True, newly encountered classes will be tracked. """ self._class_accuracies: Dict[int, Dict[int, Mean]] = defaultdict( lambda: defaultdict(Mean) ) """ A dictionary "task_id -> {class_id -> Mean}". """ if classes is not None: if isinstance(classes, dict): # Task-id -> classes dict self.classes = { task_id: self._ensure_int_classes(class_list) for task_id, class_list in classes.items() } else: # Assume is a plain iterable self.classes = {0: self._ensure_int_classes(classes)} else: self.dynamic_classes = True self.__init_accs_for_known_classes()
@torch.no_grad() def update( self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[int, Tensor], ) -> None: """ Update the running accuracy given the true and predicted labels for each class. :param predicted_y: The model prediction. Both labels and logit vectors are supported. :param true_y: The ground truth. Both labels and one-hot vectors are supported. :param task_labels: the int task label associated to the current experience or the task labels vector showing the task label for each pattern. :return: None. """ if len(true_y) != len(predicted_y): raise ValueError("Size mismatch for true_y and predicted_y tensors") if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): raise ValueError("Size mismatch for true_y and task_labels tensors") if not isinstance(task_labels, (int, Tensor)): raise ValueError( f"Task label type: {type(task_labels)}, " f"expected int or Tensor" ) if isinstance(task_labels, int): task_labels = [task_labels] * len(true_y) true_y = torch.as_tensor(true_y) predicted_y = torch.as_tensor(predicted_y) # Check if logits or labels if len(predicted_y.shape) > 1: # Logits -> transform to labels predicted_y = torch.max(predicted_y, 1)[1] if len(true_y.shape) > 1: # Logits -> transform to labels true_y = torch.max(true_y, 1)[1] for pred, true, t in zip(predicted_y, true_y, task_labels): t = int(t) if self.dynamic_classes: self.classes[t].add(int(true)) else: if t not in self.classes: continue if int(true) not in self.classes[t]: continue true_positives = (pred == true).float().item() # 1.0 or 0.0 self._class_accuracies[t][int(true)].update(true_positives, 1) def result(self) -> Dict[int, Dict[int, float]]: """ Retrieves the running accuracy for each class. Calling this method will not change the internal state of the metric. :return: A dictionary `{task_id -> {class_id -> running_accuracy}}`. The running accuracy of each class is a float value between 0 and 1. """ running_class_accuracies: Dict[int, Dict[int, float]] = OrderedDict() for task_label in sorted(self._class_accuracies.keys()): task_dict = self._class_accuracies[task_label] running_class_accuracies[task_label] = OrderedDict() for class_id in sorted(task_dict.keys()): running_class_accuracies[task_label][class_id] = task_dict[ class_id ].result() return running_class_accuracies def reset(self) -> None: """ Resets the metric. :return: None. """ self._class_accuracies = defaultdict(lambda: defaultdict(Mean)) self.__init_accs_for_known_classes() def __init_accs_for_known_classes(self): for task_id, task_classes in self.classes.items(): for c in task_classes: self._class_accuracies[task_id][c].reset() @staticmethod def _ensure_int_classes(classes_iterable: Iterable[int]): return set(int(c) for c in classes_iterable)
class ClassAccuracyPluginMetric(_ExtendedGenericPluginMetric[ClassAccuracy]): """ Base class for all class accuracy plugin metrics """ def __init__(self, reset_at, emit_at, mode, classes=None): super(ClassAccuracyPluginMetric, self).__init__( ClassAccuracy(classes=classes), reset_at=reset_at, emit_at=emit_at, mode=mode, ) self.phase_name = "train" self.stream_name = "train" self.experience_id = 0 def update(self, strategy: "SupervisedTemplate"): assert strategy.mb_output is not None assert strategy.experience is not None self._metric.update(strategy.mb_output, strategy.mb_y, strategy.mb_task_id) self.phase_name = "train" if strategy.is_training else "eval" self.stream_name = strategy.experience.origin_stream.name self.experience_id = strategy.experience.current_experience def result(self) -> List[_ExtendedPluginMetricValue]: metric_values = [] task_accuracies = self._metric.result() for task_id, task_classes in task_accuracies.items(): for class_id, class_accuracy in task_classes.items(): metric_values.append( _ExtendedPluginMetricValue( metric_name=str(self), metric_value=class_accuracy, phase_name=self.phase_name, stream_name=self.stream_name, task_label=task_id, experience_id=self.experience_id, class_id=class_id, ) ) return metric_values def metric_value_name(self, m_value: _ExtendedPluginMetricValue) -> str: m_value_values = vars(m_value) add_exp = self._emit_at == "experience" if not add_exp: del m_value_values["experience_id"] m_value_values["class_id"] = m_value.other_info["class_id"] return generic_get_metric_name( default_metric_name_template(m_value_values) + "/{class_id}", m_value_values, )
[docs]class MinibatchClassAccuracy(ClassAccuracyPluginMetric): """ The minibatch plugin class accuracy metric. This metric only works at training time. This metric computes the average accuracy over patterns from a single minibatch. It reports the result after each iteration. If a more coarse-grained logging is needed, consider using :class:`EpochClassAccuracy` instead. """
[docs] def __init__(self, classes=None): """ Creates an instance of the MinibatchClassAccuracy metric. """ super().__init__( reset_at="iteration", emit_at="iteration", mode="train", classes=classes, )
def __str__(self): return "Top1_ClassAcc_MB"
[docs]class EpochClassAccuracy(ClassAccuracyPluginMetric): """ The average class accuracy over a single training epoch. This plugin metric only works at training time. The accuracy will be logged after each training epoch by computing the number of correctly predicted patterns during the epoch divided by the overall number of patterns encountered in that epoch (separately for each class). """
[docs] def __init__(self, classes=None): """ Creates an instance of the EpochClassAccuracy metric. """ super().__init__( reset_at="epoch", emit_at="epoch", mode="train", classes=classes )
def __str__(self): return "Top1_ClassAcc_Epoch"
[docs]class RunningEpochClassAccuracy(ClassAccuracyPluginMetric): """ The average class accuracy across all minibatches up to the current epoch iteration. This plugin metric only works at training time. At each iteration, this metric logs the accuracy averaged over all patterns seen so far in the current epoch (separately for each class). The metric resets its state after each training epoch. """
[docs] def __init__(self, classes=None): """ Creates an instance of the RunningEpochClassAccuracy metric. """ super().__init__( reset_at="epoch", emit_at="iteration", mode="train", classes=classes )
def __str__(self): return "Top1_RunningClassAcc_Epoch"
[docs]class ExperienceClassAccuracy(ClassAccuracyPluginMetric): """ At the end of each experience, this plugin metric reports the average accuracy over all patterns seen in that experience (separately for each class). This metric only works at eval time. """
[docs] def __init__(self, classes=None): """ Creates an instance of ExperienceClassAccuracy metric """ super().__init__( reset_at="experience", emit_at="experience", mode="eval", classes=classes, )
def __str__(self): return "Top1_ClassAcc_Exp"
[docs]class StreamClassAccuracy(ClassAccuracyPluginMetric): """ At the end of the entire stream of experiences, this plugin metric reports the average accuracy over all patterns seen in all experiences (separately for each class). This metric only works at eval time. """
[docs] def __init__(self, classes=None): """ Creates an instance of StreamClassAccuracy metric """ super().__init__( reset_at="stream", emit_at="stream", mode="eval", classes=classes )
def __str__(self): return "Top1_ClassAcc_Stream"
[docs]def class_accuracy_metrics( *, minibatch=False, epoch=False, epoch_running=False, experience=False, stream=False, classes=None, ) -> List[ClassAccuracyPluginMetric]: """ Helper method that can be used to obtain the desired set of plugin metrics. :param minibatch: If True, will return a metric able to log the per-class minibatch accuracy at training time. :param epoch: If True, will return a metric able to log the per-class epoch accuracy at training time. :param epoch_running: If True, will return a metric able to log the per-class running epoch accuracy at training time. :param experience: If True, will return a metric able to log the per-class accuracy on each evaluation experience. :param stream: If True, will return a metric able to log the per-class accuracy averaged over the entire evaluation stream of experiences. :param classes: The list of classes to track. See the corresponding parameter of :class:`ClassAccuracy` for a precise explanation. :return: A list of plugin metrics. """ metrics: List[ClassAccuracyPluginMetric] = [] if minibatch: metrics.append(MinibatchClassAccuracy(classes=classes)) if epoch: metrics.append(EpochClassAccuracy(classes=classes)) if epoch_running: metrics.append(RunningEpochClassAccuracy(classes=classes)) if experience: metrics.append(ExperienceClassAccuracy(classes=classes)) if stream: metrics.append(StreamClassAccuracy(classes=classes)) return metrics
__all__ = [ "TrackedClassesType", "ClassAccuracy", "MinibatchClassAccuracy", "EpochClassAccuracy", "RunningEpochClassAccuracy", "ExperienceClassAccuracy", "StreamClassAccuracy", "class_accuracy_metrics", ]