from typing import Callable, List, Sequence, Optional, Union
from packaging.version import parse
import warnings
import torch
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
if parse(torch.__version__) < parse("2.0.0"):
    warnings.warn(
        f"LaMAML requires torch >= 2.0.0 (found {torch.__version__}): "
        "torch.func.functional_call is used for the inner updates."
    )
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, CrossEntropyLoss
from torch.optim import Optimizer
from torch import Tensor
import math
from copy import deepcopy
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedMetaLearningTemplate
from avalanche.training.storage_policy import ReservoirSamplingBuffer


class LaMAML(SupervisedMetaLearningTemplate):
    def __init__(
self,
*,
model: Module,
optimizer: Optimizer,
criterion: CriterionType = CrossEntropyLoss(),
n_inner_updates: int = 5,
second_order: bool = True,
grad_clip_norm: float = 1.0,
learn_lr: bool = True,
lr_alpha: float = 0.25,
sync_update: bool = False,
alpha_init: float = 0.1,
max_buffer_size: int = 200,
buffer_mb_size: int = 10,
train_mb_size: int = 1,
train_epochs: int = 1,
eval_mb_size: int = 1,
device: Union[str, torch.device] = "cpu",
plugins: Optional[Sequence["SupervisedPlugin"]] = None,
evaluator: Union[
EvaluationPlugin, Callable[[], EvaluationPlugin]
] = default_evaluator,
        eval_every: int = -1,
        peval_mode: str = "epoch",
):
"""Implementation of Look-ahead MAML (LaMAML) strategy.
:param model: PyTorch model.
:param optimizer: PyTorch optimizer.
:param criterion: loss function.
:param n_inner_updates: number of inner updates.
:param second_order: If True, it computes the second-order derivative
of the inner update trajectory for the meta-loss. Otherwise,
it computes the meta-loss with a first-order approximation.
:param grad_clip_norm: gradient clipping norm.
:param learn_lr: if True, it learns the LR for each batch of data.
:param lr_alpha: LR for learning the main update's learning rate.
:param sync_update: if True, it updates the meta-model with a fixed
learning rate. Mutually exclusive with learn_lr and
lr_alpha.
:param alpha_init: initialization value for learnable LRs.
:param max_buffer_size: maximum buffer size. The default storage
policy is reservoir-sampling.
        :param buffer_mb_size: number of buffer samples drawn at each step.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for evaluation.
        :param device: PyTorch device where the model will be allocated.
        :param plugins: plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether
            the periodic evaluation during training should execute every
            `eval_every` epochs or iterations.
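
        Example (minimal sketch; assumes a multi-task model such as
        ``MTSimpleMLP`` and an Avalanche benchmark ``benchmark`` built with
        task labels, since the strategy forwards ``(x, task_id)`` to the
        model)::

            model = MTSimpleMLP()
            strategy = LaMAML(
                model=model, optimizer=SGD(model.parameters(), lr=0.1)
            )
            for experience in benchmark.train_stream:
                strategy.train(experience)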
"""
super().__init__(
model,
optimizer,
criterion,
train_mb_size,
train_epochs,
eval_mb_size,
device,
plugins,
evaluator,
eval_every,
peval_mode,
)
self.n_inner_updates = n_inner_updates
self.second_order = second_order
self.grad_clip_norm = grad_clip_norm
self.learn_lr = learn_lr
self.lr_alpha = lr_alpha
self.sync_update = sync_update
self.alpha_init = alpha_init
self.alpha_params: nn.ParameterDict = nn.ParameterDict()
self.alpha_params_initialized: bool = False
self.meta_losses: List[Tensor] = []
self.buffer = Buffer(
max_buffer_size=max_buffer_size,
buffer_mb_size=buffer_mb_size,
device=device,
)
self.model.apply(init_kaiming_normal)

    def _before_training_exp(self, **kwargs):
        """Create (or expand) the learnable per-parameter LRs and their
        optimizer before each training experience."""
        super()._before_training_exp(drop_last=True, **kwargs)
# Initialize alpha-lr parameters
if not self.alpha_params_initialized:
self.alpha_params_initialized = True
# Iterate through model parameters and add the corresponding
# alpha_lr parameter
for n, p in self.model.named_parameters():
alpha_param = nn.Parameter(
torch.ones(p.shape) * self.alpha_init, requires_grad=True
)
self.alpha_params[n.replace(".", "_")] = alpha_param
self.alpha_params.to(self.device)
# Create optimizer for the alpha_lr parameters
self.optimizer_alpha = torch.optim.SGD(
self.alpha_params.parameters(), lr=self.lr_alpha
)
        # Update/expand the alpha-lr parameters to match the (possibly
        # dynamically grown) model
        for n, p in self.model.named_parameters():
            n = n.replace(".", "_")  # ParameterDict keys cannot contain '.'
if n in self.alpha_params:
if self.alpha_params[n].shape != p.shape:
old_shape = self.alpha_params[n].shape
# parameter expansion
expanded = False
assert len(p.shape) == len(
old_shape
), "Expansion cannot add new dimensions"
for i, (snew, sold) in enumerate(zip(p.shape, old_shape)):
assert snew >= sold, "Shape cannot decrease."
if snew > sold:
                            assert not expanded, (
                                "Expansion cannot occur in more than one dimension."
                            )
expanded = True
exp_idx = i
                    alpha_param = torch.ones(p.shape) * self.alpha_init
                    # Copy the old values into the matching slice of the
                    # expanded tensor; use a tuple of slices (a plain list is
                    # not a valid multidimensional index).
                    idx = tuple(
                        slice(el) if i != exp_idx else slice(old_shape[exp_idx])
                        for i, el in enumerate(p.shape)
                    )
                    alpha_param[idx] = self.alpha_params[n].detach().clone()
alpha_param = nn.Parameter(alpha_param, requires_grad=True)
self.alpha_params[n] = alpha_param
else:
# Add new alpha_lr for the new parameter
alpha_param = nn.Parameter(
torch.ones(p.shape) * self.alpha_init, requires_grad=True
)
self.alpha_params[n] = alpha_param
self.alpha_params.to(self.device)
# Re-init optimizer for the new set of alpha_lr parameters
self.optimizer_alpha = torch.optim.SGD(
self.alpha_params.parameters(), lr=self.lr_alpha
)
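
    # Note: ``alpha_params`` stores one learnable LR tensor per model
    # parameter, keyed by the parameter name with '.' replaced by '_'
    # (ParameterDict keys cannot contain dots). When a dynamic module grows,
    # e.g. a classifier head expanding from shape (10, f) to (15, f), the
    # matching alpha tensor is expanded along that one dimension: old entries
    # keep their learned values, new entries start at ``alpha_init``.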

    def copy_grads(self, params_1, params_2):
        """Copy gradients from ``params_2`` onto ``params_1`` in place."""
        for p1, p2 in zip(params_1, params_2):
            if p2.grad is not None:
                p1.grad = p2.grad

    def inner_update_step(self, fast_params, x, y, t):
        """Update the fast weights with one gradient step on the current
        samples and return the new fast-parameter dict.
        """
logits = torch.func.functional_call(self.model, fast_params, (x, t))
loss = self._criterion(logits, y)
# Compute gradient with respect to the current fast weights
grads = list(
torch.autograd.grad(
loss,
fast_params.values(),
retain_graph=self.second_order,
create_graph=self.second_order,
allow_unused=True,
)
)
# Clip grad norms
grads = [
(
torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm)
if g is not None
else g
)
for g in grads
]
# New fast parameters
new_fast_params = {
n: param - alpha * grad if grad is not None else param
for ((n, param), alpha, grad) in zip(
fast_params.items(), self.alpha_params.parameters(), grads
)
}
return new_fast_params
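
    # One inner step computes, with per-parameter LRs ``alpha``:
    #     theta_{k+1} = theta_k - alpha * grad_{theta_k} L(theta_k)
    # With ``second_order=True``, create_graph/retain_graph keep every step
    # inside the autograd graph, so the meta-loss can be differentiated
    # through the whole inner trajectory; otherwise the meta-gradient is a
    # first-order approximation.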
def _inner_updates(self, **kwargs):
# Make a copy of model parameters for fast updates
self.initial_fast_params = {
n: deepcopy(p) for (n, p) in self.model.named_parameters()
}
# Keep reference to the initial fast params
fast_params = self.initial_fast_params
# Samples from the current batch
batch_x, batch_y, batch_t = self.mb_x, self.mb_y, self.mb_task_id
# Get batches from the buffer
if self.clock.train_exp_counter > 0:
buff_x, buff_y, buff_t = self.buffer.get_buffer_batch()
mixed_x = torch.cat([batch_x, buff_x], dim=0)
mixed_y = torch.cat([batch_y, buff_y], dim=0)
mixed_t = torch.cat([batch_t, buff_t], dim=0)
else:
mixed_x, mixed_y, mixed_t = batch_x, batch_y, batch_t
        # Split the current batch into smaller chunks: each inner update
        # consumes one chunk of roughly len(batch) / n_inner_updates samples.
        bsize_data = batch_x.shape[0]
        rough_sz = math.ceil(bsize_data / self.n_inner_updates)
self.meta_losses = [torch.empty(0) for _ in range(self.n_inner_updates)]
# Iterate through the chunks as inner-loops
for i in range(self.n_inner_updates):
batch_x_i = batch_x[i * rough_sz : (i + 1) * rough_sz]
batch_y_i = batch_y[i * rough_sz : (i + 1) * rough_sz]
batch_t_i = batch_t[i * rough_sz : (i + 1) * rough_sz]
# We assume that samples for inner update are from the same task
fast_params = self.inner_update_step(
fast_params, batch_x_i, batch_y_i, batch_t_i
)
# Compute meta-loss with the combination of batch and buffer samples
logits_meta = torch.func.functional_call(
self.model, fast_params, (mixed_x, mixed_t)
)
meta_loss_i = self._criterion(logits_meta, mixed_y)
self.meta_losses[i] = meta_loss_i
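
    # The outer step averages the per-step meta-losses computed above on the
    # mixed (current + replay) batch and applies the resulting meta-gradient
    # to the initial parameters of the inner trajectory.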
def _outer_update(self, **kwargs):
self.model.zero_grad()
self.alpha_params.zero_grad()
# Compute meta-gradient for the main model
meta_loss = sum(self.meta_losses) / len(self.meta_losses)
meta_loss.backward()
self.copy_grads(self.model.parameters(), self.initial_fast_params.values())
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
if self.learn_lr:
# Update lr for the current batch
torch.nn.utils.clip_grad_norm_(
self.alpha_params.parameters(), self.grad_clip_norm
)
self.optimizer_alpha.step()
# If sync-update: update with self.optimizer
# o.w: use the learned LRs to update the model
if self.sync_update:
self.optimizer.step()
        else:
            # Apply the meta-gradient with the learned per-parameter LRs,
            # clamped to non-negative values via ReLU:
            #     theta <- theta - relu(alpha) * grad(meta_loss)
            for p, alpha in zip(
                self.model.parameters(), self.alpha_params.parameters()
            ):
                if p.grad is not None:
                    p.data = p.data - p.grad * F.relu(alpha)
self.loss = meta_loss

    def _after_training_exp(self, **kwargs):
self.buffer.update(self)
super()._after_training_exp(**kwargs)


class Buffer:
    """Reservoir-sampling replay buffer providing mini-batches
    for LaMAML's meta-updates."""

    def __init__(
        self, max_buffer_size=100, buffer_mb_size=10, device=torch.device("cpu")
    ):
self.storage_policy = ReservoirSamplingBuffer(max_size=max_buffer_size)
self.buffer_mb_size = buffer_mb_size
self.device = device

    def update(self, strategy):
        """Update the buffer from the strategy's current experience."""
        self.storage_policy.update(strategy)

    def __len__(self):
        return len(self.storage_policy.buffer)

    def get_buffer_batch(self):
        """Sample a random mini-batch of (x, y, task_id) tensors from the
        buffer and move it to ``self.device``."""
rnd_ind = torch.randperm(len(self))[: self.buffer_mb_size]
buff = self.storage_policy.buffer.subset(rnd_ind)
buff_x, buff_y, buff_t = [], [], []
for bx, by, bt in buff:
buff_x.append(bx)
buff_y.append(by)
buff_t.append(bt)
buff_x = torch.stack(buff_x, dim=0).to(self.device)
buff_y = torch.tensor(buff_y).to(self.device).long()
buff_t = torch.tensor(buff_t).to(self.device).long()
return buff_x, buff_y, buff_t
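

# Illustrative usage note: ``Buffer.get_buffer_batch`` returns
# already-stacked tensors on ``device``. For example (shapes assume
# MNIST-like inputs and buffer_mb_size=10):
#     buff_x, buff_y, buff_t = buffer.get_buffer_batch()
#     # buff_x: (10, 1, 28, 28); buff_y, buff_t: (10,)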

def init_kaiming_normal(m):
    """Kaiming-normal initialization for Conv2d and Linear weights;
    biases are zeroed."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        torch.nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()
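

if __name__ == "__main__":
    # Minimal usage sketch (illustrative assumptions, not part of the API
    # above: a task-incremental SplitMNIST benchmark with task labels and an
    # MTSimpleMLP multi-task model, since LaMAML forwards (x, task_id) to
    # the model).
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import MTSimpleMLP

    benchmark = SplitMNIST(n_experiences=5, return_task_id=True)
    model = MTSimpleMLP()
    strategy = LaMAML(
        model=model,
        optimizer=SGD(model.parameters(), lr=0.1),
        train_mb_size=32,
        eval_mb_size=32,
        device="cpu",
    )
    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)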