Source code for avalanche.models.ncm_classifier

from typing import Dict

import torch
from torch import Tensor, nn

from avalanche.models import DynamicModule


class NCMClassifier(DynamicModule):
    """
    NCM Classifier.

    NCMClassifier performs nearest class mean classification,
    measuring the distance between the input tensor and the
    ones stored in `self.class_means`.

    Before being used for inference, NCMClassifier needs to be
    updated with a mean vector per class, by calling
    `update_class_means_dict`.

    This class registers a `class_means` buffer that stores the class
    means in a single tensor of shape [max_class_id_seen, feature_size].
    Classes with ID smaller than `max_class_id_seen` for which no mean
    has been set are associated with a 0-vector.
    """
    def __init__(self, normalize: bool = True):
        """
        :param normalize: whether to normalize the input with
            2-norm = 1 before computing the distance.
        """
        super().__init__()
        # vectorized version of class means
        self.register_buffer("class_means", None)
        self.class_means_dict = {}
        self.normalize = normalize
        self.max_class = -1
    def load_state_dict(self, state_dict, strict: bool = True):
        self.class_means = state_dict["class_means"]
        super().load_state_dict(state_dict, strict)
        # fill dictionary: rebuild {class_id: mean} from the loaded buffer,
        # skipping the all-zero rows used as placeholders
        if self.class_means is not None:
            for i in range(self.class_means.shape[0]):
                if (self.class_means[i] != 0).any():
                    self.class_means_dict[i] = self.class_means[i].clone()
            self.max_class = max(self.class_means_dict.keys())

    def _vectorize_means_dict(self):
        """
        Transform the dictionary of {class_id: mean vector} into a tensor
        of shape (num_classes, feature_size), where `feature_size` is the
        size of the mean vector and `num_classes` is `self.max_class + 1`,
        with `self.max_class` being the maximum class id seen so far.
        Missing class means are stored as zero vectors.
        """
        if self.class_means_dict == {}:
            return

        max_class = max(self.class_means_dict.keys())
        self.max_class = max(max_class, self.max_class)
        first_mean = list(self.class_means_dict.values())[0]
        feature_size = first_mean.size(0)
        device = first_mean.device
        self.class_means = torch.zeros(self.max_class + 1, feature_size).to(device)

        for k, v in self.class_means_dict.items():
            self.class_means[k] = self.class_means_dict[k].clone()

    @torch.no_grad()
    def forward(self, x):
        """
        :param x: (batch_size, feature_size)

        Returns a tensor of size (batch_size, num_classes) with the
        negative distance of each element in the mini-batch with
        respect to each class.
        """
        if self.class_means_dict == {}:
            self.init_missing_classes(range(self.max_class + 1), x.shape[1], x.device)

        assert self.class_means_dict != {}, "no class means available."
        if self.normalize:
            # normalize across feature_size
            x = (x.T / torch.norm(x.T, dim=0)).T
        # (num_classes, batch_size)
        sqd = torch.cdist(self.class_means.to(x.device), x)
        # (batch_size, num_classes)
        return (-sqd).T

    def update_class_means_dict(
        self, class_means_dict: Dict[int, Tensor], momentum: float = 0.5
    ):
        """
        Update the dictionary of class means. If a mean for a class
        already exists, the new mean is combined with the stored one
        according to `momentum` (the default of 0.5 yields a simple
        average of the two).

        :param class_means_dict: a dictionary mapping class id to class
            mean tensor.
        :param momentum: weighting of the new means vs. old means in the
            update. 1 = replace, 0 = don't update.
        """
        assert 0 <= momentum <= 1
        assert isinstance(class_means_dict, dict), (
            "class_means_dict must be a dictionary mapping class_id "
            "to mean vector"
        )
        for k, v in class_means_dict.items():
            if k not in self.class_means_dict or (self.class_means_dict[k] == 0).all():
                # unseen class (or zero placeholder): store the new mean as-is
                self.class_means_dict[k] = class_means_dict[k].clone()
            else:
                # blend the new mean with the stored one
                device = self.class_means_dict[k].device
                self.class_means_dict[k] = (
                    momentum * class_means_dict[k].to(device)
                    + (1 - momentum) * self.class_means_dict[k]
                )

        self._vectorize_means_dict()

    def replace_class_means_dict(self, class_means_dict: Dict[int, Tensor]):
        """
        Replace the existing dictionary of class means with the given one.
        """
        assert isinstance(class_means_dict, dict), (
            "class_means_dict must be a dictionary mapping class_id "
            "to mean vector"
        )
        self.class_means_dict = {k: v.clone() for k, v in class_means_dict.items()}
        self._vectorize_means_dict()

    def init_missing_classes(self, classes, class_size, device):
        # register a zero mean for every class that has no mean yet
        for k in classes:
            if k not in self.class_means_dict:
                self.class_means_dict[k] = torch.zeros(class_size).to(device)
        self._vectorize_means_dict()

    def adaptation(self, experience):
        super().adaptation(experience)
        if not self.training:
            # at evaluation time, make sure every class of the current
            # experience has at least a (zero) mean vector registered
            classes = experience.classes_in_this_experience
            for k in classes:
                self.max_class = max(k, self.max_class)
            if self.class_means is not None:
                self.init_missing_classes(
                    classes, self.class_means.shape[1], self.class_means.device
                )
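
A minimal sketch (not part of the module) of the blending rule used by
`update_class_means_dict`; the 4-dimensional vectors and class id 0 are
arbitrary choices for illustration. With the default `momentum=0.5`, an
existing non-zero mean and the newly provided one are simply averaged::

    import torch
    from avalanche.models.ncm_classifier import NCMClassifier

    ncm = NCMClassifier(normalize=False)
    # first update for class 0: the mean is stored as-is
    ncm.update_class_means_dict({0: torch.full((4,), 2.0)})
    # second update: 0.5 * new + 0.5 * old, element-wise
    ncm.update_class_means_dict({0: torch.full((4,), 4.0)}, momentum=0.5)
    print(ncm.class_means_dict[0])  # tensor([3., 3., 3., 3.])
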
__all__ = ["NCMClassifier"]
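
A usage sketch for the nearest-class-mean workflow described in the class
docstring (not part of the module): the random 8-dimensional features below
stand in for the output of a real feature extractor, and the two class ids
are arbitrary::

    import torch
    from avalanche.models.ncm_classifier import NCMClassifier

    feature_size = 8
    # fake, well-separated features for two classes
    feats_c0 = torch.randn(32, feature_size) + 3.0
    feats_c1 = torch.randn(32, feature_size) - 3.0

    ncm = NCMClassifier(normalize=True)
    ncm.replace_class_means_dict({0: feats_c0.mean(0), 1: feats_c1.mean(0)})

    scores = ncm(feats_c0[:4])    # (4, 2): negative distance to each class mean
    preds = scores.argmax(dim=1)  # expected to be class 0 for these features

`update_class_means_dict` can be used instead of `replace_class_means_dict`
when the means are accumulated incrementally over several experiences.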