avalanche.benchmarks.utils.make_tensor_classification_dataset
- avalanche.benchmarks.utils.make_tensor_classification_dataset(*dataset_tensors: Sequence, transform: Optional[Callable[[Any], Any]] = None, target_transform: Optional[Callable[[int], int]] = None, transform_groups: Optional[Dict[str, Tuple[Optional[Union[XTransformDef, XComposedTransformDef]], Optional[YTransformDef]]]] = None, initial_transform_group: str = 'train', task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Union[int, Sequence[int]]] = None, collate_fn: Optional[Callable[[List], Any]] = None)[source]
Creates an AvalancheTensorDataset instance. A Dataset that wraps existing ndarrays, Tensors, lists… to provide basic Dataset functionalities. Very similar to TensorDataset from PyTorch, this Dataset also supports transformations, slicing, advanced indexing, the targets field, and all the other goodies listed in AvalancheDataset.
- Parameters
dataset_tensors – Sequences, Tensors or ndarrays representing the content of the dataset.
transform – A function/transform that takes in a single element from the first tensor and returns a transformed version.
target_transform – A function/transform that takes a single element of the second tensor and transforms it.
transform_groups – A dictionary containing the transform groups. Transform groups are used to quickly switch between training and eval (test) transformations. This becomes useful when one needs to test on the training dataset, as test transformations usually don’t contain random augmentations. AvalancheDataset natively supports the ‘train’ and ‘eval’ groups by calling the train() and eval() methods. When using custom groups, one can use the with_transforms(group_name) method instead. Defaults to None, which means that the current transforms will be used to handle both the ‘train’ and ‘eval’ groups (just like in standard torchvision datasets).
initial_transform_group – The name of the transform group to be used. Defaults to ‘train’.
task_labels – The task labels for each pattern. Must be a sequence of ints, one for each pattern in the dataset. Alternatively can be a single int value, in which case that value will be used as the task label for all the instances. Defaults to None, which means that a default task label 0 will be applied to all patterns.
targets – The label of each pattern. Defaults to None, which means that the targets will be retrieved from the second tensor of the dataset. Otherwise, it can be a sequence of values containing as many elements as the number of patterns.
collate_fn – The function to use when slicing to merge single patterns. In the future this function may become the function used in the data loading process, too.