# Source code for tensorpack.dataflow.dataset.mnist

# -*- coding: utf-8 -*-
# File: mnist.py


import os
import gzip
import numpy
from six.moves import range

from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow

__all__ = ['Mnist', 'FashionMnist']


def maybe_download(url, work_directory):
    """Return the local path of ``url`` inside ``work_directory``, fetching it if absent.

    The last path component of the URL is used as the local file name. The
    file is only downloaded (via ``download``) when it does not already exist.
    """
    local_path = os.path.join(work_directory, url.split('/')[-1])
    if os.path.exists(local_path):
        return local_path
    logger.info("Downloading to {}...".format(local_path))
    download(url, work_directory)
    return local_path


def _read32(bytestream):
    """Consume 4 bytes from ``bytestream`` and decode them as a big-endian uint32.

    Returns a numpy uint32 scalar.
    """
    big_endian_u32 = numpy.dtype('>u4')
    raw = bytestream.read(4)
    return numpy.frombuffer(raw, dtype=big_endian_u32)[0]


def extract_images(filename):
    """Extract the images of a gzip IDX3 file into a 4D float32 numpy array.

    The layout is [index, y, x, depth] with depth == 1. Pixel bytes (uint8,
    0..255) are scaled to float32 values in [0, 1].

    Args:
        filename (str): path to a gzip-compressed IDX3 image file.

    Returns:
        numpy.ndarray: float32 array of shape (num_images, rows, cols, 1).

    Raises:
        ValueError: if the magic number is not 2051, or the file contains
            fewer pixel bytes than its header announces.
    """
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        # Convert the header fields to Python ints so the size product below
        # is computed exactly instead of in wrapping uint32 arithmetic.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected = num_images * rows * cols
        buf = bytestream.read(expected)
    # Report truncation explicitly rather than via an opaque reshape error.
    if len(buf) != expected:
        raise ValueError(
            'Truncated MNIST image file %s: expected %d bytes, got %d' %
            (filename, expected, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data.astype('float32') / 255.0


def extract_labels(filename):
    """Extract the labels of a gzip IDX1 file into a 1D uint8 numpy array [index].

    Args:
        filename (str): path to a gzip-compressed IDX1 label file.

    Returns:
        numpy.ndarray: uint8 array of shape (num_items,).

    Raises:
        ValueError: if the magic number is not 2049, or the file contains
            fewer label bytes than its header announces.
    """
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        # Python int, so the byte count and length comparison are exact.
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
    # A short read would otherwise silently yield fewer labels than declared.
    if len(buf) != num_items:
        raise ValueError(
            'Truncated MNIST label file %s: expected %d bytes, got %d' %
            (filename, num_items, len(buf)))
    return numpy.frombuffer(buf, dtype=numpy.uint8)


class Mnist(RNGDataFlow):
    """
    Produces [image, label] in MNIST dataset,
    image is 28x28 in the range [0,1], label is an integer class id (numpy uint8).
    """

    _DIR_NAME = 'mnist_data'
    _SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

    def __init__(self, train_or_test, shuffle=True, dir=None):
        """
        Args:
            train_or_test (str): either 'train' or 'test'
            shuffle (bool): shuffle the dataset
            dir (str): directory in which the gzip files are looked up and
                downloaded to. Defaults to ``get_dataset_path(self._DIR_NAME)``.

        Raises:
            ValueError: if the image file and label file disagree on the
                number of examples.
        """
        if dir is None:
            dir = get_dataset_path(self._DIR_NAME)
        assert train_or_test in ['train', 'test']
        self.train_or_test = train_or_test
        self.shuffle = shuffle

        def get_images_and_labels(image_file, label_file):
            # Fetch (if needed) and decode one image/label file pair.
            f = maybe_download(self._SOURCE_URL + image_file, dir)
            images = extract_images(f)
            f = maybe_download(self._SOURCE_URL + label_file, dir)
            labels = extract_labels(f)
            # Explicit raise (not `assert`) so this data-integrity check
            # still runs under `python -O`.
            if images.shape[0] != labels.shape[0]:
                raise ValueError(
                    'MNIST image/label count mismatch: {} images in {} vs '
                    '{} labels in {}'.format(
                        images.shape[0], image_file,
                        labels.shape[0], label_file))
            return images, labels

        if self.train_or_test == 'train':
            self.images, self.labels = get_images_and_labels(
                'train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz')
        else:
            self.images, self.labels = get_images_and_labels(
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz')

    def __len__(self):
        return self.images.shape[0]

    def __iter__(self):
        # Visit every example once; each pass reshuffles when shuffle=True.
        idxs = list(range(len(self)))
        if self.shuffle:
            # NOTE(review): self.rng is assumed to be initialized by
            # RNGDataFlow.reset_state() before iteration -- confirm.
            self.rng.shuffle(idxs)
        for k in idxs:
            # Drop the trailing depth-1 axis: (rows, cols, 1) -> (rows, cols),
            # without hard-coding the 28x28 size.
            img = self.images[k][:, :, 0]
            label = self.labels[k]
            yield [img, label]
class FashionMnist(Mnist):
    """
    Same API as :class:`Mnist`, but reads the Fashion-MNIST files instead.
    """

    _DIR_NAME = 'fashion_mnist_data'
    _SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

    def get_label_names(self):
        """
        Returns:
            [str]: human-readable class names, indexed by label id
        """
        # Order taken from https://github.com/zalandoresearch/fashion-mnist
        names = (
            'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
        )
        return list(names)
if __name__ == '__main__':
    # Manual smoke test: load the training split and open an interactive
    # IPython shell with the first (img, label) pair in scope.
    ds = Mnist('train')
    ds.reset_state()
    first = next(iter(ds), None)
    if first is not None:
        img, label = first
        from IPython import embed
        embed()