
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-05-2020                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the Avalanche Dataset,
which is the standard Avalanche implementation of a PyTorch dataset. Despite
being a child class of the PyTorch Dataset, the AvalancheDataset (and its
derivatives) is much more powerful as it offers many more features
out-of-the-box.
"""
import copy
import warnings
from collections import OrderedDict, defaultdict, deque
from enum import Enum, auto

import torch
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataset import Dataset, Subset, ConcatDataset

from .adaptive_transform import Compose, MultiParamTransform
from .dataset_utils import (
    manage_advanced_indexing,
    SequenceDataset,
    ClassificationSubset,
    LazyConcatIntTargets,
    find_list_from_index,
    ConstantSequence,
    LazyClassMapping,
    optimize_sequence,
    SubSequence,
    LazyConcatTargets,
    TupleTLabel,
)
from .dataset_definitions import (
    ITensorDataset,
    ClassificationDataset,
    IDatasetWithTargets,
    ISupportedClassificationDataset,
)

from typing import (
    List,
    Any,
    Sequence,
    Union,
    Optional,
    TypeVar,
    SupportsInt,
    Callable,
    Dict,
    Tuple,
    Collection,
)

from typing_extensions import Protocol

T_co = TypeVar("T_co", covariant=True)
TTargetType = TypeVar("TTargetType")

TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")


# Info: https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols
class XComposedTransformDef(Protocol):
    def __call__(self, *input_values: Any) -> Any:
        pass


class XTransformDef(Protocol):
    def __call__(self, input_value: Any) -> Any:
        pass


class YTransformDef(Protocol):
    def __call__(self, input_value: Any) -> Any:
        pass


XTransform = Optional[Union[XTransformDef, XComposedTransformDef]]
YTransform = Optional[YTransformDef]
TransformGroupDef = Union[None, XTransform, Tuple[XTransform, YTransform]]


SupportedDataset = Union[
    IDatasetWithTargets, ITensorDataset, Subset, ConcatDataset
]


class AvalancheDatasetType(Enum):
    UNDEFINED = auto()
    CLASSIFICATION = auto()
    REGRESSION = auto()
    SEGMENTATION = auto()


class AvalancheDataset(IDatasetWithTargets[T_co, TTargetType], Dataset[T_co]):
    """
    The Dataset used as the base implementation for Avalanche.

    Instances of this dataset are usually returned from benchmarks, but it
    can also be used in a completely standalone manner. This dataset can be
    used to apply transformations before returning patterns/targets, it
    supports slicing and advanced indexing and it also contains useful fields
    such as `targets`, which contains the pattern labels, and
    `targets_task_labels`, which contains the pattern task labels. The
    `task_set` field can be used to obtain the subset of patterns labeled
    with a given task label.

    This dataset can also be used to apply several advanced operations
    involving transformations. For instance, it allows the user to add and
    replace transformations, freeze them so that they can't be changed, etc.

    This dataset also allows the user to keep distinct transformations
    groups. Simply put, a transformation group is a pair of
    transform+target_transform (exactly as in torchvision datasets). This
    dataset natively supports keeping two transformation groups: the first,
    'train', contains transformations applied to training patterns. Those
    transformations usually involve some kind of data augmentation. The
    second one is 'eval', which contains transformations applied to test
    patterns. Having both groups is useful when, for instance, you need to
    test on the training data (as this process usually involves removing data
    augmentation operations). Switching between transformations can be easily
    achieved by using the :func:`train` and :func:`eval` methods.

    Moreover, arbitrary transformation groups can be added and used. For more
    info see the constructor and the :func:`with_transforms` method.

    This dataset will try to inherit the task labels from the input dataset.
    If none are available and none are given via the `task_labels` parameter,
    each pattern will be assigned a default task label "0".
    See the constructor for more details.
    """
    def __init__(
        self,
        dataset: SupportedDataset,
        *,
        transform: XTransform = None,
        target_transform: YTransform = None,
        transform_groups: Dict[str, TransformGroupDef] = None,
        initial_transform_group: str = None,
        task_labels: Union[int, Sequence[int]] = None,
        targets: Sequence[TTargetType] = None,
        dataset_type: AvalancheDatasetType = None,
        collate_fn: Callable[[List], Any] = None,
        targets_adapter: Callable[[Any], TTargetType] = None,
    ):
        """
        Creates an ``AvalancheDataset`` instance.

        :param dataset: The dataset to decorate. Beware that
            AvalancheDataset will not overwrite transformations already
            applied by this dataset.
        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :param transform_groups: A dictionary containing the transform
            groups. Transform groups are used to quickly switch between
            training and eval (test) transformations. This becomes useful
            when you need to test on the training dataset, as test
            transformations usually don't contain random augmentations.
            ``AvalancheDataset`` natively supports the 'train' and 'eval'
            groups by calling the ``train()`` and ``eval()`` methods. When
            using custom groups one can use the
            ``with_transforms(group_name)`` method instead. Defaults to None,
            which means that the current transforms will be used to handle
            both 'train' and 'eval' groups (just like in standard
            ``torchvision`` datasets).
        :param initial_transform_group: The name of the initial transform
            group to be used. Defaults to None, which means that the current
            group of the input dataset will be used (if an AvalancheDataset).
            If the input dataset is not an AvalancheDataset, then 'train'
            will be used.
        :param task_labels: The task label of each instance. Must be a
            sequence of ints, one for each instance in the dataset.
            Alternatively can be a single int value, in which case that value
            will be used as the task label for all the instances. Defaults to
            None, which means that the dataset will try to obtain the task
            labels from the original dataset. If no task labels could be
            found, a default task label "0" will be applied to all instances.
        :param targets: The label of each pattern. Defaults to None, which
            means that the targets will be retrieved from the dataset (if
            possible).
        :param dataset_type: The type of the dataset. Defaults to None,
            which means that the type will be inferred from the input
            dataset. When the `dataset_type` is different than UNDEFINED, a
            proper value for `collate_fn` and `targets_adapter` will be set.
            If the `dataset_type` is different than UNDEFINED, then
            `collate_fn` and `targets_adapter` must not be set.
        :param collate_fn: The function to use when slicing to merge single
            patterns. In the future this function may become the function
            used in the data loading process, too. If None and the
            `dataset_type` is UNDEFINED, the constructor will check if a
            `collate_fn` field exists in the dataset. If no such field
            exists, the default collate function will be used.
        :param targets_adapter: A function used to convert the values of the
            targets field. Defaults to None. Note: the adapter will not
            change the value of the second element returned by
            `__getitem__`. The adapter is used to adapt the values of the
            targets field only.
        """
        if transform_groups is not None and (
            transform is not None or target_transform is not None
        ):
            raise ValueError(
                "transform_groups can't be used with transform "
                "and target_transform values"
            )

        detected_type = False
        if dataset_type is None:
            detected_type = True
            if isinstance(dataset, AvalancheDataset):
                dataset_type = dataset.dataset_type
            else:
                dataset_type = AvalancheDatasetType.UNDEFINED

        if dataset_type != AvalancheDatasetType.UNDEFINED and (
            collate_fn is not None or targets_adapter is not None
        ):
            if detected_type:
                raise ValueError(
                    "dataset_type {} was inferred from the input dataset. "
                    "This dataset type can't be used "
                    "with custom collate_fn or targets_adapter "
                    "parameters. Only the UNDEFINED type supports "
                    "custom collate_fn or targets_adapter "
                    "parameters".format(dataset_type)
                )
            else:
                raise ValueError(
                    "dataset_type {} can't be used with custom collate_fn "
                    "or targets_adapter. Only the UNDEFINED type supports "
                    "custom collate_fn or targets_adapter "
                    "parameters.".format(dataset_type)
                )

        if transform_groups is not None:
            AvalancheDataset._check_groups_dict_format(transform_groups)

        if not isinstance(dataset_type, AvalancheDatasetType):
            raise ValueError(
                "dataset_type must be a value of type AvalancheDatasetType"
            )

        self._dataset: SupportedDataset = dataset
        """
        The original dataset.
        """

        self.dataset_type = dataset_type
        """
        The type of this dataset (UNDEFINED, CLASSIFICATION, ...).
        """

        self.targets: Sequence[
            TTargetType
        ] = self._initialize_targets_sequence(
            dataset, targets, dataset_type, targets_adapter
        )
        """
        A sequence of values describing the label of each pattern contained
        in the dataset.
        """

        self.targets_task_labels: Sequence[
            int
        ] = self._initialize_task_labels_sequence(dataset, task_labels)
        """
        A sequence of ints describing the task label of each pattern
        contained in the dataset.
        """

        self.tasks_pattern_indices: Dict[
            int, Sequence[int]
        ] = self._initialize_tasks_dict(dataset, self.targets_task_labels)
        """
        A dictionary mapping task labels to the indices of the patterns with
        that task label. If you need to obtain the subset of patterns labeled
        with a certain task label, consider using the `task_set` field.
        """

        self.collate_fn = self._initialize_collate_fn(
            dataset, dataset_type, collate_fn
        )
        """
        The collate function to use when creating mini-batches from this
        dataset.
        """

        # Compress targets and task labels to save some memory
        self._optimize_targets()
        self._optimize_task_labels()
        self._optimize_task_dict()

        self.task_set = self._make_task_set_dict()
        """
        A dictionary that can be used to obtain the subset of patterns given
        a specific task label.
        """

        if initial_transform_group is None:
            # Detect from the input dataset. If not an AvalancheDataset then
            # use 'train' as the initial transform group
            if isinstance(dataset, AvalancheDataset):
                initial_transform_group = dataset.current_transform_group
            else:
                initial_transform_group = "train"

        self.current_transform_group = initial_transform_group
        """
        The name of the transform group currently in use.
        """

        self.transform_groups: Dict[
            str, Tuple[XTransform, YTransform]
        ] = self._initialize_groups_dict(
            transform_groups, dataset, transform, target_transform
        )
        """
        A dictionary containing the transform groups. Transform groups are
        used to quickly switch between training and test (eval)
        transformations. This becomes useful when you need to test on the
        training dataset, as test transformations usually don't contain
        random augmentations.

        AvalancheDataset natively supports switching between the 'train' and
        'eval' groups by calling the ``train()`` and ``eval()`` methods. When
        using custom groups one can use the ``with_transforms(group_name)``
        method instead.

        May be null, which means that the current transforms will be used to
        handle both 'train' and 'eval' groups.
        """

        if self.current_transform_group not in self.transform_groups:
            raise ValueError(
                "Invalid transformations group "
                + str(self.current_transform_group)
            )
        t_group = self.transform_groups[self.current_transform_group]

        self.transform: XTransform = t_group[0]
        """
        A function/transform that takes in a PIL image and returns a
        transformed version.
        """

        self.target_transform: YTransform = t_group[1]
        """
        A function/transform that takes in the target and transforms it.
        """

        self._frozen_transforms: Dict[
            str, Tuple[XTransform, YTransform]
        ] = dict()
        """
        A dictionary containing frozen transformations.
        """

        for group_name in self.transform_groups.keys():
            self._frozen_transforms[group_name] = (None, None)

        self._set_original_dataset_transform_group(
            self.current_transform_group
        )

        self._flatten_dataset()
    def __add__(self, other: Dataset) -> "AvalancheDataset":
        return AvalancheConcatDataset([self, other])

    def __radd__(self, other: Dataset) -> "AvalancheDataset":
        return AvalancheConcatDataset([other, self])

    def __getitem__(self, idx) -> Union[T_co, Sequence[T_co]]:
        return TupleTLabel(
            manage_advanced_indexing(
                idx, self._get_single_item, len(self), self.collate_fn
            )
        )

    def __len__(self):
        return len(self._dataset)

    def train(self):
        """
        Returns a new dataset with the transformations of the 'train' group
        loaded.

        The current dataset will not be affected.

        :return: A new dataset with the training transformations loaded.
        """
        return self.with_transforms("train")

    def eval(self):
        """
        Returns a new dataset with the transformations of the 'eval' group
        loaded.

        Eval transformations usually don't contain augmentation procedures.
        This function may be useful when you need to test on training data
        (for instance, in order to run a validation pass).

        The current dataset will not be affected.

        :return: A new dataset with the eval transformations loaded.
        """
        return self.with_transforms("eval")

    def freeze_transforms(self: TAvalancheDataset) -> TAvalancheDataset:
        """
        Returns a new dataset where the current transformations are frozen.

        Frozen transformations will be permanently glued to the original
        dataset so that they can't be changed anymore. This is usually done
        when using transformations to create derived datasets: in this way
        freezing the transformations will ensure that the user won't be able
        to inadvertently change them by directly setting the transformations
        field or by using the other transformations utility methods like
        ``replace_transforms``. Please note that transformations of all
        groups will be frozen. If you want to freeze a specific group, please
        use ``freeze_group_transforms``.

        The current dataset will not be affected.

        :return: A new dataset with the current transformations frozen.
        """
        dataset_copy = self._fork_dataset()

        for group_name in dataset_copy.transform_groups.keys():
            AvalancheDataset._freeze_dataset_group(dataset_copy, group_name)

        return dataset_copy

    def freeze_group_transforms(
        self: TAvalancheDataset, group_name: str
    ) -> TAvalancheDataset:
        """
        Returns a new dataset where the transformations for a specific group
        are frozen.

        Frozen transformations will be permanently glued to the original
        dataset so that they can't be changed anymore. This is usually done
        when using transformations to create derived datasets: in this way
        freezing the transformations will ensure that the user won't be able
        to inadvertently change them by directly setting the transformations
        field or by using the other transformations utility methods like
        ``replace_transforms``. To freeze transformations of all groups
        please use ``freeze_transforms``.

        The current dataset will not be affected.

        :return: A new dataset with the transformations frozen for the given
            group.
        """
        dataset_copy = self._fork_dataset()
        AvalancheDataset._freeze_dataset_group(dataset_copy, group_name)
        return dataset_copy

    def get_transforms(
        self: TAvalancheDataset, transforms_group: str = None
    ) -> Tuple[Any, Any]:
        """
        Returns the transformations given a group.

        Beware that this will not return the frozen transformations, nor the
        ones included in the wrapped dataset. Only transformations directly
        attached to this dataset will be returned.

        :param transforms_group: The transformations group. Defaults to None,
            which means that the current group is returned.
        :return: The transformation group, as a tuple
            (transform, target_transform).
        """
        if transforms_group is None:
            transforms_group = self.current_transform_group

        if transforms_group == self.current_transform_group:
            return self.transform, self.target_transform

        return self.transform_groups[transforms_group]

    def add_transforms(
        self: TAvalancheDataset,
        transform: Callable[[Any], Any] = None,
        target_transform: Callable[[int], int] = None,
    ) -> TAvalancheDataset:
        """
        Returns a new dataset with the given transformations added to
        the existing ones.

        The transformations will be added to the current transformations
        group. Other transformation groups will not be affected.

        The given transformations will be added "at the end" of previous
        transformations of the current transformations group. This means
        that existing transformations will be applied to the patterns first.

        The current dataset will not be affected.

        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :return: A new dataset with the added transformations.
        """
        dataset_copy = self._fork_dataset()

        if transform is not None:
            if dataset_copy.transform is not None:
                dataset_copy.transform = Compose(
                    [dataset_copy.transform, transform]
                )
            else:
                dataset_copy.transform = transform

        if target_transform is not None:
            if dataset_copy.target_transform is not None:
                dataset_copy.target_transform = Compose(
                    [dataset_copy.target_transform, target_transform]
                )
            else:
                dataset_copy.target_transform = target_transform

        return dataset_copy

    def add_transforms_to_group(
        self: TAvalancheDataset,
        group_name: str,
        transform: Callable[[Any], Any] = None,
        target_transform: Callable[[int], int] = None,
    ) -> TAvalancheDataset:
        """
        Returns a new dataset with the given transformations added to
        the existing ones for a certain group.

        The transformations will be added to the given transformations
        group. Other transformation groups will not be affected. The group
        must already exist.

        The given transformations will be added "at the end" of previous
        transformations of that group. This means that existing
        transformations will be applied to the patterns first.

        The current dataset will not be affected.

        :param group_name: The name of the group.
        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :return: A new dataset with the added transformations.
        """
        if self.current_transform_group == group_name:
            return self.add_transforms(transform, target_transform)

        if group_name not in self.transform_groups:
            raise ValueError("Invalid group name " + str(group_name))

        dataset_copy = self._fork_dataset()

        t_group: List[XTransform, YTransform] = list(
            dataset_copy.transform_groups[group_name]
        )
        if transform is not None:
            if t_group[0] is not None:
                t_group[0] = Compose([t_group[0], transform])
            else:
                t_group[0] = transform

        if target_transform is not None:
            if t_group[1] is not None:
                t_group[1] = Compose([t_group[1], target_transform])
            else:
                t_group[1] = target_transform

        # tuple(t_group) works too, but it triggers a type warning
        tuple_t_group: Tuple[XTransform, YTransform] = tuple(
            (t_group[0], t_group[1])
        )
        dataset_copy.transform_groups[group_name] = tuple_t_group

        return dataset_copy

    def replace_transforms(
        self: TAvalancheDataset,
        transform: XTransform,
        target_transform: YTransform,
        group: str = None,
    ) -> TAvalancheDataset:
        """
        Returns a new dataset with the existing transformations replaced with
        the given ones.

        The given transformations will replace the ones of the current
        transformations group. Other transformation groups will not be
        affected.

        If the original dataset is an instance of :class:`AvalancheDataset`,
        then transformations of the original set will be considered as well
        (the original dataset will be left untouched).

        The current dataset will not be affected.

        Note that this function will not override frozen transformations.
        This will also not affect transformations found in datasets that are
        not instances of :class:`AvalancheDataset`.

        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :param group: The transforms group to replace. Defaults to None,
            which means that the current group will be replaced.
        :return: A new dataset with the new transformations.
        """
        if group is None:
            group = self.current_transform_group

        dataset_copy = self._fork_dataset().with_transforms(group)
        dataset_copy._replace_original_dataset_group(None, None)

        dataset_copy.transform = transform
        dataset_copy.target_transform = target_transform

        return dataset_copy.with_transforms(self.current_transform_group)

    def with_transforms(
        self: TAvalancheDataset, group_name: str
    ) -> TAvalancheDataset:
        """
        Returns a new dataset with the transformations of a different group
        loaded.

        The current dataset will not be affected.

        :param group_name: The name of the transformations group to use.
        :return: A new dataset with the new transformations.
        """
        dataset_copy = self._fork_dataset()

        if group_name not in dataset_copy.transform_groups:
            raise ValueError("Invalid group name " + str(group_name))

        # Store current group (loaded in transform and target_transform
        # fields)
        dataset_copy.transform_groups[
            dataset_copy.current_transform_group
        ] = (
            dataset_copy.transform,
            dataset_copy.target_transform,
        )

        # Load new group in transform and target_transform fields
        switch_group = dataset_copy.transform_groups[group_name]
        dataset_copy.transform = switch_group[0]
        dataset_copy.target_transform = switch_group[1]

        dataset_copy.current_transform_group = group_name

        # Finally, align the underlying dataset
        dataset_copy._set_original_dataset_transform_group(group_name)

        return dataset_copy

    def add_transforms_group(
        self: TAvalancheDataset,
        group_name: str,
        transform: XTransform,
        target_transform: YTransform,
    ) -> TAvalancheDataset:
        """
        Returns a new dataset with a new transformations group.

        The current dataset will not be affected.

        This method raises an exception if a group with the same name already
        exists.

        :param group_name: The name of the new transformations group.
        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :return: A new dataset with the new transformations.
        """
        dataset_copy = self._fork_dataset()

        if group_name in dataset_copy.transform_groups:
            raise ValueError("A group with the same name already exists")

        dataset_copy.transform_groups[group_name] = (
            transform,
            target_transform,
        )

        AvalancheDataset._check_groups_dict_format(
            dataset_copy.transform_groups
        )

        dataset_copy._frozen_transforms[group_name] = (None, None)

        # Finally, align the underlying dataset
        dataset_copy._add_original_dataset_group(group_name)

        return dataset_copy

    def _fork_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
        dataset_copy = copy.copy(self)
        dataset_copy._frozen_transforms = dict(
            dataset_copy._frozen_transforms
        )
        dataset_copy.transform_groups = dict(dataset_copy.transform_groups)
        dataset_copy.task_set = dataset_copy._make_task_set_dict()

        return dataset_copy

    @staticmethod
    def _freeze_dataset_group(
        dataset_copy: TAvalancheDataset, group_name: str
    ):
        # Freeze the current transformations. Frozen transformations are
        # saved in a separate dict.

        # This may raise an error if no group with the given name exists!
        frozen_group = dataset_copy._frozen_transforms[group_name]

        final_frozen_transform = frozen_group[0]
        final_frozen_target_transform = frozen_group[1]

        is_current_group = dataset_copy.current_transform_group == group_name

        # If the required group is not the current one, just freeze the ones
        # found in dataset_copy.transform_groups.
        to_be_glued = dataset_copy.transform_groups[group_name]

        # Now that transformations are stored in to_be_glued,
        # we can safely reset them in the transform_groups dictionary.
        dataset_copy.transform_groups[group_name] = (None, None)

        if is_current_group:
            # If the required group is the current one, use the
            # transformations already found in the transform and
            # target_transform fields (because the ones stored in
            # dataset_copy.transform_groups may be not up-to-date).
            to_be_glued = (
                dataset_copy.transform,
                dataset_copy.target_transform,
            )

            # And of course, once frozen, set transformations to None
            dataset_copy.transform = None
            dataset_copy.target_transform = None

        if to_be_glued[0] is not None:
            if frozen_group[0] is None:
                final_frozen_transform = to_be_glued[0]
            else:
                final_frozen_transform = Compose(
                    [frozen_group[0], to_be_glued[0]]
                )

        if to_be_glued[1] is not None:
            if frozen_group[1] is None:
                final_frozen_target_transform = to_be_glued[1]
            else:
                final_frozen_target_transform = Compose(
                    [frozen_group[1], to_be_glued[1]]
                )

        # Set the frozen transformations
        dataset_copy._frozen_transforms[group_name] = (
            final_frozen_transform,
            final_frozen_target_transform,
        )

        # Finally, apply the freeze procedure to the original dataset
        dataset_copy._freeze_original_dataset(group_name)

    def _get_single_item(self, idx: int):
        return self._process_pattern(self._dataset[idx], idx)

    def _process_pattern(self, element: Tuple, idx: int):
        has_task_label = isinstance(element, TupleTLabel)
        if has_task_label:
            element = element[:-1]

        element = self._apply_transforms(element)

        return TupleTLabel((*element, self.targets_task_labels[idx]))

    def _apply_transforms(self, element: Sequence[Any]):
        element = list(element)
        frozen_group = self._frozen_transforms[self.current_transform_group]

        # Target transform
        if frozen_group[1] is not None:
            element[1] = frozen_group[1](element[1])

        if self.target_transform is not None:
            element[1] = self.target_transform(element[1])

        if frozen_group[0] is not None:
            element = MultiParamTransform(frozen_group[0])(*element)

        if self.transform is not None:
            element = MultiParamTransform(self.transform)(*element)

        return element

    @staticmethod
    def _check_groups_dict_format(groups_dict):
        # The original groups_dict must be convertible to native Python dict
        groups_dict = dict(groups_dict)

        # Check if the format of the groups is correct
        for map_key in groups_dict:
            if not isinstance(map_key, str):
                raise ValueError(
                    "Every group must be identified by a string. "
                    'Wrong key was: "' + str(map_key) + '"'
                )

            map_value = groups_dict[map_key]
            if not isinstance(map_value, tuple):
                raise ValueError(
                    'Transformations for group "'
                    + str(map_key)
                    + '" must be contained in a tuple'
                )

            if not len(map_value) == 2:
                raise ValueError(
                    'Transformations for group "' + str(map_key) + '" must '
                    "be a tuple containing 2 elements: a transformation for "
                    "the X values and a transformation for the Y values"
                )

        if "test" in groups_dict:
            warnings.warn(
                'A transformation group named "test" has been found. Beware '
                "that by default AvalancheDataset supports test "
                'transformations through the "eval" group. Consider using '
                "that one!"
            )

    def _initialize_groups_dict(
        self,
        transform_groups: Optional[Dict[str, TransformGroupDef]],
        dataset: Any,
        transform: XTransform,
        target_transform: YTransform,
    ) -> Dict[str, Tuple[XTransform, YTransform]]:
        """
        A simple helper method that tries to fill the 'train' and 'eval'
        groups as those two groups must always exist.

        If no transform_groups are passed to the class constructor, then the
        transform and target_transform parameters are used for both groups.

        If train transformations are set and eval transformations are not,
        then train transformations will be used for the eval group.

        :param dataset: The original dataset. Will be used to detect existing
            groups.
        :param transform: The transformation passed as a parameter to the
            class constructor.
        :param target_transform: The target transformation passed as a
            parameter to the class constructor.
        """
        if transform_groups is None:
            transform_groups = {
                "train": (transform, target_transform),
                "eval": (transform, target_transform),
            }
        else:
            transform_groups = dict(transform_groups)

        for group_name, group_transforms in dict(transform_groups).items():
            if group_transforms is None:
                transform_groups[group_name] = (None, None)
            elif isinstance(group_transforms, Callable):
                # Single transformation: (safely) assume it is the X
                # transform
                transform_groups[group_name] = (group_transforms, None)
            elif (
                isinstance(group_transforms, Sequence)
                and len(group_transforms) == 2
            ):
                # X and Y transforms
                transform_groups[group_name] = (
                    group_transforms[0],
                    group_transforms[1],
                )
            else:
                raise ValueError(
                    f"Unsupported transformations for group {group_name}. "
                    f"The transformation group may be None, a single "
                    f"Callable, or a tuple of 2 elements containing the X "
                    f"and Y transforms"
                )

        if "train" in transform_groups:
            if "eval" not in transform_groups:
                transform_groups["eval"] = transform_groups["train"]

        if "train" not in transform_groups:
            transform_groups["train"] = (None, None)

        if "eval" not in transform_groups:
            transform_groups["eval"] = (None, None)

        self._add_groups_from_original_dataset(dataset, transform_groups)

        return transform_groups

    def _initialize_targets_sequence(
        self, dataset, targets, dataset_type, targets_adapter
    ) -> Sequence[TTargetType]:
        if targets is not None:
            # User defined targets always take precedence
            # Note: no adapter is applied!
            if len(targets) != len(dataset):
                raise ValueError(
                    "Invalid amount of target labels. It must be equal to "
                    "the number of patterns in the dataset. Got {}, "
                    "expected {}!".format(len(targets), len(dataset))
                )
            return targets

        if targets_adapter is None:
            if dataset_type == AvalancheDatasetType.CLASSIFICATION:
                targets_adapter = int
            else:
                targets_adapter = None
        return _make_target_from_supported_dataset(dataset, targets_adapter)

    def _initialize_task_labels_sequence(
        self, dataset, task_labels: Optional[Sequence[int]]
    ) -> Sequence[int]:
        if task_labels is not None:
            # task_labels has priority over the dataset fields
            if isinstance(task_labels, int):
                task_labels = ConstantSequence(task_labels, len(dataset))
            elif len(task_labels) != len(dataset):
                raise ValueError(
                    "Invalid amount of task labels. It must be equal to the "
                    "number of patterns in the dataset. Got {}, expected "
                    "{}!".format(len(task_labels), len(dataset))
                )

            return SubSequence(task_labels, converter=int)

        return _make_task_labels_from_supported_dataset(dataset)

    def _initialize_collate_fn(self, dataset, dataset_type, collate_fn):
        if collate_fn is not None:
            return collate_fn

        if dataset_type == AvalancheDatasetType.UNDEFINED:
            if hasattr(dataset, "collate_fn"):
                return getattr(dataset, "collate_fn")

        return default_collate

    def _initialize_tasks_dict(
        self, dataset, task_labels: Sequence[int]
    ) -> Dict[int, Sequence[int]]:
        if isinstance(task_labels, ConstantSequence) and len(task_labels) > 0:
            # Shortcut :)
            return {task_labels[0]: range(len(task_labels))}

        result = dict()
        for i, x in enumerate(task_labels):
            if x not in result:
                result[x] = []
            result[x].append(i)

        if len(result) == 1:
            result[next(iter(result.keys()))] = range(len(task_labels))

        return result

    def _set_original_dataset_transform_group(self, group_name: str) -> None:
        if isinstance(self._dataset, AvalancheDataset):
            if self._dataset.current_transform_group == group_name:
                # Prevents a huge slowdown in some corner cases
                # (apart from being actually more performant)
                return

            self._dataset = self._dataset.with_transforms(group_name)

    def _freeze_original_dataset(self, group_name: str) -> None:
        if isinstance(self._dataset, AvalancheDataset):
            self._dataset = self._dataset.freeze_group_transforms(group_name)

    def _replace_original_dataset_group(
        self, transform: XTransform, target_transform: YTransform
    ) -> None:
        if isinstance(self._dataset, AvalancheDataset):
            self._dataset = self._dataset.replace_transforms(
                transform, target_transform
            )

    def _add_original_dataset_group(self, group_name: str) -> None:
        if isinstance(self._dataset, AvalancheDataset):
            self._dataset = self._dataset.add_transforms_group(
                group_name, None, None
            )

    def _add_groups_from_original_dataset(
        self, dataset, transform_groups
    ) -> None:
        if isinstance(dataset, AvalancheDataset):
            for original_dataset_group in dataset.transform_groups.keys():
                if original_dataset_group not in transform_groups:
                    transform_groups[original_dataset_group] = (None, None)

    def _has_own_transformations(self):
        # Used to check if the current dataset has its own transformations.
        # This method returns False if transformations are applied
        # by the wrapped dataset only.
        if self.transform is not None:
            return True

        if self.target_transform is not None:
            return True

        for transform_group in self.transform_groups.values():
            for transform in transform_group:
                if transform is not None:
                    return True

        for transform_group in self._frozen_transforms.values():
            for transform in transform_group:
                if transform is not None:
                    return True

        return False

    @staticmethod
    def _borrow_transformations(
        dataset, transform_groups, frozen_transform_groups
    ):
        if not isinstance(dataset, AvalancheDataset):
            return

        for original_dataset_group in dataset.transform_groups.keys():
            if original_dataset_group not in transform_groups:
                transform_groups[original_dataset_group] = (None, None)

        for original_dataset_group in dataset._frozen_transforms.keys():
            if original_dataset_group not in frozen_transform_groups:
                frozen_transform_groups[original_dataset_group] = (
                    None,
                    None,
                )

        # Standard transforms
        for original_dataset_group in dataset.transform_groups.keys():
            other_dataset_transforms = dataset.transform_groups[
                original_dataset_group
            ]
            if dataset.current_transform_group == original_dataset_group:
                other_dataset_transforms = (
                    dataset.transform,
                    dataset.target_transform,
                )

            transform_groups[
                original_dataset_group
            ] = AvalancheDataset._prepend_transforms(
                transform_groups[original_dataset_group],
                other_dataset_transforms,
            )

        # Frozen transforms
        for original_dataset_group in dataset._frozen_transforms.keys():
            other_dataset_transforms = dataset._frozen_transforms[
                original_dataset_group
            ]

            frozen_transform_groups[
                original_dataset_group
            ] = AvalancheDataset._prepend_transforms(
                frozen_transform_groups[original_dataset_group],
                other_dataset_transforms,
            )

    @staticmethod
    def _prepend_transforms(transforms, to_be_prepended):
        if len(transforms) != 2:
            raise ValueError(
                "Transformation groups must contain exactly 2 "
                "transformations"
            )

        if len(transforms) != len(to_be_prepended):
            raise ValueError(
                "Transformation group size mismatch: {} vs {}".format(
                    len(transforms), len(to_be_prepended)
                )
            )

        result = []

        for i in range(len(transforms)):
            if to_be_prepended[i] is None:
                # Nothing to prepend
                result.append(transforms[i])
            elif transforms[i] is None:
                result.append(to_be_prepended[i])
            else:
                result.append(Compose([to_be_prepended[i], transforms[i]]))

        return tuple(result)  # Transform to tuple

    def _optimize_targets(self):
        self.targets = optimize_sequence(self.targets)

    def _optimize_task_labels(self):
        self.targets_task_labels = optimize_sequence(
            self.targets_task_labels
        )

    def _optimize_task_dict(self):
        for task_label in self.tasks_pattern_indices:
            self.tasks_pattern_indices[task_label] = optimize_sequence(
                self.tasks_pattern_indices[task_label]
            )

    def _flatten_dataset(self):
        pass

    def _make_task_set_dict(self) -> Dict[int, "AvalancheDataset"]:
        task_dict = _TaskSubsetDict()
        for task_id in sorted(self.tasks_pattern_indices.keys()):
            task_indices = self.tasks_pattern_indices[task_id]
            task_dict[task_id] = (self, task_indices)

        return task_dict
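
# Illustrative sketch (not part of the original module): how the 'train' and
# 'eval' transform groups documented above are typically used. The MNIST
# dataset and the torchvision transforms below are assumptions made for this
# example only.
def _example_transform_groups():
    from torchvision.datasets import MNIST
    from torchvision.transforms import RandomCrop, ToTensor

    mnist = MNIST("./data", train=True, download=True)
    avl = AvalancheDataset(
        mnist,
        transform_groups={
            # (X transform, Y transform) pairs, as checked by
            # _check_groups_dict_format above
            "train": (Compose([RandomCrop(28, padding=4), ToTensor()]), None),
            "eval": (ToTensor(), None),
        },
    )

    # Switching groups returns a new dataset; `avl` itself is unaffected.
    eval_view = avl.eval()

    # Freezing glues the current transforms to the data so that later calls
    # to replace_transforms/add_transforms can't remove them.
    frozen = avl.freeze_transforms()

    x, y, task_label = avl[0]  # patterns carry their task label (here: 0)
    return eval_view, frozen, (x, y, task_label)
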
class _TaskSubsetDict(OrderedDict):
    def __getitem__(self, item):
        avl_dataset, indices = super().__getitem__(item)
        return _TaskSubsetDict._make_subset(avl_dataset, indices)

    @staticmethod
    def _make_subset(avl_dataset, indices: Sequence[int]):
        return AvalancheSubset(avl_dataset, indices=indices)
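
# A minimal sketch (not part of the original module) of how the `task_set`
# field backed by _TaskSubsetDict behaves: iterating yields task labels, and
# indexing lazily builds an AvalancheSubset restricted to that task.
def _example_task_set(avl_dataset: "AvalancheDataset"):
    for task_label in avl_dataset.task_set:
        task_subset = avl_dataset.task_set[task_label]
        print(task_label, len(task_subset))
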
class AvalancheSubset(AvalancheDataset[T_co, TTargetType]):
    """
    A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`.
    This Dataset also supports transformations, slicing, advanced indexing,
    the targets field, class mapping and all the other goodies listed in
    :class:`AvalancheDataset`.
    """
    def __init__(
        self,
        dataset: SupportedDataset,
        indices: Sequence[int] = None,
        *,
        class_mapping: Sequence[int] = None,
        transform: Callable[[Any], Any] = None,
        target_transform: Callable[[int], int] = None,
        transform_groups: Dict[str, Tuple[XTransform, YTransform]] = None,
        initial_transform_group: str = None,
        task_labels: Union[int, Sequence[int]] = None,
        targets: Sequence[TTargetType] = None,
        dataset_type: AvalancheDatasetType = None,
        collate_fn: Callable[[List], Any] = None,
        targets_adapter: Callable[[Any], TTargetType] = None,
    ):
        """
        Creates an ``AvalancheSubset`` instance.

        :param dataset: The whole dataset.
        :param indices: Indices in the whole set selected for subset. Can be
            None, which means that the whole dataset will be returned.
        :param class_mapping: A list that, for each possible target (Y)
            value, contains its corresponding remapped value. Can be None.
            Beware that setting this parameter will force the final dataset
            type to be CLASSIFICATION or UNDEFINED.
        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :param transform_groups: A dictionary containing the transform
            groups. Transform groups are used to quickly switch between
            training and eval (test) transformations. This becomes useful
            when you need to test on the training dataset, as test
            transformations usually don't contain random augmentations.
            ``AvalancheDataset`` natively supports the 'train' and 'eval'
            groups by calling the ``train()`` and ``eval()`` methods. When
            using custom groups one can use the
            ``with_transforms(group_name)`` method instead. Defaults to None,
            which means that the current transforms will be used to handle
            both 'train' and 'eval' groups (just like in standard
            ``torchvision`` datasets).
        :param initial_transform_group: The name of the initial transform
            group to be used. Defaults to None, which means that the current
            group of the input dataset will be used (if an AvalancheDataset).
            If the input dataset is not an AvalancheDataset, then 'train'
            will be used.
        :param task_labels: The task label for each instance. Must be a
            sequence of ints, one for each instance in the dataset. This can
            either be a list of task labels for the original dataset or the
            list of task labels for the instances of the subset (an automatic
            detection will be made). In the unfortunate case in which the
            original dataset and the subset contain the same amount of
            instances, then this parameter is considered to contain the task
            labels of the subset. Alternatively can be a single int value, in
            which case that value will be used as the task label for all the
            instances. Defaults to None, which means that the dataset will
            try to obtain the task labels from the original dataset. If no
            task labels could be found, a default task label "0" will be
            applied to all instances.
        :param targets: The label of each pattern. Defaults to None, which
            means that the targets will be retrieved from the dataset (if
            possible). This can either be a list of target labels for the
            original dataset or the list of target labels for the instances
            of the subset (an automatic detection will be made). In the
            unfortunate case in which the original dataset and the subset
            contain the same amount of instances, then this parameter is
            considered to contain the target labels of the subset.
        :param dataset_type: The type of the dataset. Defaults to None,
            which means that the type will be inferred from the input
            dataset. When the `dataset_type` is different than UNDEFINED, a
            proper value for `collate_fn` and `targets_adapter` will be set.
            If the `dataset_type` is different than UNDEFINED, then
            `collate_fn` and `targets_adapter` must not be set. The only
            exception to this rule regards `class_mapping`. If
            `class_mapping` is set, the final dataset_type (as set by this
            parameter or detected from the subset) must be CLASSIFICATION or
            UNDEFINED.
        :param collate_fn: The function to use when slicing to merge single
            patterns. In the future this function may become the function
            used in the data loading process, too. If None and the
            `dataset_type` is UNDEFINED, the constructor will check if a
            `collate_fn` field exists in the dataset. If no such field
            exists, the default collate function will be used.
        :param targets_adapter: A function used to convert the values of the
            targets field. Defaults to None. Note: the adapter will not
            change the value of the second element returned by
            `__getitem__`. The adapter is used to adapt the values of the
            targets field only.
        """
        detected_type = False
        if dataset_type is None:
            detected_type = True
            if isinstance(dataset, AvalancheDataset):
                dataset_type = dataset.dataset_type
                if dataset_type == AvalancheDatasetType.UNDEFINED:
                    if collate_fn is None:
                        collate_fn = dataset.collate_fn
            else:
                dataset_type = AvalancheDatasetType.UNDEFINED

        if dataset_type != AvalancheDatasetType.UNDEFINED and (
            collate_fn is not None or targets_adapter is not None
        ):
            if detected_type:
                raise ValueError(
                    "dataset_type {} was inferred from the input dataset. "
                    "This dataset type can't be used "
                    "with custom collate_fn or targets_adapter "
                    "parameters. Only the UNDEFINED type supports "
                    "custom collate_fn or targets_adapter "
                    "parameters".format(dataset_type)
                )
            else:
                raise ValueError(
                    "dataset_type {} can't be used with custom collate_fn "
                    "or targets_adapter. Only the UNDEFINED type supports "
                    "custom collate_fn or targets_adapter "
                    "parameters.".format(dataset_type)
                )

        if class_mapping is not None:
            if dataset_type not in [
                AvalancheDatasetType.CLASSIFICATION,
                AvalancheDatasetType.UNDEFINED,
            ]:
                raise ValueError(
                    "class_mapping is defined but the dataset type "
                    "is neither CLASSIFICATION nor UNDEFINED."
                )

        if class_mapping is not None:
            subset = ClassificationSubset(
                dataset, indices=indices, class_mapping=class_mapping
            )
        elif indices is not None:
            subset = Subset(dataset, indices=indices)
        else:
            subset = dataset  # Exactly like a plain AvalancheDataset

        self._original_dataset = dataset

        # self._indices and self._class_mapping currently not used apart
        # from initialization procedures
        self._class_mapping = class_mapping
        self._indices = indices

        if initial_transform_group is None:
            if isinstance(dataset, AvalancheDataset):
                initial_transform_group = dataset.current_transform_group
            else:
                initial_transform_group = "train"

        super().__init__(
            subset,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=task_labels,
            targets=targets,
            dataset_type=dataset_type,
            collate_fn=collate_fn,
            targets_adapter=targets_adapter,
        )
    def _initialize_targets_sequence(
        self, dataset, targets, dataset_type, targets_adapter
    ) -> Sequence[TTargetType]:
        if targets is not None:
            # For the reasoning behind this, have a look at
            # _initialize_task_labels_sequence (it's basically the same).
            if len(targets) == len(self._original_dataset) and not len(
                targets
            ) == len(dataset):
                return SubSequence(targets, indices=self._indices)
            elif len(targets) == len(dataset):
                return targets
            else:
                raise ValueError(
                    "Invalid amount of targets. It must be equal to the "
                    "number of patterns in the subset. "
                    "Got {}, expected {}!".format(len(targets), len(dataset))
                )

        return super()._initialize_targets_sequence(
            dataset, None, dataset_type, targets_adapter
        )

    def _initialize_task_labels_sequence(
        self, dataset, task_labels: Optional[Sequence[int]]
    ) -> Sequence[int]:
        if task_labels is not None:
            # The task_labels parameter is kind of ambiguous...
            # it may either be the list of task labels of the required subset
            # or it may be the list of task labels of the original dataset.
            # Simple solution: check the length of task_labels!
            # However, if task_labels, the original dataset and this subset
            # have the same size, then task_labels is considered to contain
            # the task labels for the subset!
            if isinstance(task_labels, int):
                # Simplest case: constant task label
                return ConstantSequence(task_labels, len(dataset))
            elif len(task_labels) == len(self._original_dataset) and not len(
                task_labels
            ) == len(dataset):
                # task_labels refers to the original dataset ...
                return SubSequence(
                    task_labels, indices=self._indices, converter=int
                )
            elif len(task_labels) == len(dataset):
                # One label for each instance
                return SubSequence(task_labels, converter=int)
            else:
                raise ValueError(
                    "Invalid amount of task labels. It must be equal to the "
                    "number of patterns in the subset. "
                    "Got {}, expected {}!".format(
                        len(task_labels), len(dataset)
                    )
                )

        return super()._initialize_task_labels_sequence(dataset, None)

    def _set_original_dataset_transform_group(self, group_name: str) -> None:
        if isinstance(self._original_dataset, AvalancheDataset):
            if self._original_dataset.current_transform_group == group_name:
                # Prevents a huge slowdown in some corner cases
                # (apart from being actually more performant)
                return

            self._original_dataset = self._original_dataset.with_transforms(
                group_name
            )

            self._replace_original_dataset_reference()

    def _freeze_original_dataset(self, group_name: str) -> None:
        if isinstance(self._original_dataset, AvalancheDataset):
            self._original_dataset = (
                self._original_dataset.freeze_group_transforms(group_name)
            )

            self._replace_original_dataset_reference()

    def _replace_original_dataset_group(
        self, transform: XTransform, target_transform: YTransform
    ) -> None:
        if isinstance(self._original_dataset, AvalancheDataset):
            self._original_dataset = (
                self._original_dataset.replace_transforms(
                    transform, target_transform
                )
            )

            self._replace_original_dataset_reference()

    def _add_original_dataset_group(self, group_name: str) -> None:
        if isinstance(self._original_dataset, AvalancheDataset):
            self._original_dataset = (
                self._original_dataset.add_transforms_group(
                    group_name, None, None
                )
            )

            self._replace_original_dataset_reference()

    def _add_groups_from_original_dataset(
        self, dataset, transform_groups
    ) -> None:
        if isinstance(self._original_dataset, AvalancheDataset):
            for (
                original_dataset_group
            ) in self._original_dataset.transform_groups.keys():
                if original_dataset_group not in transform_groups:
                    transform_groups[original_dataset_group] = (None, None)

    def _flatten_dataset(self):
        # Flattens this subset by borrowing indices and class mappings from
        # the original dataset (if it's an AvalancheSubset or PyTorch Subset)

        if isinstance(self._original_dataset, AvalancheSubset):
            # In order to flatten the subset, we have to integrate the
            # transformations (also frozen ones!)
            AvalancheDataset._borrow_transformations(
                self._original_dataset,
                self.transform_groups,
                self._frozen_transforms,
            )

            # We need to reload transformations after borrowing from the
            # subset. This assumes that _flatten_dataset is called by
            # __init__!
            self.transform, self.target_transform = self.transform_groups[
                self.current_transform_group
            ]

            forward_dataset = self._original_dataset._original_dataset
            forward_indices = self._original_dataset._indices
            forward_class_mapping = self._original_dataset._class_mapping

            if self._class_mapping is not None:
                if forward_class_mapping is not None:
                    new_class_mapping = []
                    for mapped_class in forward_class_mapping:
                        # -1 is sometimes used to mark unused classes;
                        # this shouldn't be a problem here
                        if mapped_class == -1:
                            forward_mapped = -1
                        else:
                            # forward_mapped may be -1
                            forward_mapped = self._class_mapping[
                                mapped_class
                            ]

                        new_class_mapping.append(forward_mapped)
                else:
                    new_class_mapping = self._class_mapping
            else:
                new_class_mapping = forward_class_mapping  # May be None

            if self._indices is not None:
                if forward_indices is not None:
                    new_indices = [
                        forward_indices[x] for x in self._indices
                    ]
                else:
                    new_indices = self._indices
            else:
                new_indices = forward_indices  # May be None

            self._original_dataset = forward_dataset
            self._indices = new_indices
            self._class_mapping = new_class_mapping

            if new_class_mapping is None:
                # Subset
                self._dataset = Subset(forward_dataset, indices=new_indices)
            else:
                # ClassificationSubset
                self._dataset = ClassificationSubset(
                    forward_dataset,
                    indices=new_indices,
                    class_mapping=new_class_mapping,
                )

        # --------
        # Flattening PyTorch Subset has been temporarily
        # disabled as the semantics of transformation groups collide
        # with the flattening process: PyTorch Subset doesn't have
        # transform groups and flattening it would expose the underlying
        # dataset, which may contain AvalancheDatasets.
        # --------
        # elif isinstance(self._original_dataset, Subset):
        #     # Very simple: just borrow indices (no transformations or
        #     # class mappings to adapt here!)
        #     forward_dataset = self._original_dataset.dataset
        #     forward_indices = self._original_dataset.indices
        #
        #     if self._indices is not None:
        #         new_indices = [forward_indices[x] for x in self._indices]
        #     else:
        #         new_indices = forward_indices
        #
        #     self._original_dataset = forward_dataset
        #     self._indices = new_indices
        #
        #     if self._class_mapping is not None:
        #         self._dataset = ClassificationSubset(
        #             forward_dataset, indices=new_indices,
        #             class_mapping=self._class_mapping)
        #
        #     elif self._indices is not None:
        #         self._dataset = Subset(forward_dataset, indices=new_indices)

    def _replace_original_dataset_reference(self):
        if isinstance(self._dataset, ClassificationSubset):
            self._dataset = ClassificationSubset(
                self._original_dataset,
                indices=self._indices,
                class_mapping=self._class_mapping,
            )
        elif isinstance(self._dataset, Subset):
            self._dataset = Subset(
                self._original_dataset, indices=self._indices
            )
        else:
            self._dataset = self._original_dataset
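
# Illustrative sketch (not part of the original module): creating a subset
# with remapped class labels. The 10-class setup and the values below are
# assumptions made for this example only.
def _example_avalanche_subset(base_dataset: "AvalancheDataset"):
    # Keep the first 100 patterns and remap classes 0..9 to 9..0. Setting
    # class_mapping forces the resulting dataset type to be CLASSIFICATION
    # or UNDEFINED (see the constructor above).
    return AvalancheSubset(
        base_dataset,
        indices=list(range(100)),
        class_mapping=list(range(9, -1, -1)),
        task_labels=1,  # a single int: applied to every instance
    )
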
class AvalancheTensorDataset(AvalancheDataset[T_co, TTargetType]):
    """
    A Dataset that wraps existing ndarrays, Tensors, lists... to provide
    basic Dataset functionalities. Very similar to TensorDataset from
    PyTorch, this Dataset also supports transformations, slicing, advanced
    indexing, the targets field and all the other goodies listed in
    :class:`AvalancheDataset`.
    """
    def __init__(
        self,
        *dataset_tensors: Sequence,
        transform: Callable[[Any], Any] = None,
        target_transform: Callable[[int], int] = None,
        transform_groups: Dict[str, Tuple[XTransform, YTransform]] = None,
        initial_transform_group: str = "train",
        task_labels: Union[int, Sequence[int]] = None,
        targets: Union[Sequence[TTargetType], int] = None,
        dataset_type: AvalancheDatasetType = AvalancheDatasetType.UNDEFINED,
        collate_fn: Callable[[List], Any] = None,
        targets_adapter: Callable[[Any], TTargetType] = None,
    ):
        """
        Creates an ``AvalancheTensorDataset`` instance.

        :param dataset_tensors: Sequences, Tensors or ndarrays representing
            the content of the dataset.
        :param transform: A function/transform that takes in a single element
            from the first tensor and returns a transformed version.
        :param target_transform: A function/transform that takes a single
            element of the second tensor and transforms it.
        :param transform_groups: A dictionary containing the transform
            groups. Transform groups are used to quickly switch between
            training and eval (test) transformations. This becomes useful
            when you need to test on the training dataset, as test
            transformations usually don't contain random augmentations.
            ``AvalancheDataset`` natively supports the 'train' and 'eval'
            groups by calling the ``train()`` and ``eval()`` methods. When
            using custom groups one can use the
            ``with_transforms(group_name)`` method instead. Defaults to None,
            which means that the current transforms will be used to handle
            both 'train' and 'eval' groups (just like in standard
            ``torchvision`` datasets).
        :param initial_transform_group: The name of the transform group to be
            used. Defaults to 'train'.
        :param task_labels: The task labels for each pattern. Must be a
            sequence of ints, one for each pattern in the dataset.
            Alternatively can be a single int value, in which case that value
            will be used as the task label for all the instances. Defaults to
            None, which means that a default task label "0" will be applied
            to all patterns.
        :param targets: The label of each pattern. Defaults to None, which
            means that the targets will be retrieved from the dataset.
            Otherwise, can be 1) a sequence of values containing as many
            elements as the number of patterns, or 2) the index of the
            sequence to use as the targets field. When using the default
            value of None, the targets field will be populated using the
            second tensor. If the dataset is made of only one tensor, then
            that tensor will be used for the targets field, too.
        :param dataset_type: The type of the dataset. Defaults to UNDEFINED.
            Setting this parameter will automatically set a proper value for
            `collate_fn` and `targets_adapter`. If this parameter is set to a
            value different from UNDEFINED then `collate_fn` and
            `targets_adapter` must not be set.
        :param collate_fn: The function to use when slicing to merge single
            patterns. In the future this function may become the function
            used in the data loading process, too.
        :param targets_adapter: A function used to convert the values of the
            targets field. Defaults to None. Note: the adapter will not
            change the value of the second element returned by
            `__getitem__`. The adapter is used to adapt the values of the
            targets field only.
        """
        if dataset_type != AvalancheDatasetType.UNDEFINED and (
            collate_fn is not None or targets_adapter is not None
        ):
            raise ValueError(
                "dataset_type {} can't be used with custom collate_fn "
                "or targets_adapter. Only the UNDEFINED type supports "
                "custom collate_fn or targets_adapter "
                "parameters.".format(dataset_type)
            )

        if len(dataset_tensors) < 1:
            raise ValueError("At least one sequence must be passed")

        if targets is None:
            # Use the second tensor as the targets field (or the only
            # tensor, if a single one was passed)
            targets = min(1, len(dataset_tensors) - 1)

        if isinstance(targets, int):
            base_dataset = SequenceDataset(*dataset_tensors, targets=targets)
            targets = None
        else:
            base_dataset = SequenceDataset(
                *dataset_tensors, targets=min(1, len(dataset_tensors) - 1)
            )

        super().__init__(
            base_dataset,
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=task_labels,
            targets=targets,
            dataset_type=dataset_type,
            collate_fn=collate_fn,
            targets_adapter=targets_adapter,
        )
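
# Minimal usage sketch (not part of the original module). The tensor shapes
# and the CLASSIFICATION type below are assumptions for illustration.
def _example_tensor_dataset():
    xs = torch.rand(10, 3, 32, 32)
    ys = torch.randint(0, 5, (10,))
    dataset = AvalancheTensorDataset(
        xs,
        ys,  # with targets=None, this second tensor becomes `targets`
        dataset_type=AvalancheDatasetType.CLASSIFICATION,
        task_labels=0,
    )
    x, y, t = dataset[0]  # patterns come back with their task label appended
    return x, y, t
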
class AvalancheConcatDataset(AvalancheDataset[T_co, TTargetType]):
    """
    A Dataset that behaves like a PyTorch
    :class:`torch.utils.data.ConcatDataset`. However, this Dataset also
    supports transformations, slicing, advanced indexing and the targets
    field and all the other goodies listed in :class:`AvalancheDataset`.

    This dataset guarantees that the operations involving the
    transformations and transformations groups are consistent across the
    concatenated dataset (if they are subclasses of
    :class:`AvalancheDataset`).
    """
[docs]    def __init__(
        self,
        datasets: Collection[SupportedDataset],
        *,
        transform: Callable[[Any], Any] = None,
        target_transform: Callable[[int], int] = None,
        transform_groups: Dict[str, Tuple[XTransform, YTransform]] = None,
        initial_transform_group: str = None,
        task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]] = None,
        targets: Union[
            Sequence[TTargetType], Sequence[Sequence[TTargetType]]
        ] = None,
        dataset_type: AvalancheDatasetType = None,
        collate_fn: Callable[[List], Any] = None,
        targets_adapter: Callable[[Any], TTargetType] = None,
    ):
        """
        Creates an ``AvalancheConcatDataset`` instance.

        :param datasets: A collection of datasets.
        :param transform: A function/transform that takes the X value of a
            pattern from the original dataset and returns a transformed
            version.
        :param target_transform: A function/transform that takes in the
            target and transforms it.
        :param transform_groups: A dictionary containing the transform
            groups. Transform groups are used to quickly switch between
            training and eval (test) transformations. This becomes useful
            when in need to test on the training dataset, as test
            transformations usually don't contain random augmentations.
            ``AvalancheDataset`` natively supports the 'train' and 'eval'
            groups by calling the ``train()`` and ``eval()`` methods. When
            using custom groups one can use the
            ``with_transforms(group_name)`` method instead. Defaults to
            None, which means that the current transforms will be used to
            handle both 'train' and 'eval' groups (just like in standard
            ``torchvision`` datasets).
        :param initial_transform_group: The name of the initial transform
            group to be used. Defaults to None, which means that if all
            AvalancheDatasets in the input datasets list agree on a common
            group (the "current group" is the same for all datasets), then
            that group will be used as the initial one. If the list of
            input datasets does not contain an AvalancheDataset or if the
            AvalancheDatasets do not agree on a common group, then 'train'
            will be used.
        :param task_labels: The task labels for each pattern. Must be a
            sequence of ints, one for each pattern in the dataset.
            Alternatively, task labels can be expressed as a sequence
            containing sequences of ints (one for each dataset to be
            concatenated) or even a single int, in which case that value
            will be used as the task label for all instances. Defaults to
            None, which means that the dataset will try to obtain the task
            labels from the original datasets. If no task labels could be
            found for a dataset, a default task label "0" will be applied
            to all patterns of that dataset.
        :param targets: The label of each pattern. Can either be a sequence
            of labels or, alternatively, a sequence containing sequences of
            labels (one for each dataset to be concatenated). Defaults to
            None, which means that the targets will be retrieved from the
            datasets (if possible).
        :param dataset_type: The type of the dataset. Defaults to None,
            which means that the type will be inferred from the list of
            input datasets. When `dataset_type` is None and the list of
            datasets contains incompatible types, an error will be raised.
            A list of datasets is compatible if they all have the same
            type. Datasets that are not instances of `AvalancheDataset`,
            and instances of `AvalancheDataset` with type `UNDEFINED`, are
            always compatible with other types. When `dataset_type` is
            different from UNDEFINED, proper values for `collate_fn` and
            `targets_adapter` will be set automatically, and the
            `collate_fn` and `targets_adapter` parameters must not be set.
        :param collate_fn: The function to use when slicing to merge single
            patterns. In the future this function may become the function
            used in the data loading process, too. If None, the constructor
            will check if a `collate_fn` field exists in the first dataset.
            If no such field exists, the default collate function will be
            used. Beware that the chosen collate function will be applied
            to all the concatenated datasets, even if a different collate
            is defined in some of them.
        :param targets_adapter: A function used to convert the values of
            the targets field. Defaults to None. Note: the adapter will not
            change the value of the second element returned by
            `__getitem__`. The adapter is used to adapt the values of the
            targets field only.
        """
        dataset_list = list(datasets)

        (
            dataset_type,
            collate_fn,
            targets_adapter,
        ) = self._get_dataset_type_collate_and_adapter(
            dataset_list, dataset_type, collate_fn, targets_adapter
        )

        self._dataset_list = dataset_list
        self._datasets_lengths = [len(dataset) for dataset in dataset_list]
        self._datasets_cumulative_lengths = ConcatDataset.cumsum(dataset_list)
        self._overall_length = sum(self._datasets_lengths)

        if initial_transform_group is None:
            # Detect a common group among the concatenated AvalancheDatasets
            uniform_group = None
            for d_set in self._dataset_list:
                if isinstance(d_set, AvalancheDataset):
                    if uniform_group is None:
                        uniform_group = d_set.current_transform_group
                    else:
                        if uniform_group != d_set.current_transform_group:
                            uniform_group = None
                            break

            if uniform_group is None:
                initial_transform_group = "train"
            else:
                initial_transform_group = uniform_group

        if task_labels is not None:
            task_labels = self._concat_task_labels(task_labels)

        if targets is not None:
            targets = self._concat_targets(targets)

        self._adapt_concat_datasets()

        super().__init__(
            ClassificationDataset(),  # not used
            transform=transform,
            target_transform=target_transform,
            transform_groups=transform_groups,
            initial_transform_group=initial_transform_group,
            task_labels=task_labels,
            targets=targets,
            dataset_type=dataset_type,
            collate_fn=collate_fn,
            targets_adapter=targets_adapter,
        )
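    # Usage sketch (illustrative only; `first_set` and `second_set` stand
    # for any torchvision-style classification datasets):
    #
    #     from torchvision import transforms
    #
    #     concat = AvalancheConcatDataset(
    #         [first_set, second_set],
    #         transform_groups={
    #             'train': (transforms.RandomHorizontalFlip(), None),
    #             'eval': (None, None),
    #         },
    #         task_labels=[
    #             [0] * len(first_set),   # task 0 for all of first_set
    #             [1] * len(second_set),  # task 1 for all of second_set
    #         ],
    #     )
    #     concat = concat.eval()  # switch every concatenated dataset
    #                             # to the 'eval' transform group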
    def _get_dataset_type_collate_and_adapter(
        self, datasets, dataset_type, collate_fn, targets_adapter
    ):
        if dataset_type is not None:
            return dataset_type, collate_fn, targets_adapter

        identified_types = set()
        first_collate_fn = None
        for dataset in datasets:
            if isinstance(dataset, AvalancheDataset):
                if dataset.dataset_type != AvalancheDatasetType.UNDEFINED:
                    identified_types.add(dataset.dataset_type)

                if first_collate_fn is None:
                    first_collate_fn = getattr(dataset, "collate_fn", None)

        if len(identified_types) > 1:
            raise ValueError(
                "Error trying to infer a common dataset type while "
                "concatenating different datasets. "
                "Incompatible types: {}".format(list(identified_types))
            )
        elif len(identified_types) == 0:
            dataset_type = AvalancheDatasetType.UNDEFINED
        else:  # len(identified_types) == 1
            dataset_type = next(iter(identified_types))

        if dataset_type != AvalancheDatasetType.UNDEFINED and (
            collate_fn is not None or targets_adapter is not None
        ):
            raise ValueError(
                "dataset_type {} was inferred from the list of "
                "concatenated datasets. This dataset type can't be used "
                "with custom collate_fn or targets_adapter "
                "parameters. Only the UNDEFINED type supports "
                "custom collate_fn or targets_adapter "
                "parameters.".format(dataset_type)
            )

        if (
            collate_fn is None
            and dataset_type == AvalancheDatasetType.UNDEFINED
        ):
            collate_fn = first_collate_fn

        return dataset_type, collate_fn, targets_adapter

    def __len__(self) -> int:
        return self._overall_length

    def _get_single_item(self, idx: int):
        dataset_idx, internal_idx = find_list_from_index(
            idx,
            self._datasets_lengths,
            self._overall_length,
            cumulative_sizes=self._datasets_cumulative_lengths,
        )

        single_element = self._dataset_list[dataset_idx][internal_idx]
        return self._process_pattern(single_element, idx)

    def _fork_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
        dataset_copy = super()._fork_dataset()

        dataset_copy._dataset_list = list(dataset_copy._dataset_list)
        # Note: there is no need to duplicate _datasets_lengths

        return dataset_copy

    def _initialize_targets_sequence(
        self, dataset, targets, dataset_type, targets_adapter
    ) -> Sequence[TTargetType]:
        if targets is not None:
            if len(targets) != self._overall_length:
                raise ValueError(
                    "Invalid amount of target labels. It must be equal to "
                    "the number of patterns in the dataset. Got {}, "
                    "expected {}!".format(
                        len(targets), self._overall_length
                    )
                )
            return targets

        targets_list = []
        # Could be easily done with a single line of code.
        # This, however, allows the user to better check which was the
        # problematic dataset by using a debugger.
        for dataset_idx, single_dataset in enumerate(self._dataset_list):
            targets_list.append(
                super()._initialize_targets_sequence(
                    single_dataset, None, dataset_type, targets_adapter
                )
            )

        return LazyConcatTargets(targets_list)

    def _initialize_task_labels_sequence(
        self, dataset, task_labels: Optional[Sequence[int]]
    ) -> Sequence[int]:
        if task_labels is not None:
            # task_labels has priority over the dataset fields
            if isinstance(task_labels, int):
                return ConstantSequence(task_labels, self._overall_length)
            elif len(task_labels) != self._overall_length:
                raise ValueError(
                    "Invalid amount of task labels. It must be equal to "
                    "the number of patterns in the dataset. Got {}, "
                    "expected {}!".format(
                        len(task_labels), self._overall_length
                    )
                )
            return SubSequence(task_labels, converter=int)

        concat_t_labels = []
        for dataset_idx, single_dataset in enumerate(self._dataset_list):
            concat_t_labels.append(
                super()._initialize_task_labels_sequence(
                    single_dataset, None
                )
            )

        return LazyConcatTargets(concat_t_labels)

    def _initialize_collate_fn(self, dataset, dataset_type, collate_fn):
        if collate_fn is not None:
            return collate_fn

        if len(self._dataset_list) > 0 and hasattr(
            self._dataset_list[0], "collate_fn"
        ):
            return getattr(self._dataset_list[0], "collate_fn")
        return default_collate

    def _set_original_dataset_transform_group(self, group_name: str) -> None:
        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                if dataset.current_transform_group == group_name:
                    # Prevents a huge slowdown in some corner cases
                    # (apart from being actually more performant)
                    continue

                self._dataset_list[dataset_idx] = dataset.with_transforms(
                    group_name
                )

    def _freeze_original_dataset(self, group_name: str) -> None:
        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                self._dataset_list[
                    dataset_idx
                ] = dataset.freeze_group_transforms(group_name)

    def _replace_original_dataset_group(
        self, transform: XTransform, target_transform: YTransform
    ) -> None:
        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                self._dataset_list[dataset_idx] = dataset.replace_transforms(
                    transform, target_transform
                )

    def _add_original_dataset_group(self, group_name: str) -> None:
        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                self._dataset_list[dataset_idx] = dataset.add_transforms_group(
                    group_name, None, None
                )

    def _add_groups_from_original_dataset(
        self, dataset, transform_groups
    ) -> None:
        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                for original_dataset_group in dataset.transform_groups.keys():
                    if original_dataset_group not in transform_groups:
                        transform_groups[original_dataset_group] = (None, None)

    def _adapt_concat_datasets(self):
        all_groups = set()

        for dataset in self._dataset_list:
            if isinstance(dataset, AvalancheDataset):
                all_groups.update(dataset.transform_groups.keys())

        for dataset_idx, dataset in enumerate(self._dataset_list):
            if isinstance(dataset, AvalancheDataset):
                for group_name in all_groups:
                    if group_name not in dataset.transform_groups:
                        self._dataset_list[
                            dataset_idx
                        ] = dataset.add_transforms_group(
                            group_name, None, None
                        )

    @staticmethod
    def _concat_task_labels(
        task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]]
    ):
        if isinstance(task_labels, int):
            # A single value has been passed -> use it for all instances.
            # The value is returned as is because it's already managed when
            # in this form (in _initialize_task_labels_sequence).
            return task_labels
        elif isinstance(task_labels[0], int):
            # Flat list of task labels -> just return it.
            # The constructor will check if it has the correct size.
            return task_labels

        # One list for each dataset -> concatenate them lazily.
        return LazyConcatIntTargets(task_labels)
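    # Normalization sketch for the three accepted `task_labels` forms
    # (values illustrative):
    #
    #     _concat_task_labels(3)
    #     # -> 3 (a bare int is expanded to a ConstantSequence later, in
    #     #       _initialize_task_labels_sequence)
    #
    #     _concat_task_labels([0, 0, 1, 1])
    #     # -> [0, 0, 1, 1] (flat per-pattern labels, returned unchanged;
    #     #    the constructor checks the overall length)
    #
    #     _concat_task_labels([[0, 0], [1, 1]])
    #     # -> LazyConcatIntTargets([[0, 0], [1, 1]]), a lazy flat view of
    #     #    the per-dataset label lists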
    @staticmethod
    def _concat_targets(
        targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
    ):
        if isinstance(targets[0], Sequence):
            return LazyConcatTargets(targets)
        else:
            return targets

    def _flatten_dataset(self):
        # Flattens this concatenation by borrowing the list of concatenated
        # datasets from the original datasets (if they're
        # 'AvalancheConcatDataset's or PyTorch 'ConcatDataset's)

        flattened_list = []
        for dataset in self._dataset_list:
            if isinstance(dataset, AvalancheConcatDataset):
                if dataset._has_own_transformations():
                    # Can't flatten: the dataset has custom transformations
                    flattened_list.append(dataset)
                else:
                    flattened_list.extend(dataset._dataset_list)
            # PyTorch ConcatDataset doesn't have custom transformations
            # --------
            # Flattening PyTorch ConcatDatasets has been temporarily
            # disabled as the semantics of transformation groups collide
            # with the flattening process: PyTorch ConcatDataset doesn't
            # have transform groups, and flattening it would expose the
            # underlying concatenated datasets list, which may contain
            # 'AvalancheDataset's.
            # --------
            # elif isinstance(dataset, ConcatDataset):
            #     flattened_list.extend(dataset.datasets)
            elif isinstance(dataset, AvalancheSubset):
                flattened_list += self._flatten_subset_concat_branch(dataset)
            else:
                flattened_list.append(dataset)

        self._dataset_list = flattened_list
        self._datasets_lengths = [len(dataset) for dataset in flattened_list]
        self._datasets_cumulative_lengths = ConcatDataset.cumsum(
            flattened_list
        )
        self._overall_length = sum(self._datasets_lengths)

    def _flatten_subset_concat_branch(
        self, dataset: AvalancheSubset
    ) -> List[Dataset]:
        """
        Optimizes the dataset hierarchy in the corner case:

        self -> [Subset, Subset, ...] -> ConcatDataset -> [Dataset]

        :param dataset: The dataset to flatten. This function returns
            [dataset] if the dataset is not a subset containing a concat
            dataset (or if other corner cases are encountered).
        :return: The flattened list of datasets to be concatenated.
        """
        if not isinstance(dataset._original_dataset, AvalancheConcatDataset):
            return [dataset]

        concat_dataset: AvalancheConcatDataset = dataset._original_dataset

        if concat_dataset._has_own_transformations():
            # The dataset has custom transforms -> do nothing
            return [dataset]

        result: List[AvalancheSubset] = []
        last_c_dataset = None
        last_c_idxs = []
        last_c_targets = []
        last_c_tasks = []
        for subset_idx, idx in enumerate(dataset._indices):
            dataset_idx, internal_idx = find_list_from_index(
                idx,
                concat_dataset._datasets_lengths,
                concat_dataset._overall_length,
                cumulative_sizes=concat_dataset._datasets_cumulative_lengths,
            )

            if last_c_dataset is None:
                last_c_dataset = dataset_idx
            elif last_c_dataset != dataset_idx:
                # Consolidate the current subset
                result.append(
                    AvalancheConcatDataset._make_similar_subset(
                        dataset,
                        concat_dataset._dataset_list[last_c_dataset],
                        last_c_idxs,
                        last_c_targets,
                        last_c_tasks,
                    )
                )

                # Switch to the next dataset
                last_c_dataset = dataset_idx
                last_c_idxs = []
                last_c_targets = []
                last_c_tasks = []

            last_c_idxs.append(internal_idx)
            last_c_targets.append(dataset.targets[subset_idx])
            last_c_tasks.append(dataset.targets_task_labels[subset_idx])

        if last_c_dataset is not None:
            result.append(
                AvalancheConcatDataset._make_similar_subset(
                    dataset,
                    concat_dataset._dataset_list[last_c_dataset],
                    last_c_idxs,
                    last_c_targets,
                    last_c_tasks,
                )
            )

        return result

    @staticmethod
    def _make_similar_subset(subset, ref_dataset, indices, targets, tasks):
        t_groups = dict()
        f_groups = dict()
        AvalancheDataset._borrow_transformations(subset, t_groups, f_groups)

        collate_fn = None
        if subset.dataset_type == AvalancheDatasetType.UNDEFINED:
            collate_fn = subset.collate_fn

        result = AvalancheSubset(
            ref_dataset,
            indices=indices,
            class_mapping=subset._class_mapping,
            transform_groups=t_groups,
            initial_transform_group=subset.current_transform_group,
            task_labels=tasks,
            targets=targets,
            dataset_type=subset.dataset_type,
            collate_fn=collate_fn,
        )

        result._frozen_transforms = f_groups

        return result
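# Flattening sketch (illustrative; `d1`, `d2`, `d3` are plain datasets and
# the base class is assumed to invoke _flatten_dataset during setup):
#
#     inner = AvalancheConcatDataset([d1, d2])
#     outer = AvalancheConcatDataset([inner, d3])
#     # outer._dataset_list flattens to [d1, d2, d3] because `inner`
#     # carries no transformations of its own; with custom transforms on
#     # `inner`, it would be kept as a single nested entry instead.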
def concat_datasets_sequentially(
    train_dataset_list: Sequence[ISupportedClassificationDataset],
    test_dataset_list: Sequence[ISupportedClassificationDataset],
) -> Tuple[AvalancheConcatDataset, AvalancheConcatDataset, List[list]]:
    """
    Concatenates a list of datasets. This is completely different from
    :class:`ConcatDataset`, in which datasets are merged together without
    other processing. Instead, this function re-maps the datasets' class
    IDs. For instance: let dataset[0] contain patterns of 3 different
    classes and dataset[1] contain patterns of 2 different classes; class
    IDs will then be mapped as follows:

        dataset[0] class "0" -> new class ID is "0"
        dataset[0] class "1" -> new class ID is "1"
        dataset[0] class "2" -> new class ID is "2"

        dataset[1] class "0" -> new class ID is "3"
        dataset[1] class "1" -> new class ID is "4"

        ... -> ...

        dataset[-1] class "C-1" -> new class ID is "overall_n_classes-1"

    In contrast, using PyTorch ConcatDataset:

        dataset[0] class "0" -> ID is "0"
        dataset[0] class "1" -> ID is "1"
        dataset[0] class "2" -> ID is "2"

        dataset[1] class "0" -> ID is "0"
        dataset[1] class "1" -> ID is "1"

    Note: ``train_dataset_list`` and ``test_dataset_list`` must have the
    same number of datasets.

    :param train_dataset_list: A list of training datasets.
    :param test_dataset_list: A list of test datasets.

    :returns: A tuple containing the concatenated training dataset, the
        concatenated test dataset and a list containing, for each input
        dataset, the list of new class IDs assigned to its classes.
    """
    remapped_train_datasets = []
    remapped_test_datasets = []
    next_remapped_idx = 0

    # Obtain the number of classes of each dataset
    classes_per_dataset = [
        _count_unique(
            train_dataset_list[dataset_idx].targets,
            test_dataset_list[dataset_idx].targets,
        )
        for dataset_idx in range(len(train_dataset_list))
    ]

    new_class_ids_per_dataset = []
    for dataset_idx in range(len(train_dataset_list)):
        # Get the train and test sets of the dataset
        train_set = train_dataset_list[dataset_idx]
        test_set = test_dataset_list[dataset_idx]

        # Get the classes in the dataset
        dataset_classes = set(map(int, train_set.targets))

        # The class IDs for this dataset will be in range
        # [n_classes_in_previous_datasets,
        #  n_classes_in_previous_datasets + classes_in_this_dataset)
        new_classes = list(
            range(
                next_remapped_idx,
                next_remapped_idx + classes_per_dataset[dataset_idx],
            )
        )
        new_class_ids_per_dataset.append(new_classes)

        # AvalancheSubset is used to apply the class IDs transformation.
        # Remember, the class_mapping parameter must be a list in which:
        #   new_class_id = class_mapping[original_class_id]
        # Hence, a list of size equal to the maximum class index is created.
        # Only elements corresponding to the present classes are remapped.
        # Classes are visited in ascending order so that the mapping is
        # deterministic (a plain set iteration would not guarantee this).
        class_mapping = [-1] * (max(dataset_classes) + 1)
        for j, i in enumerate(sorted(dataset_classes)):
            class_mapping[i] = new_classes[j]

        # Create remapped datasets and append them to the final list
        remapped_train_datasets.append(
            AvalancheSubset(train_set, class_mapping=class_mapping)
        )
        remapped_test_datasets.append(
            AvalancheSubset(test_set, class_mapping=class_mapping)
        )
        next_remapped_idx += classes_per_dataset[dataset_idx]

    return (
        AvalancheConcatDataset(remapped_train_datasets),
        AvalancheConcatDataset(remapped_test_datasets),
        new_class_ids_per_dataset,
    )
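# Remapping sketch (illustrative only; `a_train`/`a_test` hold classes
# {0, 1, 2} and `b_train`/`b_test` hold classes {0, 1}):
#
#     train_c, test_c, new_ids = concat_datasets_sequentially(
#         [a_train, b_train], [a_test, b_test]
#     )
#     # new_ids == [[0, 1, 2], [3, 4]]
#     # a pattern labeled "1" in the second dataset now has target "4"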
def as_avalanche_dataset(
    dataset: ISupportedClassificationDataset[T_co],
    dataset_type: AvalancheDatasetType = None,
) -> AvalancheDataset[T_co, TTargetType]:
    if isinstance(dataset, AvalancheDataset) and dataset_type is None:
        # There is no need to show the warning
        return dataset

    if dataset_type is None:
        warnings.warn(
            '"as_avalanche_dataset" was called without setting '
            '"dataset_type": this behaviour is deprecated. Consider '
            "setting this value or calling the specific functions "
            '"as_classification_dataset", "as_regression_dataset", '
            '"as_segmentation_dataset" or "as_undefined_dataset"'
        )
        dataset_type = AvalancheDatasetType.UNDEFINED

    if (
        isinstance(dataset, AvalancheDataset)
        and dataset.dataset_type == dataset_type
    ):
        return dataset

    return AvalancheDataset(dataset, dataset_type=dataset_type)


def as_classification_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> AvalancheDataset[T_co, int]:
    return as_avalanche_dataset(
        dataset, dataset_type=AvalancheDatasetType.CLASSIFICATION
    )


def as_regression_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> AvalancheDataset[T_co, Any]:
    return as_avalanche_dataset(
        dataset, dataset_type=AvalancheDatasetType.REGRESSION
    )


def as_segmentation_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> AvalancheDataset[T_co, Any]:
    return as_avalanche_dataset(
        dataset, dataset_type=AvalancheDatasetType.SEGMENTATION
    )


def as_undefined_dataset(
    dataset: ISupportedClassificationDataset[T_co],
) -> AvalancheDataset[T_co, Any]:
    return as_avalanche_dataset(
        dataset, dataset_type=AvalancheDatasetType.UNDEFINED
    )


def train_eval_avalanche_datasets(
    train_dataset: ISupportedClassificationDataset,
    test_dataset: ISupportedClassificationDataset,
    train_transformation,
    eval_transformation,
    dataset_type=None,
):
    train = AvalancheDataset(
        train_dataset,
        transform_groups=dict(
            train=(train_transformation, None),
            eval=(eval_transformation, None),
        ),
        initial_transform_group="train",
        dataset_type=dataset_type,
    )

    test = AvalancheDataset(
        test_dataset,
        transform_groups=dict(
            train=(train_transformation, None),
            eval=(eval_transformation, None),
        ),
        initial_transform_group="eval",
        dataset_type=dataset_type,
    )
    return train, test


def _traverse_supported_dataset(
    dataset, values_selector: Callable[[Dataset, List[int]], List], indices=None
) -> List:
    initial_error = None
    try:
        result = values_selector(dataset, indices)
        if result is not None:
            return result
    except BaseException as e:
        initial_error = e

    if isinstance(dataset, Subset):
        if indices is None:
            indices = range(len(dataset))
        indices = [dataset.indices[x] for x in indices]
        return list(
            _traverse_supported_dataset(
                dataset.dataset, values_selector, indices
            )
        )

    if isinstance(dataset, ConcatDataset):
        result = []
        if indices is None:
            for c_dataset in dataset.datasets:
                result += list(
                    _traverse_supported_dataset(
                        c_dataset, values_selector, indices
                    )
                )
            return result

        datasets_to_indexes = defaultdict(list)
        indexes_to_dataset = []
        datasets_len = []
        recursion_result = []

        all_size = 0
        for c_dataset in dataset.datasets:
            len_dataset = len(c_dataset)
            datasets_len.append(len_dataset)
            all_size += len_dataset

        for subset_idx in indices:
            dataset_idx, pattern_idx = find_list_from_index(
                subset_idx, datasets_len, all_size
            )
            datasets_to_indexes[dataset_idx].append(pattern_idx)
            indexes_to_dataset.append(dataset_idx)

        for dataset_idx, c_dataset in enumerate(dataset.datasets):
            recursion_result.append(
                deque(
                    _traverse_supported_dataset(
                        c_dataset,
                        values_selector,
                        datasets_to_indexes[dataset_idx],
                    )
                )
            )

        result = []
        for idx in range(len(indices)):
            dataset_idx = indexes_to_dataset[idx]
            result.append(recursion_result[dataset_idx].popleft())

        return result

    if initial_error is not None:
        raise initial_error

    raise ValueError("Error: can't find the needed data in the given dataset")


def _count_unique(*sequences: Sequence[SupportsInt]):
    uniques = set()

    for seq in sequences:
        for x in seq:
            uniques.add(int(x))

    return len(uniques)
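# Traversal sketch (illustrative; `base` is any dataset exposing a
# `targets` field): indices are resolved through Subset/ConcatDataset
# wrappers before the selector is applied to the leaf datasets.
#
#     wrapped = Subset(ConcatDataset([base, base]), [0, len(base)])
#     _traverse_supported_dataset(wrapped, _select_targets)
#     # -> [base.targets[0], base.targets[0]]
#     # both wrapper indices resolve to pattern 0 of `base`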
def _select_targets(dataset, indices):
    if hasattr(dataset, "targets"):
        # Standard supported dataset
        found_targets = dataset.targets
    elif hasattr(dataset, "tensors"):
        # Support for PyTorch TensorDataset
        if len(dataset.tensors) < 2:
            raise ValueError(
                "Tensor dataset does not have enough tensors: "
                "at least 2 are required."
            )
        found_targets = dataset.tensors[1]
    else:
        raise ValueError(
            "Unsupported dataset: must have a valid targets field "
            "or must be a TensorDataset with at least 2 tensors"
        )

    if indices is not None:
        found_targets = SubSequence(found_targets, indices=indices)

    return found_targets


def _select_task_labels(dataset, indices):
    found_task_labels = None
    if hasattr(dataset, "targets_task_labels"):
        found_task_labels = dataset.targets_task_labels

    if found_task_labels is None:
        if isinstance(dataset, (Subset, ConcatDataset)):
            return None  # Continue traversing

    if found_task_labels is None:
        if indices is None:
            return ConstantSequence(0, len(dataset))
        return ConstantSequence(0, len(indices))

    if indices is not None:
        found_task_labels = SubSequence(found_task_labels, indices=indices)

    return found_task_labels


def _make_target_from_supported_dataset(
    dataset: SupportedDataset, converter: Callable[[Any], TTargetType] = None
) -> Sequence[TTargetType]:
    if isinstance(dataset, AvalancheDataset):
        if converter is None:
            return dataset.targets
        elif (
            isinstance(dataset.targets, (SubSequence, LazyConcatTargets))
            and dataset.targets.converter == converter
        ):
            return dataset.targets
        elif isinstance(dataset.targets, LazyClassMapping) and converter == int:
            # LazyClassMapping already outputs int targets
            return dataset.targets

    targets = _traverse_supported_dataset(dataset, _select_targets)

    return SubSequence(targets, converter=converter)


def _make_task_labels_from_supported_dataset(
    dataset: SupportedDataset,
) -> Sequence[int]:
    if isinstance(dataset, AvalancheDataset):
        return dataset.targets_task_labels

    task_labels = _traverse_supported_dataset(dataset, _select_task_labels)

    return SubSequence(task_labels, converter=int)


__all__ = [
    "SupportedDataset",
    "AvalancheDatasetType",
    "AvalancheDataset",
    "AvalancheSubset",
    "AvalancheTensorDataset",
    "AvalancheConcatDataset",
    "concat_datasets_sequentially",
    "as_avalanche_dataset",
    "as_classification_dataset",
    "as_regression_dataset",
    "as_segmentation_dataset",
    "as_undefined_dataset",
    "train_eval_avalanche_datasets",
]
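# Conversion sketch (illustrative; `plain_set` is any dataset with a
# `targets` field): the typed helpers avoid the deprecation warning raised
# by as_avalanche_dataset when no dataset_type is given.
#
#     clf_set = as_classification_dataset(plain_set)
#     # clf_set.dataset_type == AvalancheDatasetType.CLASSIFICATION
#     # calling it again returns the very same instance (no re-wrapping)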