Source code for avalanche.training.plugins.ewc

from collections import defaultdict
from typing import Dict, Tuple, Union
import warnings
import itertools

import torch
from torch.utils.data import DataLoader

from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.utils import copy_params_dict, zerolike_params_dict, \
    ParamData


class EWCPlugin(SupervisedPlugin):
    """
    Elastic Weight Consolidation (EWC) plugin.

    EWC computes the importance of each weight at the end of training on the
    current experience. During training on each minibatch, the loss is
    augmented with a penalty which keeps the value of the current weights
    close to the value they had on previous experiences, in proportion to
    their importance on that experience. Importances are computed with an
    additional pass on the training set.

    This plugin does not use task identities.
    """

    def __init__(
        self,
        ewc_lambda,
        mode="separate",
        decay_factor=None,
        keep_importance_data=False,
    ):
        """
        :param ewc_lambda: hyperparameter to weigh the penalty inside the
            total loss. The larger the lambda, the larger the regularization.
        :param mode: `separate` to keep a separate penalty for each previous
            experience. `online` to keep a single penalty summed with a decay
            factor over all previous tasks.
        :param decay_factor: used only if mode is `online`. It specifies the
            decay term of the importance matrix.
        :param keep_importance_data: if True, keep in memory both parameter
            values and importances for all previous tasks, for all modes.
            If False, keep only the last parameter values and importances.
            If mode is `separate`, `keep_importance_data` is forced to True.
        """
        super().__init__()

        # NOTE: validation is kept as `assert` so that callers relying on
        # AssertionError keep working.
        assert (decay_factor is None) or (
            mode == "online"
        ), "You need to set `online` mode to use `decay_factor`."
        assert (decay_factor is not None) or (
            mode != "online"
        ), "You need to set `decay_factor` to use the `online` mode."
        assert (
            mode == "separate" or mode == "online"
        ), "Mode must be separate or online."

        self.ewc_lambda = ewc_lambda
        self.mode = mode
        self.decay_factor = decay_factor

        # `separate` mode needs the data of every past experience to build
        # its penalty, so the user's choice is overridden in that case.
        if self.mode == "separate":
            self.keep_importance_data = True
        else:
            self.keep_importance_data = keep_importance_data

        # Both dictionaries are keyed by training-experience counter and map
        # to a {parameter name: value} dictionary.
        self.saved_params = defaultdict(dict)
        self.importances = defaultdict(dict)

    def before_backward(self, strategy, **kwargs):
        """
        Compute the EWC penalty and add it to `strategy.loss`.

        Nothing is done during the first training experience, since there
        are no previous parameters to stay close to.

        :param strategy: the running strategy; its `clock`, `device`,
            `model` and `loss` attributes are used.
        :raises ValueError: if `self.mode` is neither `separate` nor `online`.
        """
        exp_counter = strategy.clock.train_exp_counter
        if exp_counter == 0:
            return

        # Select which past experiences contribute to the penalty:
        # all of them in `separate` mode, only the latest one in `online`
        # mode (its importances already accumulate the older ones).
        if self.mode == "separate":
            past_experiences = range(exp_counter)
        elif self.mode == "online":
            past_experiences = (exp_counter - 1,)
        else:
            raise ValueError("Wrong EWC mode.")

        penalty = torch.zeros(
            (), dtype=torch.float32, device=strategy.device
        )

        for experience in past_experiences:
            saved_params = self.saved_params[experience]
            importances = self.importances[experience]
            for k, cur_param in strategy.model.named_parameters():
                # new parameters do not count
                if k not in saved_params:
                    continue
                # saved values / importances may need to be expanded when
                # the current parameter has a different shape
                new_shape = cur_param.shape
                saved_param = saved_params[k].expand(new_shape)
                imp = importances[k].expand(new_shape)
                penalty += (imp * (cur_param - saved_param).pow(2)).sum()

        strategy.loss += self.ewc_lambda * penalty

    def after_training_exp(self, strategy, **kwargs):
        """
        Compute importances of parameters after each experience and store a
        copy of the current model parameters.

        :param strategy: the running strategy; its `clock`, `model`,
            `_criterion`, `optimizer`, `experience`, `device` and
            `train_mb_size` attributes are used.
        """
        exp_counter = strategy.clock.train_exp_counter
        importances = self.compute_importances(
            strategy.model,
            strategy._criterion,
            strategy.optimizer,
            strategy.experience.dataset,
            strategy.device,
            strategy.train_mb_size,
        )
        self.update_importances(importances, exp_counter)
        self.saved_params[exp_counter] = copy_params_dict(strategy.model)

        # clear previous parameter values
        if exp_counter > 0 and (not self.keep_importance_data):
            del self.saved_params[exp_counter - 1]

    def compute_importances(
        self, model, criterion, optimizer, dataset, device, batch_size
    ):
        """
        Compute the EWC importance matrix for each parameter.

        The importance of a parameter is the square of its gradient with
        respect to the loss, accumulated over all mini-batches of `dataset`
        and divided by the number of mini-batches.

        :param model: the model whose parameters are scored. It is switched
            to `eval` mode by this method.
        :param criterion: callable invoked as `criterion(output, target)`.
        :param optimizer: only used to reset the gradients before each
            backward pass.
        :param dataset: dataset iterated with a `DataLoader`; each batch must
            provide the input first, the target second and the task labels
            last. Its `collate_fn` attribute is used when present.
        :param device: device (string or `torch.device`) on which inputs and
            targets are moved.
        :param batch_size: mini-batch size used for the additional pass.
        :return: dictionary mapping each parameter name to its importance.
        """
        model.eval()

        # Set RNN-like modules on GPU to training mode to avoid CUDA error.
        # The device *type* is compared so that `torch.device` objects and
        # indexed strings such as "cuda:0" are recognised too (a plain
        # `device == "cuda"` test is False for both).
        if isinstance(device, torch.device):
            device_type = device.type
        else:
            device_type = str(device).split(":")[0]

        if device_type == "cuda":
            for module in model.modules():
                if isinstance(module, torch.nn.RNNBase):
                    warnings.warn(
                        "RNN-like modules do not support "
                        "backward calls while in `eval` mode on CUDA "
                        "devices. Setting all `RNNBase` modules to "
                        "`train` mode. May produce inconsistent "
                        "output if such modules have `dropout` > 0."
                    )
                    module.train()

        # one zero-initialised entry per model parameter
        importances = zerolike_params_dict(model)
        collate_fn = (
            dataset.collate_fn if hasattr(dataset, "collate_fn") else None
        )
        dataloader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=collate_fn
        )

        for batch in dataloader:
            # get only input, target and task_id from the batch
            x, y, task_labels = batch[0], batch[1], batch[-1]
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out = avalanche_forward(model, x, task_labels)
            loss = criterion(out, y)
            loss.backward()

            for (k1, p), (k2, imp) in zip(
                model.named_parameters(), importances.items()
            ):
                assert k1 == k2
                if p.grad is not None:
                    imp.data += p.grad.data.pow(2)

        # Average over the number of mini-batches. With an empty dataloader
        # the importances are left at zero instead of dividing 0 by 0, which
        # would fill them with NaN and corrupt every later penalty.
        num_batches = len(dataloader)
        if num_batches > 0:
            for imp in importances.values():
                imp.data /= float(num_batches)

        return importances

    @torch.no_grad()
    def update_importances(self, importances, t):
        """
        Update importance for each parameter based on the currently computed
        importances.

        In `separate` mode (and for the first experience in any mode) the new
        importances are stored as they are. In `online` mode they are summed
        with the importances of experience `t - 1` scaled by `decay_factor`.

        :param importances: dictionary of freshly computed importances.
        :param t: training-experience counter the importances belong to.
        :raises ValueError: if `self.mode` is neither `separate` nor `online`.
        """
        if self.mode == "separate" or t == 0:
            self.importances[t] = importances
        elif self.mode == "online":
            for (k1, old_imp), (k2, curr_imp) in itertools.zip_longest(
                self.importances[t - 1].items(),
                importances.items(),
                fillvalue=(None, None),
            ):
                # Add new module importances to the importances value
                # (New head)
                if k1 is None:
                    self.importances[t][k2] = curr_imp
                    continue

                assert k1 == k2, "Error in importance computation."

                # manage expansion of existing layers
                self.importances[t][k1] = ParamData(
                    f"imp_{k1}",
                    curr_imp.shape,
                    init_tensor=self.decay_factor
                    * old_imp.expand(curr_imp.shape)
                    + curr_imp.data,
                    device=curr_imp.device,
                )

            # clear previous parameter importances
            if t > 0 and (not self.keep_importance_data):
                del self.importances[t - 1]
        else:
            raise ValueError("Wrong EWC mode.")
# Mapping from a parameter name to its `ParamData` entry.
ParamDict = Dict[str, ParamData]

# Pair of parameter dictionaries.
EwcDataType = Tuple[ParamDict, ParamDict]