Source code for avalanche.training.plugins.update_ncm

#!/usr/bin/env python3
import collections
import copy
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import Tensor, nn

from avalanche.benchmarks.utils import concat_datasets
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.models import NCMClassifier


@torch.no_grad()
def compute_class_means(model, dataset, batch_size, normalize, device, **kwargs):
    """Compute one feature-space prototype (mean) per class in ``dataset``.

    Returns a dict mapping each class index to the mean of the features
    produced by ``model.feature_extractor``, L2-normalized if ``normalize``
    is True.
    """
    class_means_dict = collections.defaultdict(list)
    class_counts = collections.defaultdict(lambda: 0)
    num_workers = kwargs.get("num_workers", 0)
    loader = torch.utils.data.DataLoader(
        dataset.eval(), batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    model.eval()

    for x, y, t in loader:
        x = x.to(device)
        for class_idx in torch.unique(y):
            mask = y == class_idx
            out = model.feature_extractor(x[mask])
            class_means_dict[int(class_idx)].append(out)
            class_counts[int(class_idx)] += len(x[mask])

    for k, v in class_means_dict.items():
        v = torch.cat(v)
        if normalize:
            class_means_dict[k] = (
                torch.sum(v / torch.norm(v, dim=1, keepdim=True), dim=0)
                / class_counts[k]
            )
        else:
            class_means_dict[k] = torch.sum(v, dim=0) / class_counts[k]

        if normalize:
            class_means_dict[k] = class_means_dict[k] / class_means_dict[k].norm()

    model.train()

    return class_means_dict
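
# A minimal usage sketch (hypothetical helper, not part of this module's API)
# showing how the means returned by `compute_class_means` feed an
# `NCMClassifier`. It assumes `model` exposes `feature_extractor` and an
# NCMClassifier under `eval_classifier`, exactly as the plugins below do.
def _example_refresh_prototypes(model, dataset, batch_size=32, device="cpu"):
    """Hypothetical sketch; mirrors what the plugins below automate."""
    class_means_dict = compute_class_means(
        model,
        dataset,  # an AvalancheDataset (it must support `.eval()`)
        batch_size,
        normalize=model.eval_classifier.normalize,
        device=device,
    )
    # The plugins below call `update_class_means_dict` for current-task data
    # and `replace_class_means_dict` when recomputing from a buffer or from
    # the full stream seen so far.
    model.eval_classifier.replace_class_means_dict(class_means_dict)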


def _check_has_ncm(model):
    assert hasattr(model, "eval_classifier"), "model must have an `eval_classifier`"
    assert isinstance(
        model.eval_classifier, NCMClassifier
    ), "`eval_classifier` must be an NCMClassifier"


class CurrentDataNCMUpdate(SupervisedPlugin):
    """
    Updates the NCM prototypes using the current task data.
    """

    def __init__(self):
        super().__init__()

    # Maybe change with before_eval
    @torch.no_grad()
    def after_training_exp(self, strategy, **kwargs):
        _check_has_ncm(strategy.model)
        class_means_dict = compute_class_means(
            strategy.model,
            strategy.experience.dataset,
            strategy.train_mb_size,
            normalize=strategy.model.eval_classifier.normalize,
            device=strategy.device,
        )
        strategy.model.eval_classifier.update_class_means_dict(class_means_dict)


class MemoryNCMUpdate(SupervisedPlugin):
    """
    Updates NCM prototypes using the data contained inside a memory buffer
    (as is done in iCaRL).
    """

    def __init__(self, mem_size=2000, storage_policy=None):
        super().__init__()
        if storage_policy is None:
            self.storage_policy = ClassBalancedBuffer(max_size=mem_size)
        else:
            self.storage_policy = storage_policy

    def after_training_exp(self, strategy, **kwargs):
        _check_has_ncm(strategy.model)
        self.storage_policy.update(strategy)
        class_means_dict = compute_class_means(
            strategy.model,
            self.storage_policy.buffer.eval(),
            batch_size=strategy.train_mb_size,
            normalize=strategy.model.eval_classifier.normalize,
            device=strategy.device,
        )
        strategy.model.eval_classifier.replace_class_means_dict(class_means_dict)


class NCMOracle(SupervisedPlugin):
    """
    Updates NCM prototypes using all the data seen so far.

    WARNING: This is an oracle, and thus breaks assumptions usually made in
    continual learning algorithms (storage of the full dataset).

    This is meant to be used as an upper bound for NCM-based methods
    (i.e., when trying to estimate prototype drift).
    """

    def __init__(self):
        super().__init__()
        self.all_datasets = []

    @torch.no_grad()
    def after_training_exp(self, strategy, **kwargs):
        _check_has_ncm(strategy.model)
        self.all_datasets.append(strategy.experience.dataset)
        accumulated_dataset = concat_datasets(self.all_datasets)
        class_means_dict = compute_class_means(
            strategy.model,
            accumulated_dataset,
            strategy.train_mb_size,
            normalize=strategy.model.eval_classifier.normalize,
            device=strategy.device,
        )
        strategy.model.eval_classifier.replace_class_means_dict(class_means_dict)


__all__ = ["CurrentDataNCMUpdate", "MemoryNCMUpdate", "NCMOracle"]
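

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module). It assumes a model that wraps a
# feature extractor together with an NCMClassifier exposed as
# `eval_classifier`; `TrainEvalModel` (as used by Avalanche's iCaRL
# implementation) and the `NCMClassifier(normalize=True)` signature are
# assumptions here -- verify them against the installed Avalanche version.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import TrainEvalModel
    from avalanche.training.supervised import Naive

    benchmark = SplitMNIST(n_experiences=5)

    # Plain MLP backbone used as a feature extractor; the linear head is only
    # used during training, while evaluation goes through the NCM classifier.
    feature_extractor = nn.Sequential(
        nn.Flatten(), nn.Linear(28 * 28, 256), nn.ReLU()
    )
    model = TrainEvalModel(
        feature_extractor,
        train_classifier=nn.Linear(256, 10),  # SplitMNIST has 10 classes
        eval_classifier=NCMClassifier(normalize=True),
    )

    strategy = Naive(
        model=model,
        optimizer=SGD(model.parameters(), lr=0.01, momentum=0.9),
        criterion=CrossEntropyLoss(),
        train_mb_size=64,
        train_epochs=1,
        eval_mb_size=64,
        plugins=[CurrentDataNCMUpdate()],  # or MemoryNCMUpdate() / NCMOracle()
    )

    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)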