# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 29-03-2022 #
# Author(s): Rudy Semola #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
from typing import List, Union, Dict
import torch
from torch import Tensor
from torchmetrics.functional import accuracy
from avalanche.evaluation import Metric, PluginMetric, GenericPluginMetric
from avalanche.evaluation.metrics.mean import Mean
from avalanche.evaluation.metric_utils import phase_and_task
from collections import defaultdict
class TopkAccuracy(Metric[float]):
    """
    The Top-k Accuracy metric. This is a standalone metric.

    It is defined using ``torchmetrics.functional.accuracy`` with ``top_k``.
    A separate running mean is kept for every task label.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the standalone Top-k Accuracy metric.

        In its initial state the metric holds no per-task value, so
        `result()` returns an empty dictionary.
        The metric can be updated by using the `update` method while
        the running top-k accuracy can be retrieved using the `result` method.

        :param top_k: integer number to define the value of k.
        """
        # One running mean per task label, created lazily on first access.
        self._topk_acc_dict = defaultdict(Mean)
        self.top_k = top_k

    # Metric bookkeeping never needs gradients, even when it is fed the
    # raw model output during training.
    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[float, Tensor],
    ) -> None:
        """
        Update the running top-k accuracy given the true and predicted labels.

        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if it is a single int, only the dictionary value related
        to that task is updated. If Tensor, all the dictionary elements
        belonging to the task labels will be updated.

        Note that, despite the type annotation, a scalar task label must be
        an `int`: any other non-Tensor type raises `ValueError`.

        :param predicted_y: The model prediction, forwarded as-is to
            ``torchmetrics.functional.accuracy``.
        :param true_y: The ground truth, forwarded as-is to
            ``torchmetrics.functional.accuracy``.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :raises ValueError: if `true_y` and `predicted_y` (or `true_y` and a
            Tensor `task_labels`) differ in length, or if `task_labels` is
            neither an int nor a Tensor.
        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        if isinstance(task_labels, int):
            # Whole minibatch belongs to one task: store the batch accuracy
            # weighted by the number of patterns it was computed on.
            total_patterns = len(true_y)
            self._topk_acc_dict[task_labels].update(
                accuracy(predicted_y, true_y, top_k=self.top_k), total_patterns
            )
        elif isinstance(task_labels, Tensor):
            # Per-pattern task labels: score every pattern on its own and
            # route it to the mean of its task.
            for pred, true, t in zip(predicted_y, true_y, task_labels):
                # `.item()` turns the 0-d tensor into a plain Python number
                # so that equal task labels share the same dictionary key.
                self._topk_acc_dict[t.item()].update(
                    accuracy(pred, true, top_k=self.top_k), 1
                )
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, "
                f"expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running top-k accuracy.

        Calling this method does not change the accumulated values. Asking
        for a task label that was never updated creates an empty entry for
        it in the inner dictionary.

        :param task_label: if None, return the entire dictionary of accuracies
            for each task. Otherwise return the dictionary
            `{task_label: topk_accuracy}`.
        :return: A dict of running accuracies for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._topk_acc_dict.items()}
        return {task_label: self._topk_acc_dict[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._topk_acc_dict = defaultdict(Mean)
        else:
            self._topk_acc_dict[task_label].reset()
class TopkAccuracyPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all top-k accuracies plugin metrics
    """

    def __init__(self, reset_at, emit_at, mode, top_k):
        """
        Creates the plugin metric around a standalone `TopkAccuracy`.

        :param reset_at: forwarded to `GenericPluginMetric`; the value
            "stream" is the one checked by `reset`.
        :param emit_at: forwarded to `GenericPluginMetric`; the value
            "stream" is the one checked by `result`.
        :param mode: forwarded to `GenericPluginMetric`
            (subclasses in this file pass "train" or "eval").
        :param top_k: integer number to define the value of k.
        """
        self._topk_acc = TopkAccuracy(top_k=top_k)
        super(TopkAccuracyPluginMetric, self).__init__(
            self._topk_acc, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        """
        Resets the wrapped metric.

        All tasks are reset when the metric is reset at stream level or when
        no strategy is given; otherwise only the task returned by
        `phase_and_task(strategy)` is reset.

        :param strategy: the strategy currently running, or None.
        :return: None.
        """
        # NOTE(review): `_reset_at` and `_metric` are not assigned in this
        # file; presumably set by GenericPluginMetric.__init__ - confirm.
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        """
        Retrieves the value of the wrapped metric.

        The values of all tasks are returned when the metric is emitted at
        stream level or when no strategy is given; otherwise only the value
        of the task returned by `phase_and_task(strategy)` is returned.

        :param strategy: the strategy currently running, or None.
        :return: whatever the wrapped metric's `result` returns.
        """
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        """
        Updates the wrapped metric with the current minibatch of `strategy`.

        :param strategy: the strategy currently running; its `mb_output`,
            `mb_y`, `mb_task_id` and `experience.task_labels` are read.
        :return: None.
        """
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            # single-task experience: use its only label as a scalar
            task_labels = task_labels[0]
        self._topk_acc.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The minibatch plugin top-k accuracy metric.
    This metric only works at training time.

    This metric computes the average top-k accuracy over patterns
    from a single minibatch.
    It reports the result after each iteration.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the MinibatchTopkAccuracy metric.

        :param top_k: integer number to define the value of k.
        """
        super(MinibatchTopkAccuracy, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train", top_k=top_k
        )
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        return "Topk_" + str(self.top_k) + "_Acc_MB"
class EpochTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The average top-k accuracy over a single training epoch.
    This plugin metric only works at training time.

    The top-k accuracy will be logged after each training epoch by computing
    the number of correctly predicted patterns during the epoch divided by
    the overall number of patterns encountered in that epoch.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the EpochTopkAccuracy metric.

        :param top_k: integer number to define the value of k.
        """
        super(EpochTopkAccuracy, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train", top_k=top_k
        )
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        return "Topk_" + str(self.top_k) + "_Acc_Epoch"
class RunningEpochTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The average top-k accuracy across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the top-k accuracy averaged over all
    patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the RunningEpochTopkAccuracy metric.

        :param top_k: integer number to define the value of k.
        """
        super(RunningEpochTopkAccuracy, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train", top_k=top_k
        )
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        # Must differ from the name used by EpochTopkAccuracy
        # ("Topk_<k>_Acc_Epoch"): with an identical name the running and
        # the per-epoch values would be reported under the same key when
        # both metrics are enabled.
        return "Topk_" + str(self.top_k) + "_RunningAcc_Epoch"
class ExperienceTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average top-k accuracy over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the ExperienceTopkAccuracy metric.

        :param top_k: integer number to define the value of k.
        """
        # NOTE(review): the keyword arguments below were reconstructed from
        # the class description (per-experience, eval time) and from the
        # sibling plugin metrics in this file - confirm.
        super(ExperienceTopkAccuracy, self).__init__(
            reset_at="experience",
            emit_at="experience",
            mode="eval",
            top_k=top_k,
        )
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        return "Topk_" + str(self.top_k) + "_Acc_Exp"
class TrainedExperienceTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    top-k accuracy for only the experiences
    that the model has been trained on so far.
    This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the TrainedExperienceTopkAccuracy metric.

        :param top_k: integer number to define the value of k.
        """
        super(TrainedExperienceTopkAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval", top_k=top_k
        )
        # Index of the most recent experience the model was trained on;
        # refreshed in `after_training_exp`.
        self._current_experience = 0
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def after_training_exp(self, strategy) -> None:
        """
        Records the experience that has just been trained on and resets
        the running average.

        :param strategy: the strategy currently running.
        :return: whatever the parent `after_training_exp` returns.
        """
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        TopkAccuracyPluginMetric.reset(self, strategy)
        return TopkAccuracyPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the top-k accuracy with results from experiences
        that have been trained on

        :param strategy: the strategy currently running.
        :return: None.
        """
        if strategy.experience.current_experience <= self._current_experience:
            TopkAccuracyPluginMetric.update(self, strategy)

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        return "Topk_" + str(self.top_k) + "_Acc_On_Trained_Experiences"
class StreamTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average top-k accuracy over all patterns
    seen in all experiences. This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of StreamTopkAccuracy metric

        :param top_k: integer number to define the value of k.
        """
        super(StreamTopkAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval", top_k=top_k
        )
        # Kept on the instance because `__str__` embeds it in the name.
        self.top_k = top_k

    def __str__(self):
        """Return the metric name, which embeds the value of k."""
        return "Topk_" + str(self.top_k) + "_Acc_Stream"
def topk_acc_metrics(
    *,
    top_k=3,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    trained_experience=False,
    stream=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    All arguments are keyword-only. With every flag left at False the
    returned list is empty.

    :param top_k: integer number to define the value of k, passed to every
        metric that is created.
        NOTE(review): the default value was not recoverable from the
        original text; 3 is assumed - confirm.
    :param minibatch: If True, will return a metric able to log
        the minibatch top-k accuracy at training time.
    :param epoch: If True, will return a metric able to log
        the epoch top-k accuracy at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch top-k accuracy at training time.
    :param experience: If True, will return a metric able to log
        the top-k accuracy on each evaluation experience.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation top-k accuracy only for experiences that the
        model has been trained on
    :param stream: If True, will return a metric able to log the top-k accuracy
        averaged over the entire evaluation stream of experiences.

    :return: A list of plugin metrics.
    """
    metrics = []

    if minibatch:
        metrics.append(MinibatchTopkAccuracy(top_k=top_k))

    if epoch:
        metrics.append(EpochTopkAccuracy(top_k=top_k))

    if epoch_running:
        metrics.append(RunningEpochTopkAccuracy(top_k=top_k))

    if experience:
        metrics.append(ExperienceTopkAccuracy(top_k=top_k))

    if trained_experience:
        metrics.append(TrainedExperienceTopkAccuracy(top_k=top_k))

    if stream:
        metrics.append(StreamTopkAccuracy(top_k=top_k))

    return metrics
# Public API of this module.
# NOTE(review): the original entries were lost; the list below contains the
# definitions of this file that carry a public documentation marker, which
# leaves out the TopkAccuracyPluginMetric base class - confirm.
__all__ = [
    "TopkAccuracy",
    "MinibatchTopkAccuracy",
    "EpochTopkAccuracy",
    "RunningEpochTopkAccuracy",
    "ExperienceTopkAccuracy",
    "TrainedExperienceTopkAccuracy",
    "StreamTopkAccuracy",
    "topk_acc_metrics",
]

if __name__ == "__main__":
    # Manual smoke check: build the helper's output for a single metric.
    metric = topk_acc_metrics(trained_experience=True, top_k=5)