
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-04-2021                                                             #
# Author: Lorenzo Pellegrini, Vincenzo Lomonaco                                #
# E-mail: contact@continualai.org                                              #
# Website: continualai.org                                                     #
################################################################################

"""
CUB200 PyTorch Dataset: Caltech-UCSD Birds-200-2011 (CUB-200-2011) is an
extended version of the CUB-200 dataset, with roughly double the number of
images per class and new part location annotations. For detailed information
about the dataset, please check the official website:
http://www.vision.caltech.edu/visipedia/CUB-200-2011.html.
"""

import csv
import os
from collections import OrderedDict
from pathlib import Path
from typing import Optional, Union

import gdown
from torchvision.datasets.folder import default_loader

from avalanche.benchmarks.datasets import default_dataset_location, \
    DownloadableDataset
from avalanche.benchmarks.utils import PathsDataset


class CUB200(PathsDataset, DownloadableDataset):
    """ Basic CUB200 PathsDataset to be used as a standard PyTorch Dataset.

    A classic continual learning benchmark built on top of this dataset can
    be found in 'benchmarks.classic', while 'benchmarks.generators' can be
    used to design more customized benchmarks.
    """

    images_folder = 'CUB_200_2011/images'
    official_url = 'http://www.vision.caltech.edu/visipedia-data/' \
                   'CUB-200-2011/CUB_200_2011.tgz'
    gdrive_url = "https://drive.google.com/u/0/uc?id=" \
                 "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'
    def __init__(
            self,
            root: Optional[Union[str, Path]] = None,
            *,
            train=True,
            transform=None,
            target_transform=None,
            loader=default_loader,
            download=True):
        """
        :param root: root dir where the dataset can be found or downloaded.
            Defaults to None, which means that the default location for
            'CUB_200_2011' will be used.
        :param train: train or test subset of the original dataset. Defaults
            to True.
        :param transform: optional input data transformations to apply.
            Defaults to None.
        :param target_transform: optional target data transformations to
            apply. Defaults to None.
        :param loader: the method used to load the data from disk. Defaults
            to torchvision default_loader.
        :param download: defaults to True. If the data is already downloaded
            the download will be skipped.
        """
        if root is None:
            root = default_dataset_location('CUB_200_2011')

        self.train = train

        DownloadableDataset.__init__(
            self, root, download=download, verbose=True)
        self._load_dataset()

        PathsDataset.__init__(
            self,
            os.path.join(root, CUB200.images_folder),
            self._images,
            transform=transform,
            target_transform=target_transform,
            loader=loader)
    def _download_dataset(self) -> None:
        try:
            self._download_and_extract_archive(
                CUB200.official_url, CUB200.filename,
                checksum=CUB200.tgz_md5)
        except Exception:
            if self.verbose:
                print('[CUB200] Direct download may no longer be possible, '
                      'will try GDrive.')

            filepath = self.root / self.filename
            gdown.download(self.gdrive_url, str(filepath), quiet=False)

            # cached_download also verifies the checksum of the
            # downloaded archive.
            gdown.cached_download(
                self.gdrive_url, str(filepath), md5=self.tgz_md5)

            self._extract_archive(filepath)

    def _download_error_message(self) -> str:
        return '[CUB200] Error downloading the dataset. Consider ' \
               'downloading it manually at: ' + CUB200.official_url + \
               ' and placing it in: ' + str(self.root)

    def _load_metadata(self) -> bool:
        """
        Loads the CUB200 metadata (train/test split, image paths, class
        labels and bounding boxes).

        :return: True if the metadata was loaded correctly and all the
            listed image files exist on disk, False otherwise.
        """
        cub_dir = self.root / 'CUB_200_2011'
        self._images = OrderedDict()

        with open(str(cub_dir / 'train_test_split.txt')) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=' ')
            for row in csv_reader:
                img_id = int(row[0])
                is_train_instance = int(row[1]) == 1
                if is_train_instance == self.train:
                    self._images[img_id] = []

        with open(str(cub_dir / 'images.txt')) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=' ')
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in self._images:
                    self._images[img_id].append(row[1])

        with open(str(cub_dir / 'image_class_labels.txt')) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=' ')
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in self._images:
                    # CUB starts counting classes from 1 ...
                    self._images[img_id].append(int(row[1]) - 1)

        with open(str(cub_dir / 'bounding_boxes.txt')) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=' ')
            for row in csv_reader:
                img_id = int(row[0])
                if img_id in self._images:
                    # CUB stores boxes as (x, y, width, height), while
                    # PathsDataset accepts (top, left, height, width).
                    box_cub = [int(float(x)) for x in row[1:]]
                    box_avl = [box_cub[1], box_cub[0], box_cub[3],
                               box_cub[2]]
                    self._images[img_id].append(box_avl)

        images_tuples = []
        for _, img_tuple in self._images.items():
            images_tuples.append(tuple(img_tuple))
        self._images = images_tuples

        # Integrity check: make sure all listed images exist on disk.
        for row in self._images:
            filepath = self.root / CUB200.images_folder / row[0]
            if not filepath.is_file():
                if self.verbose:
                    print('[CUB200] Error checking integrity of:', filepath)
                return False

        return True
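# A short reference for the space-separated CUB-200-2011 metadata files
# parsed in _load_metadata above (the example rows are illustrative, not
# taken from the actual files):
#
#   train_test_split.txt    "<img_id> <is_train>"       e.g. "42 1"
#   images.txt              "<img_id> <relative_path>"  e.g. "42 001.Some_Bird/img.jpg"
#   image_class_labels.txt  "<img_id> <class_id>"       e.g. "42 1"  (1-based)
#   bounding_boxes.txt      "<img_id> <x> <y> <w> <h>"  e.g. "42 60.0 27.0 325.0 304.0"
#
# After parsing, each element of self._images is a tuple of the form
# (relative_path, zero_based_class_label, [top, left, height, width]),
# which is the per-sample layout handed over to PathsDataset.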
if __name__ == "__main__":
    """ Simple test that will start if you run this script directly """
    import matplotlib.pyplot as plt

    dataset = CUB200(train=False, download=True)
    print("test data len:", len(dataset))
    img, _ = dataset[14]
    plt.imshow(img)
    plt.show()

    dataset = CUB200(train=True)
    print("train data len:", len(dataset))
    img, _ = dataset[700]
    plt.imshow(img)
    plt.show()


__all__ = [
    'CUB200'
]
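# Minimal usage sketch, kept as a comment so that importing this module
# stays side-effect free. The transform below is an arbitrary illustrative
# choice, not a recommended preprocessing pipeline:
#
#   from torch.utils.data import DataLoader
#   from torchvision import transforms
#
#   train_set = CUB200(
#       train=True,
#       transform=transforms.Compose([
#           transforms.Resize((224, 224)),
#           transforms.ToTensor()]))
#   train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
#
# For ready-made continual learning benchmarks built on top of this dataset,
# see 'benchmarks.classic', as mentioned in the class docstring.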