Source code for avalanche.training.plugins.update_fecam

#!/usr/bin/env python3
import torch

from avalanche.benchmarks.utils import concat_datasets
from avalanche.models import FeCAMClassifier
from avalanche.models.fecam import compute_covariance, compute_means
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer


def _gather_means_and_cov(model, dataset, batch_size, device, **kwargs):
    """Extract features for the whole dataset and return the per-class
    means and covariances, with the FeCAM transforms applied."""
    num_workers = kwargs.get("num_workers", 0)
    loader = torch.utils.data.DataLoader(
        dataset.eval(),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    features = []
    labels = []

    was_training = model.training
    model.eval()

    for x, y, _ in loader:  # task labels are not needed here
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            out = model.feature_extractor(x)

        features.append(out)
        labels.append(y)

    if was_training:
        model.train()

    features = torch.cat(features)
    labels = torch.cat(labels)

    # Apply the classifier's feature transforms, then estimate per-class
    # means/covariances and apply the covariance transforms
    features = model.eval_classifier.apply_transforms(features)
    class_means = compute_means(features, labels)
    class_cov = compute_covariance(features, labels)
    class_cov = model.eval_classifier.apply_cov_transforms(class_cov)

    return class_means, class_cov
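

# Illustrative sketch only (not part of the Avalanche API): FeCAM classifies
# a feature vector by its squared Mahalanobis distance to each class mean
# under that class's covariance. A minimal version over the dictionaries
# returned by ``_gather_means_and_cov`` could look like the hypothetical
# helper below (it is not used by the plugins in this module).
def _mahalanobis_predict(feats, class_means, class_cov):
    class_ids = sorted(class_means.keys())
    dists = []
    for c in class_ids:
        diff = feats - class_means[c]  # (N, D)
        # Pseudo-inverse for stability with near-singular covariance estimates
        inv_cov = torch.linalg.pinv(class_cov[c])  # (D, D)
        # Per-sample squared Mahalanobis distance: diag(diff @ inv_cov @ diff.T)
        dists.append(torch.einsum("nd,dk,nk->n", diff, inv_cov, diff))
    dists = torch.stack(dists, dim=1)  # (N, num_classes)
    # Predict the class whose mean is closest under its own covariance
    return torch.tensor(class_ids, device=feats.device)[dists.argmin(dim=1)]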


def _check_has_fecam(model):
    assert hasattr(model, "eval_classifier"), (
        "The model must expose an `eval_classifier` attribute"
    )
    assert isinstance(model.eval_classifier, FeCAMClassifier), (
        "`eval_classifier` must be a FeCAMClassifier to use FeCAM updates"
    )


class CurrentDataFeCAMUpdate(SupervisedPlugin):
    """
    Updates the FeCAM covariances and prototypes using the current
    task data (at the end of each task).
    """

    def __init__(self):
        super().__init__()

    def after_training_exp(self, strategy, **kwargs):
        _check_has_fecam(strategy.model)
        class_means, class_cov = _gather_means_and_cov(
            strategy.model,
            strategy.experience.dataset,
            strategy.train_mb_size,
            strategy.device,
            **kwargs,
        )
        strategy.model.eval_classifier.update_class_means_dict(class_means)
        strategy.model.eval_classifier.update_class_cov_dict(class_cov)
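

# Usage sketch (assumptions: `my_model` exposes a `feature_extractor` and a
# FeCAMClassifier as `eval_classifier`; `my_model` and `my_benchmark` are
# placeholders):
#
#     from torch.optim import SGD
#     from avalanche.training.supervised import Naive
#
#     strategy = Naive(
#         my_model,
#         SGD(my_model.parameters(), lr=0.01),
#         train_mb_size=64,
#         plugins=[CurrentDataFeCAMUpdate()],
#     )
#     for experience in my_benchmark.train_stream:
#         strategy.train(experience)
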
class MemoryFeCAMUpdate(SupervisedPlugin):
    """
    Updates the FeCAM covariances and prototypes using the data
    contained inside a memory buffer.
    """

    def __init__(self, mem_size=2000, storage_policy=None):
        super().__init__()
        if storage_policy is None:
            self.storage_policy = ClassBalancedBuffer(max_size=mem_size)
        else:
            self.storage_policy = storage_policy

    def after_training_exp(self, strategy, **kwargs):
        _check_has_fecam(strategy.model)
        self.storage_policy.update(strategy)
        class_means, class_cov = _gather_means_and_cov(
            strategy.model,
            self.storage_policy.buffer,
            strategy.train_mb_size,
            strategy.device,
            **kwargs,
        )
        strategy.model.eval_classifier.update_class_means_dict(class_means)
        strategy.model.eval_classifier.update_class_cov_dict(class_cov)
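

# Usage sketch: the default buffer is a ClassBalancedBuffer holding `mem_size`
# samples, but any Avalanche storage policy can be substituted, e.g.:
#
#     from avalanche.training.storage_policy import ReservoirSamplingBuffer
#
#     plugin = MemoryFeCAMUpdate(
#         storage_policy=ReservoirSamplingBuffer(max_size=500)
#     )
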
class FeCAMOracle(SupervisedPlugin):
    """
    Updates the FeCAM covariances and prototypes using all the data
    seen so far.

    WARNING: This is an oracle and thus breaks an assumption usually
    made by continual learning algorithms (it stores the full dataset).
    It is meant to be used as an upper bound for FeCAM-based methods
    (e.g., when trying to estimate prototype and covariance drift).
    """

    def __init__(self):
        super().__init__()
        self.all_datasets = []

    def after_training_exp(self, strategy, **kwargs):
        _check_has_fecam(strategy.model)
        self.all_datasets.append(strategy.experience.dataset)
        full_dataset = concat_datasets(self.all_datasets)
        class_means, class_cov = _gather_means_and_cov(
            strategy.model,
            full_dataset,
            strategy.train_mb_size,
            strategy.device,
            **kwargs,
        )
        strategy.model.eval_classifier.update_class_means_dict(class_means)
        strategy.model.eval_classifier.update_class_cov_dict(class_cov)
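

# Note: because FeCAMOracle concatenates every experience's dataset, its
# memory footprint grows linearly with the stream. It is intended for
# analysis runs, e.g. comparing its class statistics against those of
# MemoryFeCAMUpdate to measure prototype and covariance drift, not for
# realistic continual learning settings.

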
__all__ = ["CurrentDataFeCAMUpdate", "MemoryFeCAMUpdate", "FeCAMOracle"]