avalanche.training.plugins.EWCPlugin

class avalanche.training.plugins.EWCPlugin(ewc_lambda, mode='separate', decay_factor=None, keep_importance_data=False)

Elastic Weight Consolidation (EWC) plugin. EWC computes the importance of each weight at the end of training on the current experience. During training on each minibatch, the loss is augmented with a penalty that keeps the current weights close to the values they had after previous experiences, in proportion to their importance on those experiences. Importances are computed with an additional pass over the training set. This plugin does not use task identities.
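As a rough illustration of the penalty described above, the sketch below uses plain Python lists in place of tensors. It assumes the common diagonal form, in which each weight's squared drift from its saved value is scaled by that weight's importance and the sum is multiplied by ewc_lambda. The helper `ewc_penalty` is hypothetical and not part of Avalanche; some formulations also include a factor of 1/2, which this sketch omits.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # parameter name -> flat list of weight values


def ewc_penalty(current: Params, saved: Params, importances: Params, ewc_lambda: float) -> float:
    """Hypothetical helper: ewc_lambda * sum_i F_i * (theta_i - theta*_i) ** 2."""
    total = 0.0
    # Only parameters with a saved snapshot (and importance) contribute.
    for name, old_values in saved.items():
        for w, w_old, f in zip(current[name], old_values, importances[name]):
            total += f * (w - w_old) ** 2  # importance-weighted squared drift
    return ewc_lambda * total


# Usage: the minibatch loss is augmented with the penalty before backward().
current = {"fc.weight": [0.5, -1.0], "fc.bias": [0.2]}
saved = {"fc.weight": [0.0, -1.0], "fc.bias": [0.0]}
importances = {"fc.weight": [2.0, 10.0], "fc.bias": [1.0]}
task_loss = 0.3
loss = task_loss + ewc_penalty(current, saved, importances, ewc_lambda=0.4)
print(round(loss, 4))  # 0.516
```

Because the penalty is quadratic in the drift, weights with large importance are pulled back strongly, while weights with importance near zero remain free to change.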

__init__(ewc_lambda, mode='separate', decay_factor=None, keep_importance_data=False)
Parameters
  • ewc_lambda – hyperparameter to weigh the penalty inside the total loss. The larger the lambda, the larger the regularization.

  • mode – 'separate' to keep a separate penalty for each previous experience; 'online' to keep a single penalty summed with a decay factor over all previous experiences.

  • decay_factor – used only if mode is online. It specifies the decay term of the importance matrix.

  • keep_importance_data – if True, keep in memory both parameter values and importances for all previous experiences, for all modes. If False, keep only the last parameter values and importances. If mode is 'separate', keep_importance_data is forced to True.
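The difference between the two modes can be sketched with a single scalar weight. This is a simplified model, not the plugin's code: it assumes that 'separate' mode stores one (importance, saved weight) pair per experience and sums their penalties, while 'online' mode merges importances as decay_factor * old + new and keeps only the latest weight snapshot (mirroring keep_importance_data=False). The class `ImportanceStore` and its methods are hypothetical.

```python
from typing import Dict, Optional, Tuple

Entry = Tuple[float, float]  # (importance, saved weight) for one scalar weight


class ImportanceStore:
    """Hypothetical scalar model of 'separate' vs 'online' EWC bookkeeping."""

    def __init__(self, mode: str = "separate", decay_factor: Optional[float] = None):
        if mode not in ("separate", "online"):
            raise ValueError("mode must be 'separate' or 'online'")
        if mode == "online" and decay_factor is None:
            raise ValueError("online mode requires a decay_factor")
        self.mode = mode
        self.decay_factor = decay_factor
        self.entries: Dict[int, Entry] = {}

    def update(self, exp_id: int, importance: float, weight: float) -> None:
        if self.mode == "separate" or not self.entries:
            # One independent entry per experience (also the first online entry).
            self.entries[exp_id] = (importance, weight)
        else:
            # Online: decay the accumulated importance, add the new one, and keep
            # only the latest weight snapshot (like keep_importance_data=False).
            old_importance, _ = self.entries[max(self.entries)]
            merged = self.decay_factor * old_importance + importance
            self.entries = {exp_id: (merged, weight)}

    def penalty(self, weight: float, ewc_lambda: float) -> float:
        # Separate mode sums one term per stored experience; online has one term.
        return ewc_lambda * sum(f * (weight - w_old) ** 2 for f, w_old in self.entries.values())


sep = ImportanceStore("separate")
onl = ImportanceStore("online", decay_factor=0.5)
for exp_id, (imp, w) in enumerate([(4.0, 1.0), (2.0, 3.0)]):
    sep.update(exp_id, imp, w)
    onl.update(exp_id, imp, w)

print(sep.penalty(2.0, ewc_lambda=1.0))  # 4*(2-1)**2 + 2*(2-3)**2 = 6.0
print(onl.penalty(2.0, ewc_lambda=1.0))  # (0.5*4 + 2)*(2-3)**2 = 4.0
```

Memory is the main trade-off: 'separate' grows with the number of experiences, while 'online' keeps a constant-size summary at the cost of anchoring only to the most recent weights.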

Methods

__init__(ewc_lambda[, mode, decay_factor, ...])

ewc_lambda – hyperparameter to weigh the penalty inside the total loss.

after_backward(strategy, *args, **kwargs)

Called after criterion.backward() by the BaseTemplate.

after_eval(strategy, *args, **kwargs)

Called after eval by the BaseTemplate.

after_eval_dataset_adaptation(strategy, ...)

Called after eval_dataset_adaptation by the BaseTemplate.

after_eval_exp(strategy, *args, **kwargs)

Called after eval_exp by the BaseTemplate.

after_eval_forward(strategy, *args, **kwargs)

Called after model.forward() by the BaseTemplate.

after_eval_iteration(strategy, *args, **kwargs)

Called after the end of an iteration by the BaseTemplate.

after_forward(strategy, *args, **kwargs)

Called after model.forward() by the BaseTemplate.

after_train_dataset_adaptation(strategy, ...)

Called after train_dataset_adaptation by the BaseTemplate.

after_training(strategy, *args, **kwargs)

Called after train by the BaseTemplate.

after_training_epoch(strategy, *args, **kwargs)

Called after train_epoch by the BaseTemplate.

after_training_exp(strategy, **kwargs)

Compute importances of parameters after each experience.

after_training_iteration(strategy, *args, ...)

Called after the end of a training iteration by the BaseTemplate.

after_update(strategy, *args, **kwargs)

Called after optimizer.update() by the BaseTemplate.

before_backward(strategy, **kwargs)

Compute EWC penalty and add it to the loss.

before_eval(strategy, *args, **kwargs)

Called before eval by the BaseTemplate.

before_eval_dataset_adaptation(strategy, ...)

Called before eval_dataset_adaptation by the BaseTemplate.

before_eval_exp(strategy, *args, **kwargs)

Called before eval_exp by the BaseTemplate.

before_eval_forward(strategy, *args, **kwargs)

Called before model.forward() by the BaseTemplate.

before_eval_iteration(strategy, *args, **kwargs)

Called before the start of an eval iteration by the BaseTemplate.

before_forward(strategy, *args, **kwargs)

Called before model.forward() by the BaseTemplate.

before_train_dataset_adaptation(strategy, ...)

Called before train_dataset_adaptation by the BaseTemplate.

before_training(strategy, *args, **kwargs)

Called before train by the BaseTemplate.

before_training_epoch(strategy, *args, **kwargs)

Called before train_epoch by the BaseTemplate.

before_training_exp(strategy, *args, **kwargs)

Called before train_exp by the BaseTemplate.

before_training_iteration(strategy, *args, ...)

Called before the start of a training iteration by the BaseTemplate.

before_update(strategy, *args, **kwargs)

Called before optimizer.update() by the BaseTemplate.

compute_importances(model, criterion, ...)

Compute the EWC importance matrix for each parameter.

update_importances(importances, t)

Update importance for each parameter based on the currently computed importances.
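compute_importances and update_importances produce the importance values used by the penalty. A common choice, and the one this sketch assumes, is the diagonal empirical Fisher approximation: the squared gradient of the training loss with respect to each weight, averaged over the minibatches of an additional pass over the training set. The sketch replaces autograd with a hypothetical one-weight linear model and a hand-derived mean-squared-error gradient, so `grad_mse_linear` and `compute_importance` are illustrative names that do not mirror the plugin's actual signature.

```python
from typing import Callable, List, Sequence, Tuple

Batch = List[Tuple[float, float]]  # (input, target) pairs


def grad_mse_linear(w: float, batch: Batch) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w for a one-weight model."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def compute_importance(
    w: float,
    batches: Sequence[Batch],
    grad_fn: Callable[[float, Batch], float] = grad_mse_linear,
) -> float:
    """Hypothetical helper: mean over minibatches of the squared loss gradient."""
    total = 0.0
    for batch in batches:
        g = grad_fn(w, batch)
        total += g ** 2  # accumulate one squared-gradient term per minibatch
    return total / len(batches)


# Usage: weight w=1.0 after training on the current experience, two minibatches.
batches = [[(1.0, 0.0), (2.0, 2.0)], [(1.0, 2.0)]]
print(compute_importance(1.0, batches))  # (1.0**2 + (-2.0)**2) / 2 = 2.5
```

A weight whose gradient is consistently large on the current experience gets a large importance, so the penalty holds it near its current value when training on later experiences.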