Source code for avalanche.logging.interactive_logging

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 2020-01-25                                                             #
# Author(s): Antonio Carta, Lorenzo Pellegrini                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
import sys
from typing import List, TYPE_CHECKING

from avalanche.core import SupervisedPlugin
from avalanche.evaluation.metric_results import MetricValue
from avalanche.logging import TextLogger
from avalanche.benchmarks.scenarios import OnlineCLExperience

from tqdm import tqdm

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


[docs]class InteractiveLogger(TextLogger, SupervisedPlugin): """ The `InteractiveLogger` class provides logging facilities for the console standard output. The logger shows a progress bar during training and evaluation flows and interactively display metric results as soon as they become available. The logger writes metric results after each training epoch, evaluation experience and at the end of the entire evaluation stream. .. note:: To avoid an excessive amount of printed lines, this logger will **not** print results after each iteration. If the user is monitoring metrics which emit results after each minibatch (e.g., `MinibatchAccuracy`), only the last recorded value of such metrics will be reported at the end of the epoch. .. note:: Since this logger works on the standard output, metrics producing images or more complex visualizations will be converted to a textual format suitable for console printing. You may want to add more loggers to your `EvaluationPlugin` to better support different formats. """
[docs] def __init__(self): super().__init__(file=sys.stdout) self._pbar = None
def before_training_epoch( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): if isinstance(strategy.experience, OnlineCLExperience): return super().before_training_epoch(strategy, metric_values, **kwargs) self._progress.total = len(strategy.dataloader) def after_training_epoch( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): if isinstance(strategy.experience, OnlineCLExperience): return self._end_progress() super().after_training_epoch(strategy, metric_values, **kwargs) def before_training_exp( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): if isinstance(strategy.experience, OnlineCLExperience): experience = strategy.experience.logging() if experience.is_first_subexp: super().before_training_exp(strategy, metric_values, **kwargs) self._progress.total = experience.sub_stream_length def after_training_exp( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): if isinstance(strategy.experience, OnlineCLExperience): experience = strategy.experience.logging() if experience.is_last_subexp: self._end_progress() super().after_training_exp(strategy, metric_values, **kwargs) else: self._progress.update() self._progress.refresh() def before_eval_exp( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): super().before_eval_exp(strategy, metric_values, **kwargs) self._progress.total = len(strategy.dataloader) def after_eval_exp( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): self._end_progress() super().after_eval_exp(strategy, metric_values, **kwargs) def after_training_iteration( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): if isinstance(strategy.experience, OnlineCLExperience): return self._progress.update() self._progress.refresh() super().after_training_iteration(strategy, metric_values, **kwargs) def after_eval_iteration( self, strategy: "SupervisedTemplate", metric_values: List["MetricValue"], **kwargs ): self._progress.update() self._progress.refresh() super().after_eval_iteration(strategy, metric_values, **kwargs) @property def _progress(self): if self._pbar is None: self._pbar = tqdm(leave=True, position=0, file=sys.stdout) return self._pbar def _end_progress(self): if self._pbar is not None: self._pbar.close() self._pbar = None def __getstate__(self): out = super().__getstate__() del out['_pbar'] return out def __setstate__(self, state): state['_pbar'] = None super().__setstate__(state)