avalanche.benchmarks.utils.data_loader.TaskBalancedDataLoader

class avalanche.benchmarks.utils.data_loader.TaskBalancedDataLoader(data: avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset, oversample_small_tasks: bool = False, collate_mbatches=<function _default_collate_mbatches_fn>, **kwargs)[source]

Task-balanced data loader for Avalanche’s datasets.

__init__(data: avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset, oversample_small_tasks: bool = False, collate_mbatches=<function _default_collate_mbatches_fn>, **kwargs)[source]

Task-balanced data loader for Avalanche’s datasets.

The iterator returns mini-batches balanced across tasks, which is useful when training in multi-task scenarios where the amount of data is highly unbalanced across tasks.

If oversample_small_tasks == True, smaller tasks are oversampled to match the largest one. Otherwise, once the data for a specific task is exhausted, that task is no longer represented in subsequent mini-batches.

Parameters
  • data – an instance of AvalancheDataset.

  • oversample_small_tasks – whether smaller tasks should be oversampled to match the largest one.

  • collate_mbatches – a function that, given a sequence of mini-batches (one per task), combines them into a single mini-batch. Used to merge the mini-batches obtained separately from each task.

  • kwargs – data loader arguments used to instantiate a separate loader for each task. See the PyTorch DataLoader documentation.
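Example

A minimal usage sketch, assuming a SplitMNIST benchmark and AvalancheConcatDataset for concatenating the per-experience datasets (the exact concatenation helper may differ across Avalanche versions); the batch structure in the loop reflects the usual (input, target, task id) mini-batch format:

>>> from avalanche.benchmarks.classic import SplitMNIST
>>> from avalanche.benchmarks.utils import AvalancheConcatDataset
>>> from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
>>>
>>> benchmark = SplitMNIST(n_experiences=5, return_task_id=True)
>>>
>>> # Concatenate the training sets of all experiences so the resulting
>>> # AvalancheDataset contains samples from several task labels.
>>> train_data = AvalancheConcatDataset(
...     [exp.dataset for exp in benchmark.train_stream]
... )
>>>
>>> loader = TaskBalancedDataLoader(
...     train_data,
...     oversample_small_tasks=True,  # resample smaller tasks to match the largest
...     batch_size=32,  # forwarded to the DataLoader created for each task
...     shuffle=True,
... )
>>>
>>> for x, y, task_ids in loader:
...     # each mini-batch mixes samples from every task that still has data
...     pass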

Methods

__init__(data[, oversample_small_tasks, ...])

Task-balanced data loader for Avalanche's datasets.