Training module

Training Templates

Templates define the training/eval loop for each setting (supervised CL, online CL, RL, …). Each template supports a set of callbacks that plugins can use to execute code inside the training/eval loops.
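
A minimal sketch of how a template-based strategy is driven from user code. The SplitMNIST benchmark, the SimpleMLP model, and the exact import paths are assumptions based on typical Avalanche usage, not taken from this page:

    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST  # assumed benchmark
    from avalanche.models import SimpleMLP               # assumed model
    from avalanche.training import Naive                 # strategy listed below

    benchmark = SplitMNIST(n_experiences=5)
    model = SimpleMLP(num_classes=benchmark.n_classes)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Naive inherits its training/eval loop from SupervisedTemplate.
    strategy = Naive(
        model, optimizer, CrossEntropyLoss(),
        train_mb_size=32, train_epochs=1, eval_mb_size=64,
        device="cpu",  # or a CUDA device
    )

    for experience in benchmark.train_stream:
        strategy.train(experience)            # training loop (fires plugin callbacks)
        strategy.eval(benchmark.test_stream)  # eval loop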

Templates

Templates are defined in the avalanche.training.templates module.

BaseTemplate(model[, device, plugins])

Base class for continual learning skeletons.

BaseSGDTemplate(model, optimizer[, ...])

Base SGD class for continual learning skeletons.

SupervisedTemplate(model, optimizer[, ...])

Base class for continual learning strategies.

OnlineSupervisedTemplate(model, optimizer[, ...])

Base class for online continual learning strategies.

Plugins ABCs

ABCs for plugins are available in avalanche.core.

BasePlugin()

ABC for BaseTemplate plugins.

BaseSGDPlugin()

ABC for BaseSGDTemplate plugins.

SupervisedPlugin()

ABC for SupervisedTemplate plugins.
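
As a sketch, a custom plugin subclasses one of these ABCs and overrides only the callbacks it needs; the templates then invoke those callbacks at the corresponding points of the loop. The callback names below follow Avalanche's before_*/after_* convention, and the attribute access on the strategy state is an assumption:

    from avalanche.core import SupervisedPlugin

    class ExperienceLoggerPlugin(SupervisedPlugin):
        """Toy plugin: report which experience is being trained."""

        def before_training_exp(self, strategy, **kwargs):
            # Invoked by the template right before training on an experience.
            print(f"Starting experience {strategy.experience.current_experience}")

        def after_training_exp(self, strategy, **kwargs):
            # Invoked right after the training loop on that experience ends.
            print(f"Finished experience {strategy.experience.current_experience}")

    # Attach it to any strategy or template via the `plugins` argument, e.g.:
    #   strategy = Naive(model, optimizer, criterion,
    #                    plugins=[ExperienceLoggerPlugin()])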

Training Strategies

Ready-to-use continual learning strategies.

Cumulative(model, optimizer, criterion, ...)

Cumulative training strategy.

JointTraining(model, optimizer, criterion, ...)

Joint training on the entire stream.

Naive(model, optimizer[, criterion, device, ...])

Naive finetuning.

AR1([criterion, momentum, l2, ...])

AR1 with Latent Replay.

StreamingLDA(slda_model, criterion, ...[, ...])

Deep Streaming Linear Discriminant Analysis.

ICaRL(feature_extractor, classifier, ...[, ...])

iCaRL Strategy.

PNNStrategy(model, optimizer[, criterion, ...])

Progressive Neural Network strategy.

CWRStar(model, optimizer, criterion, ...[, ...])

CWR* Strategy.

Replay(model, optimizer, criterion, ...[, ...])

Experience replay strategy.

GSS_greedy(model, optimizer, criterion, mem_size)

Experience replay strategy with Gradient-based Sample Selection (GSS-greedy).

GDumb(model, optimizer, criterion, mem_size, ...)

GDumb strategy.

LwF(model, optimizer, criterion, alpha, ...)

Learning without Forgetting (LwF) strategy.

AGEM(model, optimizer, criterion, ...[, ...])

Average Gradient Episodic Memory (A-GEM) strategy.

GEM(model, optimizer, criterion, ...[, ...])

Gradient Episodic Memory (GEM) strategy.

EWC(model, optimizer, criterion, ewc_lambda, ...)

Elastic Weight Consolidation (EWC) strategy.

SynapticIntelligence(model, optimizer, ...)

Synaptic Intelligence strategy.

CoPE(model, optimizer, criterion, mem_size, ...)

Continual Prototype Evolution strategy.

LFL(model, optimizer, criterion, lambda_e, ...)

Less-Forgetful Learning (LFL) strategy.

GenerativeReplay(model, optimizer[, ...])

Generative Replay strategy.

MAS(model, optimizer, criterion, lambda_reg, ...)

Memory Aware Synapses (MAS) strategy.

BiC(model, optimizer, criterion, mem_size, ...)

Bias Correction (BiC) strategy.

MIR(model, optimizer, criterion, mem_size, ...)

Maximally Interfered Replay (MIR) strategy. See the MIRPlugin for details.
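
All strategies above share the same constructor pattern (model, optimizer, criterion, strategy-specific arguments, then the common training arguments) and the same train/eval interface, so they can be swapped with little code change. A hedged sketch, assuming the strategies are importable from avalanche.training and reusing model/benchmark from the first sketch; argument values are illustrative only:

    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.training import EWC, Replay  # import path assumed

    # `model` and `benchmark` as in the first sketch above.
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss()

    # Regularization-based: quadratic penalty on weights important for old tasks.
    strategy = EWC(model, optimizer, criterion, ewc_lambda=0.4,
                   train_mb_size=32, train_epochs=1, eval_mb_size=64)

    # Rehearsal-based alternative with the same interface:
    # strategy = Replay(model, optimizer, criterion, mem_size=200,
    #                   train_mb_size=32, train_epochs=1, eval_mb_size=64)

    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)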

Replay Buffers and Selection Strategies

Buffers to store past samples according to different policies and selection strategies.

Buffers

ExemplarsBuffer(max_size)

ABC for rehearsal buffers to store exemplars.

ReservoirSamplingBuffer(max_size)

Buffer updated with reservoir sampling.

BalancedExemplarsBuffer(max_size[, ...])

A buffer that stores exemplars for rehearsal in separate groups.

ExperienceBalancedBuffer(max_size[, ...])

Rehearsal buffer with samples balanced over experiences.

ClassBalancedBuffer(max_size[, ...])

Stores samples for replay, equally divided over classes.

ParametricBuffer(max_size[, groupby, ...])

Stores samples for replay using a custom selection strategy and grouping.

Selection strategies

ExemplarsSelectionStrategy()

Base class to define how to select a subset of exemplars from a dataset.

RandomExemplarsSelectionStrategy()

Selects exemplars at random from the dataset.

FeatureBasedExemplarsSelectionStrategy(...)

Base class to select exemplars based on their features.

HerdingSelectionStrategy(model, layer_name)

The herding strategy as described in iCaRL.

ClosestToCenterSelectionStrategy(model, ...)

A greedy algorithm that selects the remaining exemplar that is the closest to the center of all elements (in feature space).
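
A sketch of how a buffer and a selection strategy are typically combined inside a plugin: the buffer is updated after each training experience, and its stored dataset can then be used for rehearsal. The module path (avalanche.training.storage_policy) and the layer name are assumptions:

    from avalanche.core import SupervisedPlugin
    from avalanche.training.storage_policy import (  # module path assumed
        HerdingSelectionStrategy,
        ParametricBuffer,
    )

    class HerdingBufferPlugin(SupervisedPlugin):
        """Toy plugin: maintain a class-balanced buffer filled by herding."""

        def __init__(self, model, max_size=200):
            super().__init__()
            self.storage = ParametricBuffer(
                max_size=max_size,
                groupby="class",  # keep one group of exemplars per class
                selection_strategy=HerdingSelectionStrategy(
                    model, layer_name="features"  # hypothetical feature layer
                ),
            )

        def after_training_exp(self, strategy, **kwargs):
            # Re-select exemplars using the data of the experience just seen.
            self.storage.update(strategy)
            print(f"Buffer now holds {len(self.storage.buffer)} exemplars")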

Loss Functions

ICaRLLossPlugin()

Loss similar to the Knowledge Distillation Loss, used by the iCaRL strategy.

RegularizationMethod()

RegularizationMethod implements regularization strategies.

LearningWithoutForgetting([alpha, temperature])

Learning Without Forgetting.

Training Plugins

Plugins can be added to any CL strategy to support additional behavior.

Utilities in avalanche.training.plugins.

EarlyStoppingPlugin(patience, val_stream_name)

Early stopping and model checkpoint plugin.

EvaluationPlugin(*metrics[, loggers, ...])

Manager for logging and metrics.

LRSchedulerPlugin(scheduler[, ...])

Learning Rate Scheduler Plugin.
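
A sketch combining these utility plugins with a strategy. The metric helpers, the logger, the evaluator/eval_every keywords, and the stream name passed to EarlyStoppingPlugin are standard Avalanche usage but are not documented on this page, so treat those names as assumptions:

    from torch.optim.lr_scheduler import StepLR

    from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics  # assumed helpers
    from avalanche.logging import InteractiveLogger                          # assumed logger
    from avalanche.training import Naive
    from avalanche.training.plugins import (
        EarlyStoppingPlugin,
        EvaluationPlugin,
        LRSchedulerPlugin,
    )

    # `model`, `optimizer`, `criterion` as in the first sketch above.
    evaluator = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True),
        loss_metrics(experience=True),
        loggers=[InteractiveLogger()],
    )

    strategy = Naive(
        model, optimizer, criterion,
        train_epochs=10,
        eval_every=1,  # periodic eval so early stopping has metrics to check
        evaluator=evaluator,
        plugins=[
            EarlyStoppingPlugin(patience=3, val_stream_name="test_stream"),
            LRSchedulerPlugin(StepLR(optimizer, step_size=5, gamma=0.1)),
        ],
    )
    # During training, pass the stream to monitor, e.g.:
    #   strategy.train(experience, eval_streams=[benchmark.test_stream])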

Strategies implemented as plugins in avalanche.training.plugins.

AGEMPlugin(patterns_per_experience, sample_size)

Average Gradient Episodic Memory Plugin.

CoPEPlugin([mem_size, n_classes, p_size, ...])

Continual Prototype Evolution plugin.

CWRStarPlugin(model[, cwr_layer_name, ...])

CWR* Strategy.

EWCPlugin(ewc_lambda[, mode, decay_factor, ...])

Elastic Weight Consolidation (EWC) plugin.

GDumbPlugin([mem_size])

GDumb plugin.

GEMPlugin(patterns_per_experience, ...)

Gradient Episodic Memory Plugin.

GSS_greedyPlugin([mem_size, mem_strength, ...])

Gradient-based Sample Selection (GSS-greedy) replay plugin.

LFLPlugin(lambda_e)

Less-Forgetful Learning (LFL) Plugin.

LwFPlugin([alpha, temperature])

Learning without Forgetting plugin.

ReplayPlugin([mem_size, batch_size, ...])

Experience replay plugin.

SynapticIntelligencePlugin(si_lambda[, eps, ...])

Synaptic Intelligence plugin.

MASPlugin([lambda_reg, alpha, verbose])

Memory Aware Synapses (MAS) plugin.

TrainGeneratorAfterExpPlugin()

Ensures that, after training the solver of a scholar model on each experience, the generator is also trained on the data of the current experience.

RWalkPlugin([ewc_lambda, ewc_alpha, delta_t])

Riemannian Walk (RWalk) plugin.

GenerativeReplayPlugin([generator_strategy, ...])

Experience generative replay plugin.

BiCPlugin([mem_size, batch_size, ...])

Bias Correction (BiC) plugin.

MIRPlugin([mem_size, subsample, batch_size_mem])

Maximally Interfered Retrieval plugin. Implements the strategy defined in "Online Continual Learning with Maximally Interfered Retrieval" (https://arxiv.org/abs/1908.04742).
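
Because these are plugins rather than full strategies, several of them can be attached to a single base strategy at once. A sketch composing replay and EWC regularization on top of Naive, reusing model/optimizer/criterion/benchmark from the first sketch; parameter values are illustrative only:

    from avalanche.training import Naive
    from avalanche.training.plugins import EWCPlugin, ReplayPlugin

    strategy = Naive(
        model, optimizer, criterion,
        train_mb_size=32, train_epochs=1,
        plugins=[
            ReplayPlugin(mem_size=200),  # rehearse up to 200 stored samples
            EWCPlugin(ewc_lambda=0.4),   # quadratic penalty on important weights
        ],
    )

    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)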