Source code for avalanche.training.plugins.gss_greedy

from typing import TYPE_CHECKING

import torch
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
from avalanche.training.plugins.strategy_plugin import StrategyPlugin

if TYPE_CHECKING:
    from .. import BaseStrategy
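
# Background: GSS-greedy (Aljundi et al., "Gradient based sample selection
# for online continual learning", NeurIPS 2019) scores each incoming sample
# by the maximal cosine similarity between its gradient and the gradients of
# random memory subsets; the most redundant samples are the most likely to
# be replaced, keeping the buffer gradient-diverse.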


class GSS_greedyPlugin(StrategyPlugin):
    """GSS-greedy replay plugin.

    Code adapted from the repository:
    https://github.com/RaptorMai/online-continual-learning

    Handles an external memory filled with samples selected using the
    greedy approach of the GSS algorithm. The `after_forward` callback is
    used to process the current mini-batch and estimate a score for each
    of its samples.
    """
    def __init__(self, mem_size=200, mem_strength=5, input_size=[]):
        """
        :param mem_size: total number of patterns to be stored in the
            external memory.
        :param mem_strength: number of random memory batches drawn when
            estimating the gradient-similarity score of a new batch.
        :param input_size: shape of a single input pattern, used to
            pre-allocate the memory buffer.
        """
        super().__init__()
        self.mem_size = mem_size
        self.mem_strength = mem_strength
        self.device = 'cpu'

        self.ext_mem_list_x = \
            torch.FloatTensor(mem_size, *input_size).fill_(0)
        self.ext_mem_list_y = torch.LongTensor(mem_size).fill_(0)
        self.ext_mem_list_current_index = 0

        self.buffer_score = torch.FloatTensor(self.mem_size).fill_(0)
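
    # Implementation note (comment added for clarity): `buffer_score` keeps
    # one score per memory slot, namely the cosine similarity between the
    # stored sample's gradient and random memory gradients, measured when
    # the sample was inserted or last replaced. In `after_forward`,
    # replacement candidates are drawn with probability proportional to
    # this score, so the most redundant samples are the most likely to be
    # evicted.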

    def before_training(self, strategy: 'BaseStrategy', **kwargs):
        self.device = strategy.device
        self.ext_mem_list_x = self.ext_mem_list_x.to(strategy.device)
        self.ext_mem_list_y = self.ext_mem_list_y.to(strategy.device)
        self.buffer_score = self.buffer_score.to(strategy.device)

    def cosine_similarity(self, x1, x2=None, eps=1e-8):
        x2 = x1 if x2 is None else x2
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
        sim = torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
        return sim

    def get_grad_vector(self, pp, grad_dims):
        """
        Gather the gradients of all parameters in one flat vector.
        """
        grads = torch.zeros(sum(grad_dims), device=self.device)
        cnt = 0
        for param in pp():
            if param.grad is not None:
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[:cnt + 1])
                grads[beg:en].copy_(param.grad.data.view(-1))
            cnt += 1
        return grads

    def get_batch_sim(self, strategy, grad_dims, batch_x, batch_y):
        """
        Args:
            grad_dims: gradient dimensions
            batch_x: current batch x
            batch_y: current batch y
        Returns: score of current batch, gradient from memory subsets
        """
        mem_grads = self.get_rand_mem_grads(strategy, grad_dims,
                                            len(batch_x))

        strategy.model.zero_grad()
        loss = strategy._criterion(strategy.model.forward(batch_x),
                                   batch_y)
        loss.backward()
        batch_grad = self.get_grad_vector(
            strategy.model.parameters, grad_dims).unsqueeze(0)
        batch_sim = max(self.cosine_similarity(mem_grads, batch_grad))
        return batch_sim, mem_grads

    def get_rand_mem_grads(self, strategy, grad_dims, gss_batch_size):
        """
        Args:
            grad_dims: gradient dimensions
            gss_batch_size: size of each random memory batch
        Returns: gradient from memory subsets
        """
        temp_gss_batch_size = min(gss_batch_size,
                                  self.ext_mem_list_current_index)
        num_mem_subs = min(self.mem_strength,
                           self.ext_mem_list_current_index
                           // gss_batch_size)
        mem_grads = torch.zeros(num_mem_subs, sum(grad_dims),
                                dtype=torch.float32, device=self.device)
        shuffled_inds = torch.randperm(self.ext_mem_list_current_index,
                                       device=self.device)
        for i in range(num_mem_subs):
            random_batch_inds = shuffled_inds[
                i * temp_gss_batch_size:
                i * temp_gss_batch_size + temp_gss_batch_size]
            batch_x = self.ext_mem_list_x[random_batch_inds].to(
                strategy.device)
            batch_y = self.ext_mem_list_y[random_batch_inds].to(
                strategy.device)
            strategy.model.zero_grad()
            loss = strategy._criterion(strategy.model.forward(batch_x),
                                       batch_y)
            loss.backward()
            mem_grads[i].data.copy_(self.get_grad_vector(
                strategy.model.parameters, grad_dims))
        return mem_grads

    def get_each_batch_sample_sim(self, strategy, grad_dims, mem_grads,
                                  batch_x, batch_y):
        """
        Args:
            grad_dims: gradient dimensions
            mem_grads: gradient from memory subsets
            batch_x: batch images
            batch_y: batch labels
        Returns: score of each sample from current batch
        """
        cosine_sim = torch.zeros(batch_x.size(0), device=strategy.device)
        for i, (x, y) in enumerate(zip(batch_x, batch_y)):
            strategy.model.zero_grad()
            ptloss = strategy._criterion(
                strategy.model.forward(x.unsqueeze(0)), y.unsqueeze(0))
            ptloss.backward()
            # compute the gradient of the new sample and measure its
            # cosine similarity to the memory gradients
            this_grad = self.get_grad_vector(
                strategy.model.parameters, grad_dims).unsqueeze(0)
            cosine_sim[i] = max(self.cosine_similarity(mem_grads,
                                                       this_grad))
        return cosine_sim

    def before_training_exp(self, strategy, num_workers=0, shuffle=True,
                            **kwargs):
        """
        Build a dataloader whose batches contain examples from both the
        external memory and the training dataset.
        """
        if self.ext_mem_list_current_index == 0:
            return

        temp_x_tensors = []
        for elem in self.ext_mem_list_x:
            temp_x_tensors.append(elem.to('cpu'))
        temp_y_tensors = self.ext_mem_list_y.to('cpu')

        memory = list(zip(temp_x_tensors, temp_y_tensors))
        memory = AvalancheDataset(memory, targets=temp_y_tensors)

        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            memory,
            oversample_small_tasks=True,
            num_workers=num_workers,
            batch_size=strategy.train_mb_size,
            shuffle=shuffle)

    def after_forward(self, strategy, num_workers=0, shuffle=True,
                      **kwargs):
        """
        After every forward pass, select the samples that fill the memory
        buffer, based on cosine similarity of gradients.
        """
        strategy.model.eval()

        # compute the gradient dimension of each parameter
        grad_dims = []
        for param in strategy.model.parameters():
            grad_dims.append(param.data.numel())

        place_left = \
            self.ext_mem_list_x.size(0) - self.ext_mem_list_current_index
        if place_left <= 0:  # buffer is full
            batch_sim, mem_grads = self.get_batch_sim(
                strategy, grad_dims,
                batch_x=strategy.mb_x, batch_y=strategy.mb_y)
            if batch_sim < 0:
                buffer_score = self.buffer_score[
                    :self.ext_mem_list_current_index].cpu()
                buffer_sim = \
                    (buffer_score - torch.min(buffer_score)) / \
                    ((torch.max(buffer_score) - torch.min(buffer_score))
                     + 0.01)

                # draw candidates for replacement from the buffer
                index = torch.multinomial(
                    buffer_sim, strategy.mb_x.size(0),
                    replacement=False).to(strategy.device)

                # estimate the similarity of each sample in the received
                # batch to the randomly drawn samples from the buffer
                batch_item_sim = self.get_each_batch_sample_sim(
                    strategy, grad_dims, mem_grads,
                    strategy.mb_x, strategy.mb_y)

                # normalize to [0, 1]
                scaled_batch_item_sim = \
                    ((batch_item_sim + 1) / 2).unsqueeze(1)
                buffer_repl_batch_sim = \
                    ((self.buffer_score[index] + 1) / 2).unsqueeze(1)

                # draw an event to decide on replacement
                outcome = torch.multinomial(
                    torch.cat((scaled_batch_item_sim,
                               buffer_repl_batch_sim), dim=1),
                    1, replacement=False)

                # replace the buffer samples for which outcome == 1
                added_indx = torch.arange(end=batch_item_sim.size(0),
                                          device=strategy.device)
                sub_index = outcome.squeeze(1).bool()
                self.ext_mem_list_x[index[sub_index]] = \
                    strategy.mb_x[added_indx[sub_index]].clone()
                self.ext_mem_list_y[index[sub_index]] = \
                    strategy.mb_y[added_indx[sub_index]].clone()
                self.buffer_score[index[sub_index]] = \
                    batch_item_sim[added_indx[sub_index]].clone()
        else:
            # the buffer still has room: append (part of) the batch
            offset = min(place_left, strategy.mb_x.size(0))
            updated_mb_x = strategy.mb_x[:offset]
            updated_mb_y = strategy.mb_y[:offset]

            if self.ext_mem_list_current_index == 0:
                # first buffer insertion: assign a small constant score
                batch_sample_memory_cos = \
                    torch.zeros(updated_mb_x.size(0)) + 0.1
            else:
                # draw random samples from the buffer
                mem_grads = self.get_rand_mem_grads(
                    strategy=strategy,
                    grad_dims=grad_dims,
                    gss_batch_size=len(strategy.mb_x))
                # estimate a score for each added sample
                batch_sample_memory_cos = self.get_each_batch_sample_sim(
                    strategy, grad_dims, mem_grads,
                    updated_mb_x, updated_mb_y)

            curr_idx = self.ext_mem_list_current_index
            self.ext_mem_list_x[curr_idx:curr_idx + offset] \
                .data.copy_(updated_mb_x)
            self.ext_mem_list_y[curr_idx:curr_idx + offset] \
                .data.copy_(updated_mb_y)
            self.buffer_score[curr_idx:curr_idx + offset] \
                .data.copy_(batch_sample_memory_cos)
            self.ext_mem_list_current_index += offset

        strategy.model.train()
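
Below is a minimal usage sketch (not part of the original module) showing how
the plugin can be attached to a strategy. It assumes the Avalanche API of the
same era as the source above (`Naive` from `avalanche.training.strategies`,
the `SplitMNIST` benchmark, and the `SimpleMLP` model); exact module paths
may differ across Avalanche versions.

import torch
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins import GSS_greedyPlugin
from avalanche.training.strategies import Naive

benchmark = SplitMNIST(n_experiences=5)
model = SimpleMLP(num_classes=benchmark.n_classes)

# `input_size` must match the shape of one input pattern so that the
# memory buffer tensors can be pre-allocated (MNIST images are 1x28x28).
gss = GSS_greedyPlugin(mem_size=200, mem_strength=5,
                       input_size=[1, 28, 28])

strategy = Naive(
    model,
    torch.optim.SGD(model.parameters(), lr=0.01),
    torch.nn.CrossEntropyLoss(),
    train_mb_size=10,
    train_epochs=1,
    plugins=[gss],
)

for experience in benchmark.train_stream:
    strategy.train(experience)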