Source code for avalanche.benchmarks.utils.data_loader

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
    Avalanche supports data loading using pytorch's dataloaders.
    This module provides custom dataloaders for continual learning such as
    support for balanced dataloading between different tasks or balancing
    between the current data and the replay memory.
"""
from itertools import chain
from typing import Dict, Sequence

import torch
from torch.utils.data import RandomSampler
from torch.utils.data.dataloader import DataLoader

from avalanche.benchmarks.utils import AvalancheDataset


def _default_collate_mbatches_fn(mbatches):
    """ Combines multiple mini-batches together.

    Concatenates each tensor in the mini-batches along dimension 0 (usually this
    is the batch size).

    :param mbatches: sequence of mini-batches.
    :return: a single mini-batch
    """
    batch = []
    for i in range(len(mbatches[0])):
        t = torch.cat([el[i] for el in mbatches], dim=0)
        batch.append(t)
    return batch
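
# A minimal sketch of how the collate function above behaves (the tensors
# here are illustrative and not part of this module): each position of the
# mini-batch is concatenated along dimension 0.
#
#   >>> mb_a = [torch.zeros(4, 3), torch.zeros(4, dtype=torch.long)]
#   >>> mb_b = [torch.ones(2, 3), torch.ones(2, dtype=torch.long)]
#   >>> x, y = _default_collate_mbatches_fn([mb_a, mb_b])
#   >>> x.shape, y.shape
#   (torch.Size([6, 3]), torch.Size([6]))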


class TaskBalancedDataLoader:
    """ Task-balanced data loader for Avalanche's datasets."""

    def __init__(self, data: AvalancheDataset,
                 oversample_small_tasks: bool = False,
                 collate_mbatches=_default_collate_mbatches_fn,
                 **kwargs):
        """ Task-balanced data loader for Avalanche's datasets.

        The iterator returns a mini-batch balanced across each task, which
        makes it useful when training in multi-task scenarios whenever data
        is highly unbalanced.

        If `oversample_small_tasks == True` smaller tasks are oversampled to
        match the largest task. Otherwise, once the data for a specific task
        is terminated, that task will not be present in the subsequent
        mini-batches.

        :param data: an instance of `AvalancheDataset`.
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param collate_mbatches: function that given a sequence of
            mini-batches (one for each task) combines them into a single
            mini-batch. Used to combine the mini-batches obtained separately
            from each task.
        :param kwargs: data loader arguments used to instantiate the loader
            for each task separately. See pytorch :class:`DataLoader`.
        """
        self.data = data
        self.dataloaders: Dict[int, DataLoader] = {}
        self.oversample_small_tasks = oversample_small_tasks
        self.collate_mbatches = collate_mbatches

        # Split data by task.
        task_datasets = []
        for task_label in self.data.task_set:
            tdata = self.data.task_set[task_label]
            task_datasets.append(tdata)

        # The iteration logic is implemented by GroupBalancedDataLoader.
        # We use kwargs to pass the arguments to avoid passing the same
        # arguments multiple times.
        if 'data' in kwargs:
            # needed if they are passed as positional arguments
            del kwargs['data']
        kwargs['oversample_small_groups'] = oversample_small_tasks
        kwargs['collate_mbatches'] = collate_mbatches
        self._dl = GroupBalancedDataLoader(datasets=task_datasets, **kwargs)

    def __iter__(self):
        for el in self._dl.__iter__():
            yield el

    def __len__(self):
        return self._dl.__len__()
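
# Usage sketch (not part of the original module): `train_dataset` is assumed
# to be an AvalancheDataset carrying task labels, e.g. the dataset of a
# training experience. `batch_size` is forwarded to every per-task loader,
# so the collated mini-batch contains up to 32 samples from each task.
#
#   loader = TaskBalancedDataLoader(train_dataset, batch_size=32,
#                                   oversample_small_tasks=True)
#   for x, y, t in loader:
#       ...  # train on a mini-batch balanced across task labels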
class GroupBalancedDataLoader:
    """ Data loader that balances data from multiple datasets."""

    def __init__(self, datasets: Sequence[AvalancheDataset],
                 oversample_small_groups: bool = False,
                 collate_mbatches=_default_collate_mbatches_fn,
                 **kwargs):
        """ Data loader that balances data from multiple datasets.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance
        data among classes, experiences, tasks, and so on.

        If `oversample_small_groups == True` smaller groups are oversampled
        to match the largest group. Otherwise, once data from a group is
        completely iterated, the group will be skipped.

        :param datasets: a sequence of `AvalancheDataset` instances, one for
            each group.
        :param oversample_small_groups: whether smaller groups should be
            oversampled to match the largest one.
        :param collate_mbatches: function that given a sequence of
            mini-batches (one for each group) combines them into a single
            mini-batch. Used to combine the mini-batches obtained separately
            from each group.
        :param kwargs: data loader arguments used to instantiate the loader
            for each group separately. See pytorch :class:`DataLoader`.
        """
        self.datasets = datasets
        self.dataloaders = []
        self.oversample_small_groups = oversample_small_groups
        self.collate_mbatches = collate_mbatches

        for data in self.datasets:
            self.dataloaders.append(DataLoader(data, **kwargs))
        self.max_len = max([len(d) for d in self.dataloaders])

    def __iter__(self):
        iter_dataloaders = []
        for dl in self.dataloaders:
            iter_dataloaders.append(iter(dl))
        max_num_mbatches = max([len(d) for d in iter_dataloaders])

        for it in range(max_num_mbatches):
            mb_curr = []
            is_removed_dataloader = False
            # Exhausted loaders are marked with None and removed only after
            # the loop: removing elements while enumerating the list would
            # skip the loaders that follow them.
            for tid, t_loader in enumerate(iter_dataloaders):
                try:
                    batch = next(t_loader)
                except StopIteration:
                    # StopIteration is thrown if dataset ends.
                    if self.oversample_small_groups:
                        # reinitialize data loader
                        iter_dataloaders[tid] = iter(self.dataloaders[tid])
                        batch = next(iter_dataloaders[tid])
                    else:
                        # We iterated over all the data from this group
                        # and we don't need the iterator anymore.
                        iter_dataloaders[tid] = None
                        is_removed_dataloader = True
                        continue
                mb_curr.append(batch)
            yield self.collate_mbatches(mb_curr)

            # clear empty data-loaders
            if is_removed_dataloader:
                while None in iter_dataloaders:
                    iter_dataloaders.remove(None)

    def __len__(self):
        return self.max_len
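
# Usage sketch (not part of the original module): `ds_a` and `ds_b` are
# assumed to be AvalancheDataset instances, e.g. one per class or per
# experience. Each mini-batch collates 16 samples from each group until the
# smaller group is exhausted (or restarted when
# `oversample_small_groups=True`).
#
#   loader = GroupBalancedDataLoader([ds_a, ds_b], batch_size=16)
#   for x, y, t in loader:
#       ...  # mini-batch with (up to) 16 samples per group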
class GroupBalancedInfiniteDataLoader:
    """ Data loader that balances data from multiple datasets emitting an
    infinite stream."""

    def __init__(self, datasets: Sequence[AvalancheDataset],
                 collate_mbatches=_default_collate_mbatches_fn,
                 **kwargs):
        """ Data loader that balances data from multiple datasets emitting an
        infinite stream.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance
        data among classes, experiences, tasks, and so on.

        :param datasets: a sequence of `AvalancheDataset` instances, one for
            each group.
        :param collate_mbatches: function that given a sequence of
            mini-batches (one for each group) combines them into a single
            mini-batch. Used to combine the mini-batches obtained separately
            from each group.
        :param kwargs: data loader arguments used to instantiate the loader
            for each group separately. See pytorch :class:`DataLoader`.
        """
        self.datasets = datasets
        self.dataloaders = []
        self.collate_mbatches = collate_mbatches

        for data in self.datasets:
            infinite_sampler = RandomSampler(data, replacement=True,
                                             num_samples=10 ** 10)
            dl = DataLoader(data, sampler=infinite_sampler, **kwargs)
            self.dataloaders.append(dl)
        self.max_len = 10 ** 10

    def __iter__(self):
        iter_dataloaders = []
        for dl in self.dataloaders:
            iter_dataloaders.append(iter(dl))

        while True:
            mb_curr = []
            for tid, t_loader in enumerate(iter_dataloaders):
                batch = next(t_loader)
                mb_curr.append(batch)
            yield self.collate_mbatches(mb_curr)

    def __len__(self):
        return self.max_len
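
# Usage sketch (not part of the original module), with `ds_a` and `ds_b` as
# above. The stream never ends, so the consumer decides when to stop; samples
# are drawn with replacement from each group.
#
#   loader = GroupBalancedInfiniteDataLoader([ds_a, ds_b], batch_size=8)
#   loader_it = iter(loader)
#   for step in range(1000):
#       x, y, t = next(loader_it)
#       ...  # 8 samples per group in every mini-batch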
class ReplayDataLoader:
    """ Custom data loader for rehearsal/replay strategies."""

    def __init__(self, data: AvalancheDataset,
                 memory: AvalancheDataset = None,
                 oversample_small_tasks: bool = False,
                 collate_mbatches=_default_collate_mbatches_fn,
                 batch_size: int = 32,
                 force_data_batch_size: int = None,
                 **kwargs):
        """ Custom data loader for rehearsal strategies.

        This loader iterates in parallel over two datasets, the current
        `data` and the rehearsal `memory`, which are used to create
        mini-batches by concatenating their data together. Mini-batches from
        both of them are balanced using the task label (i.e. each mini-batch
        contains a balanced number of examples from all the tasks in `data`
        and `memory`).

        If `oversample_small_tasks == True` smaller tasks are oversampled to
        match the largest task.

        :param data: AvalancheDataset.
        :param memory: AvalancheDataset.
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param collate_mbatches: function that given a sequence of
            mini-batches (one for each task) combines them into a single
            mini-batch. Used to combine the mini-batches obtained separately
            from each task.
        :param batch_size: the size of the batch. It must be greater than or
            equal to the number of tasks.
        :param force_data_batch_size: if not None, the per-task mini-batch
            size used for the current `data`, while the `memory` loaders use
            the remaining `batch_size - force_data_batch_size`.
        :param kwargs: data loader arguments used to instantiate the loader
            for each task separately. See pytorch :class:`DataLoader`.
        """
        self.data = data
        self.memory = memory
        self.loader_data: Dict[int, DataLoader] = {}
        self.loader_memory: Dict[int, DataLoader] = {}
        self.oversample_small_tasks = oversample_small_tasks
        self.collate_mbatches = collate_mbatches

        if force_data_batch_size is not None:
            assert force_data_batch_size <= batch_size, \
                "Forced batch size of data must be <= entire batch size"

            mem_batch_size = batch_size - force_data_batch_size
            remaining_example = 0

            mem_keys = len(self.memory.task_set)
            assert mem_batch_size >= mem_keys, \
                "Batch size must be greater than or equal " \
                "to the number of tasks in the memory."

            self.loader_data, _ = self._create_dataloaders(
                data, force_data_batch_size, remaining_example, **kwargs)
            self.loader_memory, _ = self._create_dataloaders(
                memory, mem_batch_size, remaining_example, **kwargs)
        else:
            num_keys = len(self.data.task_set) + len(self.memory.task_set)
            assert batch_size >= num_keys, \
                "Batch size must be greater than or equal " \
                "to the number of tasks in the memory " \
                "and current data."

            single_group_batch_size = batch_size // num_keys
            remaining_example = batch_size % num_keys

            self.loader_data, remaining_example = self._create_dataloaders(
                data, single_group_batch_size, remaining_example, **kwargs)
            self.loader_memory, remaining_example = self._create_dataloaders(
                memory, single_group_batch_size, remaining_example, **kwargs)

        self.max_len = max([len(d) for d in chain(
            self.loader_data.values(), self.loader_memory.values())]
        )

    def __iter__(self):
        iter_data_dataloaders = {}
        iter_buffer_dataloaders = {}

        for t in self.loader_data.keys():
            iter_data_dataloaders[t] = iter(self.loader_data[t])
        for t in self.loader_memory.keys():
            iter_buffer_dataloaders[t] = iter(self.loader_memory[t])

        max_len = max([len(d) for d in chain(
            iter_data_dataloaders.values(),
            iter_buffer_dataloaders.values())])

        try:
            for it in range(max_len):
                mb_curr = []
                self._get_mini_batch_from_data_dict(
                    self.data, iter_data_dataloaders,
                    self.loader_data, self.oversample_small_tasks,
                    mb_curr)

                self._get_mini_batch_from_data_dict(
                    self.memory, iter_buffer_dataloaders,
                    self.loader_memory, self.oversample_small_tasks,
                    mb_curr)

                yield self.collate_mbatches(mb_curr)
        except StopIteration:
            return

    def __len__(self):
        return self.max_len

    def _get_mini_batch_from_data_dict(self, data, iter_dataloaders,
                                       loaders_dict, oversample_small_tasks,
                                       mb_curr):
        # list() is necessary because we may remove keys from the
        # dictionary. This would break the generator.
        for t in list(iter_dataloaders.keys()):
            t_loader = iter_dataloaders[t]
            try:
                tbatch = next(t_loader)
            except StopIteration:
                # StopIteration is thrown if dataset ends.
                if oversample_small_tasks:
                    # reinitialize data loader
                    iter_dataloaders[t] = iter(loaders_dict[t])
                    tbatch = next(iter_dataloaders[t])
                else:
                    del iter_dataloaders[t]
                    continue
            mb_curr.append(tbatch)

    def _create_dataloaders(self, data_dict, single_exp_batch_size,
                            remaining_example, **kwargs):
        loaders_dict: Dict[int, DataLoader] = {}
        for task_id in data_dict.task_set:
            data = data_dict.task_set[task_id]
            current_batch_size = single_exp_batch_size
            if remaining_example > 0:
                current_batch_size += 1
                remaining_example -= 1
            loaders_dict[task_id] = DataLoader(
                data, batch_size=current_batch_size, **kwargs)
        return loaders_dict, remaining_example
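
# Usage sketch (not part of the original module): `current_data` and
# `replay_buffer` are assumed to be AvalancheDataset instances, e.g. the
# current experience's dataset and the rehearsal memory of a replay
# strategy. The 32 samples of each mini-batch are split among the task
# labels found in `current_data` and `replay_buffer`; extra keyword
# arguments such as `shuffle` are forwarded to the per-task DataLoaders.
#
#   loader = ReplayDataLoader(current_data, replay_buffer,
#                             batch_size=32, shuffle=True)
#   for x, y, t in loader:
#       ...  # mini-batch mixing new data and replayed samples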
__all__ = [
    'TaskBalancedDataLoader',
    'GroupBalancedDataLoader',
    'ReplayDataLoader',
    'GroupBalancedInfiniteDataLoader'
]