# Source code for avalanche.training.plugins.feature_distillation

#!/usr/bin/env python3
import copy
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import Tensor, nn

from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.utils import _at_task_boundary, cycle


class FeatureDistillationPlugin(SupervisedPlugin):
    """Add a distillation penalty on the model's features.

    After each training experience a copy of the model is stored. While
    training on later experiences, every mini-batch is also fed through that
    copy. The features it produces are compared with those of the current
    model, and the resulting penalty, scaled by ``alpha``, is added to
    ``strategy.loss`` before the backward pass.

    NOTE(review): the plugin reads features from a ``features`` attribute on
    the model. It assumes the model sets that attribute during its forward
    pass; confirm against the models used with this plugin.
    """

    def __init__(self, alpha=1, mode="cosine"):
        """
        Adds a Distillation loss term on the features of the model, trying
        to maximize the cosine similarity between current and old features
        (``mode="cosine"``) or to minimize their mean squared error
        (``mode="mse"``).

        :param alpha: distillation hyperparameter. It can be either a float
            number or a list containing alpha for each experience.
        :param mode: distance used between old and new features; either
            ``"cosine"`` or ``"mse"``.
        """
        super().__init__()
        self.alpha = alpha
        # Copy of the model taken at the end of the previous experience;
        # None until the first experience has finished training.
        self.prev_model = None
        assert mode in ["mse", "cosine"]
        self.mode = mode

    def _current_alpha(self, strategy):
        """Return the alpha to use for the experience currently training.

        A scalar ``alpha`` is returned unchanged. A list/tuple is indexed by
        the strategy's training-experience counter, which is the same
        convention as Avalanche's LwF plugin.
        """
        if isinstance(self.alpha, (list, tuple)):
            return self.alpha[strategy.clock.train_exp_counter]
        return self.alpha

    def before_backward(self, strategy, **kwargs):
        """
        Add distillation loss to ``strategy.loss``.

        Does nothing during the first experience, since no previous model
        has been stored yet.
        """
        if self.prev_model is None:
            return

        # The old model is only a target: no gradients flow into it.
        # NOTE(review): prev_model is a deepcopy of the training-mode model,
        # so this forward may update its BatchNorm running statistics and
        # apply dropout. Confirm whether it should be put in eval mode.
        with torch.no_grad():
            avalanche_forward(self.prev_model, strategy.mb_x, strategy.mb_task_id)
            old_features = self.prev_model.features

        new_features = strategy.model.features
        alpha = self._current_alpha(strategy)

        if self.mode == "cosine":
            # Assumes features are (batch, dim), so dim=1 is the feature
            # axis. TODO confirm for models with non-flat features.
            similarity = F.cosine_similarity(new_features, old_features, dim=1)
            strategy.loss += alpha * (1 - similarity.mean())
        elif self.mode == "mse":
            # F.mse_loss has no ``dim`` argument; it averages over all
            # elements by default (reduction="mean").
            strategy.loss += alpha * F.mse_loss(new_features, old_features)

    def after_training_exp(self, strategy, **kwargs):
        """
        Save a copy of the model after each experience, to be used as the
        distillation target during the following experiences.
        """
        if _at_task_boundary(strategy.experience, before=False):
            # Drop the cached features of the last mini-batch before copying.
            # They are part of the autograd graph, and PyTorch refuses to
            # deepcopy non-leaf tensors that require grad.
            strategy.model.features = None
            self.prev_model = copy.deepcopy(strategy.model)
# Public API of this module: names exported by ``from ... import *``.
__all__ = [
    "FeatureDistillationPlugin",
]