Source code for avalanche.benchmarks.scenarios.classification_scenario

from typing import (
    Callable,
    TypeVar,
    Union,
    Sequence,
    Optional,
    Iterable,
    List,
    Set,
    Tuple,
    Mapping,
    overload,
)

import warnings

from avalanche.benchmarks.scenarios.generic_scenario import (
    AbstractClassTimelineExperience,
)

from avalanche.benchmarks.scenarios.dataset_scenario import (
    DatasetScenario,
    ClassesTimelineCLScenario,
    FactoryBasedStream,
    TStreamsUserDict
)

from avalanche.benchmarks.utils import (
    AvalancheDataset
)
from avalanche.benchmarks.utils.classification_dataset import (
    ClassificationDataset,
)
from avalanche.benchmarks.utils.dataset_utils import \
    manage_advanced_indexing


# --- Dataset ---
# From utils:
TClassificationDataset = TypeVar(
    'TClassificationDataset',
    bound='ClassificationDataset')

# --- Scenario ---
# From dataset_scenario:
TDatasetScenario = TypeVar(
    'TDatasetScenario',
    bound='DatasetScenario'
)
TClassificationScenario = TypeVar(
    'TClassificationScenario',
    bound='ClassificationScenario')

# --- Stream ---
# Defined here:
TClassificationStream = TypeVar(
    'TClassificationStream',
    bound='ClassificationStream'
)

# --- Experience ---
TClassificationExperience = TypeVar(
    'TClassificationExperience',
    bound='ClassificationExperience')


def _default_classification_stream_factory(
        stream_name: str,
        benchmark: 'ClassificationScenario'):
    return ClassificationStream(
        name=stream_name,
        benchmark=benchmark
    )


def _default_classification_experience_factory(
        stream: 'ClassificationStream',
        experience_idx: int):
    return ClassificationExperience(
        origin_stream=stream,
        current_experience=experience_idx
    )


class ClassificationScenario(
    ClassesTimelineCLScenario[
        TClassificationStream,
        TClassificationExperience,
        TClassificationDataset]):
    """
    Base implementation of a Continual Learning classification benchmark.

    For more info, please refer to the base class :class:`DatasetScenario`.
    """
    
    def __init__(
        self: TClassificationScenario,
        *,
        stream_definitions: TStreamsUserDict,
        stream_factory: Callable[
            [str, TClassificationScenario],
            TClassificationStream
            ] = _default_classification_stream_factory,
        experience_factory: Callable[
            [TClassificationStream, int], 
            TClassificationExperience
            ] = _default_classification_experience_factory,
        complete_test_set_only: bool = False
    ):
        """
        Creates an instance of a Continual Learning object classification
        benchmark.

        :param stream_definitions: The definition of the streams. For a more
            precise description, please refer to :class:`DatasetScenario`.
        :param stream_factory: A callable that, given the name of the
            stream and the benchmark instance, returns a stream instance.
            Defaults to the constructor of :class:`ClassificationStream`.
        :param experience_factory: A callable that, given the
            stream instance and the experience ID, returns an experience
            instance.
            Defaults to the constructor of :class:`ClassificationExperience`.
        :param complete_test_set_only: If True, the test stream will contain
            a single experience containing the complete test set. This also
            means that the definition for the test stream must contain the
            definition for a single experience.
        """

        super().__init__(
            stream_definitions=stream_definitions,
            stream_factory=stream_factory,
            experience_factory=experience_factory,
            complete_test_set_only=complete_test_set_only)

    @property
    def classes_in_experience(self):
        return _LazyStreamClassesInClassificationExps(self)


GenericCLScenario = ClassificationScenario
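

# --- Usage sketch (illustrative only, never executed at import time) ---
# A minimal sketch of how a ClassificationScenario could be assembled by hand.
# The exact format accepted by ``stream_definitions`` is documented in
# :class:`DatasetScenario`; here we assume the
# "(experience datasets, per-experience task labels, origin dataset)" tuple
# form, and the dataset arguments are hypothetical placeholders.
def _classification_scenario_usage_sketch(
    train_exp_0: ClassificationDataset,
    train_exp_1: ClassificationDataset,
    test_set: ClassificationDataset,
) -> ClassificationScenario:
    stream_definitions = {
        "train": ([train_exp_0, train_exp_1], [0, 0], None),
        "test": ([test_set], [0], None),
    }
    return ClassificationScenario(
        stream_definitions=stream_definitions,
        # A single experience covering the whole test set:
        complete_test_set_only=True,
    )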


class ClassificationStream(FactoryBasedStream[TClassificationExperience]):
    def __init__(
        self,
        name: str,
        benchmark: ClassificationScenario,
        *,
        slice_ids: Optional[List[int]] = None,
        set_stream_info: bool = True
    ):
        self.benchmark: ClassificationScenario = benchmark

        super().__init__(
            name=name,
            benchmark=benchmark,
            slice_ids=slice_ids,
            set_stream_info=set_stream_info)
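

# A short sketch of how a ClassificationStream is typically consumed.
# Indexing and slicing behaviour is assumed to come from the stream base
# classes (see :class:`FactoryBasedStream`); ``benchmark`` is assumed to be
# an already-built ClassificationScenario.
def _classification_stream_usage_sketch(benchmark: ClassificationScenario):
    train_stream = ClassificationStream(name="train", benchmark=benchmark)
    # Experiences are created on demand through the experience factory:
    first_experience = train_stream[0]
    # Slicing is expected to return a stream restricted to the selected
    # experiences (see the ``slice_ids`` parameter above):
    first_two_experiences = train_stream[:2]
    return first_experience, first_two_experiences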


class ClassificationExperience(
    AbstractClassTimelineExperience[TClassificationDataset]
):
    """
    Definition of a learning experience based on a :class:`GenericCLScenario`
    instance. This experience implementation uses the generic
    experience-patterns assignment defined in the :class:`GenericCLScenario`
    instance. Instances of this class are usually obtained from a benchmark
    stream.
    """

    def __init__(
        self: TClassificationExperience,
        origin_stream: ClassificationStream[TClassificationExperience],
        current_experience: int
    ):
        """
        Creates an instance of a generic experience given the stream from
        which this experience was taken and the current experience ID.

        :param origin_stream: The stream from which this experience was
            obtained.
        :param current_experience: The current experience ID, as an integer.
        """
        self._benchmark: ClassificationScenario = origin_stream.benchmark

        dataset: TClassificationDataset = (
            origin_stream.benchmark.stream_definitions[
                origin_stream.name
            ].exps_data[current_experience]
        )

        (
            classes_in_this_exp,
            previous_classes,
            classes_seen_so_far,
            future_classes,
        ) = origin_stream.benchmark.get_classes_timeline(
            current_experience, stream=origin_stream.name
        )

        super().__init__(
            origin_stream,
            dataset,
            current_experience,
            classes_in_this_exp,
            previous_classes,
            classes_seen_so_far,
            future_classes,
        )

    @property  # type: ignore[override]
    def benchmark(self) -> ClassificationScenario:
        bench = self._benchmark
        ClassificationExperience._check_unset_attribute(
            'benchmark', bench
        )
        return bench

    @benchmark.setter
    def benchmark(self, bench: ClassificationScenario):
        self._benchmark = bench

    def _get_stream_def(self):
        return self.benchmark.stream_definitions[self.origin_stream.name]

    @property
    def task_labels(self) -> List[int]:
        with self.no_attribute_masking():  # Needed for "current_experience"
            stream_def = self._get_stream_def()
            return list(stream_def.exps_task_labels[self.current_experience])
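

# A sketch of the per-experience information exposed by
# ClassificationExperience. The timeline attribute names
# (``classes_in_this_experience`` and related) and the ``dataset`` attribute
# are assumed to be set by :class:`AbstractClassTimelineExperience`;
# ``benchmark`` is assumed to be an already-built ClassificationScenario.
def _classification_experience_usage_sketch(benchmark: ClassificationScenario):
    train_stream = ClassificationStream(name="train", benchmark=benchmark)
    experience = train_stream[0]
    # Dataset holding the patterns of this experience:
    exp_dataset = experience.dataset
    # Task labels assigned to this experience (see the property above):
    exp_tasks = experience.task_labels
    # Class timeline filled in through get_classes_timeline():
    current_classes = experience.classes_in_this_experience
    seen_classes = experience.classes_seen_so_far
    return exp_dataset, exp_tasks, current_classes, seen_classes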


GenericClassificationExperience = ClassificationExperience


class _LazyStreamClassesInClassificationExps(
        Mapping[str, Sequence[Set[int]]]):
    def __init__(self, benchmark: GenericCLScenario):
        self._benchmark = benchmark
        self._default_lcie = _LazyClassesInClassificationExps(
            benchmark, stream="train")

    def __len__(self):
        return len(self._benchmark.stream_definitions)

    def __getitem__(self, stream_name_or_exp_id):
        if isinstance(stream_name_or_exp_id, str):
            return _LazyClassesInClassificationExps(
                self._benchmark, stream=stream_name_or_exp_id
            )

        warnings.warn(
            "Using classes_in_experience[exp_id] is deprecated. "
            "Consider using classes_in_experience[stream_name][exp_id] "
            "instead.",
            stacklevel=2,
        )
        return self._default_lcie[stream_name_or_exp_id]

    def __iter__(self):
        yield from self._benchmark.stream_definitions.keys()


LazyClassesInExpsRet = Union[Tuple[Optional[Set[int]], ...],
                             Optional[Set[int]]]


class _LazyClassesInClassificationExps(Sequence[Optional[Set[int]]]):
    def __init__(self, benchmark: GenericCLScenario, stream: str = "train"):
        self._benchmark = benchmark
        self._stream = stream

    def __len__(self) -> int:
        return len(self._benchmark.streams[self._stream])

    @overload
    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
        ...

    @overload
    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
        ...

    def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
        indexing_collate = _LazyClassesInClassificationExps._slice_collate
        result = manage_advanced_indexing(
            exp_id, self._get_single_exp_classes, len(self), indexing_collate
        )
        return result

    def __str__(self):
        return (
            "[" + ", ".join([str(self[idx]) for idx in range(len(self))]) + "]"
        )

    def _get_single_exp_classes(self, exp_id) -> Optional[Set[int]]:
        b = self._benchmark.stream_definitions[self._stream]
        if not b.is_lazy and exp_id not in b.exps_data.targets_field_sequence:
            raise IndexError
        targets = b.exps_data.targets_field_sequence[exp_id]
        if targets is None:
            return None
        return set(targets)

    @staticmethod
    def _slice_collate(classes_in_exps: Iterable[Optional[Iterable[int]]]) -> \
            Optional[Tuple[Set[int], ...]]:
        result: List[Set[int]] = []
        for x in classes_in_exps:
            if x is None:
                return None
            result.append(set(x))
        return tuple(result)


__all__ = [
    "ClassificationScenario",
    "GenericCLScenario",
    "ClassificationStream",
    "ClassificationExperience",
    "GenericClassificationExperience",
]
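

# A sketch of how the lazy per-experience class sets are meant to be queried.
# ``benchmark`` is assumed to be an already-built scenario; indexing
# ``classes_in_experience`` by stream name avoids the deprecation warning
# raised by _LazyStreamClassesInClassificationExps.__getitem__.
def _classes_in_experience_usage_sketch(benchmark: GenericCLScenario):
    # Preferred form: classes_in_experience[stream_name][exp_id]
    train_classes_exp_0 = benchmark.classes_in_experience["train"][0]
    # Deprecated form (defaults to the "train" stream and emits a warning):
    legacy_classes_exp_0 = benchmark.classes_in_experience[0]
    return train_classes_exp_0, legacy_classes_exp_0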