import operator
from typing import Optional
import warnings
from copy import deepcopy
from avalanche.training.plugins import SupervisedPlugin
[docs]class EarlyStoppingPlugin(SupervisedPlugin):
"""Early stopping and model checkpoint plugin.
The plugin checks a metric and stops the training loop when the accuracy
on the metric stopped progressing for `patience` epochs.
After training, the best model's checkpoint is loaded.
.. warning::
The plugin checks the metric value, which is updated by the strategy
during the evaluation. This means that you must ensure that the
evaluation is called frequently enough during the training loop.
For example, if you set `patience=1`, you must also set `eval_every=1`
in the `BaseTemplate`, otherwise the metric won't be updated after
every epoch/iteration. Similarly, `peval_mode` must have the same
value.
"""
[docs] def __init__(
self,
patience: int,
val_stream_name: str,
metric_name: str = "Top1_Acc_Stream",
mode: str = "max",
peval_mode: str = "epoch",
margin: float = 0.0,
verbose=False,
):
"""Init.
:param patience: Number of epochs to wait before stopping the training.
:param val_stream_name: Name of the validation stream to search in the
metrics. The corresponding stream will be used to keep track of the
evolution of the performance of a model.
:param metric_name: The name of the metric to watch as it will be
reported in the evaluator.
:param mode: Must be "max" or "min". max (resp. min) means that the
given metric should me maximized (resp. minimized).
:param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
early stopping should happen after `patience`
epochs or iterations (Default='epoch').
:param margin: a minimal margin of improvements required to be
considered best than a previous one. It should be an float, the
default value is 0. That means that any improvement is considered
better.
:param verbose: If True, prints a message for each update
(default: False).
"""
super().__init__()
self.val_stream_name = val_stream_name
self.patience = patience
self.verbose = verbose
assert peval_mode in {"epoch", "iteration"}
self.peval_mode = peval_mode
assert type(margin) == float
self.margin = margin
self.metric_name = metric_name
self.metric_key = f"{self.metric_name}/eval_phase/" f"{self.val_stream_name}"
if mode not in ("max", "min"):
raise ValueError(f'Mode must be "max" or "min", got {mode}.')
self.operator = operator.gt if mode == "max" else operator.lt
self.best_state = None # Contains the best parameters
self.best_val = None
self.best_step: Optional[int] = None
def before_training_exp(self, strategy, **kwargs):
self.best_state = None
self.best_val = None
self.best_step = None
def before_training_iteration(self, strategy, **kwargs):
if self.peval_mode == "iteration":
ub = self._update_best(strategy)
if ub is None or self.best_step is None:
return
curr_step = self._get_strategy_counter(strategy)
if curr_step - self.best_step >= self.patience:
strategy.model.load_state_dict(self.best_state)
strategy.stop_training()
def before_training_epoch(self, strategy, **kwargs):
if self.peval_mode == "epoch":
ub = self._update_best(strategy)
if ub is None or self.best_step is None:
return
curr_step = self._get_strategy_counter(strategy)
if curr_step - self.best_step >= self.patience:
strategy.model.load_state_dict(self.best_state)
strategy.stop_training()
def _update_best(self, strategy):
res = strategy.evaluator.get_last_metrics()
names = [k for k in res.keys() if k.startswith(self.metric_key)]
if len(names) == 0:
return None
full_name = names[-1]
val_acc = res.get(full_name)
if self.best_val is None:
warnings.warn(
f"Metric {self.metric_name} used by the EarlyStopping plugin "
f"is not computed yet. EarlyStopping will not be triggered."
)
if self.best_val is None:
self.best_state = deepcopy(strategy.model.state_dict())
self.best_val = val_acc
self.best_step = self._get_strategy_counter(strategy)
return None
delta_val = float(val_acc - self.best_val)
if self.operator(delta_val, 0) and abs(delta_val) >= self.margin:
self.best_state = deepcopy(strategy.model.state_dict())
self.best_val = val_acc
self.best_step = self._get_strategy_counter(strategy)
if self.verbose:
print("EarlyStopping: new best value:", val_acc)
return self.best_val
def _get_strategy_counter(self, strategy):
if self.peval_mode == "epoch":
return strategy.clock.train_exp_epochs
elif self.peval_mode == "iteration":
return strategy.clock.train_exp_iterations
else:
raise ValueError("Invalid `peval_mode`:", self.peval_mode)