avalanche.benchmarks.utils.data_loader.GroupBalancedDataLoader

class avalanche.benchmarks.utils.data_loader.GroupBalancedDataLoader(datasets: typing.Sequence[avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset], oversample_small_groups: bool = False, collate_mbatches=<function _default_collate_mbatches_fn>, batch_size: int = 32, **kwargs)

Data loader that balances data from multiple datasets.

__init__(datasets: typing.Sequence[avalanche.benchmarks.utils.avalanche_dataset.AvalancheDataset], oversample_small_groups: bool = False, collate_mbatches=<function _default_collate_mbatches_fn>, batch_size: int = 32, **kwargs)

Data loader that balances data from multiple datasets.

Mini-batches emitted by this data loader are created by collating mini-batches drawn from each group. It can be used to balance data across classes, experiences, tasks, and so on.

If oversample_small_groups is True, smaller groups are oversampled to match the largest group. Otherwise, once all the data from a group has been iterated, that group is skipped.
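The balancing scheme described above can be sketched in plain Python. This is a simplified model, not Avalanche's implementation: each group's loader is stood in for by a precomputed list of mini-batches, a mini-batch is a plain list of samples, and collation is list concatenation. The names `balanced_batches` and `Batch` are hypothetical, and the sketch assumes that one pass lasts as many steps as the largest group has mini-batches.

```python
from typing import Iterator, List, Optional, Sequence

Batch = List[int]  # hypothetical: a mini-batch modelled as a list of samples


def balanced_batches(
    group_batches: Sequence[Sequence[Batch]],
    oversample_small_groups: bool = False,
) -> Iterator[Batch]:
    """Yield one combined mini-batch per step, built from every active group.

    group_batches[g] holds the mini-batches that group g's own loader would
    produce. The loop runs for as many steps as the largest group has batches.
    """
    if any(len(batches) == 0 for batches in group_batches):
        raise ValueError("every group needs at least one mini-batch")
    num_steps = max(len(batches) for batches in group_batches)
    iterators: List[Optional[Iterator[Batch]]] = [iter(b) for b in group_batches]
    for _ in range(num_steps):
        combined: Batch = []
        for g, it in enumerate(iterators):
            if it is None:
                continue  # group was exhausted earlier and is skipped
            try:
                mb = next(it)
            except StopIteration:
                if not oversample_small_groups:
                    iterators[g] = None  # skip this group from now on
                    continue
                # oversampling: restart the smaller group from its beginning
                iterators[g] = iter(group_batches[g])
                mb = next(iterators[g])
            combined.extend(mb)  # "collate" by concatenating the samples
        yield combined


groups = [[[1, 2], [3, 4], [5, 6]], [[10, 11]]]
print(list(balanced_batches(groups)))
# [[1, 2, 10, 11], [3, 4], [5, 6]]
print(list(balanced_batches(groups, oversample_small_groups=True)))
# [[1, 2, 10, 11], [3, 4, 10, 11], [5, 6, 10, 11]]
```

With oversampling enabled every combined batch keeps a contribution from each group, at the cost of repeating the smaller group's data; with it disabled nothing is repeated, but later batches contain only the groups that still have data.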

Parameters
  • datasets – a sequence of AvalancheDataset instances, one for each group.

  • oversample_small_groups – whether smaller groups should be oversampled to match the largest one.

  • collate_mbatches – function that, given a sequence of mini-batches (one for each group), combines them into a single mini-batch. Used to combine the mini-batches obtained separately from each group.

  • batch_size – the total size of the combined mini-batch, shared among the groups. It must be greater than or equal to the number of groups.

  • kwargs – data loader arguments used to instantiate the loader for each group separately. See PyTorch DataLoader.
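The interplay between batch_size and collate_mbatches can be sketched with two hypothetical helpers in plain Python. `split_batch_size` assumes the total batch size is divided as evenly as possible among the groups, with the remainder going to the first groups; this is consistent with the requirement that batch_size be at least the number of groups, but the exact split Avalanche uses is an assumption here. `concat_mbatches` assumes each mini-batch is a tuple of fields and joins them along the batch axis; it is not necessarily what `_default_collate_mbatches_fn` does. Any remaining kwargs (for example `shuffle`) would be passed unchanged to each group's PyTorch DataLoader.

```python
from typing import List, Sequence, Tuple


def split_batch_size(batch_size: int, num_groups: int) -> List[int]:
    """Hypothetical helper: share batch_size among the groups as evenly as possible."""
    if batch_size < num_groups:
        raise ValueError("batch_size must be >= the number of groups")
    base, remainder = divmod(batch_size, num_groups)
    # the first `remainder` groups receive one extra sample each
    return [base + 1 if g < remainder else base for g in range(num_groups)]


def concat_mbatches(mbatches: Sequence[Tuple[list, ...]]) -> Tuple[list, ...]:
    """Hypothetical collate_mbatches: join the groups' mini-batches field by field.

    Each mini-batch is a tuple of fields such as (inputs, targets, task_labels).
    Lists stand in for tensors; with PyTorch each field would typically be
    joined with torch.cat(..., dim=0) instead.
    """
    num_fields = len(mbatches[0])
    return tuple(
        [sample for mb in mbatches for sample in mb[field]]
        for field in range(num_fields)
    )


print(split_batch_size(32, 3))  # [11, 11, 10]
group_a = ([1, 2], [0, 0], [0, 0])  # (inputs, targets, task labels)
group_b = ([7], [1], [1])
print(concat_mbatches([group_a, group_b]))  # ([1, 2, 7], [0, 0, 1], [0, 0, 1])
```

Under these assumptions, the default batch_size of 32 with three groups yields combined mini-batches of 11 + 11 + 10 samples.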

Methods

__init__(datasets[, ...])

Data loader that balances data from multiple datasets.