Source code for avalanche.benchmarks.utils.data

################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the Avalanche Dataset,
a dataset class which extends PyTorch's dataset.
AvalancheDataset offers additional features like the
management of preprocessing pipelines and task/class labels.
"""
import copy
import warnings
import numpy as np

from torch.utils.data.dataloader import default_collate

from avalanche.benchmarks.utils.dataset_definitions import IDataset
from .data_attribute import DataAttribute

from typing import (
    Dict,
    List,
    Any,
    Optional,
    Sequence,
    TypeVar,
    Callable,
    Union,
    overload,
)

from .flat_data import FlatData
from .transform_groups import TransformGroups, EmptyTransformGroups
from torch.utils.data import Dataset as TorchDataset
from collections import OrderedDict


T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TDataWTransform = TypeVar("TDataWTransform", bound="_FlatDataWithTransform")


class AvalancheDataset(IDataset[T_co]):
    """Avalanche Dataset.

    Avalanche datasets are pytorch-compatible Datasets with some additional
    functionality such as:
    - management of transformation groups via :class:`TransformGroups`
    - support for sample attributes such as class targets and task labels

    Data Attributes
    ---------------

    Avalanche datasets manage sample-wise information such as class or task
    labels via :class:`DataAttribute`.

    Transformation Groups
    ---------------------

    Avalanche datasets manage transformations via transformation groups.
    Simply put, a transformation group is a named preprocessing function
    (as in torchvision datasets). By default, Avalanche expects
    two transformation groups:
    - 'train', which contains transformations applied to training patterns.
    - 'eval', which contains transformations applied to test patterns.

    Having both groups allows using different transformations during training
    and evaluation and seamlessly switching between them by using the
    :func:`train` and :func:`eval` methods. Arbitrary transformation groups
    can be added and used. If you define custom groups, you can use them by
    calling the :func:`with_transforms` method.

    Switching to a different transformation group by calling the ``train()``,
    ``eval()`` or ``with_transforms()`` methods always returns a new dataset,
    leaving the original one unchanged.

    Transformation groups can be manipulated by removing, freezing, or
    replacing transformations. Each operation returns a new dataset, leaving
    the original one unchanged.
    """

    def __init__(
        self,
        datasets: Sequence[IDataset[T_co]],
        *,
        indices: Optional[List[int]] = None,
        data_attributes: Optional[List[DataAttribute]] = None,
        transform_groups: Optional[TransformGroups] = None,
        frozen_transform_groups: Optional[TransformGroups] = None,
        collate_fn: Optional[Callable[[List], Any]] = None,
    ):
        """Creates a ``AvalancheDataset`` instance.

        :param datasets: Original dataset(s). A single PyTorch or Avalanche
            dataset is also accepted and treated as a one-element sequence.
            Beware that AvalancheDataset will not overwrite transformations
            already applied by these datasets.
        :param indices: Optional indices used to subset the concatenation of
            ``datasets``. Data attributes are subset with the same indices.
        :param data_attributes: Additional data attributes. Each one must have
            as many elements as the concatenation of ``datasets`` (before
            subsetting). An attribute with the same name as one inherited
            from the first dataset replaces it, keeping its position.
        :param transform_groups: Avalanche transform groups.
        :param frozen_transform_groups: Avalanche frozen transform groups.
        :param collate_fn: Collate function used to build mini-batches. If
            None, the ``collate_fn`` of the first dataset is used when it has
            one, otherwise ``default_collate``.
        :raises ValueError: if a data attribute has the wrong length, or if
            its name is already used by a non-property attribute.
        """
        if issubclass(type(datasets), TorchDataset) or issubclass(
            type(datasets), AvalancheDataset
        ):
            datasets = [datasets]  # type: ignore

        # NOTES on implementation:
        # - raw datasets operations are implemented by _FlatData
        # - data attributes are implemented by DataAttribute
        # - transformations are implemented by TransformGroups
        # AvalancheDataset just takes care to manage all of these attributes
        # together and decides how the information propagates through
        # operations (e.g. how to pass attributes after concat/subset
        # operations).
        flat_datas = []
        for d in datasets:
            if len(d) > 0:
                if isinstance(d, AvalancheDataset):
                    flat_datas.append(d._flat_data)
                elif not isinstance(d, _FlatDataWithTransform):
                    flat_datas.append(_FlatDataWithTransform([d]))
                else:
                    flat_datas.append(d)

        no_new_transforms = (
            transform_groups is None and frozen_transform_groups is None
        )
        if no_new_transforms and indices is not None and len(flat_datas) == 1:
            # TODO: remove. shouldn't be needed but helps with flattening
            self._flat_data = flat_datas[0].subset(indices)
        elif no_new_transforms and indices is None and len(flat_datas) >= 1:
            # TODO: remove. shouldn't be needed but helps with flattening
            self._flat_data = flat_datas[0]
            if not isinstance(self._flat_data, _FlatDataWithTransform):
                self._flat_data = _FlatDataWithTransform([self._flat_data])

            for d in flat_datas[1:]:
                if not isinstance(d, _FlatDataWithTransform):
                    d = _FlatDataWithTransform([d])
                self._flat_data = self._flat_data.concat(d)
        else:
            self._flat_data: _FlatDataWithTransform[T_co] = _FlatDataWithTransform(
                flat_datas,
                indices=indices,
                transform_groups=transform_groups,
                frozen_transform_groups=frozen_transform_groups,
            )

        ####################################
        # Init collate_fn
        ####################################
        first_dataset = datasets[0] if len(datasets) > 0 else None
        # An explicit collate_fn always wins (even when `datasets` is empty);
        # otherwise inherit the first dataset's one, else default_collate.
        self.collate_fn = self._init_collate_fn(first_dataset, collate_fn)
        """
        The collate function to use when creating mini-batches from this
        dataset.
        """

        ####################################
        # Init data attributes
        ####################################
        # concat attributes from child datasets
        new_data_attributes: Dict[str, DataAttribute] = dict()
        if data_attributes is not None:
            new_data_attributes = {da.name: da for da in data_attributes}
            ld = sum(len(d) for d in datasets)
            for da in data_attributes:
                if len(da) != ld:
                    raise ValueError(
                        "Data attribute {} has length {} but the dataset "
                        "has length {}".format(da.name, len(da), ld)
                    )

        self._data_attributes: Dict[str, DataAttribute] = OrderedDict()
        if isinstance(first_dataset, AvalancheDataset):
            for attr in first_dataset._data_attributes.values():
                if attr.name in new_data_attributes:
                    # Keep overridden attributes in their previous position
                    self._data_attributes[attr.name] = new_data_attributes.pop(
                        attr.name
                    )
                    continue

                acat = attr
                found_all = True
                for d2 in datasets[1:]:
                    if hasattr(d2, attr.name):
                        acat = acat.concat(getattr(d2, attr.name))
                    elif len(d2) > 0:  # if empty we allow missing attributes
                        found_all = False
                        break
                if found_all:
                    self._data_attributes[attr.name] = acat

        # Insert new data attributes after inherited ones
        for da in new_data_attributes.values():
            self._data_attributes[da.name] = da

        if indices is not None:
            # Attributes are always subset with the same indices as the data.
            self._data_attributes = OrderedDict(
                (name, da.subset(indices))
                for name, da in self._data_attributes.items()
            )

        # set attributes dynamically
        for el in self._data_attributes.values():
            assert len(el) == len(self), f"BUG: Wrong size for attribute {el.name}"

            if hasattr(self, el.name):
                # Do not raise an error if a property.
                # Any check related to the property will be done
                # in the property setter method.
                if not isinstance(getattr(type(self), el.name, None), property):
                    raise ValueError(
                        f"Trying to add DataAttribute `{el.name}` to "
                        f"AvalancheDataset but the attribute name is "
                        f"already used."
                    )
            else:
                setattr(self, el.name, el)

    def __len__(self) -> int:
        return len(self._flat_data)

    def __add__(self: TAvalancheDataset, other: TAvalancheDataset) -> TAvalancheDataset:
        return self.concat(other)

    def __radd__(
        self: TAvalancheDataset, other: TAvalancheDataset
    ) -> TAvalancheDataset:
        return other.concat(self)

    @property
    def _datasets(self):
        """Only for backward compatibility of old unit tests. Do not use."""
        return self._flat_data._datasets

    def concat(self: TAvalancheDataset, other: TAvalancheDataset) -> TAvalancheDataset:
        """Concatenate this dataset with other.

        :param other: Other dataset to concatenate.
        :return: A new dataset.
        """
        return self.__class__([self, other])

    def subset(self: TAvalancheDataset, indices: Sequence[int]) -> TAvalancheDataset:
        """Subset this dataset.

        :param indices: The indices to keep.
        :return: A new dataset.
        """
        return self.__class__([self], indices=indices)

    @property
    def transform(self):
        raise AttributeError(
            "Cannot access or modify transform directly. Use transform_groups "
            "methods such as `replace_current_transform_group`. "
            "See the documentation for more info."
        )

    def update_data_attribute(
        self: TAvalancheDataset, name: str, new_value
    ) -> TAvalancheDataset:
        """
        Return a new dataset with the added or replaced data attribute.

        If an object of type :class:`DataAttribute` is passed, then the data
        attribute is set as is.

        Otherwise, if a raw value is passed, a new DataAttribute is created.
        If a DataAttribute with the same name already exists, the
        use_in_getitem flag is inherited, otherwise it is set to False.

        :param name: The name of the data attribute to add/replace.
        :param new_value: Either a :class:`DataAttribute` or a sequence
            containing as many elements as the dataset.
        :returns: A copy of this dataset with the given data attribute set.
        """
        assert len(new_value) == len(
            self
        ), f"Size mismatch when updating data attribute {name}"

        datacopy = self._shallow_clone_dataset()
        datacopy._data_attributes = copy.copy(datacopy._data_attributes)

        if isinstance(new_value, DataAttribute):
            assert name == new_value.name
            datacopy._data_attributes[name] = new_value
        else:
            use_in_getitem = False
            prev_attr = datacopy._data_attributes.get(name, None)
            if prev_attr is not None:
                use_in_getitem = prev_attr.use_in_getitem

            datacopy._data_attributes[name] = DataAttribute(
                new_value, name=name, use_in_getitem=use_in_getitem
            )

        # Expose the attribute as a field. The shallow clone copied the
        # instance __dict__, so a previous DataAttribute field must be
        # replaced as well, otherwise `datacopy.<name>` would still return the
        # old values. Other existing names (e.g. properties) are left alone.
        if not hasattr(datacopy, name) or isinstance(
            datacopy.__dict__.get(name), DataAttribute
        ):
            setattr(datacopy, name, datacopy._data_attributes[name])

        return datacopy

    def __eq__(self, other: object):
        for required_attr in ["_flat_data", "_data_attributes", "collate_fn"]:
            if not hasattr(other, required_attr):
                return False
        return (
            other._flat_data == self._flat_data
            and self._data_attributes == other._data_attributes  # type: ignore
            and self.collate_fn == other.collate_fn  # type: ignore
        )

    @overload
    def __getitem__(self, exp_id: int) -> T_co:
        ...

    @overload
    def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset:
        ...

    def __getitem__(
        self: TAvalancheDataset, idx: Union[int, slice]
    ) -> Union[T_co, TAvalancheDataset]:
        elem = self._flat_data[idx]
        for da in self._data_attributes.values():
            if da.use_in_getitem:
                if isinstance(elem, dict):
                    elem[da.name] = da[idx]
                elif isinstance(elem, tuple):
                    # tuples are immutable: convert to list before appending
                    elem = list(elem)  # type: ignore
                    elem.append(da[idx])  # type: ignore
                else:
                    elem.append(da[idx])  # type: ignore
        return elem

    def train(self):
        """Returns a new dataset with the transformations of the 'train' group
        loaded.

        The current dataset will not be affected.

        :return: A new dataset with the training transformations loaded.
        """
        return self.with_transforms("train")

    def eval(self):
        """
        Returns a new dataset with the transformations of the 'eval' group
        loaded.

        Eval transformations usually don't contain augmentation procedures.
        This function may be useful when in need to test on training data
        (for instance, in order to run a validation pass).

        The current dataset will not be affected.

        :return: A new dataset with the eval transformations loaded.
        """
        return self.with_transforms("eval")

    def with_transforms(self: TAvalancheDataset, group_name: str) -> TAvalancheDataset:
        """
        Returns a new dataset with the transformations of a different group
        loaded.

        The current dataset will not be affected.

        :param group_name: The name of the transformations group to use.
        :return: A new dataset with the new transformations.
        """
        datacopy = self._shallow_clone_dataset()
        datacopy._flat_data = datacopy._flat_data.with_transforms(group_name)
        return datacopy

    def freeze_transforms(self: TAvalancheDataset) -> TAvalancheDataset:
        """Returns a new dataset with the transformation groups frozen."""
        datacopy = self._shallow_clone_dataset()
        datacopy._flat_data = datacopy._flat_data.freeze_transforms()
        return datacopy

    def remove_current_transform_group(self):
        """Recursively remove transformation groups from dataset tree."""
        datacopy = self._shallow_clone_dataset()
        fdata = datacopy._flat_data
        datacopy._flat_data = fdata.remove_current_transform_group()
        return datacopy

    def replace_current_transform_group(self, transform):
        """Recursively remove the current transformation group from the
        dataset tree and replaces it."""
        datacopy = self._shallow_clone_dataset()
        fdata = datacopy._flat_data
        datacopy._flat_data = fdata.replace_current_transform_group(transform)
        return datacopy

    def _shallow_clone_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
        """Clone dataset.

        This is a shallow copy, i.e. the data attributes are not copied.
        """
        dataset_copy = copy.copy(self)
        dataset_copy._flat_data = self._flat_data._shallow_clone_dataset()
        return dataset_copy

    def _init_collate_fn(self, dataset, collate_fn):
        """Pick the collate function: the explicit one, then the one of
        `dataset` (if it has a `collate_fn` attribute), else default_collate.
        """
        if collate_fn is not None:
            return collate_fn

        if hasattr(dataset, "collate_fn"):
            return getattr(dataset, "collate_fn")

        return default_collate

    def __repr__(self):
        return repr(self._flat_data)

    def _tree_depth(self):
        """Return the depth of the tree of datasets.
        Use only to debug performance issues.
        """
        return self._flat_data._tree_depth()
class _FlatDataWithTransform(FlatData[T_co]):
    """Private class used to wrap a dataset with a transformation group.

    Do not use outside of this file.
    """

    def __init__(
        self,
        datasets: Sequence[IDataset[T_co]],
        *,
        indices: Optional[List[int]] = None,
        transform_groups: Optional[TransformGroups] = None,
        frozen_transform_groups: Optional[TransformGroups] = None,
        discard_elements_not_in_indices: bool = False,
    ):
        """Wrap `datasets` with (frozen) transformation groups.

        :param datasets: Datasets to concatenate.
        :param indices: Optional indices used to subset the concatenation.
        :param transform_groups: Transform groups (or a plain dict, which is
            wrapped into :class:`TransformGroups`). None means no groups.
        :param frozen_transform_groups: Frozen transform groups (or a plain
            dict). None means no groups.
        :param discard_elements_not_in_indices: Forwarded to `FlatData`.
        """
        # Flattening is only enabled when this node adds no transformations of
        # its own (presumably so nested transforms are not lost).
        can_flatten = (transform_groups is None) and (frozen_transform_groups is None)
        super().__init__(
            datasets,
            indices=indices,
            can_flatten=can_flatten,
            discard_elements_not_in_indices=discard_elements_not_in_indices,
        )

        if isinstance(transform_groups, dict):
            transform_groups = TransformGroups(transform_groups)
        if isinstance(frozen_transform_groups, dict):
            frozen_transform_groups = TransformGroups(frozen_transform_groups)

        if transform_groups is None:
            transform_groups = EmptyTransformGroups()
        if frozen_transform_groups is None:
            frozen_transform_groups = EmptyTransformGroups()
        self._transform_groups: TransformGroups = transform_groups
        self._frozen_transform_groups: TransformGroups = frozen_transform_groups

        ####################################
        # Init transformations
        ####################################
        cgroup = None
        # inherit transformation group from original dataset
        for dd in datasets:
            if isinstance(dd, _FlatDataWithTransform):
                if cgroup is None and dd._transform_groups is not None:
                    cgroup = dd._transform_groups.current_group
                elif (
                    dd._transform_groups is not None
                    and dd._transform_groups.current_group != cgroup
                ):
                    # all datasets must have the same transformation group
                    warnings.warn(
                        f"Concatenated datasets have different transformation "
                        f"groups. Using group={cgroup}."
                    )
        if cgroup is None:
            cgroup = "train"
        self._frozen_transform_groups.current_group = cgroup
        self._transform_groups.current_group = cgroup

    def __eq__(self, other):
        for required_attr in [
            "_datasets",
            "_transform_groups",
            "_frozen_transform_groups",
        ]:
            if not hasattr(other, required_attr):
                return False

        eq_datasets = len(self._datasets) == len(other._datasets)  # type: ignore
        eq_datasets = eq_datasets and all(
            d1 == d2 for d1, d2 in zip(self._datasets, other._datasets)  # type: ignore
        )
        ftg = other._frozen_transform_groups  # type: ignore
        return (
            eq_datasets
            and self._transform_groups == other._transform_groups  # type: ignore
            and self._frozen_transform_groups == ftg  # type: ignore
        )

    def _getitem_recursive_call(self, idx, group_name) -> T_co:
        """Private method only for internal use.

        We need this recursive call to avoid appending task
        label multiple times inside the __getitem__.
        """
        dataset_idx, idx = self._get_idx(idx)

        dd = self._datasets[dataset_idx]
        if isinstance(dd, _FlatDataWithTransform):
            element = dd._getitem_recursive_call(idx, group_name=group_name)
        else:
            element = dd[idx]

        # Frozen transformations are applied before the non-frozen ones.
        if self._frozen_transform_groups is not None:
            element = self._frozen_transform_groups(element, group_name=group_name)
        if self._transform_groups is not None:
            element = self._transform_groups(element, group_name=group_name)
        return element

    def __getitem__(
        self: TDataWTransform, idx: Union[int, slice]
    ) -> Union[T_co, TDataWTransform]:
        if isinstance(idx, (int, np.integer)):
            elem = self._getitem_recursive_call(
                idx, self._transform_groups.current_group
            )
            return elem  # type: ignore
        else:
            return super().__getitem__(idx)

    def with_transforms(self: TDataWTransform, group_name: str) -> TDataWTransform:
        """
        Returns a new dataset with the transformations of a different group
        loaded.

        The current dataset will not be affected.

        :param group_name: The name of the transformations group to use.
        :return: A new dataset with the new transformations.
        """
        datacopy = self._shallow_clone_dataset()
        datacopy._frozen_transform_groups.with_transform(group_name)
        datacopy._transform_groups.with_transform(group_name)
        return datacopy

    def freeze_transforms(self: TDataWTransform) -> TDataWTransform:
        """Returns a new dataset with the transformation groups frozen."""
        tgroups = copy.copy(self._transform_groups)
        frozen_tgroups = copy.copy(self._frozen_transform_groups)
        datacopy = self._shallow_clone_dataset()
        datacopy._frozen_transform_groups = frozen_tgroups + tgroups
        datacopy._transform_groups = EmptyTransformGroups()
        dds: List[IDataset] = []
        for dd in datacopy._datasets:
            if isinstance(dd, _FlatDataWithTransform):
                dds.append(dd.freeze_transforms())
            else:
                dds.append(dd)
        datacopy._datasets = dds
        return datacopy

    def remove_current_transform_group(self):
        """Recursively remove transformation groups from dataset tree."""
        dataset_copy = self._shallow_clone_dataset()
        cgroup = dataset_copy._transform_groups.current_group
        dataset_copy._transform_groups[cgroup] = None
        dds = []
        for dd in dataset_copy._datasets:
            if isinstance(dd, _FlatDataWithTransform):
                dds.append(dd.remove_current_transform_group())
            else:
                dds.append(dd)
        dataset_copy._datasets = dds
        return dataset_copy

    def replace_current_transform_group(self, transform):
        """Recursively remove the current transformation group from the
        dataset tree and replace it at the top level with `transform`.

        `remove_current_transform_group` already clears the current group of
        every descendant, so the children do not need a second pass.
        """
        dataset_copy = self.remove_current_transform_group()
        cgroup = dataset_copy._transform_groups.current_group
        dataset_copy._transform_groups[cgroup] = transform
        return dataset_copy

    def _shallow_clone_dataset(self: TDataWTransform) -> TDataWTransform:
        """Clone dataset.

        This is a shallow copy, i.e. the data attributes are not copied.
        The transform group containers are copied with `copy.copy` so the
        clone can change its current group independently.
        """
        dataset_copy = copy.copy(self)
        dataset_copy._transform_groups = copy.copy(dataset_copy._transform_groups)
        dataset_copy._frozen_transform_groups = copy.copy(
            dataset_copy._frozen_transform_groups
        )
        return dataset_copy
def make_avalanche_dataset(
    dataset: IDataset[T_co],
    *,
    data_attributes: Optional[List[DataAttribute]] = None,
    transform_groups: Optional[TransformGroups] = None,
    frozen_transform_groups: Optional[TransformGroups] = None,
    collate_fn: Optional[Callable[[List], Any]] = None,
) -> AvalancheDataset[T_co]:
    """Wrap a single dataset into an :class:`AvalancheDataset`.

    Convenience factory; see ``AvalancheDataset`` for more details.

    :param dataset: The dataset to wrap. Beware that the resulting
        AvalancheDataset will not overwrite transformations already applied
        by this dataset.
    :param data_attributes: Optional data attributes to attach; each one must
        have as many elements as ``dataset``.
    :param transform_groups: Avalanche transform groups.
    :param frozen_transform_groups: Avalanche frozen transform groups.
    :param collate_fn: Optional collate function used for mini-batches.
    :return: A new ``AvalancheDataset`` wrapping ``dataset``.
    """
    options = dict(
        data_attributes=data_attributes,
        transform_groups=transform_groups,
        frozen_transform_groups=frozen_transform_groups,
        collate_fn=collate_fn,
    )
    return AvalancheDataset([dataset], **options)
def _print_frozen_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the frozen transformations of `self` (an ``AvalancheDataset`` or a
    ``_FlatDataWithTransform``) and, recursively, of its child datasets.
    """
    # AvalancheDataset keeps its transform groups inside its flat data.
    fdata = self._flat_data if isinstance(self, AvalancheDataset) else self
    print("FROZEN TRANSFORMS:\n" + str(fdata._frozen_transform_groups))
    for dd in fdata._datasets:
        if isinstance(dd, (AvalancheDataset, _FlatDataWithTransform)):
            print("PARENT FROZEN:\n")
            _print_frozen_transforms(dd)


def _print_nonfrozen_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the non-frozen transformations of `self` (an ``AvalancheDataset``
    or a ``_FlatDataWithTransform``) and, recursively, of its child datasets.
    """
    # AvalancheDataset keeps its transform groups inside its flat data.
    fdata = self._flat_data if isinstance(self, AvalancheDataset) else self
    print("TRANSFORMS:\n" + str(fdata._transform_groups))
    for dd in fdata._datasets:
        if isinstance(dd, (AvalancheDataset, _FlatDataWithTransform)):
            print("PARENT TRANSFORMS:\n")
            _print_nonfrozen_transforms(dd)


def _print_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the current frozen and non-frozen transformations.
    """
    # These helpers are module-level functions, not methods of the dataset.
    _print_frozen_transforms(self)
    _print_nonfrozen_transforms(self)


__all__ = ["AvalancheDataset", "make_avalanche_dataset"]