avalanche.benchmarks.utils.classification_subset

avalanche.benchmarks.utils.classification_subset(dataset: IDatasetWithTargets | ITensorDataset | Subset | ConcatDataset | ClassificationDataset, indices: Sequence[int] | None = None, *, class_mapping: Sequence[int] | None = None, transform: Callable[[Any], Any] | None = None, target_transform: Callable[[int], int] | None = None, transform_groups: Dict[str, Tuple[XTransformDef | XComposedTransformDef | None, YTransformDef | None]] | None = None, initial_transform_group: str | None = None, task_labels: int | Sequence[int] | None = None, targets: Sequence[int] | None = None, collate_fn: Callable[[List], Any] | None = None)

Creates an AvalancheSubset instance.

For simple subset operations you should use the method dataset.subset(indices). Use this constructor only if you need to redefine transformations or class/task labels.

The returned dataset behaves like a PyTorch torch.utils.data.Subset. It also supports transformations, slicing, advanced indexing, the targets field, class mapping and all the other goodies listed in AvalancheDataset.
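The core subset semantics (item i of the subset is item indices[i] of the wrapped dataset, and indices=None means the whole dataset) can be illustrated with a minimal plain-Python sketch. ToySubset is a hypothetical stand-in for demonstration only, not part of the Avalanche API:

```python
class ToySubset:
    """Minimal stand-in illustrating subset semantics:
    item i of the subset is item indices[i] of the wrapped dataset."""

    def __init__(self, dataset, indices=None):
        self._dataset = dataset
        # indices=None means "use the whole dataset"
        self._indices = list(range(len(dataset))) if indices is None else list(indices)

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, i):
        # Translate the subset index into an index of the original dataset
        return self._dataset[self._indices[i]]


# A toy classification dataset: (x, y) pairs
data = [(0.0, 0), (1.0, 1), (2.0, 0), (3.0, 1)]
sub = ToySubset(data, indices=[2, 0])
# sub[0] is data[2], sub[1] is data[0]
```

The real AvalancheSubset layers transformations, targets and task labels on top of this index-translation core.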

Parameters:
  • dataset – The whole dataset.

  • indices – Indices of the instances of the whole dataset to include in the subset. Can be None, which means that the whole dataset will be returned.

  • class_mapping – A list that, for each possible target (Y) value, contains its corresponding remapped value. Can be None. Beware that setting this parameter will force the final dataset type to be CLASSIFICATION or UNDEFINED.

  • transform – A function/transform that takes the X value of a pattern from the original dataset and returns a transformed version.

  • target_transform – A function/transform that takes in the target and transforms it.

  • transform_groups – A dictionary containing the transform groups. Transform groups are used to quickly switch between training and eval (test) transformations. This is useful when one needs to test on the training dataset, as test transformations usually don’t contain random augmentations. AvalancheDataset natively supports the ‘train’ and ‘eval’ groups by calling the train() and eval() methods. When using custom groups, one can use the with_transforms(group_name) method instead. Defaults to None, which means that the current transforms will be used to handle both ‘train’ and ‘eval’ groups (just like in standard torchvision datasets).

  • initial_transform_group – The name of the initial transform group to be used. Defaults to None, which means that the current group of the input dataset will be used (if an AvalancheDataset). If the input dataset is not an AvalancheDataset, then ‘train’ will be used.

  • task_labels – The task label of each instance. Must be a sequence of ints, one for each instance in the dataset, or a single int, in which case that value will be used as the task label for all instances. A sequence can either contain the task labels of the original dataset or the task labels of the instances of the subset (an automatic detection based on its length will be made). In the ambiguous case in which the original dataset and the subset contain the same number of instances, the sequence is considered to contain the task labels of the subset. Defaults to None, which means that the dataset will try to obtain the task labels from the original dataset. If no task labels can be found, the default task label 0 will be applied to all instances.

  • targets – The label of each pattern. This can either be a list of target labels for the original dataset or the list of target labels for the instances of the subset (an automatic detection based on its length will be made). In the ambiguous case in which the original dataset and the subset contain the same number of instances, the list is considered to contain the target labels of the subset. Defaults to None, which means that the targets will be retrieved from the dataset (if possible).

  • collate_fn – The function used to merge single patterns when slicing. This function is also used in the data loading process. If None, the constructor will check whether a collate_fn field exists in the dataset; if no such field exists, the default collate function will be used.
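The class_mapping remapping described above (for each possible target value, the list entry at that index gives the remapped value) can be sketched in plain Python; remap_targets is a hypothetical helper for illustration, not part of the Avalanche API:

```python
def remap_targets(targets, class_mapping):
    """For each target y, class_mapping[y] is its remapped value."""
    return [class_mapping[y] for y in targets]


# Remap a 3-class problem (classes 0, 1, 2) onto 2 classes by merging 1 and 2:
# the index is the original class, the value is the remapped class.
original = [0, 2, 1, 1, 0]
mapping = [0, 1, 1]
remapped = remap_targets(original, mapping)  # → [0, 1, 1, 1, 0]
```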
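The automatic detection described for both task_labels and targets (single int vs. per-original sequence vs. per-subset sequence, with the per-subset interpretation winning ties) can be sketched as follows; resolve_per_instance is a hypothetical helper, not Avalanche code:

```python
def resolve_per_instance(values, n_original, indices):
    """Resolve a task_labels/targets-style argument to one value per
    subset instance.

    - a single int is broadcast to every subset instance;
    - a sequence as long as the subset is used as-is (this branch also
      wins when original dataset and subset have the same length);
    - a sequence as long as the original dataset is sub-selected
      through the subset indices.
    """
    if isinstance(values, int):
        return [values] * len(indices)
    values = list(values)
    if len(values) == len(indices):
        return values
    if len(values) == n_original:
        return [values[i] for i in indices]
    raise ValueError(
        "length matches neither the original dataset nor the subset")


indices = [4, 0, 2]
per_original = [0, 0, 1, 1, 2]  # one label per original instance
print(resolve_per_instance(per_original, 5, indices))  # → [2, 0, 1]
print(resolve_per_instance(7, 5, indices))             # → [7, 7, 7]
```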
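The transform-group switching described for transform_groups and initial_transform_group can be sketched with a small plain-Python class; GroupedTransforms is a simplified illustration of the idea (named groups, a current group, and train()/eval()/with_transforms() to switch), not Avalanche's actual implementation:

```python
class GroupedTransforms:
    """Sketch of transform groups: named transforms plus a current group."""

    def __init__(self, groups, initial="train"):
        self._groups = groups    # e.g. {"train": ..., "eval": ...}
        self._current = initial  # the initial_transform_group

    def with_transforms(self, group_name):
        # Switch to an arbitrary (possibly custom) group
        self._current = group_name
        return self

    def train(self):
        return self.with_transforms("train")

    def eval(self):
        return self.with_transforms("eval")

    def __call__(self, x):
        # Apply the transform of the current group (None means identity)
        t = self._groups[self._current]
        return t(x) if t is not None else x


# "train" doubles the input (standing in for an augmentation);
# "eval" applies no transform at all.
tg = GroupedTransforms({"train": lambda x: x * 2, "eval": None})
```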