import copy
from typing import Dict
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import Tensor, nn
from avalanche.models import DynamicModule
class FeCAMClassifier(DynamicModule):
    """
    FeCAMClassifier

    Similar to NCM but uses the Mahalanobis distance instead of the L2
    distance. This approach was proposed for continual learning in
    "FeCAM: Exploiting the Heterogeneity of Class Distributions
    in Exemplar-Free Continual Learning", Goswami et al. (NeurIPS 2023).

    This requires storing a full covariance matrix per class.
    """
    def __init__(
        self,
        tukey=True,
        shrinkage=True,
        shrink1: float = 1.0,
        shrink2: float = 1.0,
        tukey1: float = 0.5,
        covnorm=True,
    ):
        """
        :param tukey: whether to use the Tukey transforms
            (helps get the distribution closer
            to a multivariate Gaussian)
        :param shrinkage: whether to shrink the covariance matrices
        :param shrink1: weight of the diagonal term added by shrinkage
        :param shrink2: weight of the off-diagonal term added by shrinkage
        :param tukey1: power used in the Tukey transforms
        :param covnorm: whether to normalize the covariance matrix
        """
        super().__init__()
        self.class_means_dict = {}
        self.class_cov_dict = {}

        self.register_buffer("class_means", None)
        self.register_buffer("class_covs", None)

        self.tukey = tukey
        self.shrinkage = shrinkage
        self.covnorm = covnorm
        self.shrink1 = shrink1
        self.shrink2 = shrink2
        self.tukey1 = tukey1

        self.max_class = -1
    @torch.no_grad()
    def forward(self, x):
        """
        :param x: (batch_size, feature_size)

        Returns a tensor of size (batch_size, num_classes) with
        the negative distance of each element in the mini-batch
        with respect to each class.
        """
        if self.class_means_dict == {}:
            self.init_missing_classes(range(self.max_class + 1), x.shape[1], x.device)

        assert self.class_means_dict != {}, "no class means available."

        if self.tukey:
            x = self._tukey_transforms(x)

        # Distances start at +inf so that classes with no stored
        # statistics can never be predicted.
        maha_dist = torch.full(
            (self.max_class + 1, x.shape[0]), np.inf, device=x.device
        )
        for class_id, prototype in self.class_means_dict.items():
            cov = self.class_cov_dict[class_id]
            dist = self._mahalanobis(x, prototype, cov)
            maha_dist[class_id] = dist

        # (num_classes, batch_size) -> (batch_size, num_classes)
        return -maha_dist.T
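
    # Shape note for forward(): with batch size B and C = max_class + 1
    # known classes, the loop fills a (C, B) distance matrix (rows of
    # classes without stored statistics stay at inf) and the negated
    # transpose is returned, so predictions can be read off with
    # scores.argmax(dim=1).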

    def _mahalanobis(self, vectors, class_means, cov):
        # Distances are computed between L2-normalized features and the
        # L2-normalized class prototype.
        x_minus_mu = F.normalize(vectors, p=2, dim=-1) - F.normalize(
            class_means, p=2, dim=-1
        )
        inv_covmat = torch.linalg.pinv(cov).float().to(vectors.device)
        left_term = torch.matmul(x_minus_mu, inv_covmat)
        mahal = torch.matmul(left_term, x_minus_mu.T)
        return torch.diagonal(mahal, 0)
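
    # _mahalanobis() computes, for every element of the batch, the squared
    # Mahalanobis distance between the L2-normalized feature and the
    # L2-normalized class prototype:
    #
    #     D_c(x) = (x_hat - mu_hat_c)^T  Sigma_c^+  (x_hat - mu_hat_c)
    #
    # where Sigma_c^+ is the pseudo-inverse of the stored class covariance.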

    def _tukey_transforms(self, x):
        x = torch.as_tensor(x)
        if self.tukey1 == 0:
            return torch.log(x)
        else:
            return torch.pow(x, self.tukey1)

    def _tukey_invert_transforms(self, x):
        x = torch.as_tensor(x)
        if self.tukey1 == 0:
            return torch.exp(x)
        else:
            return torch.pow(x, 1 / self.tukey1)
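
    # Example: with the default tukey1 = 0.5, _tukey_transforms() is an
    # element-wise square root and _tukey_invert_transforms() squares its
    # input. Fractional powers assume non-negative features (e.g. post-ReLU
    # embeddings), since torch.pow returns NaN for negative bases.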

    def _shrink_cov(self, cov):
        diag_mean = torch.mean(torch.diagonal(cov))
        off_diag = cov.clone()
        off_diag.fill_diagonal_(0.0)
        mask = off_diag != 0.0
        off_diag_mean = (off_diag * mask).sum() / mask.sum()
        iden = torch.eye(cov.shape[0]).to(cov.device)
        cov_ = (
            cov
            + (self.shrink1 * diag_mean * iden)
            + (self.shrink2 * off_diag_mean * (1 - iden))
        )
        return cov_
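
    # _shrink_cov() implements
    #
    #     Sigma' = Sigma + shrink1 * mean(diag(Sigma)) * I
    #                    + shrink2 * mean(offdiag(Sigma)) * (1 - I)
    #
    # where mean(offdiag(.)) averages the non-zero off-diagonal entries:
    # the diagonal is inflated by the average variance, the off-diagonal
    # entries by the average covariance.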

    def _vectorize_means_dict(self):
        if self.class_means_dict == {}:
            return

        max_class = max(self.class_means_dict.keys())
        self.max_class = max(max_class, self.max_class)
        first_mean = list(self.class_means_dict.values())[0]
        feature_size = first_mean.size(0)
        device = first_mean.device
        self.class_means = torch.zeros(self.max_class + 1, feature_size).to(device)

        for k, v in self.class_means_dict.items():
            self.class_means[k] = v.clone()

    def _vectorize_cov_dict(self):
        if self.class_cov_dict == {}:
            return

        max_class = max(self.class_cov_dict.keys())
        self.max_class = max(max_class, self.max_class)
        first_cov = list(self.class_cov_dict.values())[0]
        feature_size = first_cov.size(0)
        device = first_cov.device
        self.class_covs = torch.zeros(
            self.max_class + 1, feature_size, feature_size
        ).to(device)

        for k, v in self.class_cov_dict.items():
            self.class_covs[k] = v.clone()

    def _normalize_cov(self, cov_mat):
        norm_cov_mat = {}
        for key, cov in cov_mat.items():
            # Per-feature standard deviations of the variables.
            sd = torch.sqrt(torch.diagonal(cov))
            cov = cov / (torch.matmul(sd.unsqueeze(1), sd.unsqueeze(0)))
            norm_cov_mat[key] = cov

        return norm_cov_mat
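
    # _normalize_cov() divides each entry by the outer product of the
    # per-feature standard deviations, cov[i, j] / (sd_i * sd_j), turning
    # every covariance matrix into a correlation matrix with unit diagonal.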

    def update_class_means_dict(
        self, class_means_dict: Dict[int, Tensor], momentum: float = 0.5
    ):
        assert 0 <= momentum <= 1
        assert isinstance(class_means_dict, dict), (
            "class_means_dict must be a dictionary mapping class_id "
            "to mean vector"
        )
        for k, v in class_means_dict.items():
            if k not in self.class_means_dict or (self.class_means_dict[k] == 0).all():
                self.class_means_dict[k] = class_means_dict[k].clone()
            else:
                device = self.class_means_dict[k].device
                self.class_means_dict[k] = (
                    momentum * class_means_dict[k].to(device)
                    + (1 - momentum) * self.class_means_dict[k]
                )

        self._vectorize_means_dict()
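
    # The update above is an exponential moving average,
    #
    #     mean_new = momentum * mean_incoming + (1 - momentum) * mean_old,
    #
    # with new (or still all-zero) classes initialized directly from the
    # incoming statistics. update_class_cov_dict() below follows the same
    # scheme for covariances.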

    def update_class_cov_dict(
        self, class_cov_dict: Dict[int, Tensor], momentum: float = 0.5
    ):
        assert 0 <= momentum <= 1
        assert isinstance(class_cov_dict, dict), (
            "class_cov_dict must be a dictionary mapping class_id "
            "to covariance matrix"
        )
        for k, v in class_cov_dict.items():
            if k not in self.class_cov_dict or (self.class_cov_dict[k] == 0).all():
                self.class_cov_dict[k] = class_cov_dict[k].clone()
            else:
                device = self.class_cov_dict[k].device
                self.class_cov_dict[k] = (
                    momentum * class_cov_dict[k].to(device)
                    + (1 - momentum) * self.class_cov_dict[k]
                )

        self._vectorize_cov_dict()

    def replace_class_means_dict(
        self,
        class_means_dict: Dict[int, Tensor],
    ):
        self.class_means_dict = class_means_dict
        self._vectorize_means_dict()

    def replace_class_cov_dict(
        self,
        class_cov_dict: Dict[int, Tensor],
    ):
        self.class_cov_dict = class_cov_dict
        self._vectorize_cov_dict()

    def init_missing_classes(self, classes, class_size, device):
        for k in classes:
            if k not in self.class_means_dict:
                # Unseen classes start with a zero mean and an identity
                # covariance matrix.
                self.class_means_dict[k] = torch.zeros(class_size).to(device)
                self.class_cov_dict[k] = torch.eye(class_size).to(device)

        # Vectorize
        self._vectorize_means_dict()
        self._vectorize_cov_dict()

    def adaptation(self, experience):
        super().adaptation(experience)
        if not self.training:
            classes = experience.classes_in_this_experience
            for k in classes:
                self.max_class = max(k, self.max_class)

            if len(self.class_means_dict) > 0:
                self.init_missing_classes(
                    classes,
                    list(self.class_means_dict.values())[0].shape[0],
                    list(self.class_means_dict.values())[0].device,
                )

    def apply_transforms(self, features):
        if self.tukey:
            features = self._tukey_transforms(features)
        return features

    def apply_invert_transforms(self, features):
        if self.tukey:
            features = self._tukey_invert_transforms(features)
        return features

    def apply_cov_transforms(self, class_cov):
        if self.shrinkage:
            for key, cov in class_cov.items():
                # Note: shrinkage is applied twice to each covariance matrix.
                class_cov[key] = self._shrink_cov(cov)
                class_cov[key] = self._shrink_cov(class_cov[key])
        if self.covnorm:
            class_cov = self._normalize_cov(class_cov)
        return class_cov
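
    # apply_cov_transforms() is the covariance post-processing pipeline:
    # two rounds of shrinkage followed by correlation normalization. It is
    # typically applied once to freshly estimated covariance matrices before
    # they are handed to update/replace_class_cov_dict() (see the usage
    # sketch at the end of the module).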

    def load_state_dict(self, state_dict, strict: bool = True):
        self.class_means = state_dict["class_means"]
        self.class_covs = state_dict["class_covs"]

        super().load_state_dict(state_dict, strict)

        # Rebuild the per-class dictionaries from the loaded buffers.
        if self.class_means is not None:
            for i in range(self.class_means.shape[0]):
                if (self.class_means[i] != 0).any():
                    self.class_means_dict[i] = self.class_means[i].clone()
            self.max_class = max(self.class_means_dict.keys())

        if self.class_covs is not None:
            for i in range(self.class_covs.shape[0]):
                if (self.class_covs[i] != 0).any():
                    self.class_cov_dict[i] = self.class_covs[i].clone()
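
    # When restoring from a checkpoint, rows that are entirely zero are
    # treated as "no statistics stored" and skipped while rebuilding the
    # per-class dictionaries from the class_means / class_covs buffers.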


def compute_covariance(features, labels) -> Dict:
    class_cov = {}
    for class_id in list(torch.unique(labels).cpu().int().numpy()):
        mask = labels == class_id
        class_features = features[mask]
        cov = torch.cov(class_features.T)
        class_cov[class_id] = cov
    return class_cov


def compute_means(features, labels) -> Dict:
    class_means = {}
    for class_id in list(torch.unique(labels).cpu().int().numpy()):
        mask = labels == class_id
        class_features = features[mask]
        prototype = torch.mean(class_features, dim=0)
        class_means[class_id] = prototype
    return class_means
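

# A minimal end-to-end sketch of how these pieces fit together (not part of
# the Avalanche API; the synthetic data and the fit/predict flow below are
# illustrative assumptions): estimate per-class statistics on transformed
# features, post-process the covariances, load them into the classifier,
# and score a batch.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Synthetic non-negative features, as produced e.g. by a ReLU backbone
    # (the fractional Tukey power assumes non-negative inputs).
    features = torch.rand(100, 32)
    labels = torch.randint(0, 5, (100,))

    clf = FeCAMClassifier()

    # Statistics are estimated on Tukey-transformed features; forward()
    # re-applies the transform internally, so raw features are passed to it.
    transformed = clf.apply_transforms(features)
    class_means = compute_means(transformed, labels)
    class_covs = compute_covariance(transformed, labels)
    class_covs = clf.apply_cov_transforms(class_covs)

    clf.replace_class_means_dict(class_means)
    clf.replace_class_cov_dict(class_covs)

    scores = clf(features)        # (batch_size, num_classes)
    preds = scores.argmax(dim=1)  # predicted class ids
    print(preds.shape)            # torch.Size([100])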