Source code for avalanche.benchmarks.datasets.openloris.openloris

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 11-11-2020                                                             #
# Author: Vincenzo Lomonaco                                                    #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

""" OpenLoris Pytorch Dataset """

import pickle as pkl
from pathlib import Path
from typing import Union

from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor

from avalanche.benchmarks.datasets import (
    DownloadableDataset,
    default_dataset_location,
)
from avalanche.benchmarks.datasets.openloris import openloris_data


class OpenLORIS(DownloadableDataset):
    """OpenLORIS Pytorch Dataset"""

    def __init__(
        self,
        root: Union[str, Path, None] = None,
        *,
        train=True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        download=True,
    ):
        """
        Creates an instance of the OpenLORIS dataset.

        :param root: The directory where the dataset can be found or
            downloaded. Defaults to None, which means that the default
            location for 'openloris' will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformations to apply to the X values.
        :param target_transform: The transformations to apply to the Y values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        """
        if root is None:
            root = default_dataset_location("openloris")

        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        super().__init__(root, download=download, verbose=True)
        self._load_dataset()

    def _download_dataset(self) -> None:
        """
        Downloads every entry of ``openloris_data.avl_vps_data`` and extracts
        the ones recognised as zip archives.
        """
        for filename, url, checksum in openloris_data.avl_vps_data:
            if self.verbose:
                print("Downloading " + url + "...")
            file = self._download_file(url, filename, checksum)
            # NOTE(review): the archive test looks at the URL field, not at
            # the local file name; confirm this is intended for every entry.
            if url.endswith(".zip"):
                if self.verbose:
                    print(f"Extracting {filename}...")
                self._extract_archive(file)
                if self.verbose:
                    print("Extraction completed!")

    def _load_metadata(self) -> bool:
        """
        Loads paths, labels and the look-up table from the pickle files found
        in the dataset root and fills ``self.paths`` and ``self.targets``.

        :return: False if the integrity check fails, True once the metadata
            has been loaded.
        """
        if not self._check_integrity():
            return False

        # any scenario and factor is good here since we want just to load the
        # train images and targets with no particular order
        scen = "domain"
        factors = range(4)
        ntask = 9

        # SECURITY: pickle.load can execute arbitrary code. These files come
        # from the dataset download, so only use archives from a trusted
        # source.
        if self.verbose:
            print("Loading paths...")
        with open(str(self.root / "Paths.pkl"), "rb") as f:
            self.train_test_paths = pkl.load(f)

        if self.verbose:
            print("Loading labels...")
        with open(str(self.root / "Labels.pkl"), "rb") as f:
            self.all_targets = pkl.load(f)

        # Flatten the targets of all factors; the extra "+ 1" entry per factor
        # is included so that both train and test indices can be resolved.
        self.train_test_targets = []
        for fact in factors:
            for i in range(ntask + 1):
                self.train_test_targets += self.all_targets[scen][fact][i]

        if self.verbose:
            print("Loading LUP...")
        with open(str(self.root / "LUP.pkl"), "rb") as f:
            self.LUP = pkl.load(f)

        # Training uses the first ``ntask`` entries of each factor, while the
        # test set uses the last entry of each factor.
        self.idx_list = []
        if self.train:
            for fact in factors:
                for i in range(ntask):
                    self.idx_list += self.LUP[scen][fact][i]
        else:
            for fact in factors:
                self.idx_list += self.LUP[scen][fact][-1]

        self.paths = [self.train_test_paths[idx] for idx in self.idx_list]
        self.targets = [self.train_test_targets[idx] for idx in self.idx_list]

        return True

    def _download_error_message(self) -> str:
        """
        Builds the message listing the links to use for a manual download.

        :return: The message, ending with the directory where the files
            should be placed.
        """
        base_url = openloris_data.base_gdrive_url
        all_urls = [base_url + url for _, url, _ in openloris_data.avl_vps_data]

        header = (
            "[OpenLoris] Direct download may no longer be supported!\n"
            "You should download data manually using the following links:\n"
        )
        links = "".join(url + "\n" for url in all_urls)
        return header + links + "and place these files in " + str(self.root)

    def _check_integrity(self):
        """
        Checks if the data is already available and intact.

        Only the existence of each expected file is verified here.

        :return: False as soon as one expected file is missing, else True.
        """
        for name, _, _ in openloris_data.avl_vps_data:
            filepath = self.root / name
            if not filepath.is_file():
                if self.verbose:
                    print(
                        "[OpenLORIS] Error checking integrity of:",
                        str(filepath),
                    )
                return False

        return True

    def __getitem__(self, index):
        """
        Loads the image at the given index and applies the transformations.

        :param index: Position of the pattern in the selected split.
        :return: A tuple ``(img, target)``.
        """
        target = self.targets[index]
        img = self.loader(str(self.root / self.paths[index]))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Returns the number of patterns in the selected split."""
        return len(self.targets)
if __name__ == "__main__":
    # Demo script: builds the train and test splits, reports their sizes and
    # displays the first training image.

    from torch.utils.data.dataloader import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = OpenLORIS(download=True, transform=ToTensor())
    test_data = OpenLORIS(train=False, transform=ToTensor())
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))

    to_pil = transforms.ToPILImage()
    # Only the first batch is needed, hence the break at the end of the body.
    for x, y in DataLoader(train_data, batch_size=1):
        # Drop the singleton dimensions before converting to a PIL image.
        first_image = to_pil(torch.squeeze(x))
        plt.imshow(first_image)
        plt.show()
        print(x.size())
        print(len(y))
        break

__all__ = ["OpenLORIS"]