# Source code for avalanche.benchmarks.datasets.torchaudio_wrapper

# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Author(s): Andrea Cossu                                                      #
# E-mail:                                              #
# Website:                                                 #

""" This module conveniently wraps TorchAudio Datasets for using a clean and
comprehensive Avalanche API."""

try:
    import torchaudio
except ImportError:
    raise ModuleNotFoundError(
        "TorchAudio package is required to load its dataset. "
        "You can install it as extra dependency with "
        "`pip install avalanche-lib[extra]`"
    )
from torchaudio.datasets import SPEECHCOMMANDS
from avalanche.benchmarks.utils import make_classification_dataset
from avalanche.benchmarks.datasets import default_dataset_location
import torch

def speech_commands_collate(batch):
    """Collate Speech Commands samples into padded batch tensors.

    Every element of ``batch`` is unpacked as a 6-tuple
    ``(waveform, label, rate, speaker_id, utterance_id, task_label)``; only
    the waveform, the label and the task label are used.

    :param batch: iterable of 6-tuples as described above. ``waveform`` must
        be a tensor, ``label`` and ``task_label`` must be accepted by
        ``torch.tensor``.
    :return: a tuple ``(tensors, targets, t_labels)``. ``tensors`` contains
        the waveforms zero-padded along their first dimension up to the
        longest one in the batch, with the batch dimension first; if the
        padded result is 2-D a trailing feature dimension of size 1 is
        appended. ``targets`` and ``t_labels`` are the stacked labels and
        task labels, in batch order.
    """
    tensors, targets, t_labels = [], [], []
    for waveform, label, _rate, _sid, _uid, t_label in batch:
        # `.t()` leaves 1-D waveforms untouched and swaps the two axes of
        # 2-D ones.
        # NOTE(review): for 2-D (feature) inputs this means padding happens
        # along what was the second axis of each sample -- confirm this is
        # the intended layout for MFCC features.
        tensors.append(waveform.t())
        targets.append(torch.tensor(label))
        t_labels.append(torch.tensor(t_label))

    tensors = torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0.0
    )
    if tensors.dim() == 2:  # no MFCC, add feature dimension
        tensors = tensors.unsqueeze(-1)

    return tensors, torch.stack(targets), torch.stack(t_labels)

class SpeechCommandsData(SPEECHCOMMANDS):
    """TorchAudio ``SPEECHCOMMANDS`` dataset that yields integer labels.

    Compared to the parent dataset, ``__getitem__`` replaces the string label
    with its index in ``labels_names``, drops the leading dimension of the
    waveform, optionally applies an MFCC transform, and returns the fields in
    the order ``(wave, label, rate, speaker_id, ut_number)``.
    """

    def __init__(self, root, url, download, subset, mfcc_preprocessing):
        """
        :param root: dataset root location, forwarded to ``SPEECHCOMMANDS``.
        :param url: version name of the dataset, forwarded to
            ``SPEECHCOMMANDS``.
        :param download: whether to download the dataset, forwarded to
            ``SPEECHCOMMANDS``.
        :param subset: subset selector, forwarded to ``SPEECHCOMMANDS``.
        :param mfcc_preprocessing: optional callable applied to each waveform
            in ``__getitem__``; it must expose a ``sample_rate`` attribute.
            ``None`` disables the preprocessing.
        """
        super().__init__(root=root, download=download, subset=subset, url=url)
        # Class names; the integer label of a sample is the position of its
        # string label in this list.
        # NOTE(review): this list was reconstructed from the public Speech
        # Commands v0.02 vocabulary (35 words, alphabetical order) -- verify
        # against the upstream file before relying on the exact indices.
        self.labels_names = [
            "backward",
            "bed",
            "bird",
            "cat",
            "dog",
            "down",
            "eight",
            "five",
            "follow",
            "forward",
            "four",
            "go",
            "happy",
            "house",
            "learn",
            "left",
            "marvin",
            "nine",
            "no",
            "off",
            "on",
            "one",
            "right",
            "seven",
            "sheila",
            "six",
            "stop",
            "three",
            "tree",
            "two",
            "up",
            "visual",
            "wow",
            "yes",
            "zero",
        ]
        self.mfcc_preprocessing = mfcc_preprocessing

    def __getitem__(self, item):
        """Return ``(wave, label, rate, speaker_id, ut_number)`` for ``item``.

        :param item: index forwarded to the parent ``__getitem__``.
        :return: the waveform (MFCC-transformed when ``mfcc_preprocessing``
            is set), the integer label, and the sample rate, speaker id and
            utterance number as returned by the parent dataset.
        :raises ValueError: if the sample's string label is not in
            ``labels_names``.
        :raises AssertionError: if ``mfcc_preprocessing`` is set and its
            ``sample_rate`` differs from the sample's rate.
        """
        # The parent returns the label in third position; it is moved to
        # second position in the tuple returned below.
        wave, rate, label, speaker_id, ut_number = super().__getitem__(item)
        label = self.labels_names.index(label)
        wave = wave.squeeze(0)  # (T,)
        if self.mfcc_preprocessing is not None:
            assert rate == self.mfcc_preprocessing.sample_rate, (
                "sample rate of the audio does not match the sample rate "
                "of the MFCC preprocessing"
            )
            # (T, MFCC)
            wave = self.mfcc_preprocessing(wave).permute(1, 0)
        return wave, label, rate, speaker_id, ut_number

def SpeechCommands(
    root=default_dataset_location(""),
    url="speech_commands_v0.02",
    download=True,
    subset=None,
    mfcc_preprocessing=None,
):
    """
    Build the Speech Commands dataset as an Avalanche classification dataset.

    :param root: dataset root location
    :param url: version name of the dataset
    :param download: automatically download the dataset, if not present
    :param subset: one of 'training', 'validation', 'testing'. The default
        ``None`` is forwarded as-is to the underlying dataset.
    :param mfcc_preprocessing: an optional torchaudio.transforms.MFCC instance
        to preprocess each audio. Warning: this may slow down the execution
        since preprocessing is applied on-the-fly each time a sample is
        retrieved from the dataset.
    :return: the value returned by ``make_classification_dataset`` for the
        wrapped dataset, using ``speech_commands_collate`` as collate
        function and the per-sample integer labels as targets.
    """
    dataset = SpeechCommandsData(
        root=root,
        download=download,
        subset=subset,
        url=url,
        mfcc_preprocessing=mfcc_preprocessing,
    )
    # Collecting the targets iterates over the whole dataset once, so every
    # sample goes through ``__getitem__`` (including the MFCC preprocessing,
    # when enabled) before the dataset is returned.
    labels = [datapoint[1] for datapoint in dataset]
    return make_classification_dataset(
        dataset, collate_fn=speech_commands_collate, targets=labels
    )


# Names exported by `from <this module> import *`.
__all__ = ["SpeechCommands"]