Source code for avalanche.benchmarks.datasets.openloris.openloris

# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 11-11-2020                                                             #
# Author: Vincenzo Lomonaco                                                    #
# E-mail:                                              #
# Website:                                                 #

""" OpenLoris Pytorch Dataset """

import pickle as pkl
from pathlib import Path
from typing import Union

from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor

from avalanche.benchmarks.datasets import DownloadableDataset, \
    default_dataset_location
from avalanche.benchmarks.datasets.openloris import openloris_data

class OpenLORIS(DownloadableDataset):
    """ OpenLORIS Pytorch Dataset """

    def __init__(self, root: Union[str, Path] = None,
                 *, train=True, transform=None, target_transform=None,
                 loader=default_loader, download=True):
        """
        Creates an instance of the OpenLORIS dataset.

        :param root: The directory where the dataset can be found or
            downloaded. Defaults to None, which means that the default
            location for 'openloris' will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformations to apply to the X values.
        :param target_transform: The transformations to apply to the Y values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        """

        if root is None:
            root = default_dataset_location('openloris')

        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        # The attributes above are assigned before the base initializer and
        # `_load_dataset()` run, because `_load_metadata` reads `self.train`.
        super().__init__(root, download=download, verbose=True)
        self._load_dataset()

    def _download_dataset(self) -> None:
        """
        Downloads every file listed in `openloris_data.avl_vps_data`.

        Each entry is unpacked as `(name, url, md5)`, the same layout
        `_check_integrity` relies on. Files whose URL ends with '.zip' are
        extracted right after being downloaded.
        """
        for file_name, url, checksum in openloris_data.avl_vps_data:
            if self.verbose:
                print("Downloading " + url + "...")
            file = self._download_file(url, file_name, checksum)

            if url.endswith('.zip'):
                if self.verbose:
                    print(f'Extracting {file_name}...')
                self._extract_archive(file)
                if self.verbose:
                    print('Extraction completed!')

    def _load_metadata(self) -> bool:
        """
        Loads image paths and targets from the pickled metadata files.

        Reads 'Paths.pkl', 'Labels.pkl' and 'LUP.pkl' from `self.root` and
        fills `self.paths` and `self.targets` for the split selected by
        `self.train`.

        :return: False if the integrity check fails, True once the metadata
            has been loaded.
        """
        if not self._check_integrity():
            return False

        # any scenario and factor is good here since we want just to load the
        # train images and targets with no particular order
        scen = 'domain'
        factor = 0
        ntask = 9

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # are expected to come from the dataset download; do not point `root`
        # at untrusted content.
        if self.verbose:
            print("Loading paths...")
        with open(str(self.root / 'Paths.pkl'), 'rb') as f:
            self.train_test_paths = pkl.load(f)

        if self.verbose:
            print("Loading labels...")
        with open(str(self.root / 'Labels.pkl'), 'rb') as f:
            self.all_targets = pkl.load(f)

        # Concatenate the targets of entries 0..ntask into one flat list.
        self.train_test_targets = []
        for i in range(ntask + 1):
            self.train_test_targets += self.all_targets[scen][factor][i]

        if self.verbose:
            print("Loading LUP...")
        with open(str(self.root / 'LUP.pkl'), 'rb') as f:
            self.LUP = pkl.load(f)

        if self.train:
            # Training split: indices of entries 0..ntask, concatenated.
            self.idx_list = []
            for i in range(ntask + 1):
                self.idx_list += self.LUP[scen][factor][i]
        else:
            # Test split: the last entry of the look-up table.
            self.idx_list = self.LUP[scen][factor][-1]

        self.paths = [self.train_test_paths[idx] for idx in self.idx_list]
        self.targets = [
            self.train_test_targets[idx] for idx in self.idx_list
        ]

        return True

    def _download_error_message(self) -> str:
        """
        Builds the message shown when the automatic download fails.

        :return: A string listing one URL per line (the base Google Drive URL
            followed by the second field of each data entry) and the
            directory in which the files should be placed.
        """
        base_url = openloris_data.base_gdrive_url
        all_urls = [
            base_url + name_url[1] for name_url in openloris_data.avl_vps_data
        ]

        header = \
            '[OpenLoris] Direct download may no longer be supported!\n' \
            'You should download data manually using the following links:\n'
        url_lines = ''.join(url + '\n' for url in all_urls)

        return header + url_lines + \
            'and place these files in ' + str(self.root)

    def _check_integrity(self) -> bool:
        """
        Checks if the data is already available and intact.

        Only the presence of each expected file under `self.root` is
        verified; the checksum stored alongside each entry is not used here.

        :return: True if every expected file exists, False otherwise.
        """
        for name, _, _ in openloris_data.avl_vps_data:
            filepath = self.root / name
            if not filepath.is_file():
                if self.verbose:
                    print('[OpenLORIS] Error checking integrity of:',
                          str(filepath))
                return False
        return True

    def __getitem__(self, index):
        """
        Returns the `(image, target)` pair at the given position.

        The image is read through `self.loader` from the path stored for
        `index` (relative to `self.root`); `transform` and `target_transform`
        are applied when they are not None.
        """
        target = self.targets[index]
        img = self.loader(str(self.root / self.paths[index]))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """ Returns the number of patterns in the selected split. """
        return len(self.targets)
if __name__ == "__main__":
    # This little example script can be used to visualize the first image
    # loaded from the dataset.
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = OpenLORIS(download=True, transform=ToTensor())
    test_data = OpenLORIS(train=False, transform=ToTensor())
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))
    dataloader = DataLoader(train_data, batch_size=1)

    # Only the first batch is inspected; the loop exits right after it.
    for batch_data in dataloader:
        x, y = batch_data
        # batch_size=1, so squeezing drops the batch axis before the tensor
        # is converted back to a PIL image.
        plt.imshow(
            transforms.ToPILImage()(torch.squeeze(x))
        )
        print(x.size())
        print(len(y))
        # Without show() the figure is never displayed when this file is run
        # as a plain script.
        plt.show()
        break

__all__ = [
    'OpenLORIS'
]