from import LearningWithoutForgetting
from import SupervisedPlugin

class LwFPlugin(SupervisedPlugin):
    """Learning without Forgetting (LwF) plugin.

    LwF uses distillation to regularize the current loss with soft targets
    taken from a previous version of the model. When used with multi-headed
    models, all heads are distilled.

    The regularization itself is delegated to a
    ``LearningWithoutForgetting`` instance (``self.lwf``). This plugin only
    hooks it into training through the ``before_backward`` and
    ``after_training_exp`` callbacks.
    """

    def __init__(self, alpha=1, temperature=2):
        """
        :param alpha: distillation hyperparameter. It can be either a float
            number or a list containing alpha for each experience.
        :param temperature: softmax temperature for distillation.
        """
        super().__init__()
        self.lwf = LearningWithoutForgetting(alpha, temperature)

    def before_backward(self, strategy, **kwargs):
        """Add the distillation penalty to the current minibatch loss.

        :param strategy: the running training strategy. Its ``mb_x``,
            ``mb_output`` and ``model`` attributes are passed to the
            regularizer, and the returned penalty is added to
            ``strategy.loss`` before the backward pass.
        """
        strategy.loss += self.lwf(
            strategy.mb_x, strategy.mb_output, strategy.model
        )

    def after_training_exp(self, strategy, **kwargs):
        """Refresh the regularizer's state after each training experience.

        Calls ``self.lwf.update`` with the experience just trained on and the
        current model. This plugin keeps no model copy or class bookkeeping
        of its own; the original documentation states that the update
        saves a copy of the model and records the newly learned classes
        inside the ``LearningWithoutForgetting`` object.

        :param strategy: the running training strategy; its ``experience``
            and ``model`` attributes are read.
        """
        self.lwf.update(strategy.experience, strategy.model)