Source code for tensorpack.dataflow.remote

# -*- coding: utf-8 -*-
# File: remote.py


import multiprocessing as mp
import time
from collections import deque
import tqdm
from six.moves import range

from ..utils import logger
from ..utils.concurrency import DIE
from ..utils.serialize import dumps, loads
from ..utils.utils import get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard

try:
    import zmq
except ImportError:
    logger.warn("Error in 'import zmq'. remote feature won't be available")
    __all__ = []
else:
    __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']


def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
    """
    Run DataFlow and send data to a ZMQ socket addr.
    It will serialize and send each datapoint to this address with a PUSH socket.
    This function never returns.

    Args:
        df (DataFlow): Will infinitely loop over the DataFlow.
        addr: a ZMQ socket endpoint.
        hwm (int): ZMQ high-water mark (buffer size)
        format (str): The serialization format.
            Default format uses :mod:`tensorpack.utils.serialize`.
            This format works with :class:`dataflow.RemoteDataZMQ`.
            An alternate format is 'zmq_ops', used by
            https://github.com/tensorpack/zmq_ops
            and :class:`input_source.ZMQInput`.
        bind (bool): whether to bind or connect to the endpoint address.
    """
    assert format in [None, 'zmq_op', 'zmq_ops']
    if format is None:
        dump_fn = dumps
    else:
        from zmq_ops import dump_arrays
        dump_fn = dump_arrays

    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUSH)
    socket.set_hwm(hwm)
    if bind:
        socket.bind(addr)
    else:
        socket.connect(addr)

    try:
        df.reset_state()
        logger.info("Serving data to {} with {} format ...".format(
            addr, 'default' if format is None else 'zmq_ops'))
        INTERVAL = 200
        q = deque(maxlen=INTERVAL)

        try:
            total = len(df)
        except NotImplementedError:
            total = 0
        tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8)
        tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
        while True:
            with tqdm.trange(total, **tqdm_args) as pbar:
                for dp in df:
                    start = time.time()
                    socket.send(dump_fn(dp), copy=False)
                    q.append(time.time() - start)
                    pbar.update(1)
                    if pbar.n % INTERVAL == 0:
                        avg = "{:.3f}".format(sum(q) / len(q))
                        pbar.set_postfix({'AvgSendLat': avg})
    finally:
        logger.info("Exiting send_dataflow_zmq ...")
        socket.setsockopt(zmq.LINGER, 0)
        socket.close()
        if not ctx.closed:
            ctx.destroy(0)
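

# Example (added for illustration; not part of the original module): a minimal
# sender-side sketch. FakeData and the endpoint below are placeholder choices;
# any DataFlow and any ZMQ endpoint work the same way.
def _example_sender():
    from tensorpack.dataflow.raw import FakeData

    # 1000 fake datapoints, each a single float32 array; shapes are arbitrary
    ds = FakeData([(64, 224, 224, 3)], 1000, random=False)
    # RemoteDataZMQ (below) binds by default, so the sender connects.
    # This call loops over `ds` forever and never returns.
    send_dataflow_zmq(ds, 'tcp://127.0.0.1:8877', hwm=50, bind=False)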


class RemoteDataZMQ(DataFlow):
    """
    Produce data from ZMQ PULL socket(s).
    It is the receiver-side counterpart of :func:`send_dataflow_zmq`, which
    uses :mod:`tensorpack.utils.serialize` for serialization.
    See http://tensorpack.readthedocs.io/tutorial/efficient-dataflow.html#distributed-dataflow

    Attributes:
        cnt1, cnt2 (int): number of data points received from addr1 and addr2
    """

    def __init__(self, addr1, addr2=None, hwm=50, bind=True):
        """
        Args:
            addr1,addr2 (str): addr of the zmq endpoint to connect to.
                Use both if you need two protocols (e.g. both IPC and TCP).
                I don't think you'll ever need 3.
            hwm (int): ZMQ high-water mark (buffer size)
            bind (bool): whether to bind or connect to the endpoint
        """
        assert addr1
        self._addr1 = addr1
        self._addr2 = addr2
        self._hwm = int(hwm)
        self._guard = DataFlowReentrantGuard()
        self._bind = bind

    def reset_state(self):
        self.cnt1 = 0
        self.cnt2 = 0

    def bind_or_connect(self, socket, addr):
        if self._bind:
            socket.bind(addr)
        else:
            socket.connect(addr)

    def __iter__(self):
        with self._guard:
            try:
                ctx = zmq.Context()
                if self._addr2 is None:
                    socket = ctx.socket(zmq.PULL)
                    socket.set_hwm(self._hwm)
                    self.bind_or_connect(socket, self._addr1)

                    while True:
                        dp = loads(socket.recv(copy=False))
                        yield dp
                        self.cnt1 += 1
                else:
                    socket1 = ctx.socket(zmq.PULL)
                    socket1.set_hwm(self._hwm)
                    self.bind_or_connect(socket1, self._addr1)

                    socket2 = ctx.socket(zmq.PULL)
                    socket2.set_hwm(self._hwm)
                    self.bind_or_connect(socket2, self._addr2)

                    poller = zmq.Poller()
                    poller.register(socket1, zmq.POLLIN)
                    poller.register(socket2, zmq.POLLIN)

                    while True:
                        evts = poller.poll()
                        for sock, evt in evts:
                            dp = loads(sock.recv(copy=False))
                            yield dp
                            if sock == socket1:
                                self.cnt1 += 1
                            else:
                                self.cnt2 += 1
            finally:
                ctx.destroy(linger=0)
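

# Example (added for illustration; not part of the original module): a minimal
# receiver-side sketch matching the sender example above. The endpoint is a
# placeholder and must be the one the sender connects to.
def _example_receiver():
    # RemoteDataZMQ binds by default, so it pairs with a connecting sender.
    # Pass addr2 as well to pull from two endpoints (e.g. IPC and TCP) at once.
    df = RemoteDataZMQ('tcp://127.0.0.1:8877', hwm=50, bind=True)
    df.reset_state()
    for dp in df:
        pass  # each dp is one deserialized datapoint; consume it here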


# for internal use only
def dump_dataflow_to_process_queue(df, size, nr_consumer):
    """
    Convert a DataFlow to a :class:`multiprocessing.Queue`.
    The DataFlow will only be reset in the spawned process.

    Args:
        df (DataFlow): the DataFlow to dump.
        size (int): size of the queue
        nr_consumer (int): number of consumers of the queue.
            The producer will add this many ``DIE`` sentinels to the end of the queue.

    Returns:
        tuple(queue, process):
            The process will take data from ``df`` and fill
            the queue, once you start it. Each element in the queue is (idx, dp).
            idx can be the ``DIE`` sentinel when ``df`` is exhausted.
    """
    q = mp.Queue(size)

    class EnqueProc(mp.Process):

        def __init__(self, df, q, nr_consumer):
            super(EnqueProc, self).__init__()
            self.df = df
            self.q = q
            self.nr_consumer = nr_consumer

        def run(self):
            self.df.reset_state()
            try:
                for idx, dp in enumerate(self.df):
                    self.q.put((idx, dp))
            finally:
                # tell every consumer that the DataFlow is exhausted
                for _ in range(self.nr_consumer):
                    self.q.put((DIE, None))

    proc = EnqueProc(df, q, nr_consumer)
    return q, proc


if __name__ == '__main__':
    from argparse import ArgumentParser
    from .raw import FakeData
    from .common import TestDataSpeed

    """
    Test the multi-producer single-consumer model
    """
    parser = ArgumentParser()
    parser.add_argument('-t', '--task', choices=['send', 'recv'], required=True)
    parser.add_argument('-a', '--addr1', required=True)
    parser.add_argument('-b', '--addr2', default=None)
    args = parser.parse_args()

    # tcp addr like "tcp://127.0.0.1:8877"
    # ipc addr like "ipc://@ipc-test"
    if args.task == 'send':
        # use random=True to make it slow and cpu-consuming
        ds = FakeData([(128, 244, 244, 3)], 1000, random=True)
        send_dataflow_zmq(ds, args.addr1)
    else:
        ds = RemoteDataZMQ(args.addr1, args.addr2)
        logger.info("Each DP is 73.5MB")
        TestDataSpeed(ds).start_test()
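

# Example (added for illustration; not part of the original module): the
# consumer side of dump_dataflow_to_process_queue. Each of the `nr_consumer`
# workers keeps reading (idx, dp) pairs until it receives a DIE sentinel.
def _example_queue_consumer():
    from tensorpack.dataflow.raw import FakeData

    df = FakeData([(4, 4)], 100)  # placeholder DataFlow
    q, proc = dump_dataflow_to_process_queue(df, size=20, nr_consumer=1)
    proc.start()
    while True:
        idx, dp = q.get()
        if idx is DIE:  # the producer has exhausted the DataFlow
            break
        # process (idx, dp) here
    proc.join()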