
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 26-05-2022                                                             #
# Author(s): Eli Verwimp, Lorenzo Pellegrini                                   #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################
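# `statistics.fmean` is only available on Python >= 3.8; fall back to the
# slower `statistics.mean` on older interpreters.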
try:
    from statistics import fmean
except ImportError:
    from statistics import mean as fmean
from typing import Dict, List, Union, TYPE_CHECKING, Optional, Sequence, Set
from collections import defaultdict, OrderedDict

import torch
from torch import Tensor
from avalanche.evaluation import (
    Metric,
    PluginMetric,
    _ExtendedGenericPluginMetric,
    _ExtendedPluginMetricValue,
)
from avalanche.evaluation.metric_utils import generic_get_metric_name
from avalanche.evaluation.metrics.class_accuracy import (
    ClassAccuracy,
    TrackedClassesType,
)

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


class AverageMeanClassAccuracy(Metric[Dict[int, float]]):
    """
    The Average Mean Class Accuracy (AMCA) metric. This is a
    standalone metric used to compute more specific ones.

    Instances of this metric keep the running average accuracy
    over multiple <prediction, target> pairs of Tensors,
    provided incrementally.

    Beware that this class does not provide mechanisms to separate scores
    based on the originating data stream. For this, please refer to
    :class:`MultiStreamAMCA`.

    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the AMCA, that is, the
    mean class accuracy averaged over all experiences seen so far (the
    current experience included). The metric expects that the
    :meth:`next_experience` method will be called after each experience.
    This is needed to consolidate the current mean accuracy. After calling
    :meth:`next_experience`, a new experience with accuracy 0.0 is
    immediately started. If you need to obtain the AMCA up to experience
    `t-1`, obtain the :meth:`result` before calling :meth:`next_experience`.

    The set of classes to be tracked can be reduced (please refer to the
    constructor parameters).

    The reset method will bring the metric to its initial state (tracked
    classes will be kept). By default, this metric in its initial state
    will return a `{task_id -> amca}` dictionary in which all AMCAs are
    set to 0 (that is, the `reset` method will hardly be useful when using
    this metric).
    """

    def __init__(self, classes: Optional[TrackedClassesType] = None):
        """
        Creates an instance of the standalone AMCA metric.

        By default, this metric in its initial state will return an empty
        dictionary. The metric can be updated by using the `update` method,
        while the running AMCA can be retrieved using the `result` method.

        By using the `classes` parameter, one can restrict the set of
        classes to be tracked; in addition, the accuracy of each listed
        class will be initialized to 0.0.

        Setting the `classes` parameter is very important, as the mean
        class accuracy may vary based on this! If the test set is fixed
        and contains at least a sample for each class, then it is safe to
        leave `classes` set to None.

        :param classes: The classes to keep track of. If None (default),
            all classes seen are tracked. Otherwise, it can be a dict of
            classes to be tracked (as "task-id" -> "list of class ids") or,
            if running a task-free benchmark (with only task 0), a simple
            list of class ids. By passing this parameter, the list of
            classes to be considered is created immediately. This will
            ensure that the mean class accuracy is correctly computed. In
            addition, this can be used to restrict the classes that should
            be considered when computing the mean class accuracy.
        """
        self._class_accuracies = ClassAccuracy(classes=classes)
        """
        A dictionary "task_id -> {class_id -> Mean}".
        """

        # Here a Mean metric could be used as well. However, that could
        # make it difficult to compute the running AMCA...
        self._prev_exps_accuracies: Dict[int, List[float]] = defaultdict(list)
        """
        The mean class accuracy of previous experiences as a dictionary
        `{task_id -> [accuracies]}`.
        """

        self._updated_once = False

    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[int, Tensor],
    ) -> None:
        """
        Update the running accuracy given the true and predicted labels for
        each class.

        :param predicted_y: The model prediction. Both labels and logit
            vectors are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors are
            supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.
        :return: None.
        """
        self._updated_once = True
        self._class_accuracies.update(predicted_y, true_y, task_labels)

    def result(self) -> Dict[int, float]:
        """
        Retrieves the running AMCA for each task.

        Calling this method will not change the internal state of the metric.

        :return: A dictionary `{task_id -> amca}`. The running AMCA of each
            task is a float value between 0 and 1.
        """
        curr_task_acc = self._get_curr_task_acc()

        all_task_ids = set(self._prev_exps_accuracies.keys())
        all_task_ids = all_task_ids.union(curr_task_acc.keys())

        mean_accs = OrderedDict()
        for task_id in sorted(all_task_ids):
            prev_accs = self._prev_exps_accuracies.get(task_id, list())
            curr_acc = curr_task_acc.get(task_id, 0)
            mean_accs[task_id] = fmean(prev_accs + [curr_acc])
        return mean_accs

    def next_experience(self):
        """
        Moves to the next experience.

        This will consolidate the class accuracies for the current
        experience.

        This method can also be safely called before even calling the
        `update` method for the first time. In that case, this call will be
        ignored.
        """
        if not self._updated_once:
            return
        for task_id, mean_class_acc in self._get_curr_task_acc().items():
            self._prev_exps_accuracies[task_id].append(mean_class_acc)
        self._class_accuracies.reset()

    def reset(self) -> None:
        """
        Resets the metric.

        :return: None.
        """
        self._updated_once = False
        self._class_accuracies.reset()
        self._prev_exps_accuracies.clear()

    def _get_curr_task_acc(self):
        task_acc = dict()
        class_acc = self._class_accuracies.result()
        for task_id, task_classes in class_acc.items():
            class_accuracies = list(task_classes.values())
            mean_class_acc = fmean(class_accuracies)
            task_acc[task_id] = mean_class_acc
        return task_acc
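

# A minimal usage sketch of the standalone metric (the tensors and values
# below are made-up placeholders, not part of this module):
#
#   amca = AverageMeanClassAccuracy()
#   amca.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]), 0)
#   print(amca.result())    # {0: mean class accuracy of experience 0}
#   amca.next_experience()  # consolidate experience 0
#   amca.update(torch.tensor([1, 1]), torch.tensor([1, 1]), 0)
#   print(amca.result())    # {0: average of the two experience means}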


class MultiStreamAMCA(Metric[Dict[str, Dict[int, float]]]):
    """
    An extension of the Average Mean Class Accuracy (AMCA) metric
    (:class:`AverageMeanClassAccuracy`) able to separate the computation
    of the AMCA based on the current stream.
    """

    def __init__(self, classes=None, streams=None):
        """
        Creates an instance of a MultiStream AMCA.

        :param classes: The list of classes to track. This has the same
            semantics as the `classes` parameter of
            :class:`AverageMeanClassAccuracy`.
        :param streams: The list of streams to track. Defaults to None,
            which means that all streams will be tracked. This is not
            recommended, as you will usually want to track the "test"
            stream only.
        """
        self._limit_streams = streams
        if self._limit_streams is not None:
            self._limit_streams = set(self._limit_streams)
        self._limit_classes = classes
        self._amcas: Dict[str, AverageMeanClassAccuracy] = dict()

        self._current_stream: Optional[str] = None
        self._streams_in_this_phase: Set[str] = set()

    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[int, Tensor],
    ) -> None:
        """
        Update the running accuracy given the true and predicted labels for
        each class.

        This will update the accuracies for the "current stream" (the one
        set through `set_stream`). If `set_stream` has not been called, an
        error will be raised.

        :param predicted_y: The model prediction. Both labels and logit
            vectors are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors are
            supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.
        :return: None.
        """
        if self._current_stream is None:
            raise RuntimeError(
                "No current stream set. "
                'Call "set_stream" to set the current stream.'
            )

        if self._is_stream_tracked(self._current_stream):
            self._amcas[self._current_stream].update(
                predicted_y, true_y, task_labels
            )

    def result(self) -> Dict[str, Dict[int, float]]:
        """
        Retrieves the running AMCA for each stream.

        Calling this method will not change the internal state of the metric.

        :return: A dictionary `{stream_name -> {task_id -> amca}}`. The
            running AMCA of each task is a float value between 0 and 1.
        """
        all_streams_dict = OrderedDict()
        for stream_name in sorted(self._amcas.keys()):
            stream_metric = self._amcas[stream_name]
            stream_result = stream_metric.result()
            all_streams_dict[stream_name] = stream_result
        return all_streams_dict

    def set_stream(self, stream_name: str):
        """
        Switches to a specific stream.

        :param stream_name: The name of the stream.
        """
        self._current_stream = stream_name
        if not self._is_stream_tracked(stream_name):
            return
        if self._current_stream not in self._amcas:
            self._amcas[stream_name] = AverageMeanClassAccuracy(
                classes=self._limit_classes
            )

        self._streams_in_this_phase.add(stream_name)

    def finish_phase(self):
        """
        Moves to the next phase. This will consolidate the class accuracies
        recorded so far.
        """
        for stream_name in self._streams_in_this_phase:
            self._amcas[stream_name].next_experience()

        self._streams_in_this_phase.clear()

    def reset(self) -> None:
        """
        Resets the metric.

        :return: None.
        """
        for metric in self._amcas.values():
            metric.reset()
        self._current_stream = None
        self._streams_in_this_phase.clear()

    def _is_stream_tracked(self, stream_name):
        return (
            self._limit_streams is None
            or stream_name in self._limit_streams
        )
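

# A minimal usage sketch of the multi-stream variant (the tensors are
# made-up placeholders):
#
#   ms_amca = MultiStreamAMCA(streams=["test"])
#   ms_amca.set_stream("test")  # must be called before `update`
#   ms_amca.update(torch.tensor([0, 1]), torch.tensor([0, 1]), 0)
#   ms_amca.finish_phase()      # consolidate the current eval phase
#   print(ms_amca.result())     # {"test": {0: amca}}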


class AMCAPluginMetric(_ExtendedGenericPluginMetric):
    """
    Plugin metric for the Average Mean Class Accuracy (AMCA).

    The AMCA is tracked for the classes and streams defined in the
    constructor. In addition, by default, the results obtained through the
    periodic evaluation (mid-training validation) mechanism are ignored.
    """

    VALUE_NAME = "{metric_name}/{stream_name}_stream/Task{task_label:03}"

    def __init__(self, classes=None, streams=None, ignore_validation=True):
        """
        Instantiates the AMCA plugin metric.

        :param classes: The classes to track. Refer to
            :class:`MultiStreamAMCA` for more details.
        :param streams: The streams to track. Defaults to None, which means
            that all streams will be considered. Beware that, when creating
            instances of this class using the :func:`amca_metrics` helper,
            the resulting metric will only track the "test" stream by
            default.
        :param ignore_validation: Defaults to True, which means that
            periodic evaluations will be ignored (recommended).
        """
        self._ms_amca = MultiStreamAMCA(classes=classes, streams=streams)
        self._ignore_validation = ignore_validation

        self._is_training = False
        super().__init__(
            self._ms_amca, reset_at="never", emit_at="stream", mode="eval"
        )

    def update(self, strategy: "SupervisedTemplate"):
        if self._is_training and self._ignore_validation:
            # Running a validation (eval phase inside a train phase),
            # ignore it
            return
        self._ms_amca.update(
            strategy.mb_output, strategy.mb_y, strategy.mb_task_id
        )

    def before_training(self, strategy: "SupervisedTemplate"):
        self._is_training = True
        return super().before_training(strategy)

    def after_training(self, strategy: "SupervisedTemplate"):
        self._is_training = False
        return super().after_training(strategy)

    def before_eval(self, strategy: "SupervisedTemplate"):
        # In the first eval phase, calling finish_phase will do nothing
        # (as expected)
        if not (self._is_training and self._ignore_validation):
            # If not running a validation
            self._ms_amca.finish_phase()
        return super().before_eval(strategy)

    def before_eval_exp(self, strategy: "SupervisedTemplate"):
        if not (self._is_training and self._ignore_validation):
            # If not running a validation
            self._ms_amca.set_stream(strategy.experience.origin_stream.name)
        return super().before_eval_exp(strategy)

    def result(self, strategy) -> List[_ExtendedPluginMetricValue]:
        if self._is_training and self._ignore_validation:
            # Running a validation, ignore it
            return []

        metric_values = []
        stream_amca = self._ms_amca.result()

        for stream_name, stream_accs in stream_amca.items():
            for task_id, task_amca in stream_accs.items():
                metric_values.append(
                    _ExtendedPluginMetricValue(
                        metric_name=str(self),
                        metric_value=task_amca,
                        phase_name="eval",
                        stream_name=stream_name,
                        task_label=task_id,
                        experience_id=None,
                    )
                )
        return metric_values

    def metric_value_name(self, m_value: _ExtendedPluginMetricValue) -> str:
        return generic_get_metric_name(
            AMCAPluginMetric.VALUE_NAME, vars(m_value)
        )

    def __str__(self):
        return "Top1_AMCA_Stream"
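

# With the VALUE_NAME template and the str(self) name above, a value
# emitted for task 0 of the "test" stream is named (illustrative):
#   "Top1_AMCA_Stream/test_stream/Task000"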


def amca_metrics(streams: Sequence[str] = ("test",)) -> PluginMetric:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    The returned metric will not compute the AMCA when the
    :class:`PeriodicEval` plugin is used. To change this behavior,
    you can instantiate an :class:`AMCAPluginMetric` by setting
    `ignore_validation` to False.

    :param streams: The list of streams to track. Defaults to "test" only.
    :return: The AMCA plugin metric.
    """
    return AMCAPluginMetric(streams=streams, ignore_validation=True)
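

# A sketch of how the helper is typically wired into a strategy through
# `EvaluationPlugin` (assuming the standard Avalanche training setup;
# `my_logger` is a placeholder):
#
#   from avalanche.training.plugins import EvaluationPlugin
#   eval_plugin = EvaluationPlugin(amca_metrics(), loggers=[my_logger])
#   # pass `eval_plugin` as the `evaluator` argument of a strategy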


__all__ = [
    "AverageMeanClassAccuracy",
    "MultiStreamAMCA",
    "AMCAPluginMetric",
    "amca_metrics",
]