avalanche.benchmarks.utils.avalanche_dataset.AvalancheConcatDataset
- class avalanche.benchmarks.utils.avalanche_dataset.AvalancheConcatDataset(datasets: Collection[Union[IDatasetWithTargets, ITensorDataset, Subset, ConcatDataset]], *, transform: Optional[Callable[[Any], Any]] = None, target_transform: Optional[Callable[[int], int]] = None, transform_groups: Optional[Dict[str, Tuple[Optional[Union[XTransformDef, XComposedTransformDef]], Optional[YTransformDef]]]] = None, initial_transform_group: Optional[str] = None, task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None, targets: Optional[Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]] = None, dataset_type: Optional[AvalancheDatasetType] = None, collate_fn: Optional[Callable[[List], Any]] = None, targets_adapter: Optional[Callable[[Any], TTargetType]] = None)
A Dataset that behaves like a PyTorch torch.utils.data.ConcatDataset. However, this Dataset also supports transformations, slicing, advanced indexing, the targets field, and all the other features listed in AvalancheDataset.

This dataset guarantees that operations involving transformations and transformation groups are consistent across the concatenated datasets (if they are subclasses of AvalancheDataset).

- __init__(datasets: Collection[Union[IDatasetWithTargets, ITensorDataset, Subset, ConcatDataset]], *, transform: Optional[Callable[[Any], Any]] = None, target_transform: Optional[Callable[[int], int]] = None, transform_groups: Optional[Dict[str, Tuple[Optional[Union[XTransformDef, XComposedTransformDef]], Optional[YTransformDef]]]] = None, initial_transform_group: Optional[str] = None, task_labels: Optional[Union[int, Sequence[int], Sequence[Sequence[int]]]] = None, targets: Optional[Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]] = None, dataset_type: Optional[AvalancheDatasetType] = None, collate_fn: Optional[Callable[[List], Any]] = None, targets_adapter: Optional[Callable[[Any], TTargetType]] = None)
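The index bookkeeping behind a concatenated dataset can be sketched as follows. This is a minimal, standard-library-only illustration in the spirit of torch.utils.data.ConcatDataset, which locates a sample with a binary search over the cumulative sizes of the concatenated datasets. The helper names cumulative_sizes and locate are hypothetical and are not part of the Avalanche API.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence, Tuple


def cumulative_sizes(lengths: Sequence[int]) -> List[int]:
    """Hypothetical helper: running totals of dataset lengths, e.g. [3, 2, 4] -> [3, 5, 9]."""
    return list(accumulate(lengths))


def locate(global_idx: int, cum_sizes: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: map a global index to (dataset index, local index)."""
    total = cum_sizes[-1] if cum_sizes else 0
    if global_idx < 0:
        # Support negative indices, as Python sequences do.
        global_idx += total
    if not 0 <= global_idx < total:
        raise IndexError("index out of range")
    # First dataset whose cumulative size is strictly greater than the index;
    # bisect_right also skips over empty datasets.
    dataset_idx = bisect_right(cum_sizes, global_idx)
    offset = cum_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, global_idx - offset


if __name__ == "__main__":
    sizes = cumulative_sizes([3, 2, 4])
    print(sizes)              # [3, 5, 9]
    print(locate(4, sizes))   # (1, 1)
    print(locate(-1, sizes))  # (2, 3)
```

Slicing and advanced indexing can then be expressed by applying locate to each requested global index.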
Creates an AvalancheConcatDataset instance.
- Parameters
datasets – A collection of datasets.
transform – A function/transform that takes the X value of a pattern from the original dataset and returns a transformed version.
target_transform – A function/transform that takes in the target and transforms it.
transform_groups – A dictionary containing the transform groups. Transform groups are used to quickly switch between training and eval (test) transformations. This is useful when one needs to test on the training dataset, as test transformations usually don’t contain random augmentations. AvalancheDataset natively supports the ‘train’ and ‘eval’ groups through the train() and eval() methods. When using custom groups, use the with_transforms(group_name) method instead. Defaults to None, which means that the current transforms will be used to handle both the ‘train’ and ‘eval’ groups (just like in standard torchvision datasets).
initial_transform_group – The name of the initial transform group to be used. Defaults to None, which means that if all AvalancheDatasets in the input datasets list agree on a common group (the “current group” is the same for all of them), then that group will be used as the initial one. If the list of input datasets does not contain an AvalancheDataset, or if the AvalancheDatasets do not agree on a common group, then ‘train’ will be used.
targets – The label of each pattern. Can either be a sequence of labels or, alternatively, a sequence containing sequences of labels (one for each dataset to be concatenated). Defaults to None, which means that the targets will be retrieved from the datasets (if possible).
task_labels – The task labels for each pattern. Must be a sequence of ints, one for each pattern in the dataset. Alternatively, task labels can be expressed as a sequence containing sequences of ints (one for each dataset to be concatenated) or even a single int, in which case that value will be used as the task label for all instances. Defaults to None, which means that the dataset will try to obtain the task labels from the original datasets. If no task labels could be found for a dataset, a default task label “0” will be applied to all patterns of that dataset.
dataset_type – The type of the dataset. Defaults to None, which means that the type will be inferred from the list of input datasets. When dataset_type is None and the list of datasets contains incompatible types, an error will be raised. A list of datasets is compatible if they all have the same type. Datasets that are not instances of AvalancheDataset, and instances of AvalancheDataset with type UNDEFINED, are always compatible with other types. When dataset_type is different from UNDEFINED, proper values for collate_fn and targets_adapter will be set; in that case, collate_fn and targets_adapter must not be passed.
collate_fn – The function used when slicing to merge single patterns. In the future, this function may also be used in the data loading process. If None, the constructor will check whether a collate_fn field exists in the first dataset. If no such field exists, the default collate function will be used. Beware that the chosen collate function will be applied to all the concatenated datasets, even if some of them define a different collate function.
targets_adapter – A function used to convert the values of the targets field. Defaults to None. Note: the adapter will not change the value of the second element returned by __getitem__. The adapter is used to adapt the values of the targets field only.
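The rules for task_labels and dataset_type described above can be illustrated with the following sketch. It assumes that the length of each input dataset, and any task labels already found in it, are known in advance. expand_task_labels, infer_dataset_type and the DatasetType enum are hypothetical stand-ins, not Avalanche code: only the UNDEFINED member name comes from this page, and returning UNDEFINED when no dataset declares a type is an assumption.

```python
from enum import Enum
from typing import List, Optional, Sequence, Union


class DatasetType(Enum):
    """Hypothetical stand-in for AvalancheDatasetType (member names besides UNDEFINED are assumed)."""
    UNDEFINED = "undefined"
    CLASSIFICATION = "classification"
    REGRESSION = "regression"


TaskLabelsArg = Union[int, Sequence[int], Sequence[Sequence[int]]]


def expand_task_labels(task_labels: Optional[TaskLabelsArg],
                       lengths: Sequence[int],
                       found: Sequence[Optional[Sequence[int]]]) -> List[int]:
    """Hypothetical helper: return one task label per pattern of the concatenation."""
    total = sum(lengths)
    out: List[int] = []
    if task_labels is None:
        # Reuse the labels found in each dataset; default to 0 when none exist.
        for n, labels in zip(lengths, found):
            out.extend(labels if labels is not None else [0] * n)
        return out
    if isinstance(task_labels, int):
        # A single int applies to every pattern.
        return [task_labels] * total
    items = list(task_labels)
    if items and not isinstance(items[0], int):
        # Nested form: one inner sequence per concatenated dataset.
        if len(items) != len(lengths):
            raise ValueError("need one label sequence per dataset")
        for n, inner in zip(lengths, items):
            if len(inner) != n:
                raise ValueError("inner sequence length does not match dataset")
            out.extend(inner)
    else:
        # Flat form: one int per pattern.
        out = list(items)
    if len(out) != total:
        raise ValueError("need one task label per pattern")
    return out


def infer_dataset_type(types: Sequence[Optional[DatasetType]]) -> DatasetType:
    """Hypothetical helper: types[i] is None for datasets that are not AvalancheDatasets."""
    # None and UNDEFINED are compatible with everything, so ignore them.
    declared = {t for t in types if t is not None and t is not DatasetType.UNDEFINED}
    if len(declared) > 1:
        raise ValueError("incompatible dataset types")
    return declared.pop() if declared else DatasetType.UNDEFINED
```

For example, expand_task_labels(None, [2, 3], [[1, 1], None]) yields [1, 1, 0, 0, 0]: the first dataset keeps its own labels and the second falls back to task label 0.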
Methods
__init__(datasets, *[, transform, ...]) – Creates an AvalancheConcatDataset instance.
add_transforms([transform, target_transform]) – Returns a new dataset with the given transformations added to the existing ones.
add_transforms_group(group_name, transform, ...) – Returns a new dataset with a new transformations group.
add_transforms_to_group(group_name[, ...]) – Returns a new dataset with the given transformations added to the existing ones for a certain group.
eval() – Returns a new dataset with the transformations of the 'eval' group loaded.
freeze_group_transforms(group_name) – Returns a new dataset where the transformations for a specific group are frozen.
freeze_transforms() – Returns a new dataset where the current transformations are frozen.
get_transforms([transforms_group]) – Returns the transformations given a group.
replace_transforms(transform, target_transform) – Returns a new dataset with the existing transformations replaced with the given ones.
train() – Returns a new dataset with the transformations of the 'train' group loaded.
with_transforms(group_name) – Returns a new dataset with the transformations of a different group loaded.
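The methods above return new datasets instead of modifying the current one. The toy class below sketches that pattern for transform groups, together with the default rule used to pick initial_transform_group when it is None. GroupedDataset and resolve_initial_group are hypothetical illustrations built on the standard library, not the actual AvalancheDataset implementation.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

Transform = Optional[Callable[[Any], Any]]


@dataclass(frozen=True)
class GroupedDataset:
    """Hypothetical toy dataset illustrating transform groups (not Avalanche code)."""
    data: Tuple[Tuple[Any, int], ...]
    groups: Dict[str, Tuple[Transform, Transform]]  # group -> (x transform, y transform)
    current_group: str = "train"

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        x, y = self.data[idx]
        x_tf, y_tf = self.groups[self.current_group]
        return (x_tf(x) if x_tf is not None else x,
                y_tf(y) if y_tf is not None else y)

    def with_transforms(self, group_name: str) -> "GroupedDataset":
        if group_name not in self.groups:
            raise KeyError(group_name)
        # Return a copy with a different active group; self is left unchanged.
        return replace(self, current_group=group_name)

    def train(self) -> "GroupedDataset":
        return self.with_transforms("train")

    def eval(self) -> "GroupedDataset":
        return self.with_transforms("eval")


def resolve_initial_group(current_groups: Sequence[Optional[str]]) -> str:
    """Hypothetical helper: current_groups[i] is None for non-AvalancheDatasets."""
    groups = [g for g in current_groups if g is not None]
    # Use the common group only if every AvalancheDataset agrees on it.
    if groups and all(g == groups[0] for g in groups):
        return groups[0]
    return "train"
```

With groups={"train": (lambda x: x * 10, None), "eval": (None, None)}, indexing the dataset applies the 'train' augmentation, while ds.eval() returns a separate view that yields the raw inputs and leaves ds itself untouched.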
Attributes