from fnmatch import fnmatch
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.modules.batchnorm import _NormBase

from .ewc import EwcDataType, ParamDict

# NOTE(review): the module paths of the two imports below were reconstructed;
# confirm them against the package layout.
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.utils import get_layers_and_params, ParamData

if TYPE_CHECKING:
    from ..templates import SupervisedTemplate

# Two-level mapping used by SynapticIntelligencePlugin:
# outer key = data slot ("old_theta", "new_theta", "grad", "trajectory",
# "cum_trajectory"), inner key = model parameter name.
SynDataType = Dict[
    str,
    Dict[str, Union[ParamData, Tensor]],
]

class SynapticIntelligencePlugin(SupervisedPlugin):
    """Synaptic Intelligence plugin.

    This is the Synaptic Intelligence PyTorch implementation of the
    algorithm described in the paper
    "Continuous Learning in Single-Incremental-Task Scenarios".

    The original implementation has been proposed in the paper
    "Continual Learning Through Synaptic Intelligence".

    This plugin can be attached to existing strategies to achieve a
    regularization effect.

    This plugin will require the strategy `loss` field to be set before the
    `before_backward` callback is invoked. The loss Tensor will be updated to
    achieve the S.I. regularization effect.
    """

    def __init__(
        self,
        si_lambda: Union[float, Sequence[float]],
        eps: float = 0.0000001,
        excluded_parameters: Optional[Sequence[str]] = None,
        device: Any = "as_strategy",
    ):
        """Creates an instance of the Synaptic Intelligence plugin.

        :param si_lambda: Synaptic Intelligence lambda term.
            If list, one lambda for each experience. If the list has less
            elements than the number of experiences, last lambda will be
            used for the remaining experiences.
        :param eps: Synaptic Intelligence damping parameter.
        :param excluded_parameters: Optional sequence of parameter-name
            patterns (``fnmatch`` syntax) that must not be regularized.
            A pattern that does not end with ``*`` also excludes every
            ``<pattern>.*`` name. Parameters of ``_NormBase`` layers
            (e.g. batch norm) are always excluded.
        :param device: The device to use to run the S.I. experiences.
            Defaults to "as_strategy", which means that the `device` field of
            the strategy will be used. Using a different device may lead to a
            performance drop due to the required data transfer.
        :raises ValueError: if ``si_lambda`` is an empty sequence.
        """
        super().__init__()

        if excluded_parameters is None:
            excluded_parameters = []
        self.si_lambda = (
            si_lambda if isinstance(si_lambda, (list, tuple)) else [si_lambda]
        )
        if len(self.si_lambda) == 0:
            # before_backward falls back to si_lambda[-1]; an empty sequence
            # would otherwise only fail there with an IndexError.
            raise ValueError("si_lambda must contain at least one value")
        self.eps: float = eps
        self.excluded_parameters: Set[str] = set(excluded_parameters)

        # ewc_data[0]: parameter values at the loss minimum (end of the
        # previous experience); ewc_data[1]: per-parameter importance.
        self.ewc_data: EwcDataType = (dict(), dict())

        self.syn_data: SynDataType = {
            "old_theta": dict(),
            "new_theta": dict(),
            "grad": dict(),
            "trajectory": dict(),
            "cum_trajectory": dict(),
        }

        self._device = device

    def before_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        """Allocate/expand the S.I. buffers and snapshot the starting weights."""
        super().before_training_exp(strategy, **kwargs)
        SynapticIntelligencePlugin.create_syn_data(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            self.excluded_parameters,
        )

        SynapticIntelligencePlugin.init_batch(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            self.excluded_parameters,
        )

    def before_backward(self, strategy: "SupervisedTemplate", **kwargs):
        """Add the S.I. penalty for the current experience to ``strategy.loss``."""
        super().before_backward(strategy, **kwargs)

        exp_id = strategy.clock.train_exp_counter
        try:
            si_lamb = self.si_lambda[exp_id]
        except IndexError:  # less than one lambda per experience, take last
            si_lamb = self.si_lambda[-1]

        syn_loss = SynapticIntelligencePlugin.compute_ewc_loss(
            strategy.model,
            self.ewc_data,
            self.excluded_parameters,
            lambd=si_lamb,
            device=self.device(strategy),
        )

        # None means no parameter was regularized: leave the loss untouched.
        if syn_loss is not None:
            strategy.loss += syn_loss

    def before_training_iteration(self, strategy: "SupervisedTemplate", **kwargs):
        """Snapshot the weights before the optimizer step."""
        super().before_training_iteration(strategy, **kwargs)
        SynapticIntelligencePlugin.pre_update(
            strategy.model, self.syn_data, self.excluded_parameters
        )

    def after_training_iteration(self, strategy: "SupervisedTemplate", **kwargs):
        """Accumulate this iteration's contribution to the path integral."""
        super().after_training_iteration(strategy, **kwargs)
        SynapticIntelligencePlugin.post_update(
            strategy.model, self.syn_data, self.excluded_parameters
        )

    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        """Consolidate importances and store the new reference weights."""
        super().after_training_exp(strategy, **kwargs)
        SynapticIntelligencePlugin.update_ewc_data(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            0.001,
            self.excluded_parameters,
            1,
            eps=self.eps,
        )

    def device(self, strategy: "SupervisedTemplate") -> Any:
        """Return the device used for the S.I. loss computation."""
        if self._device == "as_strategy":
            return strategy.device

        return self._device

    @staticmethod
    @torch.no_grad()
    def create_syn_data(
        model: Module,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        excluded_parameters: Set[str],
    ):
        """Create (or expand) one flat buffer per allowed parameter.

        New parameters get fresh ``ParamData`` entries in every dictionary.
        Already-known parameters whose flattened size changed are expanded.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        for param_name, param in params:
            # All buffers are stored flat, so compare/allocate with the
            # flattened shape (comparing against param.shape would flag every
            # >=2-D parameter as "expanded" on each experience).
            flat_shape = param.flatten().shape

            if param_name not in ewc_data[0]:
                # new parameter
                ewc_data[0][param_name] = ParamData(param_name, flat_shape)
                ewc_data[1][param_name] = ParamData(f"imp_{param_name}", flat_shape)

                syn_data["old_theta"][param_name] = ParamData(
                    f"old_theta_{param_name}", flat_shape
                )
                syn_data["new_theta"][param_name] = ParamData(
                    f"new_theta_{param_name}", flat_shape
                )
                syn_data["grad"][param_name] = ParamData(
                    f"grad{param_name}", flat_shape
                )
                syn_data["trajectory"][param_name] = ParamData(
                    f"trajectory_{param_name}", flat_shape
                )
                syn_data["cum_trajectory"][param_name] = ParamData(
                    f"cum_trajectory_{param_name}", flat_shape
                )
            elif ewc_data[0][param_name].shape != flat_shape:
                # parameter expansion
                ewc_data[0][param_name].expand(flat_shape)
                ewc_data[1][param_name].expand(flat_shape)
                for slot in (
                    "old_theta",
                    "new_theta",
                    "grad",
                    "trajectory",
                    "cum_trajectory",
                ):
                    syn_data[slot][param_name].expand(flat_shape)

    @staticmethod
    @torch.no_grad()
    def extract_weights(
        model: Module, target: ParamDict, excluded_parameters: Set[str]
    ):
        """Copy every allowed parameter, flattened and on CPU, into ``target``."""
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        for name, param in params:
            target[name].data = param.detach().cpu().flatten()

    @staticmethod
    @torch.no_grad()
    def extract_grad(model, target: ParamDict, excluded_parameters: Set[str]):
        """Copy every allowed parameter's gradient, flattened and on CPU.

        A parameter without a gradient (``param.grad is None``, e.g. it did
        not take part in the forward pass or grads were reset to ``None``)
        gets an all-zero entry. Its contribution to the path integral is
        then zero, instead of an ``AttributeError`` being raised.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        # Store the gradients into target
        for name, param in params:
            if param.grad is None:
                target[name].data = torch.zeros(param.numel(), dtype=param.dtype)
            else:
                target[name].data = param.grad.detach().cpu().flatten()

    @staticmethod
    @torch.no_grad()
    def init_batch(
        model,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        excluded_parameters: Set[str],
    ):
        """Store the starting weights and reset the per-experience trajectory."""
        # Keep initial weights
        SynapticIntelligencePlugin.extract_weights(
            model, ewc_data[0], excluded_parameters
        )
        # The trajectory accumulates only within the current experience.
        for param_trajectory in syn_data["trajectory"].values():
            param_trajectory.data.fill_(0.0)

    @staticmethod
    @torch.no_grad()
    def pre_update(model, syn_data: SynDataType, excluded_parameters: Set[str]):
        """Snapshot the current weights into ``syn_data["old_theta"]``."""
        SynapticIntelligencePlugin.extract_weights(
            model, syn_data["old_theta"], excluded_parameters
        )

    @staticmethod
    @torch.no_grad()
    def post_update(model, syn_data: SynDataType, excluded_parameters: Set[str]):
        """Add ``grad * (new_theta - old_theta)`` to each parameter's trajectory."""
        SynapticIntelligencePlugin.extract_weights(
            model, syn_data["new_theta"], excluded_parameters
        )
        SynapticIntelligencePlugin.extract_grad(
            model, syn_data["grad"], excluded_parameters
        )

        for param_name in syn_data["trajectory"]:
            syn_data["trajectory"][param_name].data += syn_data["grad"][
                param_name
            ].data * (
                syn_data["new_theta"][param_name].data
                - syn_data["old_theta"][param_name].data
            )

    @staticmethod
    def compute_ewc_loss(
        model,
        ewc_data: EwcDataType,
        excluded_parameters: Set[str],
        device,
        lambd=0.0,
    ):
        """Compute ``lambd / 2 * sum_k imp_k * (w_k - w_old_k) ** 2``.

        :param model: The model whose current (non-detached) weights are used.
        :param ewc_data: ``(reference weights, importances)`` dictionaries.
        :param excluded_parameters: Parameter-name patterns to skip.
        :param device: Device on which the loss is computed.
        :param lambd: Regularization strength.
        :return: The loss tensor, or ``None`` if no parameter is regularized.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        loss = None
        for name, param in params:
            weights = param.to(device).flatten()  # Flat, not detached
            ewc_data0 = ewc_data[0][name].data.to(device)  # Flat, detached
            ewc_data1 = ewc_data[1][name].data.to(device)  # Flat, detached
            syn_loss: Tensor = torch.dot(ewc_data1, (weights - ewc_data0) ** 2) * (
                lambd / 2
            )

            loss = syn_loss if loss is None else loss + syn_loss

        return loss

    @staticmethod
    @torch.no_grad()
    def update_ewc_data(
        net,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        clip_to: float,
        excluded_parameters: Set[str],
        c=0.0015,
        eps: float = 0.0000001,
    ):
        """Fold the experience's trajectory into the importances.

        For every tracked parameter this does three things:

        - It updates ``cum_trajectory += c * trajectory / ((new - old) ** 2 + eps)``.
        - It sets the importance to ``-cum_trajectory`` clipped from above at
          ``clip_to``.
        - It stores the current weights as the new reference weights.
        """
        SynapticIntelligencePlugin.extract_weights(
            net, syn_data["new_theta"], excluded_parameters
        )

        # Each parameter's update is independent of the others, so all the
        # steps are done in a single pass.
        for param_name, cum_trajectory in syn_data["cum_trajectory"].items():
            new_theta = syn_data["new_theta"][param_name].data
            old_theta = ewc_data[0][param_name].data

            cum_trajectory.data += (
                c
                * syn_data["trajectory"][param_name].data
                / (torch.square(new_theta - old_theta) + eps)
            )

            # change sign here because the Ewc regularization
            # in Caffe (theta - thetaold) is inverted w.r.t. syn equation [4]
            # (thetaold - theta)
            ewc_data[1][param_name].data = torch.clamp(
                -cum_trajectory.data, max=clip_to
            )

            ewc_data[0][param_name].data = new_theta.clone()

    @staticmethod
    def explode_excluded_parameters(excluded: Set[str]) -> Set[str]:
        """
        Explodes a list of excluded parameters by adding a generic final ".*"
        wildcard at its end.

        :param excluded: The original set of excluded parameters.

        :return: The set of excluded parameters in which ".*" patterns have been
            added.
        """
        result = set(excluded)
        result.update(x + ".*" for x in excluded if not x.endswith("*"))
        return result

    @staticmethod
    def not_excluded_parameters(
        model: Module, excluded_parameters: Set[str]
    ) -> Sequence[Tuple[str, Tensor]]:
        """Return ``(name, param)`` pairs not matching any exclusion pattern.

        Parameters of ``_NormBase`` layers are always excluded. The caller's
        ``excluded_parameters`` set is not modified, because
        ``explode_excluded_parameters`` returns a new set.
        """
        # Add wildcards ".*" to all excluded parameter names
        excluded_parameters = SynapticIntelligencePlugin.explode_excluded_parameters(
            excluded_parameters
        )
        for lp in get_layers_and_params(model):
            if isinstance(lp.layer, _NormBase):
                # Exclude batch norm parameters
                excluded_parameters.add(lp.parameter_name)

        result: List[Tuple[str, Tensor]] = [
            (name, param)
            for name, param in model.named_parameters()
            if not any(fnmatch(name, pattern) for pattern in excluded_parameters)
        ]
        return result

    @staticmethod
    def allowed_parameters(
        model: Module, excluded_parameters: Set[str]
    ) -> List[Tuple[str, Tensor]]:
        """Return the non-excluded ``(name, param)`` pairs that require grad."""
        allow_list = SynapticIntelligencePlugin.not_excluded_parameters(
            model, excluded_parameters
        )
        return [(name, param) for name, param in allow_list if param.requires_grad]