# -*- coding: utf-8 -*-
# File:

from __future__ import division

import itertools
import pprint
from collections import defaultdict, deque
from copy import copy

import numpy as np
import six
import tqdm
from termcolor import colored

from ..utils import logger
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow

try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

# Public names of this module: the only ones exported by a star-import.
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
           'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData',
           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
           'LocallyShuffleData', 'CacheData']

class TestDataSpeed(ProxyDataFlow):
    """ Measure how quickly a DataFlow produces datapoints. """

    def __init__(self, ds, size=5000, warmup=0):
        """
        Args:
            ds (DataFlow): the DataFlow to benchmark.
            size (int): how many datapoints to pull during the measured run.
            warmup (int): how many datapoints to pull before the measured run.
        """
        super(TestDataSpeed, self).__init__(ds)
        self.test_size = int(size)
        self.warmup = int(warmup)
        # Remembers whether the caller already reset us, so that start()
        # knows if it has to reset the wrapped dataflow on its own.
        self._reset_called = False

    def reset_state(self):
        self._reset_called = True
        super(TestDataSpeed, self).reset_state()

    def __iter__(self):
        """ Run the benchmark first, then yield the datapoints of ``ds`` as usual. """
        self.start()
        for dp in self.ds:
            yield dp

    def start(self):
        """ Run the benchmark, showing a progress bar. """
        if not self._reset_called:
            self.ds.reset_state()
        source = self.ds.__iter__()

        if self.warmup:
            # Warm-up datapoints come from the same iterator as the measured ones.
            for _ in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
                next(source)

        # add smoothing for speed benchmark
        with get_tqdm(total=self.test_size, leave=True, smoothing=0.2) as progress:
            for fetched, _ in enumerate(source, 1):
                progress.update()
                if fetched == self.test_size:
                    break
class BatchData(ProxyDataFlow):
    """
    Stack datapoints into batches.
    It produces datapoints of the same number of components as ``ds``, but
    each component has one new extra dimension of size ``batch_size``.
    The batch can be either a list of original components, or (by default)
    a numpy array of original components.
    """

    def __init__(self, ds, batch_size, remainder=False, use_list=False):
        """
        Args:
            ds (DataFlow): A dataflow that produces either list or dict.
                When ``use_list=False``, the components of ``ds``
                must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
            batch_size(int): batch size
            remainder (bool): When the remaining datapoints in ``ds`` is not
                enough to form a batch, whether or not to also produce the remaining
                data as a smaller batch.
                If set to False, all produced datapoints are guaranteed to have the same batch size.
                If set to True, `len(ds)` must be accurate.
            use_list (bool): if True, each component will contain a list
                of datapoints instead of an numpy array of an extra dimension.
        """
        super(BatchData, self).__init__(ds)
        if not remainder:
            try:
                assert batch_size <= len(ds)
            except NotImplementedError:
                # ``ds`` does not know its size; skip the sanity check.
                pass
        self.batch_size = int(batch_size)
        assert self.batch_size > 0
        self.remainder = remainder
        self.use_list = use_list

    def __len__(self):
        """
        Returns:
            int: number of batches per pass. A trailing partial batch is
            counted only when ``remainder`` is true.
        """
        div, rem = divmod(len(self.ds), self.batch_size)
        if rem == 0:
            return div
        return div + int(self.remainder)

    def __iter__(self):
        """
        Yields:
            Batched data by stacking each component on an extra 0th dimension.
        """
        holder = []
        for data in self.ds:
            holder.append(data)
            if len(holder) == self.batch_size:
                yield BatchData.aggregate_batch(holder, self.use_list)
                del holder[:]
        if self.remainder and len(holder) > 0:
            yield BatchData.aggregate_batch(holder, self.use_list)

    @staticmethod
    def _batch_numpy(data_list):
        """
        Stack a list of same-typed values into one numpy array.

        The dtype is chosen from the first element: Python bool -> 'bool',
        Python int -> 'int32', Python float -> 'float32', text/bytes -> 'str',
        anything else must expose a ``dtype`` attribute, which is reused.

        Raises:
            TypeError: if the first element is of an unsupported type.
            Exception: whatever ``np.asarray`` raised, after logging diagnostics.
        """
        data = data_list[0]
        # bool must be tested before the integer types: bool is a subclass of
        # int, so checking integers first would batch booleans as int32.
        if isinstance(data, bool):
            dtype = 'bool'
        elif isinstance(data, six.integer_types):
            dtype = 'int32'
        elif type(data) == float:
            # exact type on purpose: numpy float scalars keep their own dtype below
            dtype = 'float32'
        elif isinstance(data, (six.binary_type, six.text_type)):
            dtype = 'str'
        else:
            try:
                dtype = data.dtype
            except AttributeError:
                raise TypeError("Unsupported type to batch: {}".format(type(data)))
        try:
            return np.asarray(data_list, dtype=dtype)
        except Exception as e:  # noqa
            logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
            if isinstance(data, np.ndarray):
                s = pprint.pformat([x.shape for x in data_list])
                logger.error("Shape of all arrays to be batched: " + s)
            try:
                # open an ipython shell if possible
                import IPython as IP; IP.embed()  # noqa
            except ImportError:
                pass
            # Propagate the failure instead of silently returning None as a
            # batch component.
            raise

    @staticmethod
    def aggregate_batch(data_holder, use_list=False):
        """
        Aggregate a list of datapoints to one batched datapoint.

        Args:
            data_holder (list[dp]): each dp is either a list or a dict.
            use_list (bool): whether to batch data into a list or a numpy array.

        Returns:
            dp: either a list or a dict, depend on the inputs.
                Each item is a batched version of the corresponding inputs.

        Raises:
            ValueError: if the datapoints are neither list/tuple nor dict.
        """
        first_dp = data_holder[0]
        if isinstance(first_dp, (list, tuple)):
            result = []
            for k in range(len(first_dp)):
                data_list = [x[k] for x in data_holder]
                result.append(data_list if use_list else BatchData._batch_numpy(data_list))
        elif isinstance(first_dp, dict):
            result = {}
            for key in first_dp:
                data_list = [x[key] for x in data_holder]
                result[key] = data_list if use_list else BatchData._batch_numpy(data_list)
        else:
            raise ValueError("Data point has to be list/tuple/dict. Got {}".format(type(first_dp)))
        return result
class BatchDataByShape(BatchData):
    """
    Batch together only those datapoints whose chosen component has the same shape.

    The input DataFlow may therefore mix datapoints of different shapes;
    each batch is formed from datapoints sharing one shape.

    Note:
        Pending datapoints are kept in a dict ``{shape -> datapoints}``.
        Datapoints of a rare shape may never fill up a batch and are then
        never produced.
    """

    def __init__(self, ds, batch_size, idx):
        """
        Args:
            ds (DataFlow): input DataFlow. ``dp[idx]`` has to be an :class:`np.ndarray`.
            batch_size (int): batch size
            idx (int): ``dp[idx].shape`` is the grouping key. The other
                components are assumed to be batch-able.
        """
        super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
        self.idx = idx

    def reset_state(self):
        super(BatchDataByShape, self).reset_state()
        # One pending list per shape; recreated on every reset.
        self.holder = defaultdict(list)
        self._guard = DataFlowReentrantGuard()

    def __iter__(self):
        with self._guard:
            for dp in self.ds:
                bucket = self.holder[dp[self.idx].shape]
                bucket.append(dp)
                if len(bucket) != self.batch_size:
                    continue
                yield BatchData.aggregate_batch(bucket)
                # Empty the bucket in place so self.holder keeps the same list.
                bucket[:] = []
class FixedSizeData(ProxyDataFlow):
    """ Produce datapoints from another DataFlow, a fixed number of them per pass. """

    def __init__(self, ds, size, keep_state=True):
        """
        Args:
            ds (DataFlow): input dataflow
            size (int): number of datapoints produced per pass
            keep_state (bool): whether the iterator over ``ds`` survives between
                calls to :meth:`__iter__()`. If True, the next pass continues
                where the previous one stopped; if False, every pass starts
                a new iterator over ``ds``.

        Example:

        .. code-block:: none

            ds produces: 1, 2, 3, 4, 5; 1, 2, 3, 4, 5; ...
            FixedSizeData(ds, 3, True): 1, 2, 3; 4, 5, 1; 2, 3, 4; ...
            FixedSizeData(ds, 3, False): 1, 2, 3; 1, 2, 3; ...
            FixedSizeData(ds, 6, False): 1, 2, 3, 4, 5, 1; 1, 2, 3, 4, 5, 1;...
        """
        super(FixedSizeData, self).__init__(ds)
        self._size = int(size)
        self.itr = None
        self._keep = keep_state

    def __len__(self):
        return self._size

    def reset_state(self):
        super(FixedSizeData, self).reset_state()
        self.itr = self.ds.__iter__()
        self._guard = DataFlowReentrantGuard()

    def __iter__(self):
        with self._guard:
            if self.itr is None:
                self.itr = self.ds.__iter__()
            for produced in itertools.count(1):
                try:
                    item = next(self.itr)
                except StopIteration:
                    # ``ds`` ran out: start it over and take its first datapoint.
                    self.itr = self.ds.__iter__()
                    item = next(self.itr)

                yield item
                if produced == self._size:
                    if not self._keep:
                        self.itr = None
                    return
class MapData(ProxyDataFlow):
    """
    Run every datapoint of a DataFlow through a function that maps or filters it.

    Note:
        1. ``func`` receives a shallow copy of the datapoint. It should still
           not modify its argument in place unless you are certain that is safe.
        2. When ``func`` discards datapoints, ``len(MapData(ds))`` is no longer accurate.

    Example:

        .. code-block:: none

            ds = Mnist('train')  # each datapoint is [img, label]
            ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])
    """

    def __init__(self, ds, func):
        """
        Args:
            ds (DataFlow): input DataFlow
            func (datapoint -> datapoint | None): maps a datapoint to a new
                datapoint. Returning None drops the datapoint.
        """
        super(MapData, self).__init__(ds)
        self.func = func

    def __iter__(self):
        for dp in self.ds:
            mapped = self.func(copy(dp))  # shallow copy the list
            if mapped is None:
                continue
            yield mapped
class MapDataComponent(MapData):
    """
    Run one component of every datapoint through a function that maps or filters it.

    Note:
        1. This dataflow writes the new value into a shallow copy of the
           datapoint. ``func`` should still not modify its argument in place
           unless you are certain that is safe.
        2. When ``func`` discards datapoints, ``len(MapDataComponent(ds, ..))``
           is no longer accurate.

    Example:

        .. code-block:: none

            ds = Mnist('train')  # each datapoint is [img, label]
            ds = MapDataComponent(ds, lambda img: img * 255, 0)  # map the 0th component
    """

    def __init__(self, ds, func, index=0):
        """
        Args:
            ds (DataFlow): input DataFlow which produces either list or dict.
            func (TYPE -> TYPE|None): receives ``dp[index]`` and returns its
                replacement. Returning None drops the whole datapoint.
            index (int or str): index or key of the component.
        """
        self._index = index
        self._func = func
        super(MapDataComponent, self).__init__(ds, self._mapper)

    def _mapper(self, dp):
        """ Return a copy of ``dp`` with the selected component mapped, or None to drop it. """
        new_value = self._func(dp[self._index])
        if new_value is None:
            return None
        out = copy(dp)  # shallow copy to avoid modifying the datapoint
        if isinstance(out, tuple):
            # tuples are immutable; switch to a list so the component can be replaced
            out = list(out)
        out[self._index] = new_value
        return out
class RepeatedData(ProxyDataFlow):
    """
    Iterate over another DataFlow several times in a row, each pass running
    until it is exhausted.
    i.e.: dp1, dp2, .... dpn, dp1, dp2, ....dpn
    """

    def __init__(self, ds, num):
        """
        Args:
            ds (DataFlow): input DataFlow
            num (int): how many passes over ``ds`` to make.
                Use -1 to repeat ``ds`` forever.
        """
        self.num = num
        super(RepeatedData, self).__init__(ds)

    def __len__(self):
        """
        Raises:
            :class:`NotImplementedError` when num == -1.
        """
        if self.num == -1:
            raise NotImplementedError("__len__() is unavailable for infinite dataflow")
        return len(self.ds) * self.num

    def __iter__(self):
        # An endless counter stands in for "repeat forever".
        passes = itertools.count() if self.num == -1 else range(self.num)
        for _ in passes:
            for dp in self.ds:
                yield dp
class RepeatedDataPoint(ProxyDataFlow):
    """
    Produce every datapoint of another DataFlow several times in a row.
    i.e.: dp1, dp1, ..., dp1, dp2, ..., dp2, ...
    """

    def __init__(self, ds, num):
        """
        Args:
            ds (DataFlow): input DataFlow
            num (int): how many times each datapoint is produced. Must be at least 1.
        """
        self.num = int(num)
        assert self.num >= 1, self.num
        super(RepeatedDataPoint, self).__init__(ds)

    def __len__(self):
        return len(self.ds) * self.num

    def __iter__(self):
        for dp in self.ds:
            # The very same object is yielded each time; no copy is made.
            for same_dp in itertools.repeat(dp, self.num):
                yield same_dp
class RandomChooseData(RNGDataFlow):
    """
    Randomly choose from several DataFlow.
    Stop producing when any of them is exhausted.
    """

    def __init__(self, df_lists):
        """
        Args:
            df_lists (list): a list of DataFlow, or a list of (DataFlow, probability) tuples.
                Probabilities must sum to 1 if used.
        """
        super(RandomChooseData, self).__init__()
        if isinstance(df_lists[0], (tuple, list)):
            total = sum(v[1] for v in df_lists)
            # Compare with a tolerance: exact float equality rejects valid
            # inputs such as ten probabilities of 0.1. The tolerance is kept
            # tight so that the random choice in __iter__ accepts the same values.
            assert np.isclose(total, 1.0, rtol=0, atol=1e-8), \
                "Probabilities must sum to 1, got {}".format(total)
            self.df_lists = df_lists
        else:
            prob = 1.0 / len(df_lists)
            self.df_lists = [(k, prob) for k in df_lists]

    def reset_state(self):
        super(RandomChooseData, self).reset_state()
        # After __init__, every entry is a (DataFlow, probability) pair,
        # whether it was given as a tuple or as a list.
        for d in self.df_lists:
            d[0].reset_state()

    def __iter__(self):
        itrs = [v[0].__iter__() for v in self.df_lists]
        probs = np.array([v[1] for v in self.df_lists])
        try:
            while True:
                # Draw an index rather than the iterator object itself, so
                # numpy never has to coerce iterators into an array.
                itr = itrs[self.rng.choice(len(itrs), p=probs)]
                yield next(itr)
        except StopIteration:
            # One source ran dry: stop the whole dataflow.
            return
class RandomMixData(RNGDataFlow):
    """
    Perfectly mix datapoints from several DataFlow using their
    :meth:`__len__()`. Will stop when all DataFlow exhausted.
    """

    def __init__(self, df_lists):
        """
        Args:
            df_lists (list): a list of DataFlow.
                All DataFlow must implement ``__len__()``.
        """
        super(RandomMixData, self).__init__()
        self.df_lists = df_lists
        self.sizes = [len(k) for k in self.df_lists]

    def reset_state(self):
        super(RandomMixData, self).reset_state()
        for d in self.df_lists:
            d.reset_state()

    def __len__(self):
        return sum(self.sizes)

    def __iter__(self):
        total = self.__len__()
        if total == 0:
            # Nothing to mix; also avoids calling max() on an empty array below.
            return
        sums = np.cumsum(self.sizes)
        idxs = np.arange(total)
        self.rng.shuffle(idxs)
        # Map every shuffled global position to the index of the dataflow
        # owning it, in a single vectorized call.
        idxs = np.searchsorted(sums, idxs, 'right')
        itrs = [k.__iter__() for k in self.df_lists]
        # Sanity check only. The largest index is below len(itrs) - 1 when
        # trailing dataflows are empty, so do not require equality.
        assert idxs.max() < len(itrs), "{}>={}".format(idxs.max(), len(itrs))
        for k in idxs:
            yield next(itrs[k])
            # TODO run till exception
class ConcatData(DataFlow):
    """
    Chain several DataFlow one after another.
    All datapoints of one DataFlow are produced before moving on to the next.
    """

    def __init__(self, df_lists):
        """
        Args:
            df_lists (list): a list of DataFlow.
        """
        self.df_lists = df_lists

    def reset_state(self):
        for df in self.df_lists:
            df.reset_state()

    def __len__(self):
        total = 0
        for df in self.df_lists:
            total += len(df)
        return total

    def __iter__(self):
        for df in self.df_lists:
            for dp in df:
                yield dp
class JoinData(DataFlow):
    """
    Join the components from each DataFlow. See below for its behavior.

    Note that you can't join a DataFlow that produces lists with one that produces dicts.

    Example:

    .. code-block:: none

        df1 produces: [c1, c2]
        df2 produces: [c3, c4]
        joined: [c1, c2, c3, c4]

        df1 produces: {"a":c1, "b":c2}
        df2 produces: {"c":c3}
        joined: {"a":c1, "b":c2, "c":c3}
    """

    def __init__(self, df_lists):
        """
        Args:
            df_lists (list): a list of DataFlow.
                When these dataflows have different sizes, JoinData will stop when any
                of them is exhausted.
                The list could contain the same DataFlow instance more than once,
                but note that in that case `__iter__` will then also be called many times.
        """
        self.df_lists = df_lists

        # Best-effort check only: a dataflow without __len__, or a size
        # mismatch, is logged rather than treated as fatal.
        try:
            self._size = len(self.df_lists[0])
            for d in self.df_lists:
                assert len(d) == self._size, \
                    "All DataFlow must have the same size! {} != {}".format(len(d), self._size)
        except Exception:
            logger.info("[JoinData] Size check failed for the list of dataflow to be joined!")

    def reset_state(self):
        # set() makes sure a dataflow listed several times is reset only once.
        for d in set(self.df_lists):
            d.reset_state()

    def __len__(self):
        """
        Return the minimum size among all.
        """
        return min(len(k) for k in self.df_lists)

    def __iter__(self):
        itrs = [k.__iter__() for k in self.df_lists]
        try:
            while True:
                all_dps = [next(itr) for itr in itrs]
                if isinstance(all_dps[0], (list, tuple)):
                    dp = list(itertools.chain(*all_dps))
                else:
                    dp = {}
                    for x in all_dps:
                        dp.update(x)
                yield dp
        except StopIteration:  # some of them are exhausted
            pass
def SelectComponent(ds, idxs):
    """
    Keep only the listed components of every datapoint, in the listed order.

    Args:
        ds (DataFlow): input DataFlow.
        idxs (list[int] or list[str]): indices/keys of the components to keep.

    Returns:
        MapData: a dataflow whose datapoints are lists of the selected components.

    Example:

    .. code-block:: none

        original df produces: [c1, c2, c3]
        idxs: [2,1]
        this df: [c3, c2]
    """
    def _pick(dp):
        return [dp[i] for i in idxs]

    return MapData(ds, _pick)
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
    """
    Buffer the datapoints from a given dataflow, and shuffle them before producing them.
    This can be used as an alternative when a complete random shuffle is too expensive
    or impossible for the data source.

    This dataflow has the following behavior:

    1. It takes datapoints from the given dataflow `ds` to an internal buffer of fixed size.
       Each datapoint is duplicated for `num_reuse` times.
    2. Once the buffer is full, this dataflow starts to yield data from the beginning of the buffer,
       and new datapoints will be added to the end of the buffer. This is like a FIFO queue.
    3. The internal buffer is shuffled after every `shuffle_interval` datapoints that come from `ds`.

    To maintain shuffling states, this dataflow is not reentrant.

    Datapoints from one pass of `ds` will get mixed with datapoints from a different pass.
    As a result, the iterator of this dataflow will run indefinitely
    because it does not make sense to stop the iteration anywhere.
    """

    def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None):
        """
        Args:
            ds (DataFlow): input DataFlow.
            buffer_size (int): size of the buffer.
            num_reuse (int): duplicate each datapoints several times into the buffer to improve
                speed, but duplication may hurt your model.
            shuffle_interval (int): shuffle the buffer after this many
                datapoints were produced from the given dataflow. Frequent shuffle on large buffer
                may affect speed, but infrequent shuffle may not provide enough randomness.
                Defaults to buffer_size / 3 (but at least 1).
        """
        ProxyDataFlow.__init__(self, ds)
        self.q = deque(maxlen=buffer_size)
        if shuffle_interval is None:
            # Clamp to 1: for buffer_size < 3 the integer division gives 0,
            # which would make the modulo in __iter__ divide by zero.
            shuffle_interval = max(1, int(buffer_size // 3))
        self.shuffle_interval = shuffle_interval
        self.num_reuse = num_reuse
        # Endless view over ``ds``, so the buffer can be refilled across passes.
        self._inf_ds = RepeatedData(ds, -1)

    def reset_state(self):
        self._guard = DataFlowReentrantGuard()
        ProxyDataFlow.reset_state(self)
        RNGDataFlow.reset_state(self)
        self._iter_cnt = 0
        self._inf_iter = iter(self._inf_ds)

    def __len__(self):
        return len(self.ds) * self.num_reuse

    def __iter__(self):
        with self._guard:
            for dp in self._inf_iter:
                self._iter_cnt = (self._iter_cnt + 1) % self.shuffle_interval
                # fill queue
                if self._iter_cnt == 0:
                    self.rng.shuffle(self.q)
                for _ in range(self.num_reuse):
                    # Once the buffer is full, hand out the oldest entry
                    # before appending the new one.
                    if self.q.maxlen == len(self.q):
                        yield self.q.popleft()
                    self.q.append(dp)
class CacheData(ProxyDataFlow):
    """
    Keep the whole first pass over a DataFlow in memory and serve every
    later pass from that in-memory copy.

    NOTE: The first pass has to be iterated to the end. If the iterator is
    abandoned early, the cache may be incomplete.
    """

    def __init__(self, ds, shuffle=False):
        """
        Args:
            ds (DataFlow): input DataFlow.
            shuffle (bool): whether to shuffle the cache before yielding from it.
        """
        self.shuffle = shuffle
        super(CacheData, self).__init__(ds)

    def reset_state(self):
        super(CacheData, self).reset_state()
        self._guard = DataFlowReentrantGuard()
        if self.shuffle:
            self.rng = get_rng(self)
        # The cache starts out empty after every reset.
        self.buffer = []

    def __iter__(self):
        with self._guard:
            if not len(self.buffer):
                # First pass: forward each datapoint, then remember it.
                for dp in self.ds:
                    yield dp
                    self.buffer.append(dp)
                return

            if self.shuffle:
                self.rng.shuffle(self.buffer)
            for cached in self.buffer:
                yield cached
class PrintData(ProxyDataFlow):
    """
    Behave like an identity proxy, but print shape and range of the first few datapoints.
    Good for debugging.

    Example:
        Place it somewhere in your dataflow like

        .. code-block:: python

            def create_my_dataflow():
                ds = SomeDataSource('path/to/lmdb')
                ds = SomeInscrutableMappings(ds)
                ds = PrintData(ds, num=2, max_list=2)
                return ds
            ds = create_my_dataflow()
            # other code that uses ds

        When datapoints are taken from the dataflow, it will print outputs like:

        .. code-block:: none

            [0110 09:22:21] DataFlow Info:
            datapoint 0/2 with 4 components consists of
              0: float with value 0.0816501893251
              1: ndarray:int32 of shape (64,) in range [0, 10]
              2: ndarray:float32 of shape (64, 64) in range [-1.2248, 1.2177]
              3: list of len 50
                0: ndarray:int32 of shape (64, 64) in range [-128, 80]
                1: ndarray:float32 of shape (64, 64) in range [0.6845, 0.8400]
                ...
            datapoint 1/2 with 4 components consists of
              0: float with value 5.88252075399
              1: ndarray:int32 of shape (64,) in range [0, 10]
              2: ndarray:float32 of shape (64, 64) in range [-0.9011, 0.8491]
              3: list of len 50
                0: ndarray:int32 of shape (64, 64) in range [-70, 50]
                1: ndarray:float32 of shape (64, 64) in range [0.3545, 0.7400]
                ...
    """

    def __init__(self, ds, num=1, name=None, max_depth=3, max_list=3):
        """
        Args:
            ds (DataFlow): input DataFlow.
            num (int): number of dataflow points to print.
            name (str, optional): name to identify this DataFlow.
            max_depth (int, optional): stop output when too deep recursion in sub elements
            max_list (int, optional): stop output when too many sub elements
        """
        super(PrintData, self).__init__(ds)
        self.num = num
        self.name = name
        self.cnt = 0
        self.max_depth = max_depth
        self.max_list = max_list

    def _analyze_input_data(self, entry, k, depth=1, max_depth=3, max_list=3):
        """
        Gather useful debug information from a datapoint.

        Args:
            entry: the datapoint component, either a list or a dict
            k (int): index of this component in current datapoint
            depth (int, optional): recursion depth
            max_depth, max_list: same as in :meth:`__init__`.

        Returns:
            string: debug message
        """

        class _elementInfo(object):
            def __init__(self, el, pos, depth=0, max_list=3):
                self.shape = ""
                self.type = type(el).__name__
                self.dtype = ""
                self.range = ""

                self.sub_elements = []

                self.ident = " " * (depth * 2)
                self.pos = pos

                if isinstance(el, (int, float, bool)):
                    self.range = " with value {}".format(el)
                elif type(el) is np.ndarray:
                    self.shape = " of shape {}".format(el.shape)
                    self.dtype = ":{}".format(str(el.dtype))
                    # min()/max() raise on an empty array, so the range is
                    # reported only when there is something to measure.
                    if el.size > 0:
                        self.range = " in range [{}, {}]".format(el.min(), el.max())
                elif isinstance(el, (np.generic, six.string_types, six.binary_type)):
                    # numpy scalars plus plain text/bytes. This replaces a
                    # lookup in ``np.sctypes``, which NumPy 2.0 removed.
                    self.range = " with value {}".format(el)
                elif isinstance(el, (list, tuple)):
                    self.shape = " of len {}".format(len(el))

                    # ``max_depth`` comes from the enclosing method, while
                    # ``max_list`` is passed down explicitly.
                    if depth < max_depth:
                        for k, subel in enumerate(el):
                            if k < max_list:
                                self.sub_elements.append(_elementInfo(subel, k, depth + 1, max_list))
                            else:
                                self.sub_elements.append(" " * ((depth + 1) * 2) + '...')
                                break
                    else:
                        if len(el) > 0:
                            self.sub_elements.append(" " * ((depth + 1) * 2) + ' ...')

            def __str__(self):
                strings = []
                vals = (self.ident, self.pos, self.type, self.dtype, self.shape, self.range)
                strings.append("{}{}: {}{}{}{}".format(*vals))

                # sub_elements holds nested _elementInfo objects and plain
                # "..." marker strings; str() handles both.
                for el in self.sub_elements:
                    strings.append(str(el))
                return "\n".join(strings)

        return str(_elementInfo(entry, k, depth, max_list))

    def _get_msg(self, dp):
        """ Build the multi-line description of one datapoint (a list or a mapping). """
        msg = [colored(u"datapoint %i/%i with %i components consists of" %
                       (self.cnt, self.num, len(dp)), "cyan")]
        is_dict = isinstance(dp, Mapping)
        for k, entry in enumerate(dp):
            if is_dict:
                # iterating a mapping yields its keys
                key, value = entry, dp[entry]
            else:
                key, value = k, entry
            msg.append(self._analyze_input_data(value, key, max_depth=self.max_depth, max_list=self.max_list))
        return u'\n'.join(msg)

    def __iter__(self):
        for dp in self.ds:
            if self.cnt < self.num:
                # it is important to place this here! otherwise it mixes the output of multiple PrintData
                if self.cnt == 0:
                    # Log the header once, just before the first printed
                    # datapoint (and not at all when num == 0).
                    label = ' (%s)' % self.name if self.name is not None else ""
                    logger.info(colored("Contents of DataFlow%s:" % label, 'cyan'))
                print(self._get_msg(dp))
                self.cnt += 1
            yield dp

    def reset_state(self):
        super(PrintData, self).reset_state()
        self.cnt = 0