# Source code for tensorpack.dataflow.image

# -*- coding: utf-8 -*-
# File: image.py


import copy as copy_mod
import numpy as np
from contextlib import contextmanager

from ..utils import logger
from ..utils.argtools import shape2d
from .base import RNGDataFlow
from .common import MapData, MapDataComponent

__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']


def check_dtype(img):
    """
    Assert that ``img`` is a numpy array whose dtype the augmentors accept.

    Integer images must be ``uint8``; any other integer dtype (e.g. int32,
    uint16) is rejected. Floating point and other non-integer dtypes pass.

    Args:
        img: the image to check.

    Raises:
        AssertionError: if ``img`` is not an ``np.ndarray``, or has an integer
            dtype other than uint8.
    """
    assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
    # `img.dtype` is a np.dtype object and never an instance of np.integer, so an
    # isinstance() test is always False; issubdtype inspects the dtype itself.
    assert not np.issubdtype(img.dtype, np.integer) or (img.dtype == np.uint8), \
        "[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)


def validate_coords(coords):
    """
    Assert that ``coords`` is an Nx2 floating point array.

    Args:
        coords (np.ndarray): array whose rows are (x, y) coordinates.

    Raises:
        AssertionError: if ``coords`` is not 2-D, does not have exactly two
            columns, or does not have a floating point dtype.
    """
    assert coords.ndim == 2, coords.ndim
    assert coords.shape[1] == 2, coords.shape
    # `np.float` was only an alias of the builtin float (removed in NumPy 1.24), and
    # issubdtype(dtype, float) matches float64 alone; np.floating accepts every float dtype.
    assert np.issubdtype(coords.dtype, np.floating), coords.dtype


class ExceptionHandler:
    """
    Count exceptions raised inside :meth:`catch`, optionally suppressing them.
    """

    def __init__(self, catch_exceptions=False):
        """
        Args:
            catch_exceptions (bool): if True, exceptions raised inside :meth:`catch`
                are suppressed (and logged for the first 9 and then every 100th);
                otherwise they are counted and re-raised.
        """
        self.catch_exceptions = catch_exceptions
        self._nr_error = 0

    @contextmanager
    def catch(self):
        """
        Context manager that counts every ``Exception`` raised in its body.

        The exception is re-raised unless ``catch_exceptions`` is True. In that
        case it is logged when the running count is below 10 or a multiple of
        100, and then swallowed.
        """
        try:
            yield
        except Exception:
            self._nr_error += 1
            if not self.catch_exceptions:
                raise
            count = self._nr_error
            # Log the first few failures, then only periodically, to avoid log spam.
            if count < 10 or count % 100 == 0:
                logger.exception("Got {} augmentation errors.".format(count))


class ImageFromFile(RNGDataFlow):
    """ Produce images read from a list of files as (h, w, c) arrays. """

    def __init__(self, files, channel=3, resize=None, shuffle=False):
        """
        Args:
            files (list): list of file paths. A private copy is kept, so
                shuffling never reorders the caller's list.
            channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3.
                Will produce (h, w, 1) array if channel==1.
            resize (tuple): int or (h, w) tuple. If given, resize the image.
            shuffle (bool): if True, shuffle the file order at the start of
                every iteration.
        """
        assert len(files), "No image files given to ImageFromFile!"
        # Copy so that the in-place shuffle in __iter__ never mutates the caller's
        # list (and so tuples, which cannot be shuffled in place, work too).
        self.files = list(files)
        self.channel = int(channel)
        assert self.channel in [1, 3], self.channel
        self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
        if resize is not None:
            resize = shape2d(resize)
        self.resize = resize
        self.shuffle = shuffle

    def __len__(self):
        return len(self.files)

    def __iter__(self):
        if self.shuffle:
            self.rng.shuffle(self.files)
        for f in self.files:
            im = cv2.imread(f, self.imread_mode)
            assert im is not None, f
            if self.channel == 3:
                # cv2 decodes color images as BGR; reverse the channel axis to get RGB.
                im = im[:, :, ::-1]
            if self.resize is not None:
                # self.resize is (h, w), while cv2.resize takes dsize as (w, h).
                im = cv2.resize(im, tuple(self.resize[::-1]))
            if self.channel == 1:
                im = im[:, :, np.newaxis]
            yield [im]
class AugmentImageComponent(MapDataComponent):
    """
    Apply image augmentors on 1 image component.
    """

    def __init__(self, ds, augmentors, index=0, copy=True, catch_exceptions=False):
        """
        Args:
            ds (DataFlow): input DataFlow.
            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
                Anything that is not already an AugmentorList is wrapped into one.
            index (int or str): the index or key of the image component to be augmented in the datapoint.
            copy (bool): Some augmentors modify the input images. When copy is True,
                a copy will be made before any augmentors are applied,
                to keep the original images not modified.
                Turn it off to save time when you know it's OK.
            catch_exceptions (bool): when set to True, exceptions raised while
                augmenting are caught and the mapper returns None for that
                datapoint instead of raising. The error is logged for the first 9
                failures and then for every 100th. Can be used to ignore occasional
                errors in data. The dtype check runs outside this handler, so a
                wrong image dtype always raises.
        """
        self.augs = augmentors if isinstance(augmentors, AugmentorList) else AugmentorList(augmentors)
        self._copy = copy
        self._exception_handler = ExceptionHandler(catch_exceptions)
        super(AugmentImageComponent, self).__init__(ds, self._aug_mapper, index)

    def reset_state(self):
        self.ds.reset_state()
        self.augs.reset_state()

    def _aug_mapper(self, x):
        # Validate before entering the handler so dtype mistakes are never swallowed.
        check_dtype(x)
        with self._exception_handler.catch():
            img = copy_mod.deepcopy(x) if self._copy else x
            return self.augs.augment(img)
class AugmentImageCoordinates(MapData):
    """
    Apply image augmentors on an image and a list of coordinates.
    Coordinates must be a Nx2 floating point array, each row is (x, y).
    """

    def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False):
        """
        Args:
            ds (DataFlow): input DataFlow.
            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
            img_index (int or str): the index/key of the image component to be augmented.
            coords_index (int or str): the index/key of the coordinate component to be augmented.
            copy, catch_exceptions: same as in :class:`AugmentImageComponent`
        """
        if isinstance(augmentors, AugmentorList):
            self.augs = augmentors
        else:
            self.augs = AugmentorList(augmentors)
        self._img_index = img_index
        self._coords_index = coords_index
        self._copy = copy
        self._exception_handler = ExceptionHandler(catch_exceptions)
        super(AugmentImageCoordinates, self).__init__(ds, self._aug_mapper)

    def reset_state(self):
        self.ds.reset_state()
        self.augs.reset_state()

    def _aug_mapper(self, dp):
        """
        Augment the image and its coordinates with one shared transform.

        Returns a new datapoint container holding the augmented components, or
        None if an exception was caught by the exception handler.
        """
        # Shallow-copy the container so the results go into a new list/dict. The
        # incoming datapoint (which an upstream dataflow may yield again) is never
        # modified, and a failure midway cannot leave it half-augmented.
        dp = copy_mod.copy(dp)
        with self._exception_handler.catch():
            img, coords = dp[self._img_index], dp[self._coords_index]
            check_dtype(img)
            validate_coords(coords)
            if self._copy:
                img, coords = copy_mod.deepcopy((img, coords))
            tfms = self.augs.get_transform(img)
            dp[self._img_index] = tfms.apply_image(img)
            dp[self._coords_index] = tfms.apply_coords(coords)
            return dp
class AugmentImageComponents(MapData):
    """
    Apply image augmentors on several components, with shared augmentation parameters.

    Example:

        .. code-block:: python

            ds = MyDataFlow()   # produce [image(HWC), segmask(HW), keypoint(Nx2)]
            ds = AugmentImageComponents(
                ds, augs,
                index=(0,1), coords_index=(2,))

    """

    def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False):
        """
        Args:
            ds (DataFlow): input DataFlow.
            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
            index: tuple of indices of the image components. The first one is
                used to draw the augmentation parameters shared by all components.
            coords_index: tuple of indices of the coordinates components.
            copy, catch_exceptions: same as in :class:`AugmentImageComponent`
        """
        self.augs = augmentors if isinstance(augmentors, AugmentorList) else AugmentorList(augmentors)
        self.ds = ds
        self._exception_handler = ExceptionHandler(catch_exceptions)
        self._copy = copy
        self._index = index
        self._coords_index = coords_index
        super(AugmentImageComponents, self).__init__(ds, self._aug_mapper)

    def reset_state(self):
        self.ds.reset_state()
        self.augs.reset_state()

    def _aug_mapper(self, dp):
        # Always shallow-copy so components are replaced in a new container and
        # the incoming datapoint's list stays intact.
        dp = copy_mod.copy(dp)
        if self._copy:
            maybe_copy = copy_mod.deepcopy
        else:
            def maybe_copy(obj):
                return obj

        with self._exception_handler.catch():
            # The first image component decides the augmentation parameters. TODO better design?
            ref_idx = self._index[0]
            ref_img = maybe_copy(dp[ref_idx])
            check_dtype(ref_img)
            tfms = self.augs.get_transform(ref_img)
            dp[ref_idx] = tfms.apply_image(ref_img)

            # Remaining images reuse the same transform.
            for img_idx in self._index[1:]:
                check_dtype(dp[img_idx])
                dp[img_idx] = tfms.apply_image(maybe_copy(dp[img_idx]))

            for pts_idx in self._coords_index:
                pts = maybe_copy(dp[pts_idx])
                validate_coords(pts)
                dp[pts_idx] = tfms.apply_coords(pts)
            return dp
# cv2 and AugmentorList are imported here, after the classes above are defined.
# Those classes only reference these names inside method bodies, so a late import
# works. If either import fails (ImportError from cv2 OR from .imgaug), the public
# class names are rebound to placeholders, so importing this module itself still
# succeeds.
# NOTE(review): behavior of the placeholders comes from create_dummy_class in
# ..utils.develop; presumably they raise an error mentioning 'cv2' on use. Confirm.
try:
    import cv2
    from .imgaug import AugmentorList
except ImportError:
    from ..utils.develop import create_dummy_class
    ImageFromFile = create_dummy_class('ImageFromFile', 'cv2')  # noqa
    AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2')  # noqa
    AugmentImageCoordinates = create_dummy_class('AugmentImageCoordinates', 'cv2')  # noqa
    AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2')  # noqa