Source code for avalanche.benchmarks.utils.data

################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the AvalancheDataset class,
which extends PyTorch's dataset.
AvalancheDataset offers additional features like the
management of preprocessing pipelines and task/class labels.
"""
import copy
import warnings

from torch.utils.data.dataloader import default_collate

from avalanche.benchmarks.utils.dataset_definitions import IDataset
from .data_attribute import DataAttribute

from typing import List, Any, Sequence, Union, TypeVar, Callable

from .flat_data import FlatData
from .transform_groups import TransformGroups, EmptyTransformGroups
from torch.utils.data import Dataset as TorchDataset


T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")


class AvalancheDataset(FlatData):
    """Avalanche Dataset.

    Avalanche datasets are PyTorch-compatible datasets with some additional
    functionality such as:

    - management of transformation groups via :class:`AvalancheTransform`
    - support for sample attributes such as class targets and task labels

    Data Attributes
    ---------------

    Avalanche datasets manage sample-wise information such as class or task
    labels via :class:`DataAttribute`.

    Transformation Groups
    ---------------------

    Avalanche datasets manage transformation via transformation groups.
    Simply put, a transformation group is a named preprocessing function
    (as in torchvision datasets). By default, Avalanche expects two
    transformation groups:

    - 'train', which contains transformations applied to training patterns.
    - 'eval', which contains transformations applied to test patterns.

    Having both groups allows to use different transformations during
    training and evaluation and to seamlessly switch between them by using
    the :func:`train` and :func:`eval` methods. Arbitrary transformation
    groups can be added and used. If you define custom groups, you can use
    them by calling the :func:`with_transforms` method.

    Switching to a different transformation group by calling the
    ``train()``, ``eval()`` or ``with_transforms()`` methods always returns
    a new dataset, leaving the original one unchanged.

    Transformation groups can be manipulated by removing, freezing, or
    replacing transformations. Each operation returns a new dataset, leaving
    the original one unchanged.
    """

    def __init__(
        self,
        datasets: List[IDataset],
        *,
        indices: List[int] = None,
        data_attributes: List[DataAttribute] = None,
        transform_groups: TransformGroups = None,
        frozen_transform_groups: TransformGroups = None,
        collate_fn: Callable[[List], Any] = None,
    ):
        """Creates a ``AvalancheDataset`` instance.

        :param datasets: Original datasets. Beware that AvalancheDataset
            will not overwrite transformations already applied by these
            datasets. Passing a single dataset instead of a list is
            deprecated: it is wrapped in a list and a warning is issued.
        :param indices: Forwarded to the ``FlatData`` constructor. When the
            resulting ``_indices`` is not None, every data attribute is
            subset with it.
        :param data_attributes: Sample-wise attributes. Each one must be as
            long as the sum of the lengths of the datasets.
        :param transform_groups: Avalanche transform groups. A plain dict
            is converted to :class:`TransformGroups`.
        :param frozen_transform_groups: Avalanche frozen transform groups.
            A plain dict is converted to :class:`TransformGroups`.
        :param collate_fn: Function used to build mini-batches. If None,
            the ``collate_fn`` attribute of the first dataset is used when
            present, otherwise ``default_collate``.
        :raises ValueError: if a data attribute has the wrong length or its
            name is already used by an attribute of this object.
        """
        if isinstance(datasets, (TorchDataset, AvalancheDataset)):
            warnings.warn(
                "AvalancheDataset constructor has been changed. "
                "Please check the documentation for the correct usage. You can"
                " use `avalanche.benchmarks.utils.make_classification_dataset`"
                " if you need the old behavior.",
                DeprecationWarning,
            )
            # Old-style call with a single dataset: wrap it in a list.
            datasets = [datasets]

        # NOTES on implementation:
        # - raw datasets operations are implemented by _FlatData
        # - data attributes are implemented by DataAttribute
        # - transformations are implemented by TransformGroups
        # AvalancheDataset just takes care to manage all of these attributes
        # together and decides how the information propagates through
        # operations (e.g. how to pass attributes after concat/subset
        # operations).
        can_flatten = (
            transform_groups is None
            and frozen_transform_groups is None
            and data_attributes is None
            and collate_fn is None
        )
        super().__init__(datasets, indices, can_flatten)

        if data_attributes is None:
            self._data_attributes = {}
        else:
            self._data_attributes = {da.name: da for da in data_attributes}
            # The total length does not depend on the attribute being
            # checked, so compute it once.
            total_len = sum(len(d) for d in self._datasets)
            for da in data_attributes:
                if len(da) != total_len:
                    raise ValueError(
                        "Data attribute {} has length {} but the dataset "
                        "has length {}".format(da.name, len(da), total_len)
                    )

        if isinstance(transform_groups, dict):
            transform_groups = TransformGroups(transform_groups)
        if isinstance(frozen_transform_groups, dict):
            frozen_transform_groups = TransformGroups(frozen_transform_groups)
        self._transform_groups = transform_groups
        self._frozen_transform_groups = frozen_transform_groups
        # Provisional value; the final one is chosen in the
        # "Init collate_fn" section below.
        self.collate_fn = collate_fn

        ####################################
        # Init transformations
        ####################################
        cgroup = None
        # inherit transformation group from original dataset
        for dd in self._datasets:
            if isinstance(dd, AvalancheDataset):
                if cgroup is None and dd._transform_groups is not None:
                    cgroup = dd._transform_groups.current_group
                elif (
                    dd._transform_groups is not None
                    and dd._transform_groups.current_group != cgroup
                ):
                    # all datasets must have the same transformation group
                    warnings.warn(
                        f"Concatenated datasets have different transformation "
                        f"groups. Using group={cgroup}."
                    )
        if self._frozen_transform_groups is None:
            self._frozen_transform_groups = EmptyTransformGroups()
        if self._transform_groups is None:
            self._transform_groups = EmptyTransformGroups()

        if cgroup is None:
            cgroup = "train"
        self._frozen_transform_groups.current_group = cgroup
        self._transform_groups.current_group = cgroup

        ####################################
        # Init collate_fn
        ####################################
        # The collate function to use when creating mini-batches from this
        # dataset.
        if len(datasets) > 0:
            self.collate_fn = self._init_collate_fn(datasets[0], collate_fn)
        else:
            self.collate_fn = default_collate

        ####################################
        # Init data attributes
        ####################################
        # concat attributes from child datasets
        if len(self._datasets) > 0 and isinstance(
            self._datasets[0], AvalancheDataset
        ):
            for attr in self._datasets[0]._data_attributes.values():
                if attr.name in self._data_attributes:
                    continue  # don't touch overridden attributes
                acat = attr
                found_all = True
                for d2 in self._datasets[1:]:
                    if hasattr(d2, attr.name):
                        acat = acat.concat(getattr(d2, attr.name))
                    else:
                        # An attribute is inherited only if every child
                        # dataset provides it.
                        found_all = False
                        break
                if found_all:
                    self._data_attributes[attr.name] = acat

        if self._indices is not None:  # subset operation for attributes
            # TODO: attributes are always subset here. How do we know what
            # to do if we permute the entire dataset?
            # Iterate over a snapshot because the dict is updated in place.
            for da in list(self._data_attributes.values()):
                self._data_attributes[da.name] = da.subset(self._indices)

        # set attributes dynamically
        for el in self._data_attributes.values():
            assert len(el) == len(
                self
            ), f"BUG: Wrong size for attribute {el.name}"

            if hasattr(self, el.name):
                raise ValueError(
                    f"Trying to add DataAttribute `{el.name}` to "
                    f"AvalancheDataset but the attribute name is already used."
                )
            setattr(self, el.name, el)

    @property
    def transform(self):
        """Always raises: transformations are managed via transform groups.

        :raises AttributeError: on every access.
        """
        raise AttributeError(
            "Cannot access or modify transform directly. Use transform_groups "
            "methods such as `replace_current_transform_group`. "
            "See the documentation for more info."
        )

    def __eq__(self, other: object):
        """Compare child datasets, transform groups, data attributes and
        collate function.

        Objects without a ``_datasets`` attribute are never equal to an
        ``AvalancheDataset``.
        """
        if not hasattr(other, "_datasets"):
            return False
        eq_datasets = len(self._datasets) == len(other._datasets)
        eq_datasets = eq_datasets and all(
            d1 == d2 for d1, d2 in zip(self._datasets, other._datasets)
        )
        return (
            eq_datasets
            and self._transform_groups == other._transform_groups
            and self._data_attributes == other._data_attributes
            and self.collate_fn == other.collate_fn
        )

    def _getitem_recursive_call(self, idx, group_name):
        """Private method only for internal use.

        We need this recursive call to avoid appending task label multiple
        times inside the __getitem__.

        :param idx: Index of the element in this dataset.
        :param group_name: Name of the transformation group to apply at
            every level of the dataset tree.
        :return: The element after the frozen and then the non-frozen
            transformations of this dataset have been applied.
        """
        dataset_idx, idx = self._get_idx(idx)

        dd = self._datasets[dataset_idx]
        if isinstance(dd, AvalancheDataset):
            element = dd._getitem_recursive_call(idx, group_name=group_name)
        else:
            element = dd[idx]

        if self._frozen_transform_groups is not None:
            element = self._frozen_transform_groups(
                element, group_name=group_name
            )
        if self._transform_groups is not None:
            element = self._transform_groups(element, group_name=group_name)
        return element

    def __getitem__(self, idx) -> Union[T_co, Sequence[T_co]]:
        """Return the transformed element at ``idx``.

        The values of the data attributes having ``use_in_getitem`` set are
        added to the element: stored under the attribute name if the
        element is a dict, appended otherwise (a tuple is first converted
        to a list).
        """
        elem = self._getitem_recursive_call(
            idx, self._transform_groups.current_group
        )
        for da in self._data_attributes.values():
            if da.use_in_getitem:
                if isinstance(elem, dict):
                    elem[da.name] = da[idx]
                elif isinstance(elem, tuple):
                    elem = list(elem)
                    elem.append(da[idx])
                else:
                    elem.append(da[idx])
        return elem

    def train(self):
        """Returns a new dataset with the transformations of the 'train'
        group loaded.

        The current dataset will not be affected.

        :return: A new dataset with the training transformations loaded.
        """
        return self.with_transforms("train")

    def eval(self):
        """Returns a new dataset with the transformations of the 'eval'
        group loaded.

        Eval transformations usually don't contain augmentation procedures.
        This function may be useful when in need to test on training data
        (for instance, in order to run a validation pass).

        The current dataset will not be affected.

        :return: A new dataset with the eval transformations loaded.
        """
        return self.with_transforms("eval")

    def with_transforms(
        self: TAvalancheDataset, group_name: str
    ) -> TAvalancheDataset:
        """Returns a new dataset with the transformations of a different
        group loaded.

        The current dataset will not be affected.

        :param group_name: The name of the transformations group to use.
        :return: A new dataset with the new transformations.
        """
        datacopy = self._shallow_clone_dataset()
        datacopy._frozen_transform_groups.with_transform(group_name)
        datacopy._transform_groups.with_transform(group_name)
        return datacopy

    def freeze_transforms(self):
        """Returns a new dataset with the transformation groups frozen.

        The non-frozen groups are appended to the frozen ones and replaced
        by empty groups. Child ``AvalancheDataset`` instances are frozen
        recursively.
        """
        tgroups = copy.copy(self._transform_groups)
        frozen_tgroups = copy.copy(self._frozen_transform_groups)
        datacopy = self._shallow_clone_dataset()
        datacopy._frozen_transform_groups = frozen_tgroups + tgroups
        datacopy._transform_groups = EmptyTransformGroups()
        dds = []
        for dd in datacopy._datasets:
            if isinstance(dd, AvalancheDataset):
                dds.append(dd.freeze_transforms())
            else:
                dds.append(dd)
        # Store the frozen children in `_datasets`, the attribute read by
        # every other method of this class.
        datacopy._datasets = dds
        return datacopy

    def remove_current_transform_group(self):
        """Recursively remove transformation groups from dataset tree.

        :return: A new dataset whose current (non-frozen) transformation
            group is set to None at every level of the tree.
        """
        dataset_copy = self._shallow_clone_dataset()
        cgroup = dataset_copy._transform_groups.current_group
        dataset_copy._transform_groups[cgroup] = None
        dds = []
        for dd in dataset_copy._datasets:
            if isinstance(dd, AvalancheDataset):
                dds.append(dd.remove_current_transform_group())
            else:
                dds.append(dd)
        dataset_copy._datasets = dds
        return dataset_copy

    def replace_current_transform_group(self, transform):
        """Recursively remove the current transformation group from the
        dataset tree and replaces it.

        :param transform: The transformation assigned to the current group
            of the returned dataset.
        :return: A new dataset with the current group replaced.
        """
        dataset_copy = self.remove_current_transform_group()
        cgroup = dataset_copy._transform_groups.current_group
        dataset_copy._transform_groups[cgroup] = transform
        # NOTE: remove_current_transform_group() above already recursed
        # into the children; the second pass is kept as in the original.
        dds = []
        for dd in dataset_copy._datasets:
            if isinstance(dd, AvalancheDataset):
                dds.append(dd.remove_current_transform_group())
            else:
                dds.append(dd)
        dataset_copy._datasets = dds
        return dataset_copy

    def _shallow_clone_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
        """Clone dataset.

        This is a shallow copy, i.e. the data attributes are not copied.
        Only the two transform group objects are copied (with
        ``copy.copy``) so that they can be changed on the clone.
        """
        dataset_copy = copy.copy(self)
        dataset_copy._transform_groups = copy.copy(
            dataset_copy._transform_groups
        )
        dataset_copy._frozen_transform_groups = copy.copy(
            dataset_copy._frozen_transform_groups
        )
        return dataset_copy

    def _init_collate_fn(self, dataset, collate_fn):
        """Choose the collate function.

        :param dataset: Dataset whose ``collate_fn`` attribute is used as
            a fallback.
        :param collate_fn: Explicit collate function, or None.
        :return: ``collate_fn`` if not None, else ``dataset.collate_fn`` if
            that attribute exists, else ``default_collate``.
        """
        if collate_fn is not None:
            return collate_fn

        if hasattr(dataset, "collate_fn"):
            return getattr(dataset, "collate_fn")

        return default_collate
def make_avalanche_dataset(
    dataset: IDataset,
    *,
    data_attributes: List[DataAttribute] = None,
    transform_groups: TransformGroups = None,
    frozen_transform_groups: TransformGroups = None,
    collate_fn: Callable[[List], Any] = None,
):
    """Avalanche Dataset.

    Builds an ``AvalancheDataset`` that wraps a single dataset. See
    ``AvalancheDataset`` for more details.

    :param dataset: Original dataset. Beware that AvalancheDataset will not
        overwrite transformations already applied by this dataset.
    :param data_attributes: Forwarded as-is to ``AvalancheDataset``.
    :param transform_groups: Avalanche transform groups.
    :param frozen_transform_groups: Forwarded as-is to
        ``AvalancheDataset``.
    :param collate_fn: Forwarded as-is to ``AvalancheDataset``.
    :return: The new ``AvalancheDataset`` instance.
    """
    # The constructor expects a list of datasets; every option is
    # passed through unchanged.
    options = {
        "data_attributes": data_attributes,
        "transform_groups": transform_groups,
        "frozen_transform_groups": frozen_transform_groups,
        "collate_fn": collate_fn,
    }
    wrapped = [dataset]
    return AvalancheDataset(wrapped, **options)
def _print_frozen_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the current frozen transformations of ``self`` and,
    recursively, of every child ``AvalancheDataset``.

    :param self: The dataset to inspect (this is a module-level function,
        the dataset is passed explicitly).
    """
    print("FROZEN TRANSFORMS:\n" + str(self._frozen_transform_groups))
    for dd in self._datasets:
        if isinstance(dd, AvalancheDataset):
            print("PARENT FROZEN:\n")
            _print_frozen_transforms(dd)


def _print_nonfrozen_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the current non-frozen transformations of ``self`` and,
    recursively, of every child ``AvalancheDataset``.

    :param self: The dataset to inspect (this is a module-level function,
        the dataset is passed explicitly).
    """
    print("TRANSFORMS:\n" + str(self._transform_groups))
    for dd in self._datasets:
        if isinstance(dd, AvalancheDataset):
            print("PARENT TRANSFORMS:\n")
            _print_nonfrozen_transforms(dd)


def _print_transforms(self):
    """Internal debugging method. Do not use it.

    Prints the current transformations: first the frozen ones, then the
    non-frozen ones.

    :param self: The dataset to inspect.
    """
    # The two helpers are module-level functions, not methods of the
    # dataset, so they must be called with the dataset as an argument.
    _print_frozen_transforms(self)
    _print_nonfrozen_transforms(self)


__all__ = ["AvalancheDataset", "make_avalanche_dataset"]