Source code for avalanche.benchmarks.scenarios.validation_scenario

from typing import (
    Callable,
    Union,
    Tuple,
    Optional,
)

import random
from avalanche.benchmarks.utils.data import AvalancheDataset
from .generic_scenario import EagerCLStream, CLScenario, make_stream
from .dataset_scenario import (
    LazyTrainValSplitter,
    DatasetExperience,
    split_validation_random,
)
from .supervised import with_classes_timeline


def benchmark_with_validation_stream(
    benchmark: CLScenario,
    validation_size: Union[int, float] = 0.5,
    shuffle: bool = False,
    seed: Optional[int] = None,
    split_strategy: Optional[
        Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]]
    ] = None,
) -> CLScenario:
    """Helper to obtain a benchmark with a validation stream.

    This generator accepts an existing benchmark instance and returns a
    version of it in which the train stream has been split into training and
    validation streams. Each train/validation experience is created by
    splitting the corresponding original training experience. Patterns
    selected for the validation experience are removed from the training
    experience.

    The default splitting strategy is a random split as implemented by
    `split_validation_random`. If you want class balancing you can use
    `split_validation_class_balanced`, or provide a custom `split_strategy`,
    as shown in the following example::

        validation_size = 0.2
        foo = lambda exp: split_validation_class_balanced(validation_size, exp)
        bm = benchmark_with_validation_stream(bm, split_strategy=foo)

    A class-balanced split is also sketched in the helper defined right after
    this function.

    :param benchmark: The benchmark to split.
    :param validation_size: The size of the validation experience, as an int
        or a float between 0 and 1. Ignored if `split_strategy` is used.
    :param shuffle: If True, patterns will be allocated to the validation
        stream randomly. This will use the default PyTorch random number
        generator at its current state. Defaults to False. Ignored if
        `split_strategy` is used. If False, the first instances will be
        allocated to the training dataset and the last ones to the
        validation dataset.
    :param seed: The seed used by the default (random) splitting strategy.
        If None, a random seed is drawn. Ignored if `split_strategy` is used.
    :param split_strategy: A function that implements a custom splitting
        strategy. The function must accept an AvalancheDataset and return a
        tuple containing the new train and validation datasets. By default,
        the splitting strategy splits the data according to
        `validation_size` and `shuffle`. A good starting point to understand
        the mechanism is the implementation of the standard splitting
        function :func:`split_validation_random`.

    :return: A benchmark instance in which the validation stream has been
        added.
""" if split_strategy is None: if seed is None: seed = random.randint(0, 1000000) # functools.partial is a more compact option # However, MyPy does not understand what a partial is -_- def random_validation_split_strategy_wrapper(data): return split_validation_random(validation_size, shuffle, seed, data) split_strategy = random_validation_split_strategy_wrapper else: split_strategy = split_strategy stream = benchmark.streams["train"] if isinstance(stream, EagerCLStream): # eager split train_exps, valid_exps = [], [] exp: DatasetExperience for exp in stream: train_data, valid_data = split_strategy(exp.dataset) train_exps.append(DatasetExperience(dataset=train_data)) valid_exps.append(DatasetExperience(dataset=valid_data)) else: # Lazy splitting (based on a generator) split_generator = LazyTrainValSplitter(split_strategy, stream) train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator) valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator) train_stream = make_stream(name="train", exps=train_exps) valid_stream = make_stream(name="valid", exps=valid_exps) other_streams = benchmark.streams # don't drop classes-timeline for compatibility with old API e0 = next(iter(train_stream)) if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"): train_stream = with_classes_timeline(train_stream) valid_stream = with_classes_timeline(valid_stream) del other_streams["train"] return CLScenario( streams=[train_stream, valid_stream] + list(other_streams.values()) )
__all__ = ["benchmark_with_validation_stream"]