Source code for tensorpack.dataflow.dataset.places

#-*- coding: utf-8 -*-

import os
import numpy as np

from ...utils import logger
from ..base import RNGDataFlow


class Places365Standard(RNGDataFlow):
    """
    The Places365-Standard Dataset, in low resolution format only.
    Produces BGR images of shape (256, 256, 3) in range [0, 255].
    """

    def __init__(self, dir, name, shuffle=None):
        """
        Args:
            dir: path to the Places365-Standard dataset in its "easy directory structure".
                See http://places2.csail.mit.edu/download.html
            name: one of "train" or "val"
            shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
        """
        assert name in ['train', 'val'], name
        dir = os.path.expanduser(dir)
        assert os.path.isdir(dir), dir
        self.name = name
        if shuffle is None:
            shuffle = name == 'train'
        self.shuffle = shuffle

        # The index file "<dir>/<name>.txt" lists one image path per line,
        # relative to `dir`.
        label_file = os.path.join(dir, name + ".txt")
        all_files = []
        labels = set()
        with open(label_file) as f:
            for line in f:
                relpath = line.strip()
                if not relpath:
                    # Tolerate blank lines instead of failing with an
                    # opaque IndexError on the split below.
                    continue
                filepath = os.path.join(dir, relpath)
                # The second "/"-separated component of the entry is used
                # as the class name.
                label = relpath.split("/")[1]
                all_files.append((filepath, label))
                labels.add(label)
        self._labels = sorted(labels)
        # class ids are sorted alphabetically:
        # https://github.com/CSAILVision/places365/blob/master/categories_places365.txt
        labelmap = {label: idx for idx, label in enumerate(self._labels)}
        # Store (absolute path, integer class id) pairs.
        self._files = [(path, labelmap[x]) for path, x in all_files]
        logger.info("Found {} images in {}.".format(len(self._files), label_file))

    def get_label_names(self):
        """
        Returns:
            [str]: name of each class.
        """
        return self._labels

    def __len__(self):
        """
        Returns:
            int: number of images listed in the index file.
        """
        return len(self._files)

    def __iter__(self):
        """
        Yields:
            [im, label]: the image as loaded by ``cv2.imread`` and its
            integer class id. The order is reshuffled on every pass when
            ``self.shuffle`` is true.
        """
        idxs = np.arange(len(self._files))
        if self.shuffle:
            # NOTE(review): `self.rng` is not set in this class; presumably
            # RNGDataFlow provides it via reset_state() -- confirm.
            self.rng.shuffle(idxs)
        for k in idxs:
            fname, label = self._files[k]
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None, not by raising.
            assert im is not None, fname
            yield [im, label]
# cv2 is an optional dependency: when it cannot be imported, the real class
# is replaced by a dummy created with create_dummy_class, so that importing
# this module itself does not fail.
try:
    import cv2
except ImportError:
    from ...utils.develop import create_dummy_class
    Places365Standard = create_dummy_class('Places365Standard', 'cv2')  # noqa


if __name__ == '__main__':
    # Manual smoke test: wrap the training split in PrintData and run one
    # full pass over it.
    from tensorpack.dataflow import PrintData

    dataflow = Places365Standard("~/data/places365_standard/", 'train')
    dataflow = PrintData(dataflow, num=100)
    dataflow.reset_state()
    for _ in dataflow:
        pass