Source code for avalanche.training.plugins.early_stopping

import operator
import warnings
from copy import deepcopy

from avalanche.training.plugins import SupervisedPlugin


[docs]class EarlyStoppingPlugin(SupervisedPlugin): """Early stopping and model checkpoint plugin. The plugin checks a metric and stops the training loop when the accuracy on the metric stopped progressing for `patience` epochs. After training, the best model's checkpoint is loaded. .. warning:: The plugin checks the metric value, which is updated by the strategy during the evaluation. This means that you must ensure that the evaluation is called frequently enough during the training loop. For example, if you set `patience=1`, you must also set `eval_every=1` in the `BaseTemplate`, otherwise the metric won't be updated after every epoch/iteration. Similarly, `peval_mode` must have the same value. """
[docs] def __init__( self, patience: int, val_stream_name: str, metric_name: str = "Top1_Acc_Stream", mode: str = "max", peval_mode: str = "epoch", margin: float = 0.0, verbose=False, ): """Init. :param patience: Number of epochs to wait before stopping the training. :param val_stream_name: Name of the validation stream to search in the metrics. The corresponding stream will be used to keep track of the evolution of the performance of a model. :param metric_name: The name of the metric to watch as it will be reported in the evaluator. :param mode: Must be "max" or "min". max (resp. min) means that the given metric should me maximized (resp. minimized). :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the early stopping should happen after `patience` epochs or iterations (Default='epoch'). :param margin: a minimal margin of improvements required to be considered best than a previous one. It should be an float, the default value is 0. That means that any improvement is considered better. :param verbose: If True, prints a message for each update (default: False). """ super().__init__() self.val_stream_name = val_stream_name self.patience = patience self.verbose = verbose assert peval_mode in {"epoch", "iteration"} self.peval_mode = peval_mode assert type(margin) == float self.margin = margin self.metric_name = metric_name self.metric_key = ( f"{self.metric_name}/eval_phase/" f"{self.val_stream_name}" ) if mode not in ("max", "min"): raise ValueError(f'Mode must be "max" or "min", got {mode}.') self.operator = operator.gt if mode == "max" else operator.lt self.best_state = None # Contains the best parameters self.best_val = None self.best_step = None
def before_training(self, strategy, **kwargs): self.best_state = None self.best_val = None self.best_step = None self.best_step = None def before_training_iteration(self, strategy, **kwargs): if self.peval_mode == "iteration": ub = self._update_best(strategy) if ub is None or self.best_step is None: return curr_step = self._get_strategy_counter(strategy) if curr_step - self.best_step >= self.patience: strategy.model.load_state_dict(self.best_state) strategy.stop_training() def before_training_epoch(self, strategy, **kwargs): if self.peval_mode == "epoch": ub = self._update_best(strategy) if ub is None or self.best_step is None: return curr_step = self._get_strategy_counter(strategy) if curr_step - self.best_step >= self.patience: strategy.model.load_state_dict(self.best_state) strategy.stop_training() def _update_best(self, strategy): res = strategy.evaluator.get_last_metrics() names = [k for k in res.keys() if k.startswith(self.metric_key)] if len(names) == 0: return None full_name = names[-1] val_acc = res.get(full_name) if self.best_val is None: warnings.warn( f"Metric {self.metric_name} used by the EarlyStopping plugin " f"is not computed yet. EarlyStopping will not be triggered." ) if self.best_val is None or self.operator(val_acc, self.best_val): self.best_state = deepcopy(strategy.model.state_dict()) if self.best_val is None: self.best_val = val_acc self.best_step = 0 return None if self.operator(float(val_acc - self.best_val), self.margin): self.best_step = self._get_strategy_counter(strategy) self.best_val = val_acc if self.verbose: print("EarlyStopping: new best value:", val_acc) return self.best_val def _get_strategy_counter(self, strategy): if self.peval_mode == "epoch": return strategy.clock.train_exp_epochs elif self.peval_mode == "iteration": return strategy.clock.train_exp_iterations else: raise ValueError("Invalid `peval_mode`:", self.peval_mode)