Source code for avalanche.benchmarks.utils.data_attribute

################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-07-2022                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################

"""
This module contains the implementation of the DataAttribute,
a class designed to manage task and class labels. DataAttributes allow fast
concatenation and subsampling operations and are automatically managed by
AvalancheDatasets.
"""

import torch

from .dataset_definitions import IDataset
from .flat_data import ConstantSequence, FlatData


class DataAttribute:
    """Data attributes manage sample-wise information such as task or class
    labels.

    It provides access to unique values (`self.uniques`), their number of
    occurrences (`self.count`) and their indices (`self.val_to_idx`).
    All these fields are initialized lazily, on first access, and then
    cached.

    Data attributes can be efficiently concatenated and subsampled.
    """

    def __init__(
        self, data: "IDataset", name: str = None, use_in_getitem=False
    ):
        """Data Attribute.

        :param data: a sequence of values, one for each sample.
        :param name: a name that uniquely identifies the attribute.
            It is used by `AvalancheDataset` to dynamically add it to its
            attributes.
        :param use_in_getitem: If True, `AvalancheDataset` will add
            the value at the end of each sample.
        """
        self.name = name
        self.use_in_getitem = use_in_getitem
        self._data = self._normalize_sequence(data)

        # Lazily computed caches. `None` means "not computed yet"; they are
        # filled by the `uniques`, `val_to_idx` and `count` properties.
        self._uniques = None  # set()
        self._val_to_idx = None  # dict()
        self._count = None  # dict()

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return str(self.data[:])

    def __str__(self):
        return str(self.data[:])

    @property
    def data(self):
        """The underlying (normalized) sequence of values."""
        return self._data

    @property
    def uniques(self):
        """Set of unique values in the attribute."""
        if self._uniques is None:
            data = self.data
            # Build into a local set and publish it only once complete, so
            # that an error raised while iterating cannot leave a partially
            # filled cache behind.
            found = set()
            # fast paths for special cases
            if isinstance(data, ConstantSequence):
                # every element is the same value: one lookup is enough.
                if len(data) > 0:
                    found.add(data[0])
            elif isinstance(data, DataAttribute):
                found.update(data.uniques)
            else:
                found.update(data)
            self._uniques = found
        return self._uniques

    @property
    def count(self):
        """Dictionary of value -> count."""
        if self._count is None:
            occurrences = dict.fromkeys(self.uniques, 0)
            for value in self.data:
                occurrences[value] += 1
            self._count = occurrences
        return self._count

    @property
    def val_to_idx(self):
        """Dictionary mapping unique values to indices.

        For a `ConstantSequence` the single value is mapped to a `range`
        covering every position; otherwise each value is mapped to the list
        of positions where it occurs.
        """
        if self._val_to_idx is None:
            data = self.data
            if isinstance(data, ConstantSequence):
                # an empty constant sequence has no value to map.
                if len(data) > 0:
                    mapping = {data[0]: range(len(data))}
                else:
                    mapping = {}
            else:
                mapping = {}
                for position, value in enumerate(data):
                    mapping.setdefault(value, []).append(position)
            # publish only the fully built mapping (see `uniques`).
            self._val_to_idx = mapping
        return self._val_to_idx

    def subset(self, indices):
        """Subset operation.

        Return a new `DataAttribute` by keeping only the elements in
        `indices`.

        :param indices: position of the elements in the new subset
        :return: the new `DataAttribute`
        """
        return DataAttribute(
            self.data.subset(indices),
            self.name,
            use_in_getitem=self.use_in_getitem,
        )

    def concat(self, other: "DataAttribute"):
        """Concatenation operation.

        The result keeps the name and the `use_in_getitem` flag of `self`.

        :param other: the other `DataAttribute`
        :return: the new concatenated `DataAttribute`
        :raises AssertionError: if the two attributes have different names.
        """
        assert self.name == other.name, (
            "Cannot concatenate DataAttributes with different names."
        )
        return DataAttribute(
            self.data.concat(other.data),
            self.name,
            use_in_getitem=self.use_in_getitem,
        )

    @staticmethod
    def _normalize_sequence(seq):
        """Convert `seq` to the internal representation.

        Tensors are converted to plain lists and anything that is not
        already a `FlatData` is wrapped as `FlatData([seq])`.
        """
        if isinstance(seq, torch.Tensor):
            # equality doesn't work for tensors
            seq = seq.tolist()
        if not isinstance(seq, FlatData):
            return FlatData([seq])
        return seq
class TaskLabels(DataAttribute):
    """Task labels are `DataAttribute`s that are automatically appended to
    the mini-batch.

    The attribute is always registered under the name "task_labels" and is
    created with `use_in_getitem=True`.
    """

    def __init__(self, task_labels):
        """Task labels attribute.

        :param task_labels: a sequence of task labels, one for each sample.
        """
        super().__init__(
            data=task_labels,
            name="task_labels",
            use_in_getitem=True,
        )


__all__ = ["DataAttribute", "TaskLabels"]