Source code for avalanche.evaluation.metrics.acc_matrix

import torch
from typing import List

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue, MetricResult
from avalanche.evaluation.metrics.accuracy import Accuracy
from avalanche.training.templates import SupervisedTemplate


[docs]class AccuracyMatrixPluginMetric(PluginMetric[float]): """ Class for obtaining an Accuracy Matrix for the evaluation stream """
[docs] def __init__(self): """ Creates the Accuracy Matrix plugin """ self._accuracy = Accuracy() super(AccuracyMatrixPluginMetric, self).__init__() self.count = 0 self.matrix = torch.zeros((1, 1)) self.online = False self.num_training_steps = None
def reset(self, strategy=None) -> None: """Resets the metric. :param strategy: The strategy object associated with the stream. """ self.matrix = torch.zeros((1, 1)) self.count = 0 def result(self, strategy=None) -> float: """Returns the metric result. :param strategy: The strategy object associated with the stream. :return: The metric result as a torch tensor. """ return self.matrix def add_new_task(self, new_length): """Adds a new dimension to the accuracy matrix. :param new_length: The new dimension of the matrix. We assume a square matrix """ temp = self.matrix.clone() self.matrix = torch.zeros((new_length, new_length)) self.matrix[: temp.size(0), : temp.size(1)] = temp def update(self, num_training_steps, eval_exp_id): """Updates the matrix with the accuracy value for a given task pair. :param num_training_steps: The ID of the current training experience. :param eval_exp_id: The ID of the evaluation experience. """ if (max(num_training_steps, eval_exp_id) + 1) > self.matrix.size(0): self.add_new_task(max(num_training_steps, eval_exp_id) + 1) acc = self._accuracy.result() self.matrix[num_training_steps, eval_exp_id] = acc self._accuracy.reset() def _package_result(self, strategy: "SupervisedTemplate") -> "MetricResult": """Packages the metric result. :param strategy: The strategy object associated with the stream. :return: The metric result. As a MetricValue object. """ metric_value = self.result(strategy) plot_x_position = strategy.clock.train_iterations metric_name = "EvalStream/Acc_Matrix" return [MetricValue(self, metric_name, metric_value, plot_x_position)] def after_eval_iteration(self, strategy: "SupervisedTemplate"): """Performs actions after each evaluation iteration. :param strategy: The strategy object associated with the stream. """ super().after_eval_iteration(strategy) self._accuracy.update(strategy.mb_output, strategy.mb_y) def after_eval_exp(self, strategy: "SupervisedTemplate"): """Performs actions after evaluating an experience. :param strategy: The strategy object associated with the stream. """ super().after_eval_exp(strategy) curr_exp = strategy.experience.current_experience self.update(self.num_training_steps, curr_exp) def after_eval(self, strategy: "SupervisedTemplate"): """Performs actions after the evaluation phase. :param strategy: The strategy object associated with the stream. :return: The metric result. """ super().after_eval_exp(strategy) return self._package_result(strategy) def before_training(self, strategy: "SupervisedTemplate"): """Performs actions before the training phase. :param strategy: The strategy object associated with the metric. """ super().before_training(strategy) if self.num_training_steps is not None: self.num_training_steps += 1 else: self.num_training_steps = 0
def accuracy_matrix_metrics() -> List[PluginMetric]: """ Helper method that can be used to obtain the desired set of plugin metrics. :return: A list of plugin metrics. """ metrics = [] metrics.append(AccuracyMatrixPluginMetric()) return metrics __all__ = [ "accuracy_matrix_metrics", "AccuracyMatrixPluginMetric", ]