################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
Avalanche supports data loading using PyTorch's dataloaders.

This module provides custom dataloaders for continual learning, such as
support for balanced dataloading between different tasks or balancing
between the current data and the replay memory.
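
A minimal usage sketch (assuming `data` and `buffer` are two already-built
:class:`AvalancheDataset` instances, e.g. the current experience data and a
replay buffer)::

    loader = ReplayDataLoader(data, buffer, batch_size=32, batch_size_mem=32)
    for batch in loader:
        # for classification data the batch typically unpacks as x, y, t
        ...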
"""

from itertools import chain
from typing import Dict, Optional, Sequence, Union
import torch
from torch.utils.data import RandomSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader
from avalanche.benchmarks.utils import make_classification_dataset
from avalanche.benchmarks.utils.collate_functions import (
classification_collate_mbatches_fn,
)
from avalanche.benchmarks.utils.collate_functions import (
detection_collate_fn as _detection_collate_fn,
)
from avalanche.benchmarks.utils.collate_functions import (
detection_collate_mbatches_fn as _detection_collate_mbatches_fn,
)
from avalanche.benchmarks.utils.data import AvalancheDataset

_default_collate_mbatches_fn = classification_collate_mbatches_fn

detection_collate_fn = _detection_collate_fn
detection_collate_mbatches_fn = _detection_collate_mbatches_fn


def collate_from_data_or_kwargs(data, kwargs):
    """Set the `collate_fn` entry of `kwargs` from `data`, if needed.

    If `kwargs` already defines a `collate_fn` it is left untouched;
    otherwise, when `data` exposes a `collate_fn`, it is copied into
    `kwargs` (the dictionary is modified in place).
    """
    if "collate_fn" in kwargs:
        return
    elif hasattr(data, "collate_fn"):
        kwargs["collate_fn"] = data.collate_fn


class TaskBalancedDataLoader:
    """Task-balanced data loader for Avalanche's datasets."""

    def __init__(
        self,
        data: AvalancheDataset,
        oversample_small_tasks: bool = False,
        **kwargs
    ):
"""Task-balanced data loader for Avalanche's datasets.

        The iterator returns a mini-batch balanced across each task, which
        makes it useful when training in multi-task scenarios whenever the
        data is highly unbalanced.

        If `oversample_small_tasks == True`, smaller tasks are oversampled
        to match the largest task. Otherwise, once the data for a specific
        task is exhausted, that task will no longer appear in the
        subsequent mini-batches.

        :param data: an instance of `AvalancheDataset`.
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param kwargs: data loader arguments used to instantiate the loader
            for each task separately. See PyTorch's :class:`DataLoader`.
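
        Example (a sketch, assuming `train_data` is an already-built
        multi-task :class:`AvalancheDataset`)::

            loader = TaskBalancedDataLoader(
                train_data, batch_size=64, oversample_small_tasks=True
            )
            for batch in loader:
                # for classification data the mini-batch typically
                # unpacks as x, y, t
                ...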
"""

        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )
self.data = data
self.dataloaders: Dict[int, DataLoader] = dict()
self.oversample_small_tasks = oversample_small_tasks
# split data by task.
task_datasets = []
for task_label in self.data.targets_task_labels.uniques:
tidxs = self.data.targets_task_labels.val_to_idx[task_label]
tdata = self.data.subset(tidxs)
task_datasets.append(tdata)
# the iteration logic is implemented by GroupBalancedDataLoader.
# we use kwargs to pass the arguments to avoid passing the same
# arguments multiple times.
if "data" in kwargs:
del kwargs["data"]
# needed if they are passed as positional arguments
kwargs["oversample_small_groups"] = oversample_small_tasks
self._dl = GroupBalancedDataLoader(datasets=task_datasets, **kwargs)

    def __iter__(self):
        for el in self._dl.__iter__():
            yield el

    def __len__(self):
        return self._dl.__len__()


class GroupBalancedDataLoader:
    """Data loader that balances data from multiple datasets."""

    def __init__(
        self,
        datasets: Sequence[AvalancheDataset],
        oversample_small_groups: bool = False,
        batch_size: int = 32,
        distributed_sampling: bool = True,
        **kwargs
    ):
"""Data loader that balances data from multiple datasets.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance
        data among classes, experiences, tasks, and so on.

        If `oversample_small_groups == True`, smaller groups are oversampled
        to match the largest group. Otherwise, once the data from a group is
        completely iterated, that group will be skipped.

        :param datasets: a sequence of `AvalancheDataset` instances, one for
            each group.
        :param oversample_small_groups: whether smaller groups should be
            oversampled to match the largest one.
        :param batch_size: the size of the batch. It must be greater than or
            equal to the number of groups.
        :param distributed_sampling: if True, a :class:`DistributedSampler`
            is used when a distributed environment is detected.
        :param kwargs: data loader arguments used to instantiate the loader
            for each group separately. See PyTorch's :class:`DataLoader`.
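
        Example (a sketch; `group_a` and `group_b` stand for two
        :class:`AvalancheDataset` instances, e.g. one per experience)::

            loader = GroupBalancedDataLoader([group_a, group_b], batch_size=64)
            for batch in loader:
                # roughly half of each mini-batch comes from each group
                ...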
"""

        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )
self.datasets = datasets
self.batch_sizes = []
self.oversample_small_groups = oversample_small_groups
self.distributed_sampling = distributed_sampling
self.loader_kwargs = kwargs
if "collate_fn" in kwargs:
self.collate_fn = kwargs["collate_fn"]
else:
self.collate_fn = self.datasets[0].collate_fn
        # the final collate is applied after mini-batches from all groups
        # have been gathered, so the internal dataloaders use a no-op
        # collate that returns the list of samples as-is
        self.loader_kwargs["collate_fn"] = lambda x: x

        # check that batch_size is greater than or equal to the number of
        # datasets in the group
        assert batch_size >= len(datasets), (
            "batch_size must be greater than or equal to the number of "
            "datasets"
        )
# divide the batch between all datasets in the group
ds_batch_size = batch_size // len(datasets)
remaining = batch_size % len(datasets)
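        # e.g. batch_size=10 with 3 datasets gives ds_batch_size=3 and
        # remaining=1, so the per-dataset batch sizes become [4, 3, 3]:
        # the remainder is spread one extra sample at a time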
for _ in self.datasets:
bs = ds_batch_size
if remaining > 0:
bs += 1
remaining -= 1
self.batch_sizes.append(bs)
loaders_for_len_estimation = [
_make_data_loader(
dataset,
distributed_sampling,
kwargs,
mb_size,
force_no_workers=True,
)[0]
for dataset, mb_size in zip(self.datasets, self.batch_sizes)
]
self.max_len = max([len(d) for d in loaders_for_len_estimation])

    def __iter__(self):
dataloaders = []
samplers = []
for dataset, mb_size in zip(self.datasets, self.batch_sizes):
data_l, data_l_sampler = _make_data_loader(
dataset,
self.distributed_sampling,
self.loader_kwargs,
mb_size,
)
dataloaders.append(data_l)
samplers.append(data_l_sampler)
iter_dataloaders = []
for dl in dataloaders:
iter_dataloaders.append(iter(dl))
max_num_mbatches = max([len(d) for d in dataloaders])
for it in range(max_num_mbatches):
mb_curr = []
removed_dataloaders_idxs = []
            # exhausted loaders cannot be removed while we are iterating
            # over them, so we record their indices here and drop them
            # after the loop
for tid, (t_loader, t_loader_sampler) in enumerate(
zip(iter_dataloaders, samplers)
):
try:
batch = next(t_loader)
except StopIteration:
# StopIteration is thrown if dataset ends.
if self.oversample_small_groups:
# reinitialize data loader
if isinstance(t_loader_sampler, DistributedSampler):
# Manage shuffling in DistributedSampler
t_loader_sampler.set_epoch(
t_loader_sampler.epoch + 1
)
iter_dataloaders[tid] = iter(dataloaders[tid])
batch = next(iter_dataloaders[tid])
                    else:
                        # We iterated over all the data from this group
                        # and we don't need the iterator anymore.
iter_dataloaders[tid] = None
samplers[tid] = None
removed_dataloaders_idxs.append(tid)
continue
mb_curr.extend(batch)
yield self.collate_fn(mb_curr)
# clear empty data-loaders
for tid in reversed(removed_dataloaders_idxs):
del iter_dataloaders[tid]
del samplers[tid]

    def __len__(self):
return self.max_len


class GroupBalancedInfiniteDataLoader:
    """Data loader that balances data from multiple datasets emitting an
    infinite stream."""

    def __init__(
        self,
        datasets: Sequence[AvalancheDataset],
        collate_mbatches=_default_collate_mbatches_fn,
        distributed_sampling: bool = True,
        **kwargs
    ):
"""Data loader that balances data from multiple datasets emitting an
        infinite stream.

        Mini-batches emitted by this dataloader are created by collating
        together mini-batches from each group. It may be used to balance
        data among classes, experiences, tasks, and so on.

        :param datasets: a sequence of `AvalancheDataset` instances, one for
            each group.
        :param collate_mbatches: function that given a sequence of
            mini-batches (one for each group) combines them into a single
            mini-batch. Used to combine the mini-batches obtained separately
            from each group.
        :param distributed_sampling: if True, a per-rank random generator is
            used for the sampling when a distributed environment is detected.
        :param kwargs: data loader arguments used to instantiate the loader
            for each group separately. See PyTorch's :class:`DataLoader`.
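
        Example (a sketch; the stream is infinite, so the caller decides
        when to stop; `group_a` and `group_b` are assumed to be
        :class:`AvalancheDataset` instances)::

            loader = GroupBalancedInfiniteDataLoader([group_a, group_b])
            loader_it = iter(loader)
            for _ in range(100):  # number of iterations chosen by the caller
                batch = next(loader_it)
                ...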
"""
self.datasets = datasets
self.dataloaders = []
self.collate_mbatches = collate_mbatches
for data in self.datasets:
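            # in distributed runs, derive a per-rank seed so that the
            # replacement sampling differs across processes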
if _DistributedHelper.is_distributed and distributed_sampling:
seed = torch.randint(
0,
2 ** 32 - 1 - _DistributedHelper.world_size,
(1,),
dtype=torch.int64,
)
seed += _DistributedHelper.rank
generator = torch.Generator()
generator.manual_seed(int(seed))
else:
generator = None # Default
infinite_sampler = RandomSampler(
data,
replacement=True,
num_samples=10 ** 10,
generator=generator,
)
collate_from_data_or_kwargs(data, kwargs)
dl = DataLoader(data, sampler=infinite_sampler, **kwargs)
self.dataloaders.append(dl)
self.max_len = 10 ** 10

    def __iter__(self):
iter_dataloaders = []
for dl in self.dataloaders:
iter_dataloaders.append(iter(dl))
while True:
mb_curr = []
for tid, t_loader in enumerate(iter_dataloaders):
batch = next(t_loader)
mb_curr.append(batch)
yield self.collate_mbatches(mb_curr)

    def __len__(self):
return self.max_len


class ReplayDataLoader:
    """Custom data loader for rehearsal/replay strategies."""

    def __init__(
        self,
        data: AvalancheDataset,
        memory: Optional[AvalancheDataset] = None,
        oversample_small_tasks: bool = False,
        batch_size: int = 32,
        batch_size_mem: int = 32,
        task_balanced_dataloader: bool = False,
        distributed_sampling: bool = True,
        **kwargs
    ):
"""Custom data loader for rehearsal strategies.

        This dataloader iterates in parallel over two datasets, the current
        `data` and the rehearsal `memory`, which are used to create
        mini-batches by concatenating their data together. Mini-batches from
        both of them are balanced using the task label (i.e. each mini-batch
        contains a balanced number of examples from all the tasks in the
        `data` and `memory`).

        If `oversample_small_tasks == True`, smaller tasks are oversampled
        to match the largest task.

        :param data: an `AvalancheDataset` with the data of the current
            experience.
        :param memory: an `AvalancheDataset` used as the rehearsal memory
            (replay buffer).
        :param oversample_small_tasks: whether smaller tasks should be
            oversampled to match the largest one.
        :param batch_size: the size of the data batch. It must be greater
            than or equal to the number of tasks.
        :param batch_size_mem: the size of the memory batch. If
            `task_balanced_dataloader` is set to True, it must be greater
            than or equal to the number of tasks.
        :param task_balanced_dataloader: if True, buffer data loaders will
            be task-balanced, otherwise a single data loader is created for
            the buffer samples.
        :param distributed_sampling: if True, a :class:`DistributedSampler`
            is used when a distributed environment is detected.
        :param kwargs: data loader arguments used to instantiate the loader
            for each task separately. See PyTorch's :class:`DataLoader`.
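
        Example (a sketch, assuming `current_data` holds the experience
        being trained on and `buffer` is the replay memory, both
        :class:`AvalancheDataset` instances)::

            loader = ReplayDataLoader(
                current_data,
                buffer,
                batch_size=32,
                batch_size_mem=32,
                oversample_small_tasks=True,
            )
            for batch in loader:
                # each mini-batch concatenates ~32 new samples with
                # ~32 samples drawn from the memory
                ...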
"""

        if "collate_mbatches" in kwargs:
            raise ValueError(
                "collate_mbatches is not needed anymore and it has been "
                "deprecated. Data loaders will use the collate function "
                "`data.collate_fn`."
            )
self.data = data
self.memory = memory
self.oversample_small_tasks = oversample_small_tasks
self.task_balanced_dataloader = task_balanced_dataloader
self.data_batch_sizes: Union[int, Dict[int, int]] = dict()
self.memory_batch_sizes: Union[int, Dict[int, int]] = dict()
self.distributed_sampling = distributed_sampling
self.loader_kwargs = kwargs
if "collate_fn" in kwargs:
self.collate_fn = kwargs["collate_fn"]
else:
self.collate_fn = self.data.collate_fn
        # the final collate is applied after mini-batches from all tasks
        # have been gathered, so the internal dataloaders use a no-op
        # collate that returns the list of samples as-is
        self.loader_kwargs["collate_fn"] = lambda x: x

        if task_balanced_dataloader:
            num_keys = len(self.memory.targets_task_labels.uniques)
            assert batch_size_mem >= num_keys, (
                "`batch_size_mem` must be greater than or equal "
                "to the number of tasks in the memory."
            )
self.data_batch_sizes, _ = self._get_batch_sizes(
data, batch_size, 0, False
)
# Create dataloader for memory items
if task_balanced_dataloader:
num_keys = len(self.memory.targets_task_labels.uniques)
single_group_batch_size = batch_size_mem // num_keys
remaining_example = batch_size_mem % num_keys
else:
single_group_batch_size = batch_size_mem
remaining_example = 0
self.memory_batch_sizes, _ = self._get_batch_sizes(
memory,
single_group_batch_size,
remaining_example,
task_balanced_dataloader,
)
loaders_for_len_estimation = []
if isinstance(self.data_batch_sizes, int):
loaders_for_len_estimation.append(
_make_data_loader(
data,
distributed_sampling,
kwargs,
self.data_batch_sizes,
force_no_workers=True,
)[0]
)
else:
# Task balanced
for task_id in data.task_set:
dataset = data.task_set[task_id]
mb_sz = self.data_batch_sizes[task_id]
loaders_for_len_estimation.append(
_make_data_loader(
dataset,
distributed_sampling,
kwargs,
mb_sz,
force_no_workers=True,
)[0]
)
if isinstance(self.memory_batch_sizes, int):
loaders_for_len_estimation.append(
_make_data_loader(
memory,
distributed_sampling,
kwargs,
self.memory_batch_sizes,
force_no_workers=True,
)[0]
)
else:
for task_id in memory.task_set:
dataset = memory.task_set[task_id]
mb_sz = self.memory_batch_sizes[task_id]
loaders_for_len_estimation.append(
_make_data_loader(
dataset,
distributed_sampling,
kwargs,
mb_sz,
force_no_workers=True,
)[0]
)
self.max_len = max([len(d) for d in loaders_for_len_estimation])

    def __iter__(self):
loader_data, sampler_data = self._create_loaders_and_samplers(
self.data, self.data_batch_sizes
)
loader_memory, sampler_memory = self._create_loaders_and_samplers(
self.memory, self.memory_batch_sizes
)
iter_data_dataloaders = {}
iter_buffer_dataloaders = {}
for t in loader_data.keys():
iter_data_dataloaders[t] = iter(loader_data[t])
for t in loader_memory.keys():
iter_buffer_dataloaders[t] = iter(loader_memory[t])
max_len = max(
[
len(d)
for d in chain(
loader_data.values(),
loader_memory.values(),
)
]
)
try:
for it in range(max_len):
mb_curr = []
ReplayDataLoader._get_mini_batch_from_data_dict(
iter_data_dataloaders,
sampler_data,
loader_data,
self.oversample_small_tasks,
mb_curr,
)
ReplayDataLoader._get_mini_batch_from_data_dict(
iter_buffer_dataloaders,
sampler_memory,
loader_memory,
self.oversample_small_tasks,
mb_curr,
)
yield self.collate_fn(mb_curr)
except StopIteration:
return

    def __len__(self):
return self.max_len

    @staticmethod
def _get_mini_batch_from_data_dict(
iter_dataloaders,
iter_samplers,
loaders_dict,
oversample_small_tasks,
mb_curr,
):
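        """Draw one batch from every per-task iterator and extend `mb_curr`.

        Exhausted iterators are restarted when `oversample_small_tasks` is
        True (bumping the :class:`DistributedSampler` epoch to reshuffle),
        otherwise they are dropped from the dictionaries.
        """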
        # list() is necessary because we may remove keys from the
        # dictionary, and that is not allowed while iterating over it.
for t in list(iter_dataloaders.keys()):
t_loader = iter_dataloaders[t]
t_sampler = iter_samplers[t]
try:
tbatch = next(t_loader)
            except StopIteration:
                # StopIteration is thrown if dataset ends.
                if oversample_small_tasks:
                    # reinitialize data loader
if isinstance(t_sampler, DistributedSampler):
# Manage shuffling in DistributedSampler
t_sampler.set_epoch(t_sampler.epoch + 1)
iter_dataloaders[t] = iter(loaders_dict[t])
tbatch = next(iter_dataloaders[t])
else:
del iter_dataloaders[t]
del iter_samplers[t]
continue
mb_curr.extend(tbatch)

    def _create_loaders_and_samplers(self, data, batch_sizes):
loaders = dict()
samplers = dict()
if isinstance(batch_sizes, int):
loader, sampler = _make_data_loader(
data,
self.distributed_sampling,
self.loader_kwargs,
batch_sizes,
)
loaders[0] = loader
samplers[0] = sampler
else:
for task_id in data.task_set:
dataset = data.task_set[task_id]
mb_sz = batch_sizes[task_id]
loader, sampler = _make_data_loader(
dataset,
self.distributed_sampling,
self.loader_kwargs,
mb_sz,
)
loaders[task_id] = loader
samplers[task_id] = sampler
return loaders, samplers

    @staticmethod
def _get_batch_sizes(
data_dict,
single_exp_batch_size,
remaining_example,
task_balanced_dataloader,
):
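        """Compute the per-task batch sizes for `data_dict`.

        With `task_balanced_dataloader=True`, every task gets
        `single_exp_batch_size` samples, plus one extra sample while
        `remaining_example` is positive (e.g. a base size of 4 with a
        leftover of 2 across three tasks yields sizes 5, 5 and 4).
        Otherwise, the single integer batch size is returned unchanged.
        """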
batch_sizes = dict()
if task_balanced_dataloader:
for task_id in data_dict.task_set:
current_batch_size = single_exp_batch_size
if remaining_example > 0:
current_batch_size += 1
remaining_example -= 1
batch_sizes[task_id] = current_batch_size
else:
# Current data is loaded without task balancing
batch_sizes = single_exp_batch_size
return batch_sizes, remaining_example


def _make_data_loader(
dataset,
distributed_sampling,
data_loader_args,
batch_size,
force_no_workers=False,
):
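    """Create a :class:`DataLoader` (and its sampler, if any) for `dataset`.

    A :class:`DistributedSampler` is used only when a distributed environment
    is detected and `distributed_sampling` is True; otherwise the default
    PyTorch sampling is kept and ``None`` is returned in place of the sampler.
    With `force_no_workers=True` the loader is forced to run in the main
    process (used for cheap length estimation).
    """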
data_loader_args = data_loader_args.copy()
collate_from_data_or_kwargs(dataset, data_loader_args)
if force_no_workers:
data_loader_args['num_workers'] = 0
if 'persistent_workers' in data_loader_args:
data_loader_args['persistent_workers'] = False
if _DistributedHelper.is_distributed and distributed_sampling:
sampler = DistributedSampler(
dataset,
shuffle=data_loader_args.pop("shuffle", False),
drop_last=data_loader_args.pop("drop_last", False),
)
data_loader = DataLoader(
dataset, sampler=sampler, batch_size=batch_size, **data_loader_args
)
else:
sampler = None
data_loader = DataLoader(
dataset, batch_size=batch_size, **data_loader_args
)
return data_loader, sampler


class __DistributedHelperPlaceholder:
    is_distributed = False
    world_size = 1
    rank = 0


_DistributedHelper = __DistributedHelperPlaceholder()


__all__ = [
"detection_collate_fn",
"detection_collate_mbatches_fn",
"collate_from_data_or_kwargs",
"TaskBalancedDataLoader",
"GroupBalancedDataLoader",
"ReplayDataLoader",
"GroupBalancedInfiniteDataLoader",
]