avalanche.training.plugins.CoPEPlugin

class avalanche.training.plugins.CoPEPlugin(mem_size=200, n_classes=10, p_size=100, alpha=0.99, T=0.1, max_it_cnt=1)

Continual Prototype Evolution plugin.

Each class has a prototype used for nearest-neighbor classification. The prototypes are updated continually with an exponential moving average, and class-balanced replay keeps them up to date. The embedding space is optimized with the Pseudo-Prototypical Proxy (PPP) loss, which exploits both the prototypes and the information in the current batch.
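As a rough illustration of the nearest-prototype rule, the sketch below assigns an embedding to the class whose prototype is most similar. It assumes, for illustration only, that embeddings and prototypes are L2-normalized and compared with cosine similarity; l2_normalize and nearest_prototype are hypothetical plain-Python helpers, not part of the Avalanche API.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def nearest_prototype(embedding, prototypes):
    """Return the label whose prototype has the highest cosine similarity.

    prototypes: dict mapping class label -> prototype vector (list of floats).
    """
    z = l2_normalize(embedding)
    best_label, best_sim = None, -math.inf
    for label, proto in prototypes.items():
        p = l2_normalize(proto)
        sim = sum(a * b for a, b in zip(z, p))  # dot product of unit vectors = cosine
        if sim > best_sim:
            best_label, best_sim = label, sim
    return best_label


prototypes = {0: [1.0, 0.0], 1: [0.0, 1.0]}
print(nearest_prototype([0.9, 0.2], prototypes))
```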

This plugin does not use task identities during training or evaluation (data-incremental setting) and is designed for online learning (one epoch per task).

__init__(mem_size=200, n_classes=10, p_size=100, alpha=0.99, T=0.1, max_it_cnt=1)
Parameters
  • mem_size – Maximum number of input samples in the replay memory.

  • n_classes – total number of classes that will be encountered. This is used to output predictions for all classes, with zero probability for unseen classes.

  • p_size – The prototype size, which equals the feature size of the last layer.

  • alpha – The momentum for the exponential moving average of the prototypes.

  • T – The softmax temperature, used as a concentration parameter.

  • max_it_cnt – Number of processing iterations per batch (experience).

Methods

__init__([mem_size, n_classes, p_size, ...])

Initialize the plugin. mem_size is the maximum number of input samples in the replay memory; the remaining arguments are described under Parameters above.

after_backward(strategy, *args, **kwargs)

Called after criterion.backward() by the BaseTemplate.

after_eval(strategy, *args, **kwargs)

Called after eval by the BaseTemplate.

after_eval_dataset_adaptation(strategy, ...)

Called after eval_dataset_adaptation by the BaseTemplate.

after_eval_exp(strategy, *args, **kwargs)

Called after eval_exp by the BaseTemplate.

after_eval_forward(strategy, *args, **kwargs)

Called after model.forward() by the BaseTemplate.

after_eval_iteration(strategy, **kwargs)

Convert output scores to probabilities for other metrics like accuracy and forgetting.
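One plausible reading of this conversion, sketched under assumptions: the similarities to the prototypes of seen classes are divided by the temperature T and passed through a softmax, and every class not yet seen gets probability 0, so the output always has n_classes entries. The function scores_to_probs is hypothetical and works on plain Python values rather than tensors.

```python
import math


def scores_to_probs(similarities, n_classes, T=0.1):
    """Turn per-class similarity scores into an n_classes-long probability vector.

    similarities: dict mapping each *seen* class label (an int index) to its score.
    Classes missing from the dict (unseen so far) receive probability 0.
    """
    probs = [0.0] * n_classes
    if not similarities:
        return probs
    # Subtract the maximum before exponentiating for numerical stability.
    max_s = max(similarities.values())
    exps = {c: math.exp((s - max_s) / T) for c, s in similarities.items()}
    total = sum(exps.values())
    for c, e in exps.items():
        probs[c] = e / total
    return probs


probs = scores_to_probs({0: 0.9, 2: 0.5}, n_classes=4, T=0.1)
print([round(p, 4) for p in probs])
```

A smaller T concentrates the probability mass on the most similar prototype, while a larger T spreads it more evenly, which is why T acts as a concentration parameter.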

after_forward(strategy, **kwargs)

After the forward pass, use the representations to update the running average of the prototypes.
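A minimal sketch of such a running average, assuming it is kept as a per-class sum and count of L2-normalized embeddings that is later divided to obtain the class mean. RunningClassMeans is a hypothetical helper written for this example, not the plugin's internal data structure.

```python
import math
from collections import defaultdict


class RunningClassMeans:
    """Per-class running sum and count of L2-normalized embeddings (hypothetical helper)."""

    def __init__(self, dim):
        self.dim = dim
        self.sums = defaultdict(lambda: [0.0] * self.dim)
        self.counts = defaultdict(int)

    def update(self, embeddings, labels):
        """Add one batch of embeddings (lists of floats) with their class labels."""
        for z, y in zip(embeddings, labels):
            norm = math.sqrt(sum(x * x for x in z)) or 1.0  # avoid dividing by zero
            acc = self.sums[y]
            for i, x in enumerate(z):
                acc[i] += x / norm
            self.counts[y] += 1

    def mean(self, label):
        """Mean normalized embedding of `label` since the last reset."""
        if self.counts.get(label, 0) == 0:
            raise KeyError(f"no embeddings seen for class {label}")
        n = self.counts[label]
        return [s / n for s in self.sums[label]]

    def reset(self):
        """Clear the accumulators, e.g. after the prototypes have been updated."""
        self.sums.clear()
        self.counts.clear()


acc = RunningClassMeans(dim=2)
acc.update([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]], labels=[0, 1, 0])
print(acc.mean(0), acc.mean(1))
```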

after_train_dataset_adaptation(strategy, ...)

Called after train_dataset_adaptation by the BaseTemplate.

after_training(strategy, *args, **kwargs)

Called after train by the BaseTemplate.

after_training_epoch(strategy, *args, **kwargs)

Called after train_epoch by the BaseTemplate.

after_training_exp(strategy, **kwargs)

After the current experience (batch), update prototypes and store observed samples for replay.
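The prototype update can be sketched as an exponential moving average with momentum alpha: the new prototype is alpha times the old prototype plus (1 - alpha) times the mean embedding of that class observed during the experience. Re-normalizing the result to unit length is an assumption of this sketch, as is the hypothetical function name ema_update.

```python
import math


def ema_update(prototype, batch_mean, alpha=0.99):
    """Move `prototype` toward `batch_mean` with momentum `alpha`, then renormalize."""
    mixed = [alpha * p + (1.0 - alpha) * m for p, m in zip(prototype, batch_mean)]
    norm = math.sqrt(sum(x * x for x in mixed))
    # Keep the prototype on the unit hypersphere (assumed in this sketch).
    return [x / norm for x in mixed] if norm > 0 else mixed


new_p = ema_update([1.0, 0.0], [0.0, 1.0], alpha=0.99)
print(new_p)
```

With the default alpha=0.99, each update moves a prototype only 1% of the way toward the new class mean (before re-normalization), which keeps the prototypes stable under a stream of small batches.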

after_training_iteration(strategy, **kwargs)

Implements early stopping, determining how many subsequent times a batch can be used for updates.
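The counting logic can be sketched as follows, assuming one call per training iteration and a reset once max_it_cnt iterations have been spent on the current batch. IterationLimiter and should_stop are hypothetical names; the plugin itself presumably signals the strategy to stop rather than returning a flag.

```python
class IterationLimiter:
    """Count updates on the current batch and signal a stop after `max_it_cnt`."""

    def __init__(self, max_it_cnt=1):
        if max_it_cnt < 1:
            raise ValueError("max_it_cnt must be >= 1")
        self.max_it_cnt = max_it_cnt
        self.it_cnt = 0

    def should_stop(self):
        """Call once after each training iteration; True means stop using this batch."""
        self.it_cnt += 1
        if self.it_cnt >= self.max_it_cnt:
            self.it_cnt = 0  # reset the counter for the next batch
            return True
        return False


limiter = IterationLimiter(max_it_cnt=3)
print([limiter.should_stop() for _ in range(6)])
```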

after_update(strategy, *args, **kwargs)

Called after the optimizer step (optimizer.step()) by the BaseTemplate.

before_backward(strategy, *args, **kwargs)

Called before criterion.backward() by the BaseTemplate.

before_eval(strategy, *args, **kwargs)

Called before eval by the BaseTemplate.

before_eval_dataset_adaptation(strategy, ...)

Called before eval_dataset_adaptation by the BaseTemplate.

before_eval_exp(strategy, *args, **kwargs)

Called before eval_exp by the BaseTemplate.

before_eval_forward(strategy, *args, **kwargs)

Called before model.forward() by the BaseTemplate.

before_eval_iteration(strategy, *args, **kwargs)

Called before the start of an eval iteration by the BaseTemplate.

before_forward(strategy, *args, **kwargs)

Called before model.forward() by the BaseTemplate.

before_train_dataset_adaptation(strategy, ...)

Called before train_dataset_adaptation by the BaseTemplate.

before_training(strategy, **kwargs)

Enforce the use of the PPP loss and add a nearest-neighbor classifier.

before_training_epoch(strategy, *args, **kwargs)

Called before train_epoch by the BaseTemplate.

before_training_exp(strategy[, num_workers, ...])

Random retrieval from a class-balanced memory.
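A class-balanced memory with random retrieval can be sketched as below. It assumes the capacity mem_size is split evenly across the classes seen so far, that a class exceeding its share loses randomly chosen samples, and that retrieval draws a uniform random subset. ClassBalancedMemory and its methods are hypothetical; the plugin's actual eviction policy may differ.

```python
import random
from collections import defaultdict


class ClassBalancedMemory:
    """Replay buffer holding at most mem_size // n_seen_classes samples per class."""

    def __init__(self, mem_size=200, seed=None):
        self.mem_size = mem_size
        self.store = defaultdict(list)  # class label -> list of stored samples
        self.rng = random.Random(seed)

    def add(self, samples, labels):
        """Store new samples, then evict at random so every class fits its share."""
        for x, y in zip(samples, labels):
            self.store[y].append(x)
        self._rebalance()

    def _rebalance(self):
        # In this sketch, the share drops to 0 if more classes than mem_size are seen.
        per_class = self.mem_size // max(len(self.store), 1)
        for items in self.store.values():
            while len(items) > per_class:
                items.pop(self.rng.randrange(len(items)))

    def sample(self, k):
        """Uniformly draw up to k (sample, label) pairs from the whole memory."""
        pool = [(x, y) for y, items in self.store.items() for x in items]
        return self.rng.sample(pool, min(k, len(pool)))

    def __len__(self):
        return sum(len(items) for items in self.store.values())


mem = ClassBalancedMemory(mem_size=10, seed=0)
mem.add(list(range(20)), [0] * 20)        # one class seen: keeps 10 samples
mem.add(list(range(100, 120)), [1] * 20)  # two classes seen: 5 samples each
print(len(mem), {y: len(v) for y, v in mem.store.items()})
```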

before_training_iteration(strategy, *args, ...)

Called before the start of a training iteration by the BaseTemplate.

before_update(strategy, *args, **kwargs)

Called before the optimizer step (optimizer.step()) by the BaseTemplate.