"""GSS-Greedy replay plugin (``avalanche.training.plugins.gss_greedy``)."""

from typing import TYPE_CHECKING

import torch
from avalanche.benchmarks.utils import make_classification_dataset
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin

if TYPE_CHECKING:
    from ..templates import SupervisedTemplate


class GSS_greedyPlugin(SupervisedPlugin):
    """GSS-Greedy replay plugin.

    Code adapted from the repository:
    https://github.com/RaptorMai/online-continual-learning

    Handles an external memory filled with samples selected using the
    Greedy approach of the GSS algorithm. The ``after_forward`` callback
    is used to process the current batch, estimate a gradient-similarity
    score for each sample, and decide which samples enter the buffer.
    """

    def __init__(self, mem_size=200, mem_strength=5, input_size=()):
        """
        :param mem_size: total number of patterns to be stored in the
            external memory.
        :param mem_strength: maximum number of random memory mini-batches
            used to estimate gradient similarity scores.
        :param input_size: shape of a single input pattern, excluding the
            batch dimension (e.g. ``(3, 32, 32)``).
        """
        super().__init__()
        self.mem_size = mem_size
        self.mem_strength = mem_strength
        # Placeholder; replaced with the strategy device in
        # `before_training`.
        self.device = torch.device("cpu")

        # NOTE: the default is an (immutable) empty tuple instead of the
        # original mutable `[]`; unpacking behaves identically.
        self.ext_mem_list_x = torch.zeros(mem_size, *input_size)
        self.ext_mem_list_y = torch.zeros(mem_size, dtype=torch.long)
        # Number of valid entries currently stored in the buffer.
        self.ext_mem_list_current_index = 0

        # Cosine-similarity score associated with each stored sample.
        self.buffer_score = torch.zeros(self.mem_size)

    def before_training(self, strategy: "SupervisedTemplate", **kwargs):
        """Move the memory buffers to the strategy device."""
        self.device = strategy.device
        self.ext_mem_list_x = self.ext_mem_list_x.to(strategy.device)
        self.ext_mem_list_y = self.ext_mem_list_y.to(strategy.device)
        self.buffer_score = self.buffer_score.to(strategy.device)

    def cosine_similarity(self, x1, x2=None, eps=1e-8):
        """Pairwise cosine similarity between the rows of `x1` and `x2`.

        :param x1: 2D tensor of shape (n, d).
        :param x2: optional 2D tensor of shape (m, d); defaults to `x1`.
        :param eps: clamp value that avoids division by zero for
            zero-norm rows.
        :return: tensor of shape (n, m) with cosine similarities.
        """
        x2 = x1 if x2 is None else x2
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
        sim = torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
        return sim

    def get_grad_vector(self, pp, grad_dims):
        """Gather the current gradients into one flat vector.

        :param pp: zero-argument callable returning the model parameters
            (e.g. ``model.parameters``).
        :param grad_dims: list with the number of elements of each
            parameter tensor, in iteration order.
        :return: 1D tensor of length ``sum(grad_dims)``; entries of
            parameters with no gradient stay zero.
        """
        grads = torch.zeros(sum(grad_dims), device=self.device)
        cnt = 0
        for param in pp():
            if param.grad is not None:
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[: cnt + 1])
                grads[beg:en].copy_(param.grad.data.view(-1))
            cnt += 1
        return grads

    def get_batch_sim(self, strategy, grad_dims, batch_x, batch_y):
        """Score the current batch against random memory subsets.

        :param strategy: the training strategy (provides model/criterion).
        :param grad_dims: gradient dimensions per parameter.
        :param batch_x: current batch inputs.
        :param batch_y: current batch targets.
        :return: (score of current batch, gradients from memory subsets).
        """
        mem_grads = self.get_rand_mem_grads(strategy, grad_dims, len(batch_x))
        strategy.model.zero_grad()
        loss = strategy._criterion(strategy.model.forward(batch_x), batch_y)
        loss.backward()
        batch_grad = self.get_grad_vector(
            strategy.model.parameters, grad_dims
        ).unsqueeze(0)
        # Maximum cosine similarity between the batch gradient and any
        # memory-subset gradient.
        batch_sim = max(self.cosine_similarity(mem_grads, batch_grad))
        return batch_sim, mem_grads

    def get_rand_mem_grads(self, strategy, grad_dims, gss_batch_size):
        """Compute gradients of random subsets of the memory buffer.

        :param strategy: the training strategy (provides model/criterion).
        :param grad_dims: gradient dimensions per parameter.
        :param gss_batch_size: requested subset size (capped by the
            number of stored samples).
        :return: tensor of shape (num_subsets, total_grad_dim) with one
            gradient vector per random memory subset.
        """
        temp_gss_batch_size = min(
            gss_batch_size, self.ext_mem_list_current_index
        )
        # NOTE(review): the subset count uses the *uncapped*
        # `gss_batch_size`, so this can be 0 when the buffer holds fewer
        # samples than the batch size — kept as in the original code.
        num_mem_subs = min(
            self.mem_strength, self.ext_mem_list_current_index // gss_batch_size
        )
        mem_grads = torch.zeros(
            num_mem_subs,
            sum(grad_dims),
            dtype=torch.float32,
            device=self.device,
        )
        shuffeled_inds = torch.randperm(
            self.ext_mem_list_current_index, device=self.device
        )
        for i in range(num_mem_subs):
            random_batch_inds = shuffeled_inds[
                i
                * temp_gss_batch_size : i
                * temp_gss_batch_size
                + temp_gss_batch_size
            ]
            batch_x = self.ext_mem_list_x[random_batch_inds].to(strategy.device)
            batch_y = self.ext_mem_list_y[random_batch_inds].to(strategy.device)
            strategy.model.zero_grad()
            loss = strategy._criterion(strategy.model.forward(batch_x), batch_y)
            loss.backward()
            mem_grads[i].data.copy_(
                self.get_grad_vector(strategy.model.parameters, grad_dims)
            )
        return mem_grads

    def get_each_batch_sample_sim(
        self, strategy, grad_dims, mem_grads, batch_x, batch_y
    ):
        """Score each batch sample against the memory-subset gradients.

        :param strategy: the training strategy (provides model/criterion).
        :param grad_dims: gradient dimensions per parameter.
        :param mem_grads: gradients from memory subsets.
        :param batch_x: batch images.
        :param batch_y: batch labels.
        :return: per-sample maximum cosine similarity, shape (batch,).
        """
        cosine_sim = torch.zeros(batch_x.size(0), device=strategy.device)
        for i, (x, y) in enumerate(zip(batch_x, batch_y)):
            strategy.model.zero_grad()
            ptloss = strategy._criterion(
                strategy.model.forward(x.unsqueeze(0)), y.unsqueeze(0)
            )
            ptloss.backward()
            # Gradient of this single sample, compared against every
            # memory-subset gradient; keep the best match.
            this_grad = self.get_grad_vector(
                strategy.model.parameters, grad_dims
            ).unsqueeze(0)
            cosine_sim[i] = max(self.cosine_similarity(mem_grads, this_grad))
        return cosine_sim

    def before_training_exp(
        self, strategy, num_workers=0, shuffle=True, **kwargs
    ):
        """Build a dataloader mixing memory samples with the training set.

        Does nothing while the memory buffer is still empty.
        """
        if self.ext_mem_list_current_index == 0:
            return

        # Move stored samples to CPU so the dataset/dataloader machinery
        # can handle them regardless of the training device.
        temp_x_tensors = []
        for elem in self.ext_mem_list_x:
            temp_x_tensors.append(elem.to("cpu"))
        temp_y_tensors = self.ext_mem_list_y.to("cpu")

        memory = list(zip(temp_x_tensors, temp_y_tensors))
        memory = make_classification_dataset(memory, targets=temp_y_tensors)

        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            memory,
            oversample_small_tasks=True,
            num_workers=num_workers,
            batch_size=strategy.train_mb_size,
            shuffle=shuffle,
        )

    def after_forward(self, strategy, num_workers=0, shuffle=True, **kwargs):
        """Greedy GSS buffer update, run after every forward pass.

        Selects which samples of the current mini-batch enter the memory
        buffer based on gradient cosine similarity: while there is free
        space, samples are appended with a similarity score; once the
        buffer is full, samples stochastically replace stored entries
        with higher (more redundant) scores.
        """
        strategy.model.eval()

        # Compute the gradient dimension of each parameter tensor.
        grad_dims = []
        for param in strategy.model.parameters():
            grad_dims.append(param.data.numel())

        place_left = (
            self.ext_mem_list_x.size(0) - self.ext_mem_list_current_index
        )
        if place_left <= 0:  # buffer full
            batch_sim, mem_grads = self.get_batch_sim(
                strategy,
                grad_dims,
                batch_x=strategy.mb_x,
                batch_y=strategy.mb_y,
            )
            # Only consider replacement when the batch gradient points
            # away from the memory gradients (negative similarity).
            if batch_sim < 0:
                buffer_score = self.buffer_score[
                    : self.ext_mem_list_current_index
                ].cpu()
                # Min-max normalize scores into sampling weights; the
                # +0.01 avoids division by zero for constant scores.
                buffer_sim = (buffer_score - torch.min(buffer_score)) / (
                    (torch.max(buffer_score) - torch.min(buffer_score)) + 0.01
                )
                # draw candidates for replacement from the buffer
                index = torch.multinomial(
                    buffer_sim, strategy.mb_x.size(0), replacement=False
                ).to(strategy.device)
                # estimate the similarity of each sample in the received
                # batch to the randomly drawn samples from the buffer.
                batch_item_sim = self.get_each_batch_sample_sim(
                    strategy, grad_dims, mem_grads, strategy.mb_x, strategy.mb_y
                )
                # normalize to [0,1]
                scaled_batch_item_sim = ((batch_item_sim + 1) / 2).unsqueeze(1)
                buffer_repl_batch_sim = (
                    (self.buffer_score[index] + 1) / 2
                ).unsqueeze(1)
                # draw an event to decide on replacement decision
                outcome = torch.multinomial(
                    torch.cat(
                        (scaled_batch_item_sim, buffer_repl_batch_sim), dim=1
                    ),
                    1,
                    replacement=False,
                )
                # replace samples with outcome = 1
                added_indx = torch.arange(
                    end=batch_item_sim.size(0), device=strategy.device
                )
                sub_index = outcome.squeeze(1).bool()
                self.ext_mem_list_x[index[sub_index]] = strategy.mb_x[
                    added_indx[sub_index]
                ].clone()
                self.ext_mem_list_y[index[sub_index]] = strategy.mb_y[
                    added_indx[sub_index]
                ].clone()
                self.buffer_score[index[sub_index]] = batch_item_sim[
                    added_indx[sub_index]
                ].clone()
        else:
            # Buffer has free slots: append as many samples as fit.
            offset = min(place_left, strategy.mb_x.size(0))
            updated_mb_x = strategy.mb_x[:offset]
            updated_mb_y = strategy.mb_y[:offset]

            # first buffer insertion
            if self.ext_mem_list_current_index == 0:
                # No memory to compare against yet: assign a small
                # constant score to the first inserted samples.
                batch_sample_memory_cos = (
                    torch.zeros(updated_mb_x.size(0)) + 0.1
                )
            else:
                # draw random samples from buffer
                mem_grads = self.get_rand_mem_grads(
                    strategy=strategy,
                    grad_dims=grad_dims,
                    gss_batch_size=len(strategy.mb_x),
                )
                # estimate a score for each added sample
                batch_sample_memory_cos = self.get_each_batch_sample_sim(
                    strategy, grad_dims, mem_grads, updated_mb_x, updated_mb_y
                )

            curr_idx = self.ext_mem_list_current_index
            self.ext_mem_list_x[curr_idx : curr_idx + offset].data.copy_(
                updated_mb_x
            )
            self.ext_mem_list_y[curr_idx : curr_idx + offset].data.copy_(
                updated_mb_y
            )
            self.buffer_score[curr_idx : curr_idx + offset].data.copy_(
                batch_sample_memory_cos
            )
            self.ext_mem_list_current_index += offset

        strategy.model.train()