################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 11-11-2020 #
# Author: Vincenzo Lomonaco #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
################################################################################
""" OpenLORIS PyTorch Dataset """
import pickle as pkl
from pathlib import Path
from typing import Union
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor
from avalanche.benchmarks.datasets import DownloadableDataset, \
default_dataset_location
from avalanche.benchmarks.datasets.openloris import openloris_data
class OpenLORIS(DownloadableDataset):
    """ OpenLORIS Pytorch Dataset

    Map-style dataset: ``__getitem__`` returns an ``(image, target)`` pair
    and ``__len__`` the number of selected samples. Image paths and targets
    are read from the ``Paths.pkl``, ``Labels.pkl`` and ``LUP.pkl`` files
    located in the dataset root; images are only opened on access, through
    the configured ``loader``.
    """

    def __init__(self, root: Union[str, Path, None] = None,
                 *,
                 train=True, transform=None, target_transform=None,
                 loader=default_loader, download=True):
        """
        Creates an instance of the OpenLORIS dataset.

        :param root: The directory where the dataset can be found or
            downloaded. Defaults to None, which means that the default
            location for 'openloris' will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformations to apply to the X values.
        :param target_transform: The transformations to apply to the Y
            values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        """
        if root is None:
            root = default_dataset_location('openloris')

        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        super().__init__(root, download=download, verbose=True)
        # NOTE(review): `_load_dataset` is not defined in this class; it is
        # presumably inherited from DownloadableDataset and expected to
        # drive `_download_dataset` / `_load_metadata` -- confirm.
        self._load_dataset()

    def _download_dataset(self) -> None:
        """
        Downloads every file listed in ``openloris_data.avl_vps_data`` and
        extracts the ones whose URL ends with '.zip'.
        """
        # Entries are unpacked in the same order used by `_check_integrity`.
        for file_name, url, checksum in openloris_data.avl_vps_data:
            if self.verbose:
                print("Downloading " + url + "...")
            file = self._download_file(url, file_name, checksum)

            # Archive detection is keyed on the URL, not on the file name.
            if url.endswith('.zip'):
                if self.verbose:
                    print(f'Extracting {file_name}...')
                self._extract_archive(file)
                if self.verbose:
                    print('Extraction completed!')

    def _read_pickle(self, file_name: str):
        """
        Unpickles and returns the content of ``file_name`` found in the
        dataset root.

        :param file_name: Name of the pickle file, relative to the root.
        :return: The unpickled object.
        """
        # NOTE(security): unpickling can execute arbitrary code. These files
        # are fetched by `_download_dataset` (or placed manually), so they
        # must come from a trusted source.
        with open(str(self.root / file_name), 'rb') as f:
            return pkl.load(f)

    def _load_metadata(self) -> bool:
        """
        Loads image paths and targets for the requested split.

        Populates ``train_test_paths``, ``all_targets``,
        ``train_test_targets``, ``LUP``, ``idx_list``, ``paths`` and
        ``targets``.

        :return: False if the integrity check fails, True once the metadata
            has been loaded.
        """
        if not self._check_integrity():
            return False

        # any scenario and factor is good here since we want just to load the
        # train images and targets with no particular order
        scen = 'domain'
        factor = 0
        ntask = 9

        if self.verbose:
            print("Loading paths...")
        self.train_test_paths = self._read_pickle('Paths.pkl')

        if self.verbose:
            print("Loading labels...")
        self.all_targets = self._read_pickle('Labels.pkl')
        targets_by_batch = self.all_targets[scen][factor]
        self.train_test_targets = []
        for i in range(ntask + 1):
            self.train_test_targets += targets_by_batch[i]

        if self.verbose:
            print("Loading LUP...")
        self.LUP = self._read_pickle('LUP.pkl')
        indices_by_batch = self.LUP[scen][factor]

        if self.train:
            # NOTE(review): if the lookup table holds exactly `ntask + 1`
            # entries, position `ntask` is also position -1 (used below for
            # the test split) -- confirm the layout of LUP.pkl.
            self.idx_list = []
            for i in range(ntask + 1):
                self.idx_list += indices_by_batch[i]
        else:
            self.idx_list = indices_by_batch[-1]

        self.paths = [self.train_test_paths[idx] for idx in self.idx_list]
        self.targets = [
            self.train_test_targets[idx] for idx in self.idx_list
        ]

        return True

    def _download_error_message(self) -> str:
        """
        Builds the message listing the manual download links and the
        directory in which the files have to be placed.

        :return: The error message.
        """
        base_url = openloris_data.base_gdrive_url
        url_lines = ''.join(
            base_url + name_url[1] + '\n'
            for name_url in openloris_data.avl_vps_data
        )

        return (
            '[OpenLoris] Direct download may no longer be supported!\n'
            'You should download data manually using the following links:\n'
            + url_lines
            + 'and place these files in ' + str(self.root)
        )

    def _check_integrity(self):
        """ Checks if the data is already available and intact

        Only the presence of each file is verified here.

        :return: True if every file listed in
            ``openloris_data.avl_vps_data`` exists in the root directory,
            False as soon as one of them is missing.
        """
        for name, _url, _md5 in openloris_data.avl_vps_data:
            filepath = self.root / name
            if not filepath.is_file():
                if self.verbose:
                    print('[OpenLORIS] Error checking integrity of:',
                          str(filepath))
                return False
        return True

    def __getitem__(self, index):
        """
        Loads the image at position ``index`` and returns it with its
        target, after applying the optional transformations.

        :param index: Position of the sample in the selected split.
        :return: The ``(img, target)`` pair.
        """
        target = self.targets[index]
        img = self.loader(str(self.root / self.paths[index]))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """ Returns the number of samples in the selected split. """
        return len(self.targets)
if __name__ == "__main__":
    # Small demo: build both splits, report their sizes and display the
    # first image served by a DataLoader over the training split.
    from torch.utils.data.dataloader import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = OpenLORIS(download=True, transform=ToTensor())
    test_data = OpenLORIS(train=False, transform=ToTensor())
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))

    to_pil = transforms.ToPILImage()
    for x, y in DataLoader(train_data, batch_size=1):
        # Drop the singleton dimensions before converting to a PIL image.
        plt.imshow(to_pil(torch.squeeze(x)))
        plt.show()
        print(x.size())
        print(len(y))
        break  # only the first batch is visualized

# Public names exported by this module.
__all__ = ['OpenLORIS']