# -*- coding: utf-8 -*-
# File: parallel.py
import atexit
import pickle
import errno
import traceback
import itertools
import multiprocessing as mp
import os
import sys
import uuid
import weakref
from contextlib import contextmanager
import zmq
from six.moves import queue, range
from ..utils import logger
from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal)
from ..utils.serialize import dumps_once as dumps, loads_once as loads
from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow
__all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'MultiProcessRunner', 'MultiProcessRunnerZMQ', 'MultiThreadRunner',
'PrefetchDataZMQ', 'MultiThreadPrefetchData']
# from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/__init__.py
class _ExceptionWrapper:
    """Wraps an exception plus traceback to communicate across threads"""
    MAGIC = b"EXC_MAGIC"
def __init__(self, exc_info):
# It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ]
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
def pack(self):
return self.MAGIC + pickle.dumps(self)
@staticmethod
def unpack(dp):
if isinstance(dp, bytes) and dp.startswith(_ExceptionWrapper.MAGIC):
return pickle.loads(dp[len(_ExceptionWrapper.MAGIC):])
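# How the wrapper is meant to be used (a sketch; see MultiProcessRunnerZMQ._recv below for
# the real consumer): a worker packs ``_ExceptionWrapper(sys.exc_info()).pack()`` into the
# data stream, and the receiver calls ``unpack`` on every message -- it returns ``None`` for
# ordinary datapoints and the wrapper itself for packed exceptions, which are then re-raised
# via ``raise wrapper.exc_type(wrapper.exc_msg)``.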
def _repeat_iter(get_itr):
while True:
yield from get_itr()
def _bind_guard(sock, name):
try:
sock.bind(name)
except zmq.ZMQError:
logger.error(
"ZMQError in socket.bind('{}'). Perhaps you're \
using pipes on a non-local file system. See documentation of MultiProcessRunnerZMQ \
for more information.".format(name))
raise
def _get_pipe_name(name):
if sys.platform.startswith('linux'):
# linux supports abstract sockets: http://api.zeromq.org/4-1:zmq-ipc
pipename = "ipc://@{}-pipe-{}".format(name, str(uuid.uuid1())[:8])
pipedir = os.environ.get('TENSORPACK_PIPEDIR', None)
if pipedir is not None:
logger.warn("TENSORPACK_PIPEDIR is not used on Linux any more! Abstract sockets will be used.")
else:
pipedir = os.environ.get('TENSORPACK_PIPEDIR', None)
if pipedir is not None:
logger.info("ZMQ uses TENSORPACK_PIPEDIR={}".format(pipedir))
else:
pipedir = '.'
assert os.path.isdir(pipedir), pipedir
filename = '{}/{}-pipe-{}'.format(pipedir.rstrip('/'), name, str(uuid.uuid1())[:6])
assert not os.path.exists(filename), "Pipe {} exists! You may be unlucky.".format(filename)
pipename = "ipc://{}".format(filename)
return pipename
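# For reference, the pipe names produced above look roughly like (suffixes are random):
#   Linux:     "ipc://@dataflow-pipe-1a2b3c4d"   (abstract socket, nothing created on disk)
#   elsewhere: "ipc://./dataflow-pipe-1a2b3c"    (named pipe under TENSORPACK_PIPEDIR, default '.')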
def del_weakref(x):
o = x()
if o is not None:
o.__del__()
@contextmanager
def _zmq_catch_error(name):
try:
yield
except zmq.ContextTerminated:
logger.info("[{}] Context terminated.".format(name))
raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[{}] Socket closed.".format(name))
raise DataFlowTerminated()
else:
raise
except Exception:
raise
class _MultiProcessZMQDataFlow(DataFlow):
def __init__(self):
        assert os.name != 'nt', "ZMQ IPC doesn't support Windows!"
self._reset_done = False
self._procs = []
def reset_state(self):
"""
        All forked dataflows should be reset **once and only once** in the spawned processes.
        Subclasses should call this method with ``super()``.
"""
assert not self._reset_done, "reset_state() was called twice! This violates the API of DataFlow!"
self._reset_done = True
# __del__ not guaranteed to get called at exit
atexit.register(del_weakref, weakref.ref(self))
def _start_processes(self):
start_proc_mask_signal(self._procs)
def __del__(self):
try:
if not self._reset_done:
return
if not self.context.closed:
self.socket.close(0)
self.context.destroy(0)
for x in self._procs:
x.terminate()
x.join(5)
print("{} successfully cleaned-up.".format(type(self).__name__))
except Exception:
pass
class MultiProcessRunner(ProxyDataFlow):
"""
Running a DataFlow in >=1 processes using Python multiprocessing utilities.
    It will fork the process that calls :meth:`__init__`, and collect datapoints from `ds` in each
    process through a Python :class:`multiprocessing.Queue`.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that the process will be forked ``num_proc`` times.
There will be ``num_proc`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``num_proc=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``num_proc`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
and know that you'll likely see duplicates.
To utilize parallelism with more strict data integrity, you can use
the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
2. This has more serialization overhead than :class:`MultiProcessRunnerZMQ` when data is large.
3. You can nest like this: ``MultiProcessRunnerZMQ(MultiProcessRunner(df, num_proc=a), num_proc=b)``.
A total of ``a`` instances of ``df`` worker processes will be created.
        4. The worker processes are created in `__init__`, but they are only started in `reset_state()`.
           Each worker resets its own copy of the given DataFlow once, right after it starts.
        5. This DataFlow does support Windows. However, Windows requires more strict picklability on processes,
           which means that some code that's forkable on Linux may not be picklable for Windows' spawned processes.
           If that happens you'll need to reorganize the parts of the code that cannot be pickled.
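
    Example:

        A minimal usage sketch (``MyDataFlow`` below is just a placeholder for your own dataflow)::

            df = MyDataFlow()
            # 4 worker processes, each running an independent copy of ``df``,
            # feeding a queue that holds up to 64 prefetched datapoints
            df = MultiProcessRunner(df, num_prefetch=64, num_proc=4)
            df.reset_state()
            for dp in df:
                pass   # consume the datapoints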
"""
class _Worker(mp.Process):
def __init__(self, ds, queue, idx):
super(MultiProcessRunner._Worker, self).__init__()
self.ds = ds
self.queue = queue
self.idx = idx
def run(self):
enable_death_signal(_warn=self.idx == 0)
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
for dp in self.ds:
self.queue.put(dp)
    def __init__(self, ds, num_prefetch, num_proc):
"""
Args:
ds (DataFlow): input DataFlow.
num_prefetch (int): size of the queue to hold prefetched datapoints.
Required.
num_proc (int): number of processes to use. Required.
"""
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#the-spawn-and-forkserver-start-methods
if os.name == 'nt':
logger.warn("MultiProcessRunner does support Windows. \
However, Windows requires more strict picklability on processes, which may \
lead of failure on some of the code.")
super(MultiProcessRunner, self).__init__(ds)
try:
self._size = len(ds)
except NotImplementedError:
self._size = -1
assert num_proc > 0, num_proc
assert num_prefetch > 0, num_prefetch
self.num_proc = num_proc
self.num_prefetch = num_prefetch
if num_proc > 1:
logger.info("[MultiProcessRunner] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.num_prefetch)
self.procs = [MultiProcessRunner._Worker(self.ds, self.queue, idx)
for idx in range(self.num_proc)]
ensure_proc_terminate(self.procs)
self._reset_done = False
def __iter__(self):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = self.queue.get()
yield dp
def reset_state(self):
assert not self._reset_done, "reset_state() was called twice! This violates the API of DataFlow!"
self._reset_done = True
start_proc_mask_signal(self.procs)
class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
"""
Run a DataFlow in >=1 processes, with ZeroMQ for communication.
    It will fork the process that calls :meth:`reset_state()`,
    and collect datapoints from the given dataflow in each process through a ZeroMQ IPC pipe.
This is typically faster than :class:`MultiProcessRunner`.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that the process will be forked ``num_proc`` times.
There will be ``num_proc`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``num_proc=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``num_proc`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
and know that you'll likely see duplicates.
To utilize parallelism with more strict data integrity, you can use
the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
2. `reset_state()` of the given dataflow will be called **once and only once** in the worker processes.
        3. The fork of processes happens in this dataflow's `reset_state()` method.
Please note that forking a TensorFlow GPU session may be unsafe.
If you're managing this dataflow on your own,
it's better to fork before creating the session.
4. (Fork-safety) After the fork has happened, this dataflow becomes not fork-safe.
i.e., if you fork an already reset instance of this dataflow,
it won't be usable in the forked process. Therefore, do not nest two `MultiProcessRunnerZMQ`.
        5. (Thread-safety) ZMQ is not thread safe. Therefore, do not call :meth:`__iter__` of the same dataflow
           from more than one thread.
        6. This dataflow does not support Windows. Use `MultiProcessRunner` which works on Windows.
7. (For Mac only) A UNIX named pipe will be created in the current directory.
However, certain non-local filesystem such as NFS/GlusterFS/AFS doesn't always support pipes.
You can change the directory by ``export TENSORPACK_PIPEDIR=/other/dir``.
In particular, you can use somewhere under '/tmp' which is usually local.
           Note that some non-local FS may appear to support pipes and the code
           may appear to run, but it can crash with bizarre errors.
Also note that ZMQ limits the maximum length of pipe path.
If you hit the limit, you can set the directory to a softlink
which points to a local directory.
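
    Example:

        A minimal usage sketch (``MyDataFlow`` below is just a placeholder for your own dataflow)::

            df = MyDataFlow()
            df = MultiProcessRunnerZMQ(df, num_proc=4, hwm=50)
            # reset_state() forks the workers; do this before creating a TF session
            df.reset_state()
            for dp in df:
                pass   # consume the datapoints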
"""
class _Worker(mp.Process):
def __init__(self, ds, conn_name, hwm, idx):
super(MultiProcessRunnerZMQ._Worker, self).__init__()
self.ds = ds
self.conn_name = conn_name
self.hwm = hwm
self.idx = idx
def run(self):
enable_death_signal(_warn=self.idx == 0)
self.ds.reset_state()
itr = _repeat_iter(lambda: self.ds)
context = zmq.Context()
socket = context.socket(zmq.PUSH)
socket.set_hwm(self.hwm)
socket.connect(self.conn_name)
try:
while True:
try:
dp = next(itr)
socket.send(dumps(dp), copy=False)
except Exception:
dp = _ExceptionWrapper(sys.exc_info()).pack()
socket.send(dumps(dp), copy=False)
raise
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
finally:
socket.close(0)
context.destroy(0)
    def __init__(self, ds, num_proc=1, hwm=50):
"""
Args:
ds (DataFlow): input DataFlow.
num_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
"""
super(MultiProcessRunnerZMQ, self).__init__()
self.ds = ds
self.num_proc = num_proc
self._hwm = hwm
if num_proc > 1:
logger.info("[MultiProcessRunnerZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
try:
self._size = ds.__len__()
except NotImplementedError:
self._size = -1
def _recv(self):
ret = loads(self.socket.recv(copy=False))
exc = _ExceptionWrapper.unpack(ret)
if exc is not None:
logger.error("Exception '{}' in worker:".format(str(exc.exc_type)))
raise exc.exc_type(exc.exc_msg)
return ret
def __len__(self):
return self.ds.__len__()
def __iter__(self):
with self._guard, _zmq_catch_error('MultiProcessRunnerZMQ'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self._recv()
def reset_state(self):
super(MultiProcessRunnerZMQ, self).reset_state()
self._guard = DataFlowReentrantGuard()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(self._hwm)
pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename)
self._procs = [MultiProcessRunnerZMQ._Worker(self.ds, pipename, self._hwm, idx)
for idx in range(self.num_proc)]
self._start_processes()
class MultiThreadRunner(DataFlow):
"""
Create multiple dataflow instances and run them each in one thread.
Collect outputs from them with a queue.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that each thread will create a dataflow iterator.
There will be ``num_thread`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``num_thread=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``num_thread>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``num_thread`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
and know that you'll likely see duplicates.
To utilize parallelism with more strict data integrity, you can use
the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
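
    Example:

        A minimal usage sketch (``MyDataFlow`` is a placeholder; note that a **callable**
        that creates the dataflow is passed, not a dataflow instance)::

            df = MultiThreadRunner(lambda: MyDataFlow(), num_prefetch=64, num_thread=4)
            df.reset_state()
            for dp in df:
                pass   # consume the datapoints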
"""
class _Worker(StoppableThread):
def __init__(self, get_df, queue):
super(MultiThreadRunner._Worker, self).__init__()
self.df = get_df()
assert isinstance(self.df, DataFlow), self.df
self.queue = queue
self.daemon = True
def run(self):
self.df.reset_state()
try:
while True:
for dp in self.df:
if self.stopped():
return
self.queue_put_stoppable(self.queue, dp)
except Exception:
if self.stopped():
pass # skip duplicated error messages
else:
raise
finally:
self.stop()
    def __init__(self, get_df, num_prefetch, num_thread):
"""
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow object for each call,
unless your dataflow is stateless.
num_prefetch (int): size of the queue
num_thread (int): number of threads
"""
assert num_thread > 0, num_thread
assert num_prefetch > 0, num_prefetch
self.num_thread = num_thread
self.queue = queue.Queue(maxsize=num_prefetch)
self.threads = [
MultiThreadRunner._Worker(get_df, self.queue)
for _ in range(num_thread)]
try:
self._size = self.__len__()
except NotImplementedError:
self._size = -1
def reset_state(self):
for th in self.threads:
th.df.reset_state()
th.start()
def __len__(self):
return self.threads[0].df.__len__()
def __iter__(self):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self.queue.get()
def __del__(self):
for p in self.threads:
if p.is_alive():
p.stop()
p.join()
class PlasmaPutData(ProxyDataFlow):
"""
Put each data point to plasma shared memory object store, and yield the object id instead.
Experimental.
"""
def __init__(self, ds, socket="/tmp/plasma"):
self._socket = socket
super(PlasmaPutData, self).__init__(ds)
def reset_state(self):
super(PlasmaPutData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0)
def __iter__(self):
for dp in self.ds:
oid = self.client.put(dp)
yield [oid.binary()]
class PlasmaGetData(ProxyDataFlow):
"""
Take plasma object id from a DataFlow, and retrieve it from plasma shared
memory object store.
Experimental.
"""
def __init__(self, ds, socket="/tmp/plasma"):
self._socket = socket
super(PlasmaGetData, self).__init__(ds)
def reset_state(self):
super(PlasmaGetData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0)
def __iter__(self):
for dp in self.ds:
oid = plasma.ObjectID(dp[0])
dp = self.client.get(oid)
yield dp
plasma = None
# This plasma code is only experimental
# try:
# import pyarrow.plasma as plasma
# except ImportError:
# from ..utils.develop import create_dummy_class
# PlasmaPutData = create_dummy_class('PlasmaPutData', 'pyarrow') # noqa
# PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow') # noqa
# The old inappropriate names:
PrefetchData = MultiProcessRunner
MultiProcessPrefetchData = MultiProcessRunner
PrefetchDataZMQ = MultiProcessRunnerZMQ
MultiThreadPrefetchData = MultiThreadRunner
if __name__ == '__main__':
import time
from .raw import DataFromGenerator
from .common import FixedSizeData
x = DataFromGenerator(itertools.count())
x = FixedSizeData(x, 100)
x = MultiProcessRunnerZMQ(x, 2)
x.reset_state()
for idx, dp in enumerate(x):
print(dp)
time.sleep(0.1)