Source code for avalanche.benchmarks.datasets.stream51.stream51

################################################################################
# Copyright (c) 2020 ContinualAI                                               #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 19-02-2021                                                             #
# Author: Tyler L. Hayes                                                       #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

""" Stream-51 Pytorch Dataset """

import os

import shutil
import json
import random
import dill
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, TypeVar, Union

from torchvision.datasets.folder import default_loader
from zipfile import ZipFile

from torchvision.transforms import ToTensor

from avalanche.benchmarks.datasets import (
    DownloadableDataset,
    default_dataset_location,
)
from avalanche.benchmarks.datasets.stream51 import stream51_data
from avalanche.checkpointing import constructor_based_serialization


TSequence = TypeVar("TSequence", bound=Sequence)


class Stream51(DownloadableDataset):
    """Stream-51 Pytorch Dataset.

    Frames are stored on disk; metadata (one entry per frame) is loaded from
    the JSON files shipped with the archive. Each train entry is
    ``[class_id, clip_num, video_num, frame_num, bbox, file_loc]`` and each
    test entry is ``[class_id, bbox, file_loc]``.
    """

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        *,
        train=True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        download=True
    ):
        """
        Creates an instance of the Stream-51 dataset.

        :param root: The directory where the dataset can be found or
            downloaded. Defaults to None, which means that the default
            location for 'stream51' will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformations to apply to the X values.
        :param target_transform: The transformations to apply to the Y values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        """
        if root is None:
            root = default_dataset_location("stream51")

        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        # Crop each frame to its (padded) bounding box by default; ``ratio``
        # is the padding factor applied around the box in __getitem__.
        self.bbox_crop = True
        self.ratio = 1.1
        self.samples: Sequence[Tuple[int, Any, str]] = []

        # The parent constructor triggers the download (if requested);
        # _load_dataset then reads the metadata JSON into self.samples.
        super(Stream51, self).__init__(root, download=download, verbose=True)
        self._load_dataset()

    def _download_dataset(self) -> None:
        """Download the Stream-51 archive and extract it under ``self.root``.

        The archive's top-level directory is stripped during extraction:
        JSON metadata files land directly in ``self.root``, while image
        files keep their remaining sub-folder structure.
        """
        self._download_file(
            stream51_data.name[1], stream51_data.name[0], stream51_data.name[2]
        )

        if self.verbose:
            print("[Stream-51] Extracting dataset...")

        if stream51_data.name[1].endswith(".zip"):
            lfilename = self.root / stream51_data.name[0]
            with ZipFile(str(lfilename), "r") as zipf:
                for member in zipf.namelist():
                    filename = os.path.basename(member)
                    # skip directories
                    if not filename:
                        continue

                    # copy file (taken from zipfile's extract)
                    source = zipf.open(member)
                    if "json" in filename:
                        # Metadata files go straight into the dataset root.
                        target = open(str(self.root / filename), "wb")
                    else:
                        # Drop the leading archive folder, keep the rest of
                        # the member's path below self.root.
                        dest_folder = os.path.join(
                            *(member.split(os.path.sep)[1:-1])
                        )
                        dest_folder_path = self.root / dest_folder
                        dest_folder_path.mkdir(exist_ok=True, parents=True)
                        target = open(str(dest_folder_path / filename), "wb")
                    with source, target:
                        shutil.copyfileobj(source, target)

            # NOTE: the downloaded archive is deliberately kept on disk.
            # lfilename.unlink()

    def _load_metadata(self) -> bool:
        """Load the train/test metadata JSON into ``samples``/``targets``.

        :return: True (metadata files are assumed present after download).
        """
        meta_name = (
            "Stream-51_meta_train.json"
            if self.train
            else "Stream-51_meta_test.json"
        )
        # Use a context manager so the file handle is always closed.
        with open(str(self.root / meta_name)) as f:
            data_list = json.load(f)

        self.samples = data_list
        # The class id is the first field of every metadata entry.
        self.targets = [s[0] for s in data_list]
        self.bbox_crop = True
        self.ratio = 1.1
        return True

    def _download_error_message(self) -> str:
        """Return the message shown when the automatic download fails."""
        return (
            "[Stream-51] Error downloading the dataset. Consider "
            "downloading it manually at: "
            + stream51_data.name[1]
            + " and placing it in: "
            + str(self.root)
        )

    @staticmethod
    def _instance_ordering(
        data_list: Sequence[TSequence], seed
    ) -> List[TSequence]:
        """Shuffle whole videos (clips stay contiguous) with the given seed.

        Entries with ``x[3] == 0`` (frame_num) mark the start of a new video.

        :param data_list: train metadata entries, in original order.
        :param seed: seed for the deterministic shuffle.
        :return: the entries re-ordered so videos appear in shuffled order.
        """
        # organize data by video
        total_videos = 0
        new_data_list = []
        temp_video: List[TSequence] = []
        for x in data_list:
            if x[3] == 0:
                new_data_list.append(temp_video)
                total_videos += 1
                temp_video = [x]
            else:
                temp_video.append(x)
        new_data_list.append(temp_video)
        # The first appended element is the empty initial accumulator.
        new_data_list = new_data_list[1:]

        # shuffle videos
        random.seed(seed)
        random.shuffle(new_data_list)

        # reorganize by clip
        data_list_result = []
        for v in new_data_list:
            for x in v:
                data_list_result.append(x)
        return data_list_result

    @staticmethod
    def _class_ordering(data_list, class_type, seed):
        """Order entries class-by-class, shuffling within and across classes.

        :param data_list: train metadata entries (class ids assumed to be
            0..max contiguous; last entry holds the max class id).
        :param class_type: 'class_iid' (shuffle frames within a class) or
            'class_instance' (shuffle clips within a class).
        :param seed: seed for the deterministic shuffles.
        :return: the re-ordered entries, grouped by shuffled classes.
        """
        # organize data by class
        new_data_list = []
        for class_id in range(data_list[-1][0] + 1):
            class_data_list = [x for x in data_list if x[0] == class_id]
            if class_type == "class_iid":
                # shuffle all class data
                random.seed(seed)
                random.shuffle(class_data_list)
            else:
                # shuffle clips within class
                class_data_list = Stream51._instance_ordering(
                    class_data_list, seed
                )
            new_data_list.append(class_data_list)

        # shuffle classes
        random.seed(seed)
        random.shuffle(new_data_list)

        # reorganize by class
        data_list = []
        for v in new_data_list:
            for x in v:
                data_list.append(x)
        return data_list

    @staticmethod
    def make_dataset(data_list, ordering="class_instance", seed=666):
        """
        data_list
        for train: [class_id, clip_num, video_num, frame_num, bbox, file_loc]
        for test: [class_id, bbox, file_loc]

        :param data_list: metadata entries to order.
        :param ordering: one of 'iid', 'class_iid', 'instance',
            'class_instance'. Falsy orderings (or test-set entries, which
            have only 3 fields) are returned unchanged.
        :param seed: seed for the deterministic shuffles.
        :raises ValueError: if ``ordering`` is not a recognized value.
        """
        if not ordering or len(data_list[0]) == 3:  # cannot order the test set
            return data_list
        if ordering not in ["iid", "class_iid", "instance", "class_instance"]:
            raise ValueError(
                'dataset ordering must be one of: "iid", "class_iid", '
                '"instance", or "class_instance"'
            )
        if ordering == "iid":
            # shuffle all data
            random.seed(seed)
            random.shuffle(data_list)
            return data_list
        elif ordering == "instance":
            return Stream51._instance_ordering(data_list, seed)
        elif "class" in ordering:
            return Stream51._class_ordering(data_list, ordering, seed)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the
            target class.
        """
        fpath, target = self.samples[index][-1], self.targets[index]
        sample = self.loader(str(self.root / fpath))
        if self.bbox_crop:
            # bbox is stored as [xmax, xmin, ymax, ymin] — TODO confirm
            # against the dataset metadata; the crop below assumes it.
            bbox = self.samples[index][-2]
            cw = bbox[0] - bbox[1]
            ch = bbox[2] - bbox[3]
            center = [int(bbox[1] + cw / 2), int(bbox[3] + ch / 2)]
            # Pad the box by self.ratio, clamped to the image bounds.
            bbox = [
                min([int(center[0] + (cw * self.ratio / 2)), sample.size[0]]),
                max([int(center[0] - (cw * self.ratio / 2)), 0]),
                min([int(center[1] + (ch * self.ratio / 2)), sample.size[1]]),
                max([int(center[1] - (ch * self.ratio / 2)), 0]),
            ]
            sample = sample.crop((bbox[1], bbox[3], bbox[0], bbox[2]))

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of frames in the selected split."""
        return len(self.samples)

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp,
            self.target_transform.__repr__().replace(
                "\n", "\n" + " " * len(tmp)
            ),
        )
        return fmt_str
@dill.register(Stream51)
def checkpoint_Stream51(pickler, obj: Stream51):
    """Checkpoint hook: pickle a Stream51 by re-running its constructor.

    Registered with dill so that checkpointing serializes only the
    constructor arguments (and deduplicates repeated dataset instances)
    instead of the full in-memory object.
    """
    constructor_based_serialization(
        pickler,
        obj,
        Stream51,
        deduplicate=True,
        kwargs=dict(
            root=obj.root,
            train=obj.train,
            transform=obj.transform,
            target_transform=obj.target_transform,
            loader=obj.loader,
        ),
    )


if __name__ == "__main__":
    # this little example script can be used to visualize the first image
    # loaded from the dataset.
    from torch.utils.data.dataloader import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms
    import torch

    train_data = Stream51(transform=ToTensor())
    test_data = Stream51(transform=ToTensor(), train=False)
    print("train size: ", len(train_data))
    print("Test size: ", len(test_data))

    dataloader = DataLoader(train_data, batch_size=1)
    for batch_data in dataloader:
        x, y = batch_data
        plt.imshow(transforms.ToPILImage()(torch.squeeze(x)))
        plt.show()
        print(x.size())
        print(len(y))
        break


__all__ = ["Stream51"]