avalanche.benchmarks.utils.data_loader.TaskBalancedDataLoader
- class avalanche.benchmarks.utils.data_loader.TaskBalancedDataLoader(data: AvalancheDataset, batch_size: int = 32, oversample_small_groups: bool = False, distributed_sampling: bool = True, **kwargs)[source]
Task-balanced data loader for Avalanche’s datasets.
- __init__(data: AvalancheDataset, batch_size: int = 32, oversample_small_groups: bool = False, distributed_sampling: bool = True, **kwargs)[source]
Task-balanced data loader for Avalanche’s datasets.
The iterator returns mini-batches balanced across tasks, which makes it useful when training in multi-task scenarios where the data is highly unbalanced.
If oversample_small_groups == True, smaller tasks are oversampled to match the largest one. Otherwise, once the data for a specific task is exhausted, that task will no longer appear in subsequent mini-batches.
- Parameters:
data – an instance of AvalancheDataset.
batch_size – the size of each mini-batch. Defaults to 32.
oversample_small_groups – whether smaller tasks should be oversampled to match the largest one. Defaults to False.
distributed_sampling – if True, apply the PyTorch DistributedSampler. Defaults to True. Note: the distributed sampler is not applied when not running a distributed training, even if True is passed.
kwargs – data loader arguments used to instantiate the loader for each task separately. See the PyTorch DataLoader documentation.
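The balancing and oversampling behavior described above can be illustrated with a minimal plain-Python sketch. This is an illustrative re-implementation of the idea, not Avalanche's actual code: `task_data` stands in for the per-task groups that the loader derives from the dataset's task labels.

```python
import itertools

def task_balanced_batches(task_data, batch_size, oversample_small_groups=False):
    """Yield mini-batches drawing an equal share of samples from each task.

    task_data: dict mapping task label -> list of samples.
    If oversample_small_groups is True, exhausted tasks restart from the
    beginning so every batch contains all tasks; otherwise a task drops
    out of later batches once its data is exhausted.
    """
    per_task = max(1, batch_size // len(task_data))
    if oversample_small_groups:
        # Cycle every task; stop when the largest task has been seen once.
        iters = {t: itertools.cycle(samples) for t, samples in task_data.items()}
        longest = max(len(samples) for samples in task_data.values())
        n_batches = -(-longest // per_task)  # ceiling division
        for _ in range(n_batches):
            yield [next(iters[t]) for t in iters for _ in range(per_task)]
    else:
        iters = {t: iter(samples) for t, samples in task_data.items()}
        while iters:
            batch, exhausted = [], []
            for t, it in iters.items():
                chunk = list(itertools.islice(it, per_task))
                batch.extend(chunk)
                if len(chunk) < per_task:
                    exhausted.append(t)
            for t in exhausted:
                del iters[t]  # task exhausted: drop it from later batches
            if batch:
                yield batch
```

With two tasks of sizes 4 and 2 and batch_size=4, the default mode yields `[0, 1, 10, 11]` then `[2, 3]` (task 1 drops out), while `oversample_small_groups=True` yields `[0, 1, 10, 11]` then `[2, 3, 10, 11]` (task 1 is recycled).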
Methods
- __init__(data[, batch_size, ...]) – Task-balanced data loader for Avalanche's datasets.