Source code for tensorpack.utils.concurrency

# -*- coding: utf-8 -*-
# File: concurrency.py

# Some code taken from zxytim

import sys
import atexit
import bisect
import multiprocessing as mp
import platform
import signal
import threading
import weakref
from contextlib import contextmanager
import six
from six.moves import queue
import subprocess

from . import logger
from .argtools import log_once


__all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
           'ensure_proc_terminate',
           'start_proc_mask_signal']


class StoppableThread(threading.Thread):
    """
    A thread that has a 'stop' event.
    """

    def __init__(self, evt=None):
        """
        Args:
            evt(threading.Event): if None, will create one.
        """
        super(StoppableThread, self).__init__()
        if evt is None:
            evt = threading.Event()
        self._stop_evt = evt

    def stop(self):
        """ Stop the thread."""
        self._stop_evt.set()

    def stopped(self):
        """
        Returns:
            bool: whether the thread is stopped or not.
        """
        return self._stop_evt.is_set()

    def queue_put_stoppable(self, q, obj):
        """ Put obj to queue, but will give up when the thread is stopped."""
        while not self.stopped():
            try:
                q.put(obj, timeout=5)
                break
            except queue.Full:
                pass

    def queue_get_stoppable(self, q):
        """ Take obj from queue, but will give up when the thread is stopped."""
        while not self.stopped():
            try:
                return q.get(timeout=5)
            except queue.Empty:
                pass
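

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical demo of StoppableThread: a consumer that can be
# shut down even while blocked on a queue, because queue_get_stoppable()
# polls with a timeout and rechecks stopped(). All names below are
# illustrative, not tensorpack API.
def _demo_stoppable_thread():
    q = queue.Queue()

    class Consumer(StoppableThread):
        def run(self):
            while not self.stopped():
                item = self.queue_get_stoppable(q)
                if item is not None:
                    print("got", item)

    c = Consumer()
    c.start()
    q.put("hello")
    c.stop()    # exit may take up to the 5-second poll interval
    c.join()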


class LoopThread(StoppableThread):
    """ A pausable thread that simply runs a loop."""

    def __init__(self, func, pausable=True):
        """
        Args:
            func: the function to run in a loop.
            pausable(bool): whether to allow pause() / resume().
        """
        super(LoopThread, self).__init__()
        self._func = func
        self._pausable = pausable
        if pausable:
            self._lock = threading.Lock()
        self.daemon = True

    def run(self):
        while not self.stopped():
            if self._pausable:
                # pause() holds this lock, so acquiring it here blocks the
                # loop until resume() releases it.
                self._lock.acquire()
                self._lock.release()
            self._func()

    def pause(self):
        """ Pause the loop. """
        assert self._pausable
        self._lock.acquire()

    def resume(self):
        """ Resume the loop. """
        assert self._pausable
        self._lock.release()
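

# --- Usage sketch (not part of the original module) ---
# A hypothetical demo of LoopThread: pause() blocks the loop at its next
# iteration by holding the internal lock; resume() lets it continue.
def _demo_loop_thread():
    import itertools
    counter = itertools.count()

    t = LoopThread(lambda: next(counter), pausable=True)
    t.start()
    t.pause()     # the loop blocks at the next lock acquisition
    t.resume()    # the loop continues
    t.stop()
    t.join()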


class ShareSessionThread(threading.Thread):
    """ A wrapper around thread so that the thread
        uses the default session at "start()" time.
    """
    def __init__(self, th=None):
        """
        Args:
            th (threading.Thread or None): the thread to wrap. If None,
                subclasses are expected to override :meth:`run`.
        """
        super(ShareSessionThread, self).__init__()
        self._th = th
        if th is not None:
            assert isinstance(th, threading.Thread), th
            self.name = th.name
            self.daemon = th.daemon

    @contextmanager
    def default_sess(self):
        if self._sess:
            with self._sess.as_default():
                yield self._sess
        else:
            logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
            yield None

    def start(self):
        from ..compat import tfv1
        self._sess = tfv1.get_default_session()
        super(ShareSessionThread, self).start()

    def run(self):
        if self._th is None:
            raise NotImplementedError()
        with self._sess.as_default():
            self._th.run()
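

# --- Usage sketch (not part of the original module) ---
# A hypothetical demo of ShareSessionThread, assuming a TF1-compatible setup
# (e.g. `import tensorflow.compat.v1 as tf1`). The wrapper captures whatever
# session is the default when start() is called, then re-installs it as the
# default inside the new thread, where TF default sessions are otherwise
# thread-local.
def _demo_share_session_thread():
    import tensorflow.compat.v1 as tf1

    th = threading.Thread(target=lambda: print(tf1.get_default_session()))
    t = ShareSessionThread(th)
    with tf1.Session() as sess:
        t.start()    # captures `sess`; run() executes under it
        t.join()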


class DIE(object):
    """ A placeholder class indicating end of queue """
    pass


def ensure_proc_terminate(proc):
    """
    Make sure processes terminate when the main process exits.

    Args:
        proc (multiprocessing.Process or list)
    """
    if isinstance(proc, list):
        for p in proc:
            ensure_proc_terminate(p)
        return

    def stop_proc_by_weak_ref(ref):
        proc = ref()
        if proc is None:
            return
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    assert isinstance(proc, mp.Process)
    atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
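

# --- Usage sketch (not part of the original module) ---
# A hypothetical demo of ensure_proc_terminate: the atexit hook holds only a
# weak reference, so the process object can still be garbage-collected; if
# the process is still alive at interpreter exit, it is terminated and joined.
def _demo_ensure_proc_terminate():
    import time
    p = mp.Process(target=time.sleep, args=(3600,))
    p.start()
    ensure_proc_terminate(p)    # p is terminated at interpreter exit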


def enable_death_signal(_warn=True):
    """
    Set the "death signal" of the current process, so that
    the current process will be cleaned with guarantee
    in case the parent dies accidentally.
    """
    if platform.system() != 'Linux':
        return
    try:
        import prctl    # pip install python-prctl
    except ImportError:
        if _warn:
            log_once('"import prctl" failed! Install python-prctl so that processes can be cleaned with guarantee.',
                     'warn')
        return
    else:
        assert hasattr(prctl, 'set_pdeathsig'), \
            "prctl.set_pdeathsig does not exist! Note that you need to install 'python-prctl' instead of 'prctl'."
        # is SIGHUP a good choice?
        prctl.set_pdeathsig(signal.SIGHUP)


def is_main_thread():
    if six.PY2:
        return isinstance(threading.current_thread(), threading._MainThread)
    else:
        # a nicer solution with py3
        return threading.current_thread() == threading.main_thread()


@contextmanager
def mask_sigint():
    """
    Returns:
        If called in the main thread, returns a context where ``SIGINT``
        is ignored, and yields True. Otherwise yields False.
    """
    if is_main_thread():
        sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        yield True
        signal.signal(signal.SIGINT, sigint_handler)
    else:
        yield False
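

# --- Usage sketch (not part of the original module) ---
# A hypothetical demo of mask_sigint: within the context (main thread only),
# Ctrl-C is ignored; the previous SIGINT handler is restored on exit.
def _demo_mask_sigint():
    with mask_sigint() as masked:
        if masked:
            print("SIGINT is ignored here; safe to start child processes")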


def start_proc_mask_signal(proc):
    """
    Start process(es) with SIGINT ignored.

    Args:
        proc: (mp.Process or list)

    Note:
        The signal mask is only applied when called from the main thread.
    """
    if not isinstance(proc, list):
        proc = [proc]

    with mask_sigint():
        for p in proc:
            if isinstance(p, mp.Process):
                if sys.version_info < (3, 4) or mp.get_start_method() == 'fork':
                    log_once("""
Starting a process with 'fork' method is efficient but not safe and may cause deadlock or crash.
Use 'forkserver' or 'spawn' method instead if you run into such issues.
See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods on how to set them.
""".replace("\n", ""), 'warn')  # noqa
            p.start()
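

# --- Usage sketch (not part of the original module) ---
# A hypothetical demo of start_proc_mask_signal: with the 'fork' start method
# (the Linux default), children inherit the ignored SIGINT, so Ctrl-C only
# reaches the parent, which can then shut its children down deliberately.
def _demo_start_proc_mask_signal():
    import time
    procs = [mp.Process(target=time.sleep, args=(10,)) for _ in range(2)]
    start_proc_mask_signal(procs)
    ensure_proc_terminate(procs)
    for p in procs:
        p.join()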


def subproc_call(cmd, timeout=None):
    """
    Execute a command with timeout, and return STDOUT and STDERR.

    Args:
        cmd(str): the command to execute.
        timeout(float): timeout in seconds.

    Returns:
        output(bytes), retcode(int). If timeout, retcode is -1.
    """
    try:
        output = subprocess.check_output(
            cmd, stderr=subprocess.STDOUT,
            shell=True, timeout=timeout)
        return output, 0
    except subprocess.TimeoutExpired as e:
        logger.warn("Command '{}' timeout!".format(cmd))
        if e.output:
            logger.warn(e.output.decode('utf-8'))
            return e.output, -1
        else:
            return b"", -1
    except subprocess.CalledProcessError as e:
        logger.warn("Command '{}' failed, return code={}".format(cmd, e.returncode))
        logger.warn(e.output.decode('utf-8'))
        return e.output, e.returncode
    except Exception:
        logger.warn("Command '{}' failed to run.".format(cmd))
        return b"", -2


class OrderedContainer(object):
    """
    Like a queue, but will always wait to receive item with rank
    (x+1) and produce (x+1) before producing (x+2).

    Warning:
        It is not thread-safe.
    """

    def __init__(self, start=0):
        """
        Args:
            start(int): the starting rank.
        """
        self.ranks = []
        self.data = []
        self.wait_for = start

    def put(self, rank, val):
        """
        Args:
            rank(int): rank of the element. All elements must have different ranks.
            val: an object
        """
        idx = bisect.bisect(self.ranks, rank)
        self.ranks.insert(idx, rank)
        self.data.insert(idx, val)

    def has_next(self):
        if len(self.ranks) == 0:
            return False
        return self.ranks[0] == self.wait_for

    def get(self):
        assert self.has_next()
        ret = self.data[0]
        rank = self.ranks[0]
        del self.ranks[0]
        del self.data[0]
        self.wait_for += 1
        return rank, ret


class OrderedResultGatherProc(mp.Process):
    """
    Gather indexed data from a data queue, and produce results with the
    original index-based order.
    """

    def __init__(self, data_queue, nr_producer, start=0):
        """
        Args:
            data_queue(mp.Queue): a queue which contains datapoints.
            nr_producer(int): number of producer processes. This process will
                terminate after receiving this many :class:`DIE` sentinels.
            start(int): the rank of the first object.
        """
        super(OrderedResultGatherProc, self).__init__()
        self.data_queue = data_queue
        self.ordered_container = OrderedContainer(start=start)
        self.result_queue = mp.Queue()
        self.nr_producer = nr_producer

    def run(self):
        nr_end = 0
        try:
            while True:
                task_id, data = self.data_queue.get()
                if task_id == DIE:
                    self.result_queue.put((task_id, data))
                    nr_end += 1
                    if nr_end == self.nr_producer:
                        return
                else:
                    self.ordered_container.put(task_id, data)
                    while self.ordered_container.has_next():
                        self.result_queue.put(self.ordered_container.get())
        except Exception as e:
            import traceback
            traceback.print_exc()
            raise e

    def get(self):
        return self.result_queue.get()
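

# --- Usage sketches (not part of the original module) ---
# A hypothetical demo of subproc_call: output comes back as bytes, together
# with the return code.
def _demo_subproc_call():
    out, ret = subproc_call("echo hello", timeout=5)
    assert ret == 0
    print(out.decode('utf-8').strip())    # -> hello


# A hypothetical demo of OrderedContainer: items may be put out of order,
# but get() only releases them in consecutive-rank order.
def _demo_ordered_container():
    c = OrderedContainer(start=0)
    c.put(1, 'b')
    assert not c.has_next()    # still waiting for rank 0
    c.put(0, 'a')
    assert c.get() == (0, 'a')
    assert c.get() == (1, 'b')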