avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset

class avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset(dataset: Union[IDatasetWithTargets, ITensorDataset, Subset, ConcatDataset], *, transform: Optional[Union[XTransformDef, XComposedTransformDef]] = None, target_transform: Optional[YTransformDef] = None, transform_groups: Optional[Dict[str, Union[None, XTransformDef, XComposedTransformDef, Tuple[Optional[Union[XTransformDef, XComposedTransformDef]], Optional[YTransformDef]]]]] = None, initial_transform_group: Optional[str] = None, task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, dataset_type: Optional[AvalancheDatasetType] = None, collate_fn: Optional[Callable[[List], Any]] = None, targets_adapter: Optional[Callable[[Any], TTargetType]] = None)

The Dataset used as the base implementation for Avalanche.

Instances of this dataset are usually returned by benchmarks, but the class can also be used in a completely standalone manner. This dataset can apply transformations before returning patterns/targets, supports slicing and advanced indexing, and exposes useful fields such as targets, which contains the pattern labels, and targets_task_labels, which contains the pattern task labels. The task_set field can be used to obtain the subset of patterns labeled with a given task label.
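A minimal sketch of how targets_task_labels, tasks_pattern_indices and task_set relate, written in plain Python so that it runs without Avalanche installed. The helpers tasks_pattern_indices() and task_subset() are hypothetical; the real task_set returns a subset of the dataset rather than a plain list.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def tasks_pattern_indices(task_labels: Sequence[int]) -> Dict[int, List[int]]:
    """Map each task label to the indices of the patterns that carry it."""
    indices: Dict[int, List[int]] = defaultdict(list)
    for idx, task in enumerate(task_labels):
        indices[task].append(idx)
    return dict(indices)


def task_subset(patterns: Sequence, task_labels: Sequence[int], task: int) -> list:
    """Hypothetical stand-in for task_set[task]: the patterns labeled with `task`."""
    return [patterns[i] for i in tasks_pattern_indices(task_labels).get(task, [])]


patterns = ["a", "b", "c", "d", "e"]
targets_task_labels = [0, 1, 0, 2, 1]
print(tasks_pattern_indices(targets_task_labels))  # {0: [0, 2], 1: [1, 4], 2: [3]}
print(task_subset(patterns, targets_task_labels, 1))  # ['b', 'e']
```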

This dataset also supports several advanced operations on transformations. For instance, it allows the user to add and replace transformations, freeze them so that they can’t be changed, and so on.
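The toy class below sketches the add/replace/freeze semantics under stated assumptions: every method returns a new object, add_transforms() runs the new transform after the existing ones, and frozen transforms run before the replaceable ones. ToyTransformDataset and compose() are hypothetical names used only for illustration; this is not Avalanche’s implementation.

```python
from typing import Callable, Optional

Transform = Callable[[object], object]


def compose(first: Optional[Transform], second: Optional[Transform]) -> Optional[Transform]:
    """Return a transform applying `first` and then `second` (None means no-op)."""
    if first is None:
        return second
    if second is None:
        return first
    return lambda x: second(first(x))


class ToyTransformDataset:
    """Hypothetical toy class illustrating add/replace/freeze semantics."""

    def __init__(self, data, transform=None, frozen=None):
        self._data = list(data)
        self._transform = transform  # replaceable part
        self._frozen = frozen        # frozen part, never replaced

    def __getitem__(self, idx):
        # Assumption: frozen transforms run before the replaceable ones.
        pipeline = compose(self._frozen, self._transform)
        x = self._data[idx]
        return pipeline(x) if pipeline is not None else x

    def add_transforms(self, transform):
        # Returns a new object; the new transform runs after the existing one.
        return ToyTransformDataset(self._data, compose(self._transform, transform), self._frozen)

    def replace_transforms(self, transform):
        # Returns a new object; only the non-frozen part is replaced.
        return ToyTransformDataset(self._data, transform, self._frozen)

    def freeze_transforms(self):
        # Returns a new object whose current transforms become permanent.
        return ToyTransformDataset(self._data, None, compose(self._frozen, self._transform))


ds = ToyTransformDataset([1, 2, 3]).add_transforms(lambda x: x + 1).add_transforms(lambda x: x * 10)
print(ds[0])  # 20
frozen = ds.freeze_transforms()
print(frozen.replace_transforms(lambda x: -x)[0])  # -20: the frozen part still runs
print(ds.replace_transforms(lambda x: -x)[0])  # -1: nothing was frozen here
```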

This dataset also allows the user to keep distinct transformation groups. Simply put, a transformation group is a pair of transform+target_transform (exactly as in torchvision datasets). This dataset natively supports two transformation groups: the first, ‘train’, contains the transformations applied to training patterns, which usually involve some kind of data augmentation. The second, ‘eval’, contains the transformations applied to test patterns. Having both groups is useful when, for instance, one needs to test on the training data, since this usually requires removing the data augmentation operations. Switching between groups can be easily achieved by using the train() and eval() methods.

Moreover, arbitrary transformation groups can be added and used. For more info see the constructor and the with_transforms() method.
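A minimal, hypothetical sketch of transform groups. ToyGroupedDataset and as_pair() are illustrative names, not part of Avalanche. Each group is normalized to an (X transform, Y transform) pair, mirroring the value types accepted by the transform_groups parameter (None, a single X transform, or a tuple), and train(), eval() and with_transforms() return new objects instead of mutating the current one.

```python
from typing import Callable, Dict, Optional, Tuple

Pair = Tuple[Optional[Callable], Optional[Callable]]


def as_pair(value) -> Pair:
    """Normalize a group value: None, a single X transform, or an (X, Y) tuple."""
    if value is None:
        return (None, None)
    if isinstance(value, tuple):
        return value
    return (value, None)


class ToyGroupedDataset:
    """Hypothetical toy class illustrating transform groups (not Avalanche's code)."""

    def __init__(self, xs, ys, transform=None, target_transform=None,
                 transform_groups=None, initial_transform_group="train"):
        self._xs, self._ys = list(xs), list(ys)
        if transform_groups is None:
            # Same transforms for 'train' and 'eval', like a plain torchvision dataset.
            transform_groups = {"train": (transform, target_transform),
                                "eval": (transform, target_transform)}
        self.transform_groups: Dict[str, Pair] = {
            name: as_pair(value) for name, value in transform_groups.items()}
        self.current_transform_group = initial_transform_group

    def with_transforms(self, group_name: str) -> "ToyGroupedDataset":
        # Returns a new object; the current one keeps its group.
        if group_name not in self.transform_groups:
            raise KeyError(f"unknown transform group: {group_name}")
        return ToyGroupedDataset(self._xs, self._ys,
                                 transform_groups=self.transform_groups,
                                 initial_transform_group=group_name)

    def train(self) -> "ToyGroupedDataset":
        return self.with_transforms("train")

    def eval(self) -> "ToyGroupedDataset":
        return self.with_transforms("eval")

    def __getitem__(self, idx):
        x_tf, y_tf = self.transform_groups[self.current_transform_group]
        x, y = self._xs[idx], self._ys[idx]
        return (x_tf(x) if x_tf is not None else x,
                y_tf(y) if y_tf is not None else y)


ds = ToyGroupedDataset(
    [1, 2], ["a", "b"],
    transform_groups={"train": lambda x: -x,       # stands in for an augmentation
                      "eval": None,                # no transformation at test time
                      "debug": (str, str.upper)})  # custom group with X and Y transforms
print(ds[0])                           # (-1, 'a'): 'train' is the initial group
print(ds.eval()[0])                    # (1, 'a')
print(ds.with_transforms("debug")[1])  # ('2', 'B')
```

Because each switch returns a new object, ds itself still uses the ‘train’ group after ds.eval() is called.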

This dataset will try to inherit the task labels from the input dataset. If none are available and none are given via the task_labels parameter, each pattern will be assigned a default task label “0”. See the constructor for more details.
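The precedence rules for task labels (explicit task_labels first, then labels inherited from the input dataset, then the default “0”) can be sketched as a small function. resolve_task_labels() is a hypothetical helper, and raising ValueError on a length mismatch is an assumption about how invalid input would be handled.

```python
from typing import List, Optional, Sequence, Union


def resolve_task_labels(n_patterns: int,
                        task_labels: Optional[Union[int, Sequence[int]]] = None,
                        inherited: Optional[Sequence[int]] = None) -> List[int]:
    """Hypothetical helper mirroring the documented precedence rules."""
    if task_labels is None:
        # Nothing given explicitly: inherit from the input dataset, else default to 0.
        if inherited is None:
            return [0] * n_patterns
        task_labels = inherited
    if isinstance(task_labels, int):
        # A single int is broadcast to every pattern.
        return [task_labels] * n_patterns
    labels = [int(t) for t in task_labels]
    if len(labels) != n_patterns:
        raise ValueError(f"expected {n_patterns} task labels, got {len(labels)}")
    return labels


print(resolve_task_labels(3))                        # [0, 0, 0]
print(resolve_task_labels(3, task_labels=2))         # [2, 2, 2]
print(resolve_task_labels(3, inherited=[1, 1, 0]))   # [1, 1, 0]
print(resolve_task_labels(3, [4, 4, 4], [1, 1, 0]))  # [4, 4, 4]
```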

__init__(dataset: Union[IDatasetWithTargets, ITensorDataset, Subset, ConcatDataset], *, transform: Optional[Union[XTransformDef, XComposedTransformDef]] = None, target_transform: Optional[YTransformDef] = None, transform_groups: Optional[Dict[str, Union[None, XTransformDef, XComposedTransformDef, Tuple[Optional[Union[XTransformDef, XComposedTransformDef]], Optional[YTransformDef]]]]] = None, initial_transform_group: Optional[str] = None, task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, dataset_type: Optional[AvalancheDatasetType] = None, collate_fn: Optional[Callable[[List], Any]] = None, targets_adapter: Optional[Callable[[Any], TTargetType]] = None)

Creates an AvalancheDataset instance.

Parameters
  • dataset – The dataset to decorate. Beware that AvalancheDataset will not overwrite transformations already applied by this dataset.

  • transform – A function/transform that takes the X value of a pattern from the original dataset and returns a transformed version.

  • target_transform – A function/transform that takes in the target and transforms it.

  • transform_groups – A dictionary containing the transform groups. Transform groups are used to quickly switch between training and eval (test) transformations. This is useful when testing on the training dataset, as test transformations usually don’t contain random augmentations. AvalancheDataset natively supports the ‘train’ and ‘eval’ groups, which can be selected by calling the train() and eval() methods. When using custom groups, one can use the with_transforms(group_name) method instead. Defaults to None, which means that the current transforms will be used to handle both ‘train’ and ‘eval’ groups (just like in standard torchvision datasets).

  • initial_transform_group – The name of the initial transform group to be used. Defaults to None, which means that the current group of the input dataset will be used if the input dataset is an AvalancheDataset; otherwise, ‘train’ will be used.

  • task_labels – The task label of each instance. Must be a sequence of ints, one for each instance in the dataset. Alternatively, it can be a single int value, in which case that value will be used as the task label for all the instances. Defaults to None, which means that the dataset will try to obtain the task labels from the original dataset. If no task labels can be found, a default task label “0” will be applied to all instances.

  • targets – The label of each pattern. Defaults to None, which means that the targets will be retrieved from the dataset (if possible).

  • dataset_type – The type of the dataset. Defaults to None, which means that the type will be inferred from the input dataset. When the dataset_type is different from UNDEFINED, proper values for collate_fn and targets_adapter are set automatically, so collate_fn and targets_adapter must not be passed in that case.

  • collate_fn – The function to use when slicing to merge single patterns. In the future this function may become the function used in the data loading process, too. If None and the dataset_type is UNDEFINED, the constructor will check if a collate_fn field exists in the dataset. If no such field exists, the default collate function will be used.

  • targets_adapter – A function used to convert the values of the targets field. Defaults to None. Note: the adapter will not change the value of the second element returned by __getitem__. The adapter is used to adapt the values of the targets field only.
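The difference between the targets field and the value returned by __getitem__, together with the use of collate_fn when slicing, can be illustrated with a toy class. ToyAdaptedDataset and toy_collate() are hypothetical; the sketch assumes that a slice returns the collated patterns, and a real collate function for tensor data would typically stack tensors rather than build lists.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def toy_collate(batch: List[Tuple[Any, Any]]) -> Tuple[list, list]:
    """Hypothetical collate function: merge (x, y) pairs into ([x...], [y...])."""
    if not batch:
        return [], []
    xs, ys = zip(*batch)
    return list(xs), list(ys)


class ToyAdaptedDataset:
    """Hypothetical toy class: targets_adapter affects only the targets field."""

    def __init__(self, xs: Sequence, ys: Sequence,
                 targets_adapter: Optional[Callable[[Any], Any]] = None,
                 collate_fn: Callable = toy_collate):
        self._xs, self._ys = list(xs), list(ys)
        self._targets_adapter = targets_adapter
        self.collate_fn = collate_fn

    @property
    def targets(self) -> list:
        # The adapter is applied only when the targets field is read.
        if self._targets_adapter is None:
            return list(self._ys)
        return [self._targets_adapter(y) for y in self._ys]

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Slicing: fetch single patterns, then merge them with collate_fn.
            positions = range(*idx.indices(len(self._xs)))
            return self.collate_fn([self[i] for i in positions])
        # Single pattern: the raw target is returned, not the adapted one.
        return self._xs[idx], self._ys[idx]


ds = ToyAdaptedDataset([10, 20, 30], ["0", "1", "1"], targets_adapter=int)
print(ds.targets)  # [0, 1, 1]
print(ds[1])       # (20, '1')
print(ds[0:2])     # ([10, 20], ['0', '1'])
```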

Methods

__init__(dataset, *[, transform, ...])

Creates an AvalancheDataset instance.

add_transforms([transform, target_transform])

Returns a new dataset with the given transformations added to the existing ones.

add_transforms_group(group_name, transform, ...)

Returns a new dataset with a new transformations group.

add_transforms_to_group(group_name[, ...])

Returns a new dataset with the given transformations added to the existing ones for a certain group.

eval()

Returns a new dataset with the transformations of the 'eval' group loaded.

freeze_group_transforms(group_name)

Returns a new dataset where the transformations for a specific group are frozen.

freeze_transforms()

Returns a new dataset where the current transformations are frozen.

get_transforms([transforms_group])

Returns the transformations given a group.

replace_transforms(transform, target_transform)

Returns a new dataset with the existing transformations replaced with the given ones.

train()

Returns a new dataset with the transformations of the 'train' group loaded.

with_transforms(group_name)

Returns a new dataset with the transformations of a different group loaded.

Attributes

targets

A sequence of values describing the label of each pattern contained in the dataset.

dataset_type

The type of this dataset (UNDEFINED, CLASSIFICATION, ...).

targets_task_labels

A sequence of ints describing the task label of each pattern contained in the dataset.

tasks_pattern_indices

A dictionary mapping task labels to the indices of the patterns with that task label.

collate_fn

The collate function to use when creating mini-batches from this dataset.

task_set

A dictionary that can be used to obtain the subset of patterns given a specific task label.

current_transform_group

The name of the transform group currently in use.

transform_groups

A dictionary containing the transform groups.

transform

A function/transform that takes in a PIL image and returns a transformed version.

target_transform

A function/transform that takes in the target and transforms it.