import math
import warnings
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer

from avalanche.training.templates.strategy_mixin_protocol import CriterionType
try:
    import higher
except ImportError:
    warnings.warn(
        "higher not found, if you want to use "
        "MAML please install avalanche with "
        "the extra dependencies: "
        "pip install avalanche-lib[extra]"
    )
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedMetaLearningTemplate
from avalanche.models.utils import avalanche_forward


class LaMAML(SupervisedMetaLearningTemplate):
    def __init__(
        self,
        *,
        model: Module,
        optimizer: Optimizer,
        criterion: CriterionType = CrossEntropyLoss(),
        n_inner_updates: int = 5,
        second_order: bool = True,
        grad_clip_norm: float = 1.0,
        learn_lr: bool = True,
        lr_alpha: float = 0.25,
        sync_update: bool = False,
        alpha_init: float = 0.1,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence["SupervisedPlugin"]] = None,
        evaluator: Union[
            EvaluationPlugin, Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch",
    ):
"""Implementation of Look-ahead MAML (LaMAML) algorithm in Avalanche
using Higher library for applying fast updates.
:param model: PyTorch model.
:param optimizer: PyTorch optimizer.
:param criterion: loss function.
:param n_inner_updates: number of inner updates.
:param second_order: If True, it computes the second-order derivative
of the inner update trajectory for the meta-loss. Otherwise,
it computes the meta-loss with a first-order approximation.
:param grad_clip_norm: gradient clipping norm.
:param learn_lr: if True, it learns the LR for each batch of data.
:param lr_alpha: LR for learning the main update's learning rate.
:param sync_update: if True, it updates the meta-model with a fixed
learning rate. Mutually exclusive with learn_lr and
lr_alpha.
:param alpha_init: initialization value for learnable LRs.
"""
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size,
            train_epochs,
            eval_mb_size,
            device,
            plugins,
            evaluator,
            eval_every,
            peval_mode,
        )
        self.n_inner_updates = n_inner_updates
        self.second_order = second_order
        self.grad_clip_norm = grad_clip_norm
        self.learn_lr = learn_lr
        self.lr_alpha = lr_alpha
        self.sync_update = sync_update
        self.alpha_init = alpha_init
        self.alpha_params: nn.ParameterDict = nn.ParameterDict()
        self.alpha_params_initialized: bool = False
        self.model.apply(init_kaiming_normal)
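
    # A compact sketch of the update rule implemented by the methods below
    # (notation only, not executed): for each inner step j on the current
    # task data,
    #
    #   theta_{j+1} = theta_j - alpha * grad(L_inner(theta_j))
    #
    # and, after averaging the meta-loss over the inner trajectory,
    #
    #   alpha <- alpha - lr_alpha * grad_alpha(L_meta)       (if learn_lr)
    #   theta <- theta - relu(alpha) * grad_theta(L_meta)    (if not sync_update)
    #
    # where alpha is the per-parameter learnable learning rate stored in
    # ``self.alpha_params``. See ``_inner_updates`` and ``_outer_update``.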

    def _before_training_exp(self, **kwargs):
        super()._before_training_exp(drop_last=True, **kwargs)
        # Initialize alpha-lr parameters
        if not self.alpha_params_initialized:
            self.alpha_params_initialized = True
            # Iterate through model parameters and add the corresponding
            # alpha_lr parameter
            for n, p in self.model.named_parameters():
                alpha_param = nn.Parameter(
                    torch.ones(p.shape) * self.alpha_init, requires_grad=True
                )
                self.alpha_params[n.replace(".", "_")] = alpha_param
            self.alpha_params.to(self.device)

            # Create optimizer for the alpha_lr parameters
            self.optimizer_alpha = torch.optim.SGD(
                self.alpha_params.parameters(), lr=self.lr_alpha
            )

        # Update alpha-lr parameters
        for n, p in self.model.named_parameters():
            n = n.replace(".", "_")  # dict does not support names with '.'
            if n in self.alpha_params:
                if self.alpha_params[n].shape != p.shape:
                    old_shape = self.alpha_params[n].shape
                    # Parameter expansion
                    expanded = False
                    assert len(p.shape) == len(
                        old_shape
                    ), "Expansion cannot add new dimensions"
                    for i, (snew, sold) in enumerate(zip(p.shape, old_shape)):
                        assert snew >= sold, "Shape cannot decrease."
                        if snew > sold:
                            assert not expanded, (
                                "Expansion cannot occur in more than one dimension."
                            )
                            expanded = True
                            exp_idx = i

                    alpha_param = torch.ones(p.shape) * self.alpha_init
                    idx = [
                        slice(el) if i != exp_idx else slice(old_shape[exp_idx])
                        for i, el in enumerate(p.shape)
                    ]
                    alpha_param[tuple(idx)] = self.alpha_params[n].detach().clone()
                    alpha_param = nn.Parameter(alpha_param, requires_grad=True)
                    self.alpha_params[n] = alpha_param
            else:
                # Add new alpha_lr for the new parameter
                alpha_param = nn.Parameter(
                    torch.ones(p.shape) * self.alpha_init, requires_grad=True
                )
                self.alpha_params[n] = alpha_param

        self.alpha_params.to(self.device)
        # Re-init optimizer for the new set of alpha_lr parameters
        self.optimizer_alpha = torch.optim.SGD(
            self.alpha_params.parameters(), lr=self.lr_alpha
        )
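
    # Illustrative note (hypothetical shapes): if an incremental classifier
    # head grows its weight from (10, 512) to (15, 512) when a new experience
    # arrives, the matching alpha tensor is expanded to (15, 512); rows 0-9
    # keep their learned values and rows 10-14 are filled with ``alpha_init``.
    # Only one dimension may grow at a time, as asserted above.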

    def apply_grad(self, module, grads):
        """Accumulate the given gradients into ``module``'s ``.grad`` fields."""
        for i, p in enumerate(module.parameters()):
            grad = grads[i]
            if grad is None:
                grad = torch.zeros(p.shape).float().to(self.device)

            if p.grad is None:
                p.grad = grad
            else:
                p.grad += grad

    def inner_update_step(self, fast_model, x, y, t):
        """Update the fast weights of ``fast_model`` in place using the
        current samples.
        """
        logits = avalanche_forward(fast_model, x, t)
        loss = self._criterion(logits, y)

        # Compute gradient with respect to the current fast weights
        grads = list(
            torch.autograd.grad(
                loss,
                fast_model.fast_params,
                create_graph=self.second_order,
                retain_graph=self.second_order,
                allow_unused=True,
            )
        )

        # Clamp gradient values elementwise to the clipping threshold
        grads = [
            (
                torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm)
                if g is not None
                else g
            )
            for g in grads
        ]

        # New fast parameters
        new_fast_params = [
            param - alpha * grad if grad is not None else param
            for (param, alpha, grad) in zip(
                fast_model.fast_params, self.alpha_params.parameters(), grads
            )
        ]

        # Update fast model's weights
        fast_model.update_params(new_fast_params)
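
    # For orientation, a minimal sketch of the ``higher`` functional-model API
    # used by the two methods around this point (the calls are the same ones
    # used in this file; treat the snippet as illustrative only):
    #
    #   fmodel = higher.patch.monkeypatch(model, copy_initial_weights=True)
    #   loss = criterion(fmodel(x), y)
    #   grads = torch.autograd.grad(loss, fmodel.fast_params, create_graph=True)
    #   fmodel.update_params(
    #       [p - lr * g for p, g in zip(fmodel.fast_params, grads)]
    #   )
    #
    # ``fmodel.parameters(time=0)`` are the initial (meta) parameters, which is
    # what lets the meta-loss gradients flow back to them in ``_outer_update``.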

    def _inner_updates(self, **kwargs):
        # Create a stateless copy of the model for inner-updates
        self.fast_model = higher.patch.monkeypatch(
            self.model,
            copy_initial_weights=True,
            track_higher_grads=self.second_order,
        )
        if self.clock.train_exp_counter > 0:
            # After the first experience the minibatch also contains buffer
            # samples (see the meta-loss comment below), so only the first
            # train_mb_size samples (the current batch) drive the inner updates.
            batch_x = self.mb_x[: self.train_mb_size]
            batch_y = self.mb_y[: self.train_mb_size]
            batch_t = self.mb_task_id[: self.train_mb_size]
        else:
            batch_x, batch_y, batch_t = self.mb_x, self.mb_y, self.mb_task_id

        bsize_data = batch_x.shape[0]
        rough_sz = math.ceil(bsize_data / self.n_inner_updates)
        self.meta_losses = [0 for _ in range(self.n_inner_updates)]

        for i in range(self.n_inner_updates):
            batch_x_i = batch_x[i * rough_sz : (i + 1) * rough_sz]
            batch_y_i = batch_y[i * rough_sz : (i + 1) * rough_sz]
            batch_t_i = batch_t[i * rough_sz : (i + 1) * rough_sz]

            # We assume that samples for the inner update are from the same task
            self.inner_update_step(self.fast_model, batch_x_i, batch_y_i, batch_t_i)

            # Compute meta-loss with the combination of batch and buffer samples
            logits_meta = avalanche_forward(
                self.fast_model, self.mb_x, self.mb_task_id
            )
            meta_loss = self._criterion(logits_meta, self.mb_y)
            self.meta_losses[i] = meta_loss

    def _outer_update(self, **kwargs):
        # Compute meta-gradient for the main model
        meta_loss = sum(self.meta_losses) / len(self.meta_losses)
        meta_grad_model = torch.autograd.grad(
            meta_loss,
            self.fast_model.parameters(time=0),
            retain_graph=True,
            allow_unused=True,
        )
        self.model.zero_grad()
        self.apply_grad(self.model, meta_grad_model)

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

        if self.learn_lr:
            # Compute meta-gradient for alpha-lr parameters
            meta_grad_alpha = torch.autograd.grad(
                meta_loss, self.alpha_params.parameters(), allow_unused=True
            )
            self.alpha_params.zero_grad()
            self.apply_grad(self.alpha_params, meta_grad_alpha)

            torch.nn.utils.clip_grad_norm_(
                self.alpha_params.parameters(), self.grad_clip_norm
            )
            self.optimizer_alpha.step()

        # If sync_update: update with self.optimizer.
        # Otherwise: use the learned LRs to update the model.
        if self.sync_update:
            self.optimizer.step()
        else:
            for p, alpha in zip(
                self.model.parameters(), self.alpha_params.parameters()
            ):
                # Use relu on updated LRs to avoid negative values
                p.data = p.data - p.grad * F.relu(alpha)

        self.loss = meta_loss


def init_kaiming_normal(m):
    """Kaiming-normal initialization for Conv2d and Linear layers;
    biases are zeroed."""
    if isinstance(m, nn.Conv2d):
        torch.nn.init.constant_(m.weight.data, 1.0)
        torch.nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        torch.nn.init.constant_(m.weight.data, 1.0)
        torch.nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()


__all__ = ["LaMAML"]