# -*- coding: utf-8 -*-
# File: image.py
import copy as copy_mod
import numpy as np
from contextlib import contextmanager
from ..utils import logger
from ..utils.argtools import shape2d
from .base import RNGDataFlow
from .common import MapData, MapDataComponent
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
def check_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
assert not isinstance(img.dtype, np.integer) or (img.dtype == np.uint8), \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
def validate_coords(coords):
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
class ExceptionHandler:
def __init__(self, catch_exceptions=False):
self._nr_error = 0
self.catch_exceptions = catch_exceptions
@contextmanager
def catch(self):
try:
yield
except Exception:
self._nr_error += 1
if not self.catch_exceptions:
raise
else:
if self._nr_error % 100 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
[docs]class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files as (h, w, c) arrays. """
[docs] def __init__(self, files, channel=3, resize=None, shuffle=False):
"""
Args:
files (list): list of file paths.
channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3.
Will produce (h, w, 1) array if channel==1.
resize (tuple): int or (h, w) tuple. If given, resize the image.
"""
assert len(files), "No image files given to ImageFromFile!"
self.files = files
self.channel = int(channel)
assert self.channel in [1, 3], self.channel
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
if resize is not None:
resize = shape2d(resize)
self.resize = resize
self.shuffle = shuffle
def __len__(self):
return len(self.files)
def __iter__(self):
if self.shuffle:
self.rng.shuffle(self.files)
for f in self.files:
im = cv2.imread(f, self.imread_mode)
assert im is not None, f
if self.channel == 3:
im = im[:, :, ::-1]
if self.resize is not None:
im = cv2.resize(im, tuple(self.resize[::-1]))
if self.channel == 1:
im = im[:, :, np.newaxis]
yield [im]
[docs]class AugmentImageComponent(MapDataComponent):
"""
Apply image augmentors on 1 image component.
"""
[docs] def __init__(self, ds, augmentors, index=0, copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int or str): the index or key of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
catch_exceptions (bool): when set to True, will catch
all exceptions and only warn you when there are too many (>100).
Can be used to ignore occasion errors in data.
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self._copy = copy
self._exception_handler = ExceptionHandler(catch_exceptions)
super(AugmentImageComponent, self).__init__(ds, self._aug_mapper, index)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
def _aug_mapper(self, x):
check_dtype(x)
with self._exception_handler.catch():
if self._copy:
x = copy_mod.deepcopy(x)
return self.augs.augment(x)
[docs]class AugmentImageCoordinates(MapData):
"""
Apply image augmentors on an image and a list of coordinates.
Coordinates must be a Nx2 floating point array, each row is (x, y).
"""
[docs] def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int or str): the index/key of the image component to be augmented.
coords_index (int or str): the index/key of the coordinate component to be augmented.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self._img_index = img_index
self._coords_index = coords_index
self._copy = copy
self._exception_handler = ExceptionHandler(catch_exceptions)
super(AugmentImageCoordinates, self).__init__(ds, self._aug_mapper)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
def _aug_mapper(self, dp):
with self._exception_handler.catch():
img, coords = dp[self._img_index], dp[self._coords_index]
check_dtype(img)
validate_coords(coords)
if self._copy:
img, coords = copy_mod.deepcopy((img, coords))
tfms = self.augs.get_transform(img)
dp[self._img_index] = tfms.apply_image(img)
dp[self._coords_index] = tfms.apply_coords(coords)
return dp
[docs]class AugmentImageComponents(MapData):
"""
Apply image augmentors on several components, with shared augmentation parameters.
Example:
.. code-block:: python
ds = MyDataFlow() # produce [image(HWC), segmask(HW), keypoint(Nx2)]
ds = AugmentImageComponents(
ds, augs,
index=(0,1), coords_index=(2,))
"""
[docs] def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self.ds = ds
self._exception_handler = ExceptionHandler(catch_exceptions)
self._copy = copy
self._index = index
self._coords_index = coords_index
super(AugmentImageComponents, self).__init__(ds, self._aug_mapper)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
def _aug_mapper(self, dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if self._copy else lambda x: x # noqa
with self._exception_handler.catch():
major_image = self._index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
tfms = self.augs.get_transform(im)
dp[major_image] = tfms.apply_image(im)
for idx in self._index[1:]:
check_dtype(dp[idx])
dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in self._coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = tfms.apply_coords(coords)
return dp
try:
import cv2
from .imgaug import AugmentorList
except ImportError:
from ..utils.develop import create_dummy_class
ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa
AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa
AugmentImageCoordinates = create_dummy_class('AugmentImageCoordinates', 'cv2') # noqa
AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa