avalanche.training.SynapticIntelligence

class avalanche.training.SynapticIntelligence(model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, criterion, si_lambda: typing.Union[float, typing.Sequence[float]], eps: float = 1e-07, train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, device='cpu', plugins: typing.Optional[typing.Sequence[avalanche.core.SupervisedPlugin]] = None, evaluator=<avalanche.training.plugins.evaluation.EvaluationPlugin object>, eval_every=-1, **base_kwargs)

Synaptic Intelligence strategy.

This is a PyTorch implementation of Synaptic Intelligence as described in the paper “Continuous Learning in Single-Incremental-Task Scenarios” (https://arxiv.org/abs/1806.08568).

The method was originally proposed in the paper “Continual Learning Through Synaptic Intelligence” (https://arxiv.org/abs/1703.04200).
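The regularizer from Zenke et al. can be sketched without any framework. The hypothetical class below (SISketch is an illustration name, not part of Avalanche) keeps a running path integral w_k = -Σ g_k Δθ_k over the optimizer steps of an experience. At the end of the experience it folds that into an importance Ω_k += w_k / (Δ_k² + ξ), where Δ_k is the total change of parameter k during the experience. The penalty added to the task loss is c Σ_k Ω_k (θ_k - θ*_k)². Here ξ plays the role of eps and c the role of si_lambda; this mapping is an assumption based on the parameter descriptions. The sketch follows the paper's formulation on plain floats, and Avalanche's tensor implementation may differ in details such as which gradient is accumulated or constant factors.

```python
from typing import List


class SISketch:
    """Framework-free sketch of the Synaptic Intelligence penalty
    (Zenke et al., 2017) for a flat list of float parameters.
    Hypothetical helper, not part of Avalanche."""

    def __init__(self, params: List[float], si_lambda: float, eps: float = 1e-7):
        self.si_lambda = si_lambda          # c: strength of the penalty
        self.eps = eps                      # xi: damping term in the denominator
        self.omega = [0.0] * len(params)    # consolidated importance Omega_k
        self.anchor = list(params)          # theta*_k: values at the end of the previous experience
        self.start = list(params)           # values at the start of the current experience
        self.path = [0.0] * len(params)     # running path integral w_k

    def penalty(self, params: List[float]) -> float:
        """c * sum_k Omega_k * (theta_k - theta*_k)^2, added to the task loss."""
        return self.si_lambda * sum(
            o * (p - a) ** 2 for o, p, a in zip(self.omega, params, self.anchor)
        )

    def accumulate(self, grads: List[float], old: List[float], new: List[float]) -> None:
        """Call after each optimizer step: w_k -= g_k * (new_k - old_k)."""
        for k, (g, o, n) in enumerate(zip(grads, old, new)):
            self.path[k] -= g * (n - o)

    def end_experience(self, params: List[float]) -> None:
        """Fold the path integral into Omega and reset for the next experience."""
        for k, p in enumerate(params):
            delta = p - self.start[k]       # total change Delta_k over the experience
            self.omega[k] += self.path[k] / (delta ** 2 + self.eps)
        self.anchor = list(params)
        self.start = list(params)
        self.path = [0.0] * len(params)
```

A parameter that never moves during an experience accumulates no path integral and therefore gets zero importance, so later experiences may change it freely. A parameter that contributed to reducing the loss is pulled back toward its anchored value.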

The Synaptic Intelligence regularization can also be used with a different strategy by adding a SynapticIntelligencePlugin to that strategy's plugins.

__init__(model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, criterion, si_lambda: typing.Union[float, typing.Sequence[float]], eps: float = 1e-07, train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, device='cpu', plugins: typing.Optional[typing.Sequence[avalanche.core.SupervisedPlugin]] = None, evaluator=<avalanche.training.plugins.evaluation.EvaluationPlugin object>, eval_every=-1, **base_kwargs)

Init.

Creates an instance of the Synaptic Intelligence strategy.

Parameters
  • model – PyTorch model.

  • optimizer – PyTorch optimizer.

  • criterion – loss function.

  • si_lambda – Synaptic Intelligence lambda term. If a list, one lambda is used for each experience. If the list has fewer elements than the number of experiences, the last lambda is used for the remaining experiences.

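  • eps – Synaptic Intelligence damping parameter, added to the squared parameter change in the denominator of the importance estimate to avoid division by zero.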

  • train_mb_size – mini-batch size for training.

  • train_epochs – number of training epochs.

  • eval_mb_size – mini-batch size for eval.

  • device – PyTorch device to run the model.

  • plugins – (optional) list of SupervisedPlugin instances.

  • evaluator – (optional) instance of EvaluationPlugin for logging and metric computations.

  • eval_every – the frequency of the calls to eval inside the training loop. -1 disables the evaluation. 0 means eval is called only at the end of the learning experience. Values >0 mean that eval is called every eval_every epochs and at the end of the learning experience.

  • base_kwargs – any additional BaseTemplate constructor arguments.
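The list form of si_lambda and the three regimes of eval_every can be made concrete with two small helpers. Both lambda_for_experience and eval_epochs are hypothetical names used only for illustration, not Avalanche functions. They restate the rules in the parameter descriptions above (experiences counted from 0, epochs from 1, any negative eval_every treated like -1), and the library's own bookkeeping may differ in details.

```python
from typing import List, Sequence, Union


def lambda_for_experience(si_lambda: Union[float, Sequence[float]], exp_id: int) -> float:
    """Lambda for experience `exp_id` (0-based): a scalar applies everywhere;
    a sequence gives one value per experience and its last entry is reused."""
    if isinstance(si_lambda, (int, float)):
        return float(si_lambda)
    if len(si_lambda) == 0:
        raise ValueError("si_lambda sequence must not be empty")
    return float(si_lambda[min(exp_id, len(si_lambda) - 1)])


def eval_epochs(eval_every: int, train_epochs: int) -> List[int]:
    """1-based epochs after which eval would run; the end of the learning
    experience is the end of epoch `train_epochs` and is listed only once."""
    if eval_every < 0:
        return []                   # -1: evaluation disabled
    if eval_every == 0:
        return [train_epochs]       # 0: only at the end of the experience
    periodic = {e for e in range(1, train_epochs + 1) if e % eval_every == 0}
    periodic.add(train_epochs)      # >0: every eval_every epochs, plus the end
    return sorted(periodic)
```

For example, with si_lambda=[1.0, 2.0] the third experience still uses 2.0, and with eval_every=2 and train_epochs=5 evaluation happens after epochs 2, 4 and 5.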

Methods

__init__(model, optimizer, criterion, si_lambda)

Init.

backward()

Run the backward pass.

criterion()

Loss function.

eval(exp_list, **kwargs)

Evaluate the current model on a series of experiences and return the last recorded value for each metric.

eval_dataset_adaptation(**kwargs)

Initialize self.adapted_dataset.

eval_epoch(**kwargs)

Evaluation loop over the current self.dataloader.

forward()

Compute the model's output given the current mini-batch.

make_eval_dataloader([num_workers, ...])

Initializes the eval data loader. num_workers: how many subprocesses to use for data loading; 0 means that the data will be loaded in the main process (default: 0). pin_memory: if True, the data loader will copy Tensors into CUDA pinned memory before returning them (default: True).

make_optimizer()

Optimizer initialization.

make_train_dataloader([num_workers, ...])

Data loader initialization.

model_adaptation([model])

Adapts the model to the current data.

optimizer_step()

Execute the optimizer step (weights update).

stop_training()

Signals to stop training at the next iteration.

train(experiences[, eval_streams])

Training loop.

train_dataset_adaptation(**kwargs)

Initialize self.adapted_dataset.

training_epoch(**kwargs)

Training epoch.
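The training methods listed above form a template. train runs the epochs, each training_epoch iterates over mini-batches, and each iteration calls criterion (which uses forward), backward and optimizer_step, while stop_training only takes effect at the next iteration. The sketch below reproduces that call order on a one-parameter linear model in plain Python. TinyTemplate and its hand-written gradient are hypothetical stand-ins for the PyTorch machinery. Real strategies also fire plugin callbacks around these steps, which is where a regularizer such as SynapticIntelligencePlugin can add its penalty to the loss; the sketch leaves that out.

```python
from typing import List, Optional, Tuple

Batch = Tuple[float, float]  # (x, y); a real strategy stores tensors here


class TinyTemplate:
    """Hypothetical sketch of the template-method call order, fitting y = w * x
    with squared error and a hand-computed gradient (no PyTorch)."""

    def __init__(self, lr: float = 0.05, train_epochs: int = 1) -> None:
        self.w = 0.0
        self.lr = lr
        self.train_epochs = train_epochs
        self.mbatch: Optional[Batch] = None
        self.loss = 0.0
        self.grad = 0.0
        self._stop_training = False

    def forward(self) -> float:
        x, _ = self.mbatch              # model output for the current mini-batch
        return self.w * x

    def criterion(self) -> float:
        _, y = self.mbatch              # squared-error loss
        return (self.forward() - y) ** 2

    def backward(self) -> None:
        x, y = self.mbatch              # d(loss)/dw, computed by hand
        self.grad = 2.0 * (self.forward() - y) * x

    def optimizer_step(self) -> None:
        self.w -= self.lr * self.grad   # plain SGD update

    def stop_training(self) -> None:
        self._stop_training = True      # honoured at the next iteration

    def training_epoch(self, data: List[Batch]) -> None:
        for batch in data:
            if self._stop_training:
                break
            self.mbatch = batch
            self.loss = self.criterion()
            self.backward()
            self.optimizer_step()

    def train(self, data: List[Batch]) -> None:
        self._stop_training = False
        for _ in range(self.train_epochs):
            if self._stop_training:
                break
            self.training_epoch(data)
```

Keeping each step in its own method is what lets a subclass or plugin change one piece (for instance the loss) without rewriting the loop.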

Attributes

is_eval

True if the strategy is in evaluation mode.

mb_task_id

Current mini-batch task labels.

mb_x

Current mini-batch input.

mb_y

Current mini-batch target.
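The mini-batch attributes are views on the batch the strategy is currently processing. A minimal sketch, assuming the batch is stored as a sequence laid out as (inputs, targets, ..., task labels) and that is_eval is derived from a companion is_training flag. MiniBatchView and both of those layout details are assumptions for illustration, not a statement of Avalanche's internals.

```python
from typing import Any, Optional, Sequence


class MiniBatchView:
    """Hypothetical sketch: read-only properties over the current mini-batch,
    assumed to be a sequence (inputs, targets, ..., task_labels)."""

    def __init__(self) -> None:
        self.mbatch: Optional[Sequence[Any]] = None
        self.is_training = True         # assumed companion flag

    @property
    def is_eval(self) -> bool:
        return not self.is_training     # evaluation mode = not training

    @property
    def mb_x(self) -> Any:
        return self.mbatch[0]           # current mini-batch input

    @property
    def mb_y(self) -> Any:
        return self.mbatch[1]           # current mini-batch target

    @property
    def mb_task_id(self) -> Any:
        return self.mbatch[-1]          # task labels are the last element
```

Exposing these as properties keeps a single source of truth (the stored batch), so plugins that read mb_x or mb_y always see the batch currently being processed.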