import warnings
from typing import Callable, Optional, Sequence, Union
import os
import torch
from torch.nn import Module
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.evaluation import (
EvaluationPlugin,
default_evaluator,
)
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models import FeatureExtractorBackbone
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
class StreamingLDA(SupervisedTemplate):
"""Deep Streaming Linear Discriminant Analysis.
This strategy does not use backpropagation.
Minibatches are first passed to the pretrained feature extractor.
The result is processed one element at a time to fit the LDA.
Original paper:
"Hayes et. al., Lifelong Machine Learning with Deep Streaming Linear
Discriminant Analysis, CVPR Workshop, 2020"
https://openaccess.thecvf.com/content_CVPRW_2020/papers/w15/Hayes_Lifelong_Machine_Learning_With_Deep_Streaming_Linear_Discriminant_Analysis_CVPRW_2020_paper.pdf
"""
    def __init__(
self,
*,
slda_model: Module,
criterion: CriterionType,
input_size: int,
num_classes: int,
output_layer_name: Optional[str] = None,
shrinkage_param=1e-4,
streaming_update_sigma=True,
train_epochs: int = 1,
train_mb_size: int = 1,
eval_mb_size: int = 1,
device: Union[str, torch.device] = "cpu",
plugins: Optional[Sequence["SupervisedPlugin"]] = None,
evaluator: Union[
EvaluationPlugin, Callable[[], EvaluationPlugin]
] = default_evaluator,
eval_every=-1,
**kwargs,
):
"""Init function for the SLDA model.
        :param slda_model: a PyTorch model used as the (frozen) feature
            extractor
:param criterion: loss function
:param output_layer_name: if not None, wrap model to retrieve
only the `output_layer_name` output. If None, the strategy
assumes that the model already produces a valid output.
You can use `FeatureExtractorBackbone` class to create your custom
SLDA-compatible model.
:param input_size: feature dimension
:param num_classes: number of total classes in stream
        :param train_mb_size: batch size for the feature extractor during
            training. `fit` will be called on a single pattern at a time.
        :param train_epochs: number of training epochs
        :param eval_mb_size: batch size for inference
        :param shrinkage_param: value of the shrinkage parameter
        :param streaming_update_sigma: if True, the covariance matrix Sigma
            keeps being updated as new samples arrive; if False, it stays
            fixed (e.g. after an initial estimate from `fit_base`).
        :param device: PyTorch device where the model will be allocated
        :param plugins: list of StrategyPlugins
        :param evaluator: Evaluation Plugin instance
:param eval_every: run eval every `eval_every` epochs.
See `BaseTemplate` for details.
"""
if plugins is None:
plugins = []
slda_model = slda_model.eval()
if output_layer_name is not None:
slda_model = FeatureExtractorBackbone(
slda_model.to(device), output_layer_name
).eval()
        super().__init__(
model=slda_model,
optimizer=None, # type: ignore
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
**kwargs,
)
# SLDA parameters
self.input_size = input_size
self.shrinkage_param = shrinkage_param
self.streaming_update_sigma = streaming_update_sigma
# setup weights for SLDA
self.muK = torch.zeros((num_classes, input_size)).to(self.device)
self.cK = torch.zeros(num_classes).to(self.device)
self.Sigma = torch.ones((input_size, input_size)).to(self.device)
self.num_updates = 0
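        # Lambda caches the regularized precision matrix used by predict();
        # prev_num_updates records when it was last computed so it is
        # refreshed only after new samples have been fit.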
self.Lambda = torch.zeros_like(self.Sigma).to(self.device)
self.prev_num_updates = -1
def forward(self, return_features=False):
"""Compute the model's output given the current mini-batch."""
self.model.eval()
if isinstance(self.model, MultiTaskModule):
feat = self.model(self.mb_x, self.mb_task_id)
else: # no task labels
feat = self.model(self.mb_x)
out = self.predict(feat)
if return_features:
return out, feat
else:
return out
def training_epoch(self, **kwargs):
"""
Training epoch.
:param kwargs:
:return:
"""
for _, self.mbatch in enumerate(self.dataloader):
self._unpack_minibatch()
self._before_training_iteration(**kwargs)
self.loss = self._make_empty_loss()
# Forward
self._before_forward(**kwargs)
# compute output on entire minibatch
self.mb_output, feats = self.forward(return_features=True)
self._after_forward(**kwargs)
# Loss & Backward
self.loss += self.criterion()
# Optimization step
self._before_update(**kwargs)
# process one element at a time
for f, y in zip(feats, self.mb_y):
self.fit(f.unsqueeze(0), y.unsqueeze(0))
self._after_update(**kwargs)
self._after_training_iteration(**kwargs)
def make_optimizer(self, **kwargs):
"""Empty function.
Deep SLDA does not need a Pytorch optimizer."""
pass
@torch.no_grad()
def fit(self, x, y):
"""
Fit the SLDA model to a new sample (x,y).
        :param x: a torch tensor with the input features (shape `1 x
            feature_size`)
        :param y: a torch tensor with the associated label
:return: None
"""
# covariance updates
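        # Streaming covariance estimate (with t = num_updates):
        #   delta = t / (t + 1) * (x - mu_y) (x - mu_y)^T
        #   Sigma <- (t * Sigma + delta) / (t + 1)
        # i.e. a rank-1 running-average update driven by the sample's
        # deviation from its current class mean.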
if self.streaming_update_sigma:
x_minus_mu = x - self.muK[y]
mult = torch.matmul(x_minus_mu.transpose(1, 0), x_minus_mu)
delta = mult * self.num_updates / (self.num_updates + 1)
self.Sigma = (self.num_updates * self.Sigma + delta) / (
self.num_updates + 1
)
# update class means
self.muK[y, :] += (x - self.muK[y, :]) / (self.cK[y] + 1).unsqueeze(1)
self.cK[y] += 1
self.num_updates += 1
@torch.no_grad()
def predict(self, X):
"""
Make predictions on test data X.
:param X: a torch tensor that contains N data samples (N x d)
:param return_probas: True if the user would like probabilities instead
of predictions returned
:return: the test predictions or probabilities
"""
# compute/load Lambda matrix
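        # Lambda is the pseudo-inverse of the shrinkage-regularized
        # covariance, Lambda = pinv((1 - eps) * Sigma + eps * I) with
        # eps = shrinkage_param; it is recomputed lazily, only when new
        # samples have been fit since the last prediction.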
if self.prev_num_updates != self.num_updates:
# there have been updates to the model, compute Lambda
self.Lambda = torch.pinverse(
(1 - self.shrinkage_param) * self.Sigma
+ self.shrinkage_param * torch.eye(self.input_size, device=self.device)
)
self.prev_num_updates = self.num_updates
# parameters for predictions
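        # LDA decision rule: W[:, k] = Lambda @ mu_k and
        # c[k] = 0.5 * mu_k^T Lambda mu_k, so
        # scores[n, k] = x_n^T Lambda mu_k - 0.5 * mu_k^T Lambda mu_k,
        # a linear classifier with a covariance shared across classes.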
M = self.muK.transpose(1, 0)
W = torch.matmul(self.Lambda, M)
c = 0.5 * torch.sum(M * W, dim=0)
scores = torch.matmul(X, W) - c
        # return classification scores
return scores
def fit_base(self, X, y):
"""
Fit the SLDA model to the base data.
:param X: an Nxd torch tensor of base initialization data
        :param y: an N-dimensional torch tensor of the associated labels for X
:return: None
"""
print("\nFitting Base...")
# update class means
for k in torch.unique(y):
self.muK[k] = X[y == k].mean(0)
self.cK[k] = X[y == k].shape[0]
self.num_updates = X.shape[0]
print("\nEstimating initial covariance matrix...")
from sklearn.covariance import OAS
cov_estimator = OAS(assume_centered=True)
cov_estimator.fit((X - self.muK[y]).cpu().numpy())
self.Sigma = torch.from_numpy(cov_estimator.covariance_).float().to(self.device)
def save_model(self, save_path, save_name):
"""
Save the model parameters to a torch file.
:param save_path: the path where the model will be saved
:param save_name: the name for the saved file
:return:
"""
# grab parameters for saving
d = dict()
d["muK"] = self.muK.cpu()
d["cK"] = self.cK.cpu()
d["Sigma"] = self.Sigma.cpu()
d["num_updates"] = self.num_updates
# save model out
torch.save(d, os.path.join(save_path, save_name + ".pth"))
def load_model(self, save_path, save_name):
"""
Load the model parameters into StreamingLDA object.
:param save_path: the path where the model is saved
:param save_name: the name of the saved file
:return:
"""
# load parameters
d = torch.load(os.path.join(save_path, save_name + ".pth"))
self.muK = d["muK"].to(self.device)
self.cK = d["cK"].to(self.device)
self.Sigma = d["Sigma"].to(self.device)
self.num_updates = d["num_updates"]
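    # Example round trip (the checkpoint directory is hypothetical):
    #
    #     strategy.save_model("./checkpoints", "slda_after_exp3")
    #     strategy.load_model("./checkpoints", "slda_after_exp3")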
def _check_plugin_compatibility(self):
"""Check that the list of plugins is compatible with the template.
        This means checking that each plugin implements a subset of the
        supported callbacks.
"""
# TODO: ideally we would like to check the argument's type to check
# that it's a supertype of the template.
# I don't know if it's possible to do it in Python.
ps = self.plugins
def get_plugins_from_object(obj):
def is_callback(x):
return x.startswith("before") or x.startswith("after")
return filter(is_callback, dir(obj))
cb_supported = set(get_plugins_from_object(self.PLUGIN_CLASS))
cb_supported.remove("before_backward")
cb_supported.remove("after_backward")
for p in ps:
cb_p = set(get_plugins_from_object(p))
if not cb_p.issubset(cb_supported):
warnings.warn(
f"Plugin {p} implements incompatible callbacks for template"
f" {self}. This may result in errors. Incompatible "
f"callbacks: {cb_p - cb_supported}",
)
return
__all__ = ["StreamingLDA"]