Source code for tensorpack.tfutils.scope_utils

# -*- coding: utf-8 -*-
# File: scope_utils.py


import functools
from contextlib import contextmanager

from ..compat import tfv1 as tf
from ..utils.argtools import graph_memoized
from ..utils import logger
from .common import get_tf_version_tuple

__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']


[docs]def auto_reuse_variable_scope(func): """ A decorator which automatically reuses the current variable scope if the function has been called with the same variable scope before. Example: .. code-block:: python @auto_reuse_variable_scope def myfunc(x): return tf.layers.conv2d(x, 128, 3) myfunc(x1) # will inherit parent scope reuse myfunc(x2) # will reuse with tf.variable_scope('newscope'): myfunc(x3) # will inherit parent scope reuse myfunc(x4) # will reuse """ used_scope = set() @functools.wraps(func) def wrapper(*args, **kwargs): scope = tf.get_variable_scope() h = hash((tf.get_default_graph(), scope.name)) # print("Entering " + scope.name + " reuse: " + str(h in used_scope)) if h in used_scope: if get_tf_version_tuple() >= (1, 5): with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False): return func(*args, **kwargs) else: ns = tf.get_default_graph().get_name_scope() with tf.variable_scope(scope, reuse=True), \ tf.name_scope(ns + '/' if ns else ''): return func(*args, **kwargs) else: used_scope.add(h) return func(*args, **kwargs) return wrapper
[docs]def under_name_scope(name_scope=None): """ Args: name_scope(str): the default scope to use. If None, will use the name of the function. Returns: A decorator which makes the function run under a name scope. The name scope is obtained by the following: 1. The 'name_scope' keyword argument when the decorated function is called. 2. The 'name_scope' argument of the decorator. 3. (default) The name of the decorated function itself. If the name is taken and cannot be used, a warning will be printed in the first case. Example: .. code-block:: python @under_name_scope() def rms(x): return tf.sqrt( tf.reduce_mean(tf.square(x))) rms(tensor) # will be called under name scope 'rms' rms(tensor, name_scope='scope') # will be called under name scope 'scope' Todo: Add a reuse option. """ def _impl(func): @functools.wraps(func) def wrapper(*args, **kwargs): warn_incorrect_scope = 'name_scope' in kwargs scopename = kwargs.pop('name_scope', name_scope) if scopename is None: scopename = func.__name__ if warn_incorrect_scope: # cached_name_scope will try to reenter the existing scope with cached_name_scope(scopename, top_level=False) as scope: scope = scope.strip('/') # but it can still conflict with an existing tensor if not scope.endswith(scopename): logger.warn(""" \ Calling function {} with name_scope='{}', but actual name scope becomes '{}'. \ The name '{}' might be taken.""".format(func.__name__, scopename, scope.split('/')[-1], scopename)) return func(*args, **kwargs) else: with tf.name_scope(scopename): return func(*args, **kwargs) return wrapper return _impl
def under_variable_scope(): """ Returns: A decorator which makes the function happen under a variable scope, which is named by the function itself. Example: .. code-block:: python @under_variable_scope() def mid_level(x): with argscope(Conv2D, kernel_shape=3, nl=BNReLU): x = Conv2D('conv1', x, 512, stride=1) x = Conv2D('conv2', x, 256, stride=1) return x """ def _impl(func): @functools.wraps(func) def wrapper(*args, **kwargs): name = func.__name__ with tf.variable_scope(name): return func(*args, **kwargs) return wrapper return _impl @graph_memoized def _get_cached_ns(name): with tf.name_scope(None): with tf.name_scope(name) as scope: return scope
[docs]@contextmanager def cached_name_scope(name, top_level=True): """ Return a context which either opens and caches a new name scope, or reenter an existing one. Args: top_level(bool): if True, the name scope will always be top-level. It will not be nested under any existing name scope of the caller. """ if not top_level: current_ns = tf.get_default_graph().get_name_scope() if current_ns: name = current_ns + '/' + name ns = _get_cached_ns(name) with tf.name_scope(ns): yield ns