avalanche.benchmarks.utils.data_loader.GroupBalancedDataLoader
- class avalanche.benchmarks.utils.data_loader.GroupBalancedDataLoader(datasets: Sequence[AvalancheDataset], oversample_small_groups: bool = False, batch_size: int = 32, distributed_sampling: bool = True, **kwargs)
Data loader that balances data from multiple datasets.
- __init__(datasets: Sequence[AvalancheDataset], oversample_small_groups: bool = False, batch_size: int = 32, distributed_sampling: bool = True, **kwargs)
Mini-batches emitted by this dataloader are created by collating together mini-batches from each group. It may be used to balance data among classes, experiences, tasks, and so on.
If oversample_small_groups is True, smaller groups are oversampled to match the size of the largest group. Otherwise, a group is skipped once all of its data has been iterated.
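The behaviour described above can be illustrated with a minimal pure-Python sketch. It is not Avalanche's implementation: the function balanced_batches and its per_group argument are hypothetical names, groups are plain lists of integers rather than AvalancheDataset objects, and there is no shuffling or tensor collation.

```python
from itertools import cycle, islice
from typing import Iterator, List, Sequence


def balanced_batches(
    groups: Sequence[Sequence[int]],
    per_group: int,
    oversample_small_groups: bool = False,
) -> Iterator[List[int]]:
    """Yield mini-batches built by concatenating one chunk from each group.

    Hypothetical helper for illustration only; not part of Avalanche.
    """
    longest = max(len(g) for g in groups)
    if oversample_small_groups:
        # Repeat each smaller group until it is as long as the largest one.
        streams = [list(islice(cycle(g), longest)) for g in groups]
    else:
        streams = [list(g) for g in groups]

    for start in range(0, longest, per_group):
        batch: List[int] = []
        for stream in streams:
            # An exhausted group yields an empty slice, so it is skipped.
            batch.extend(stream[start:start + per_group])
        yield batch


big = [0, 1, 2, 3]
small = [10, 11]

skipped = list(balanced_batches([big, small], per_group=2))
oversampled = list(
    balanced_batches([big, small], per_group=2, oversample_small_groups=True)
)
print(skipped)      # [[0, 1, 10, 11], [2, 3]]
print(oversampled)  # [[0, 1, 10, 11], [2, 3, 10, 11]]
```

Without oversampling, the second mini-batch contains only samples from the larger group, because the smaller group is already exhausted. With oversampling, the smaller group is repeated so every mini-batch stays balanced.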
- Parameters:
datasets – a sequence of AvalancheDataset instances, one per group.
oversample_small_groups – whether smaller groups should be oversampled to match the largest one.
batch_size – the size of the batch. It must be greater than or equal to the number of groups.
distributed_sampling – if True, apply the PyTorch DistributedSampler. Defaults to True. Note: the distributed sampler is not applied when not running distributed training, even if True is passed.
kwargs – data loader arguments used to instantiate the loader for each group separately. See the PyTorch DataLoader.
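The requirement that batch_size be at least the number of groups follows from each group needing to contribute at least one sample per mini-batch. This page does not say how the batch is divided among the groups; the sketch below assumes an even split with the remainder given to the first groups. The helper split_batch_size is a hypothetical name, not an Avalanche function.

```python
from typing import List


def split_batch_size(batch_size: int, num_groups: int) -> List[int]:
    """Divide a total batch size into per-group batch sizes.

    Hypothetical helper: assumes an even split, with any remainder
    spread one sample at a time over the first groups.
    """
    if batch_size < num_groups:
        raise ValueError(
            "batch_size must be greater than or equal to the number of groups"
        )
    base, remainder = divmod(batch_size, num_groups)
    return [base + 1 if i < remainder else base for i in range(num_groups)]


sizes = split_batch_size(32, 3)
print(sizes)  # [11, 11, 10]
```

Under this assumption the per-group sizes always sum to batch_size, and a batch_size smaller than the number of groups is rejected because at least one group would receive no samples.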
Methods
__init__(datasets[, ...]) – Data loader that balances data from multiple datasets.