# Source code for avalanche.models.dynamic_optimizers

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 14-04-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Utilities to handle optimizer updates when using dynamic architectures.
    Dynamic modules (e.g. multi-head) can change their parameters dynamically
    during training, which usually requires updating the optimizer to learn
    the new parameters or freeze the old ones.
"""
from collections import defaultdict


def reset_optimizer(optimizer, model):
    """Reset the optimizer to update the list of learnable parameters.

    The optimizer's state is discarded and its single parameter group is
    pointed at the model's current parameters.

    .. warning::
        This function fails if the optimizer uses multiple parameter groups.

    :param optimizer: the optimizer to reset. It must have exactly one
        parameter group; ``optimizer.state`` and
        ``optimizer.param_groups[0]['params']`` are replaced in place.
    :param model: the model whose ``parameters()`` become the optimizer's
        parameter list.
    :raises AssertionError: if the optimizer does not have exactly one
        parameter group.
    :return: None.
    """
    n_groups = len(optimizer.param_groups)
    # Explicit raise instead of a bare ``assert`` so the check is not
    # stripped under ``python -O``; AssertionError is kept so existing
    # callers see the same exception type as before.
    if n_groups != 1:
        raise AssertionError(
            "reset_optimizer supports optimizers with exactly one "
            f"parameter group, got {n_groups}.")
    # State contains parameter-specific information tied to the old
    # parameter list, so it is dropped entirely.
    optimizer.state = defaultdict(dict)
    optimizer.param_groups[0]['params'] = list(model.parameters())
def update_optimizer(optimizer, old_params, new_params, reset_state=True):
    """Update the optimizer by substituting old_params with new_params.

    Each element of ``old_params`` is looked up by identity in the
    optimizer's parameter groups and replaced, in place, by the element at
    the same position in ``new_params``. The two lists are paired with
    ``zip``, so extra elements in the longer list are ignored.

    :param optimizer: the optimizer whose ``param_groups`` are modified.
    :param old_params: List of old trainable parameters.
    :param new_params: List of new trainable parameters, aligned
        position-by-position with ``old_params``.
    :param reset_state: Whether to reset the optimizer's state.
        Defaults to True.
    :raises Exception: if a parameter of ``old_params`` is not found in any
        parameter group. Substitutions performed before the failure are
        not rolled back.
    :return: None.
    """
    for old_p, new_p in zip(old_params, new_params):
        for group in optimizer.param_groups:
            group_params = group['params']
            # Match by identity: distinct parameters may share a hash value,
            # and ``==`` on tensors is elementwise rather than a boolean.
            idx = next(
                (i for i, curr_p in enumerate(group_params)
                 if curr_p is old_p),
                None)
            if idx is not None:
                # update parameter reference
                group_params[idx] = new_p
                break
        else:
            # No group contained old_p: report the specific missing
            # parameter, not the whole list.
            raise Exception(f"Parameter {old_p} not found in the "
                            f"current optimizer.")

    if reset_state:
        # State contains parameter-specific information.
        # We reset it because the model is (probably) changed.
        optimizer.state = defaultdict(dict)
def add_new_params_to_optimizer(optimizer, new_params):
    """Register ``new_params`` with ``optimizer`` as one new parameter group.

    The group is passed to ``optimizer.add_param_group`` with only the
    ``'params'`` key set.

    :param optimizer: the optimizer that receives the new group.
    :param new_params: list of trainable parameters
    :return: None.
    """
    fresh_group = dict(params=new_params)
    optimizer.add_param_group(fresh_group)