Source code for tensorpack.dataflow.imgaug.convert

# -*- coding: utf-8 -*-
# File: convert.py

import numpy as np
import cv2

from .base import ImageAugmentor
from .meta import MapImage

# Public augmentors exported by this module (``from ... import *``).
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']

class ColorSpace(ImageAugmentor):
    """ Convert into another color space. """

    def __init__(self, mode, keepdims=True):
        """
        Args:
            mode: OpenCV color space conversion code (e.g., ``cv2.COLOR_BGR2HSV``)
            keepdims (bool): keep the dimension of image unchanged if OpenCV
                drops the channel axis (e.g. a 3-channel -> 1-channel
                conversion returns [H, W]; with keepdims it becomes [H, W, 1]).
        """
        # NOTE(review): `_init` is inherited and not visible here. It receives the
        # whole locals() dict; `_augment` later reads `self.mode` and
        # `self.keepdims`, so presumably `_init` stores these as attributes.
        # Do not introduce extra locals before this call.
        self._init(locals())

    def _augment(self, img, _):
        """
        Convert ``img`` with ``cv2.cvtColor`` using ``self.mode``.

        When ``keepdims`` is set and OpenCV returned an array with one fewer
        axis than the input (the channel axis was dropped), a trailing
        channel axis of size 1 is re-added.
        """
        transf = cv2.cvtColor(img, self.mode)
        # Only restore an axis that OpenCV removed. Conversions that *add* a
        # channel axis (e.g. GRAY2BGR on an [H, W] image -> [H, W, 3]) must be
        # left alone; expanding them would produce a bogus [H, W, 3, 1] array.
        # `==` is used for the ndim comparison: `is` would test object
        # identity, not numeric equality.
        if self.keepdims and transf.ndim == img.ndim - 1:
            transf = transf[..., None]
        return transf
class Grayscale(ColorSpace):
    """
    Convert image to grayscale.
    """

    def __init__(self, keepdims=True, rgb=False):
        """
        Args:
            keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
            rgb (bool): interpret input as RGB instead of the default BGR
        """
        # Pick the OpenCV conversion code matching the input channel order.
        if rgb:
            conversion = cv2.COLOR_RGB2GRAY
        else:
            conversion = cv2.COLOR_BGR2GRAY
        super(Grayscale, self).__init__(conversion, keepdims)
class ToUint8(MapImage):
    """
    Convert image to uint8. Useful to reduce communication overhead.
    """

    def __init__(self):
        def _clip_and_cast(image):
            # Saturate into [0, 255] before casting so out-of-range values do
            # not wrap around; note `astype` truncates fractional parts.
            saturated = np.clip(image, 0, 255)
            return saturated.astype(np.uint8)

        def _passthrough(value):
            # Second MapImage callback (presumably the coordinate mapping):
            # returned unchanged, as in the original identity lambda.
            return value

        super(ToUint8, self).__init__(_clip_and_cast, _passthrough)
class ToFloat32(MapImage):
    """
    Convert image to float32, may increase quality of the augmentor.
    """

    def __init__(self):
        def _cast_float32(image):
            # `astype` always returns a new array, so the caller's image is
            # never aliased by the result.
            return image.astype(np.float32)

        def _passthrough(value):
            # Second MapImage callback (presumably the coordinate mapping):
            # returned unchanged, as in the original identity lambda.
            return value

        super(ToFloat32, self).__init__(_cast_float32, _passthrough)