Source code for avalanche.benchmarks.utils.data_attribute

################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the DataAttribute,
a class designed to manage task and class labels. DataAttributes allow fast
concatenation and subsampling operations and are automatically managed by
AvalancheDatasets.
"""

from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union, overload
import torch

from .dataset_definitions import IDataset
from .flat_data import ConstantSequence, FlatData

T_co = TypeVar("T_co", covariant=True)


class DataAttribute(IDataset[T_co], Sequence[T_co]):
    """Data attributes manage sample-wise information such as task or class
    labels.

    It provides access to unique values (`self.uniques`) and their indices
    (`self.val_to_idx`). Both fields are initialized lazily.

    Data attributes can be efficiently concatenated and subsampled.
    """

    def __init__(self, data: IDataset[T_co], name: str, use_in_getitem: bool = False):
        """Data Attribute.

        :param data: a sequence of values, one for each sample.
        :param name: a name that uniquely identifies the attribute.
            It is used by `AvalancheDataset` to dynamically add it to its
            attributes.
        :param use_in_getitem: If True, `AvalancheDataset` will add the value
            at the end of each sample.
        """
        assert name is not None
        assert data is not None

        self.name: str = name
        self.use_in_getitem: bool = use_in_getitem

        self._data: FlatData[T_co] = self._normalize_sequence(data)

        # Lazily-computed caches, filled on first access of the matching
        # property (`uniques`, `val_to_idx`, `count`).
        self._uniques: Optional[Set[T_co]] = None
        self._val_to_idx: Optional[Dict[T_co, Sequence[int]]] = None
        self._count: Optional[Dict[T_co, int]] = None

    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]

    @overload
    def __getitem__(self, item: int) -> T_co:
        ...

    @overload
    def __getitem__(self, item: slice) -> Sequence[T_co]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T_co, Sequence[T_co]]:
        return self.data[item]

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return str(self.data[:])

    def __str__(self):
        return str(self.data[:])

    @property
    def data(self):
        # Underlying (possibly lazily-concatenated) FlatData sequence.
        return self._data

    @property
    def uniques(self):
        """Set of unique values in the attribute."""
        if self._uniques is not None:
            return self._uniques

        # Fast paths for special sequence types; generic scan otherwise.
        if isinstance(self.data, ConstantSequence):
            self._uniques = {self.data[0]}
        elif isinstance(self.data, DataAttribute):
            self._uniques = set(self.data.uniques)
        else:
            self._uniques = {el for el in self.data}
        return self._uniques

    @property
    def count(self):
        """Dictionary of value -> count."""
        if self._count is not None:
            return self._count

        occurrences = {value: 0 for value in self.uniques}
        for value in self.data:
            occurrences[value] += 1
        self._count = occurrences
        return self._count

    @property
    def val_to_idx(self):
        """Dictionary mapping unique values to indices."""
        if self._val_to_idx is not None:
            return self._val_to_idx

        if isinstance(self.data, ConstantSequence):
            # A constant sequence has one value covering every index.
            self._val_to_idx = {self.data[0]: range(len(self.data))}
        else:
            grouped: Dict[T_co, List[int]] = dict()
            for position, value in enumerate(self.data):
                grouped.setdefault(value, []).append(position)
            self._val_to_idx = grouped  # type: ignore
        return self._val_to_idx

    def subset(self, indices):
        """Subset operation.

        Return a new `DataAttribute` by keeping only the elements in
        `indices`.

        :param indices: position of the elements in the new subset
        :return: the new `DataAttribute`
        """
        return DataAttribute(
            self.data.subset(indices),
            self.name,
            use_in_getitem=self.use_in_getitem,
        )

    def concat(self, other: "DataAttribute"):
        """Concatenation operation.

        :param other: the other `DataAttribute`
        :return: the new concatenated `DataAttribute`
        """
        assert (
            self.name == other.name
        ), "Cannot concatenate DataAttributes with different names."
        return DataAttribute(
            self.data.concat(other.data),
            self.name,
            use_in_getitem=self.use_in_getitem,
        )

    def _normalize_sequence(self, seq: IDataset[T_co]) -> FlatData[T_co]:
        # Wrap arbitrary sequences in a FlatData so that concat/subset work.
        if isinstance(seq, torch.Tensor):
            # equality doesn't work for tensors
            seq = seq.tolist()
        if not isinstance(seq, (FlatData, ConstantSequence)):
            return FlatData([seq])
        return seq
# Backward-compatible alias: tensor-backed attributes are plain
# ``DataAttribute``s (tensors are converted to lists on construction).
TensorDataAttribute = DataAttribute[T_co]


def has_task_labels(dataset):
    """Check if `dataset` has task labels.

    :param dataset: the dataset to check.
    :return: True if the dataset exposes a ``targets_task_labels`` attribute.
    """
    return hasattr(dataset, "targets_task_labels")


class TaskLabels(DataAttribute):
    """Task labels are `DataAttribute`s that are automatically appended to
    the mini-batch."""

    def __init__(self, task_labels):
        super().__init__(task_labels, "targets_task_labels", use_in_getitem=True)


# Fix: `has_task_labels` is a public helper but was missing from `__all__`,
# so `from ... import *` silently dropped it. Adding it only widens the
# export list, which is backward compatible.
__all__ = ["DataAttribute", "TensorDataAttribute", "TaskLabels", "has_task_labels"]