################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 14-04-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
Utilities to handle optimizer's update when using dynamic architectures.
Dynamic Modules (e.g. multi-head) can change their parameters dynamically
during training, which usually requires to update the optimizer to learn
the new parameters or freeze the old ones.
"""
from collections import defaultdict


def reset_optimizer(optimizer, model):
    """Reset the optimizer to update the list of learnable parameters.

    .. warning::
        This function fails if the optimizer uses multiple parameter groups.

    :param optimizer: the optimizer to reset.
    :param model: the model whose current parameters should be optimized.
    :return: None
    """
    assert len(optimizer.param_groups) == 1, (
        "This function only supports optimizers with a single parameter group."
    )
    # Clear the parameter-specific state and re-point the single group
    # at the model's current parameters.
    optimizer.state = defaultdict(dict)
    optimizer.param_groups[0]["params"] = list(model.parameters())
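

# A minimal usage sketch, assuming a standard PyTorch SGD optimizer. The
# model, shapes, and learning rate below are illustrative only: they show
# how ``reset_optimizer`` re-points a single-group optimizer at a model
# whose parameter tensors were replaced.
def _demo_reset_optimizer():
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Simulate a dynamic module growing a wider output head: the old
    # tensors are gone, but the optimizer still references them.
    model = torch.nn.Linear(4, 3)
    reset_optimizer(optimizer, model)
    # The optimizer now tracks the new parameters, with a cleared state.
    assert optimizer.param_groups[0]["params"][0] is model.weight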


def update_optimizer(optimizer, old_params, new_params, reset_state=True):
    """Update the optimizer by substituting old_params with new_params.

    :param optimizer: the optimizer to update.
    :param old_params: list of old trainable parameters.
    :param new_params: list of new trainable parameters.
    :param reset_state: whether to reset the optimizer's state.
        Defaults to True.
    :return: None
    """
    for old_p, new_p in zip(old_params, new_params):
        found = False
        # Iterate over the parameters of each group.
        for group in optimizer.param_groups:
            for i, curr_p in enumerate(group["params"]):
                if hash(curr_p) == hash(old_p):
                    # Update the parameter reference in place.
                    group["params"][i] = new_p
                    found = True
                    break
            if found:
                break
        if not found:
            raise Exception(
                f"Parameter {old_p} not found in the current optimizer."
            )
    if reset_state:
        # The state contains parameter-specific information (e.g. momentum
        # buffers), which is stale after the substitution.
        optimizer.state = defaultdict(dict)
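

# A minimal usage sketch, again assuming a standard PyTorch SGD optimizer
# with illustrative shapes: a single grown parameter is substituted for its
# old version while the rest of the group is left untouched.
def _demo_update_optimizer():
    import torch

    old_w = torch.nn.Parameter(torch.zeros(3, 4))
    other = torch.nn.Parameter(torch.zeros(3))
    optimizer = torch.optim.SGD([old_w, other], lr=0.01)
    # E.g. a multi-head module replaced ``old_w`` with a larger tensor.
    new_w = torch.nn.Parameter(torch.zeros(5, 4))
    update_optimizer(optimizer, [old_w], [new_w])
    params = optimizer.param_groups[0]["params"]
    assert params[0] is new_w and params[1] is other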


def add_new_params_to_optimizer(optimizer, new_params):
    """Add new parameters to the optimizer's trainable parameters.

    :param optimizer: the optimizer to update.
    :param new_params: list of new trainable parameters.
    :return: None
    """
    # New parameters go into their own group, inheriting the optimizer's
    # default hyper-parameters (e.g. learning rate).
    optimizer.add_param_group({"params": new_params})
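

# A minimal usage sketch with a hypothetical backbone and head, assuming a
# standard PyTorch SGD optimizer: a freshly created head is registered so
# the optimizer starts learning it alongside the existing parameters.
def _demo_add_new_params_to_optimizer():
    import torch

    backbone = torch.nn.Linear(8, 4)
    optimizer = torch.optim.SGD(backbone.parameters(), lr=0.01)
    new_head = torch.nn.Linear(4, 5)
    add_new_params_to_optimizer(optimizer, list(new_head.parameters()))
    # The new head lives in its own group, inheriting the default lr.
    assert len(optimizer.param_groups) == 2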