################################################################################
# Copyright (c) 2020 ContinualAI #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 19-02-2021 #
# Author: Tyler L. Hayes #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
################################################################################
""" Stream-51 Pytorch Dataset """
import os
import shutil
import json
import random
from pathlib import Path
from typing import Union
from torchvision.datasets.folder import default_loader
from zipfile import ZipFile
from torchvision.transforms import ToTensor
from avalanche.benchmarks.datasets import DownloadableDataset, \
default_dataset_location
from avalanche.benchmarks.datasets.stream51 import stream51_data
class Stream51(DownloadableDataset):
    """ Stream-51 Pytorch Dataset

    Frames are listed in metadata JSON files; each training sample is
    ``[class_id, clip_num, video_num, frame_num, bbox, file_loc]`` and each
    test sample is ``[class_id, bbox, file_loc]``.
    """

    def __init__(self, root: Union[str, Path, None] = None,
                 *,
                 train=True, transform=None,
                 target_transform=None, loader=default_loader, download=True):
        """
        Creates an instance of the Stream-51 dataset.

        :param root: The directory where the dataset can be found or
            downloaded. Defaults to None, which means that the default
            location for 'stream51' will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformations to apply to the X values.
        :param target_transform: The transformations to apply to the Y values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        """
        if root is None:
            root = default_dataset_location('stream51')

        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        # Crop each frame to its (scaled) bounding box by default; `ratio`
        # enlarges the annotated bbox by 10% around its center.
        self.bbox_crop = True
        self.ratio = 1.1

        # The base class constructor triggers the download (if needed);
        # afterwards the metadata can be loaded.
        super(Stream51, self).__init__(root, download=download, verbose=True)
        self._load_dataset()

    def _download_dataset(self) -> None:
        """Download the Stream-51 archive and extract it into ``self.root``.

        JSON metadata files are extracted flat into the root directory;
        image files keep their directory layout (minus the top-level
        archive folder).
        """
        self._download_file(stream51_data.name[1], stream51_data.name[0],
                            stream51_data.name[2])

        if self.verbose:
            print('[Stream-51] Extracting dataset...')

        if stream51_data.name[1].endswith('.zip'):
            lfilename = self.root / stream51_data.name[0]
            with ZipFile(str(lfilename), 'r') as zipf:
                for member in zipf.namelist():
                    filename = os.path.basename(member)
                    # skip directories
                    if not filename:
                        continue

                    # copy file (taken from zipfile's extract)
                    source = zipf.open(member)
                    if 'json' in filename:
                        target = open(str(self.root / filename), "wb")
                    else:
                        # Zip member names always use '/' separators per the
                        # zip spec, so split on '/' (not os.path.sep, which
                        # is '\\' on Windows and would fail to split here).
                        # [1:-1] drops the archive's top-level folder and
                        # the file name itself.
                        dest_folder = os.path.join(
                            *(member.split('/')[1:-1]))
                        dest_folder = self.root / dest_folder
                        dest_folder.mkdir(exist_ok=True, parents=True)
                        target = open(str(dest_folder / filename), "wb")
                    with source, target:
                        shutil.copyfileobj(source, target)
            # lfilename.unlink()

    def _load_metadata(self) -> bool:
        """Load the split's metadata JSON.

        :return: True on success; False when the metadata file is missing,
            which signals the base class that a download is required
            (the original version raised instead of returning False).
        """
        if self.train:
            meta_path = self.root / 'Stream-51_meta_train.json'
        else:
            meta_path = self.root / 'Stream-51_meta_test.json'

        if not meta_path.exists():
            return False

        # Use a context manager so the file handle is closed promptly
        # (the previous json.load(open(...)) leaked the handle).
        with open(str(meta_path)) as f:
            data_list = json.load(f)

        self.samples = data_list
        self.targets = [s[0] for s in data_list]
        self.bbox_crop = True
        self.ratio = 1.1
        return True

    def _download_error_message(self) -> str:
        """Return the message shown when the automatic download fails."""
        return '[Stream-51] Error downloading the dataset. Consider ' \
               'downloading it manually at: ' + stream51_data.name[1] + \
               ' and placing it in: ' + str(self.root)

    @staticmethod
    def _instance_ordering(data_list, seed):
        """Shuffle at video granularity, keeping frames of a video in order.

        A new video starts whenever frame_num (``x[3]``) is 0.
        """
        # organize data by video
        total_videos = 0
        new_data_list = []
        temp_video = []
        for x in data_list:
            if x[3] == 0:
                new_data_list.append(temp_video)
                total_videos += 1
                temp_video = [x]
            else:
                temp_video.append(x)
        new_data_list.append(temp_video)
        # drop the empty list appended before the first video started
        new_data_list = new_data_list[1:]
        # shuffle videos
        random.seed(seed)
        random.shuffle(new_data_list)
        # reorganize by clip
        data_list = []
        for v in new_data_list:
            for x in v:
                data_list.append(x)
        return data_list

    @staticmethod
    def _class_ordering(data_list, class_type, seed):
        """Order data class-by-class.

        Within a class, frames are either fully shuffled ('class_iid') or
        shuffled at video granularity ('class_instance'); the class order
        itself is then shuffled.
        """
        # organize data by class
        new_data_list = []
        for class_id in range(data_list[-1][0] + 1):
            class_data_list = [x for x in data_list if x[0] == class_id]
            if class_type == 'class_iid':
                # shuffle all class data
                random.seed(seed)
                random.shuffle(class_data_list)
            else:
                # shuffle clips within class
                class_data_list = Stream51._instance_ordering(
                    class_data_list, seed)
            new_data_list.append(class_data_list)
        # shuffle classes
        random.seed(seed)
        random.shuffle(new_data_list)
        # reorganize by class
        data_list = []
        for v in new_data_list:
            for x in v:
                data_list.append(x)
        return data_list

    @staticmethod
    def make_dataset(data_list, ordering='class_instance', seed=666):
        """Reorder ``data_list`` according to the requested stream ordering.

        data_list
        for train: [class_id, clip_num, video_num, frame_num, bbox, file_loc]
        for test: [class_id, bbox, file_loc]

        :param data_list: the samples to reorder.
        :param ordering: one of 'iid', 'class_iid', 'instance',
            'class_instance', or a falsy value to keep the original order.
        :param seed: seed used for the deterministic shuffles.
        :raises ValueError: if ``ordering`` is not a recognized value.
        """
        # Empty lists and test-set entries (length 3) cannot be ordered.
        if not data_list or not ordering or len(data_list[0]) == 3:
            return data_list
        if ordering not in ['iid', 'class_iid', 'instance', 'class_instance']:
            raise ValueError(
                'dataset ordering must be one of: "iid", "class_iid", '
                '"instance", or "class_instance"')
        if ordering == 'iid':
            # shuffle all data
            random.seed(seed)
            random.shuffle(data_list)
            return data_list
        elif ordering == 'instance':
            return Stream51._instance_ordering(data_list, seed)
        elif 'class' in ordering:
            return Stream51._class_ordering(data_list, ordering, seed)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target
                class.
        """
        fpath, target = self.samples[index][-1], self.targets[index]
        sample = self.loader(str(self.root / fpath))
        if self.bbox_crop:
            # bbox layout is [x_max, x_min, y_max, y_min]; enlarge it by
            # `self.ratio` around its center, clamped to the image bounds.
            bbox = self.samples[index][-2]
            cw = bbox[0] - bbox[1]
            ch = bbox[2] - bbox[3]
            center = [int(bbox[1] + cw / 2), int(bbox[3] + ch / 2)]
            bbox = [
                min([int(center[0] + (cw * self.ratio / 2)), sample.size[0]]),
                max([int(center[0] - (cw * self.ratio / 2)), 0]),
                min([int(center[1] + (ch * self.ratio / 2)), sample.size[1]]),
                max([int(center[1] - (ch * self.ratio / 2)), 0])]
            sample = sample.crop((bbox[1],
                                  bbox[3],
                                  bbox[0],
                                  bbox[2]))

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace(
                '\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace(
                '\n', '\n' + ' ' * len(tmp)))
        return fmt_str
if __name__ == "__main__":
    # Quick visual sanity check: build both splits, report their sizes,
    # then display the very first training image.
    from torch.utils.data.dataloader import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = Stream51(transform=ToTensor())
    test_data = Stream51(transform=ToTensor(), train=False)
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))

    loader = DataLoader(train_data, batch_size=1)

    for x, y in loader:
        # Convert the single-image batch back to a PIL image for display.
        plt.imshow(transforms.ToPILImage()(torch.squeeze(x)))
        plt.show()
        print(x.size())
        print(len(y))
        break
# Public API of this module.
__all__ = [
    'Stream51'
]