Source code for avalanche.evaluation.metrics.labels_repartition

from collections import defaultdict
from typing import (
    Callable,
    Dict,
    Sequence,
    TYPE_CHECKING,
    Union,
    Optional,
    List,
    Counter,
)

from matplotlib.figure import Figure

from avalanche.evaluation import GenericPluginMetric, Metric, PluginMetric
from avalanche.evaluation.metric_results import MetricValue, AlternativeValues
from avalanche.evaluation.metric_utils import (
    stream_type,
    default_history_repartition_image_creator,
)

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


if TYPE_CHECKING:
    from avalanche.training.templates.supervised import SupervisedTemplate
    from avalanche.evaluation.metric_results import MetricResult


class LabelsRepartition(Metric):
    """
    Standalone metric counting how often each label is seen, per task.

    The counts are stored as a nested mapping
    ``{task_id: {label: count}}``. When a class order has been provided,
    :meth:`result` lists the labels of every task following that order.
    """

    def __init__(self):
        """Create the metric with empty counters and no class order."""
        self.task2label2count: Dict[int, Dict[int, int]] = {}
        self.class_order = None
        self.reset()

    def reset(self, **kargs) -> None:
        """
        Drop every count accumulated so far.

        The stored class order is not modified.

        :param kargs: Extra keyword arguments; unused.
        """
        self.task2label2count = defaultdict(Counter)

    def update(
        self,
        tasks: Sequence[int],
        labels: Sequence[Union[str, int]],
        class_order: Optional[List[int]],
    ):
        """
        Record one occurrence for each ``(task, label)`` pair.

        :param tasks: The task ids, paired element-wise with ``labels``.
        :param labels: The labels, paired element-wise with ``tasks``.
            Pairing uses ``zip``, so extra elements of the longer sequence
            are ignored.
        :param class_order: The label order used by :meth:`result`. It
            replaces the previously stored order (``None`` disables
            ordering).
        """
        self.class_order = class_order
        counts_by_task = self.task2label2count
        for task_id, label in zip(tasks, labels):
            task_counts = counts_by_task[task_id]
            task_counts[label] += 1

    def update_order(self, class_order: Optional[List[int]]):
        """
        Replace the stored class order without touching the counts.

        :param class_order: The new label order, or ``None`` for no ordering.
        """
        self.class_order = class_order

    def result(self) -> Dict[int, Dict[int, int]]:
        """
        Return the current counts.

        :return: The internal ``{task_id: {label: count}}`` mapping itself
            when no class order is set. Otherwise a new dictionary in which,
            for every task, only the labels listed in the class order AND
            already counted for that task appear, in class-order sequence.
        """
        order = self.class_order
        if order is None:
            return self.task2label2count

        ordered: Dict[int, Dict[int, int]] = {}
        for task_id, label2count in self.task2label2count.items():
            ordered[task_id] = {
                label: label2count[label]
                for label in order
                if label in label2count
            }
        return ordered
# Signature of the plotting callbacks: they receive the per-label history of
# counts and the list of steps, and return a matplotlib Figure.
LabelsRepartitionImageCreator = Callable[
    [Dict[int, List[int]], List[int]], Figure
]


class LabelsRepartitionPlugin(GenericPluginMetric[Figure]):
    """
    A plugin to monitor the labels repartition.

    It wraps a :class:`LabelsRepartition` metric and keeps, for every task,
    the history of the count of each label across the recorded steps.

    :param image_creator: The function to use to create an image from the
        history of the labels repartition. It will receive a dictionary of
        the form {label_id: [count_at_step_0, count_at_step_1, ...], ...}
        and the list of the corresponding steps [step_0, step_1, ...].
        If set to None, only the raw data is emitted.
    :param mode: Indicates if this plugin should run on train or eval.
    :param emit_reset_at: When the metric is both emitted and reset; the
        same value is forwarded as ``emit_at`` and ``reset_at`` to the
        parent class.
    """

    def __init__(
        self,
        *,
        image_creator: Optional[
            LabelsRepartitionImageCreator
        ] = default_history_repartition_image_creator,
        mode: Literal["train", "eval"] = "train",
        emit_reset_at: Literal["stream", "experience", "epoch"] = "epoch",
    ):
        # The wrapped metric must exist before the parent constructor runs,
        # since it is handed over to it.
        self.labels_repartition = LabelsRepartition()
        super().__init__(
            metric=self.labels_repartition,
            emit_at=emit_reset_at,
            reset_at=emit_reset_at,
            mode=mode,
        )
        self.emit_reset_at = emit_reset_at
        self.mode = mode
        self.image_creator = image_creator
        # Steps (training iterations) at which counts were recorded.
        self.steps = [0]
        # {task_id: {label: [count at each entry of self.steps]}}
        self.task2label2counts: Dict[int, Dict[int, List[int]]] = defaultdict(
            dict
        )

    def reset(self, strategy) -> None:
        """
        Record the current training iteration as a step, then reset.

        :param strategy: The strategy; its ``clock.train_iterations`` is
            appended to the list of steps.
        :return: Whatever the parent ``reset`` returns.
        """
        current_iteration = strategy.clock.train_iterations
        self.steps.append(current_iteration)
        return super().reset(strategy)

    def update(self, strategy: "SupervisedTemplate"):
        """
        Feed the labels of the current mini-batch to the wrapped metric.

        Nothing is counted when ``strategy.clock.train_exp_epochs`` is
        non-zero and the plugin is not emitting at every epoch.

        :param strategy: The strategy providing ``mb_task_id``, ``mb_y`` and
            the current experience.
        """
        # NOTE(review): this check does not look at ``mode``, so it also
        # applies when the plugin runs in eval mode -- confirm it is intended.
        if (
            not strategy.clock.train_exp_epochs
            or self.emit_reset_at == "epoch"
        ):
            task_ids = strategy.mb_task_id.tolist()
            targets = strategy.mb_y.tolist()
            # Benchmarks without a ``classes_order`` attribute yield None,
            # which disables label ordering in the wrapped metric.
            order = getattr(
                strategy.experience.benchmark, "classes_order", None
            )
            self.labels_repartition.update(
                task_ids,
                targets,
                class_order=order,
            )

    def _package_result(self, strategy: "SupervisedTemplate") -> "MetricResult":
        """
        Extend the per-label histories and build the emitted values.

        :param strategy: The strategy; its clock gives the current step and
            its experience gives the stream name.
        :return: One ``MetricValue`` per task present in the history. The
            value is the raw ``{label: counts}`` history, wrapped with the
            generated image when an image creator is set.
        """
        self.steps.append(strategy.clock.train_iterations)
        n_steps = len(self.steps)

        latest = self.labels_repartition.result()
        for task, label2count in latest.items():
            for label, count in label2count.items():
                # Looked up inside the inner loop on purpose: a task with no
                # label must not get an (empty) history entry.
                history = self.task2label2counts[task]
                if label not in history:
                    # A label seen for the first time gets zeros for all the
                    # steps recorded before the current pair.
                    history[label] = [0] * (n_steps - 2)
                # The count is stored twice; presumably one entry for the
                # step appended by reset() and one for the step appended
                # above -- assumes reset() runs once before each emission.
                history[label].extend([count, count])

        # Labels not seen this time are padded with zeros so that every
        # history is at least as long as the list of steps.
        for label2counts in self.task2label2counts.values():
            for counts in label2counts.values():
                missing = n_steps - len(counts)
                counts.extend([0] * missing)

        packaged = []
        for task, label2counts in self.task2label2counts.items():
            metric_name = (
                f"Repartition/{self._mode}_phase"
                f"/{stream_type(strategy.experience)}_stream"
                f"/Task_{task:03}"
            )
            if self.image_creator is None:
                value = label2counts
            else:
                value = AlternativeValues(
                    self.image_creator(label2counts, self.steps),
                    label2counts,
                )
            packaged.append(
                MetricValue(
                    self,
                    name=metric_name,
                    value=value,
                    x_plot=strategy.clock.train_iterations,
                )
            )
        return packaged

    def __str__(self):
        """Return the short name of the metric."""
        return "Repartition"
def labels_repartition_metrics(
    *,
    on_train: bool = True,
    emit_train_at: Literal["stream", "experience", "epoch"] = "epoch",
    on_eval: bool = False,
    emit_eval_at: Literal["stream", "experience"] = "stream",
    image_creator: Optional[
        LabelsRepartitionImageCreator
    ] = default_history_repartition_image_creator,
) -> List[PluginMetric]:
    """
    Create plugins to monitor the labels repartition.

    :param on_train: If True, emit the metrics during training.
    :param emit_train_at: (only if on_train is True) when to emit the training
        metrics.
    :param on_eval: If True, emit the metrics during evaluation.
    :param emit_eval_at: (only if on_eval is True) when to emit the evaluation
        metrics.
    :param image_creator: The function to use to create an image from the
        history of the labels repartition. It will receive a dictionary of the
        form {label_id: [count_at_step_0, count_at_step_1, ...], ...}
        and the list of the corresponding steps [step_0, step_1, ...].
        If set to None, only the raw data is emitted.
    :return: The list of corresponding plugins: the eval plugin first (if
        requested), then the train plugin (if requested).
    """
    # (enabled flag, mode, emission point), in the order the plugins are
    # returned.
    requested = (
        (on_eval, "eval", emit_eval_at),
        (on_train, "train", emit_train_at),
    )

    plugins = []
    for enabled, mode, emit_at in requested:
        if not enabled:
            continue
        plugins.append(
            LabelsRepartitionPlugin(
                image_creator=image_creator,
                mode=mode,
                emit_reset_at=emit_at,
            )
        )
    return plugins
# Names exported by ``from <module> import *``.
__all__ = [
    "LabelsRepartitionPlugin",
    "LabelsRepartition",
    "labels_repartition_metrics",
]