# -*- coding: utf-8 -*-
# File: common.py
import tensorflow as tf
from ..compat import tfv1
from ..utils.argtools import graph_memoized
from .collect_env import collect_env_info
# Public API exported by ``from ... import *``.
__all__ = [
    'get_default_sess_config',
    'get_global_step_value',
    'get_global_step_var',
    'get_tf_version_tuple',
    'collect_env_info',
    # The following are defined in this module but deliberately not
    # exported via ``import *``:
    # 'get_op_tensor_name',
    # 'get_tensors_by_names',
    # 'get_op_or_tensor_by_name',
]
def get_default_sess_config(mem_fraction=0.99):
    """
    Build the tf.ConfigProto used as the default session config.

    A fresh ConfigProto is created on every call, so callers may freely
    modify the returned object to fit their needs.

    Args:
        mem_fraction(float): value for the ``per_process_gpu_memory_fraction``
            option in TensorFlow's GPUOptions protobuf, see
            https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto

    Returns:
        tf.ConfigProto: the config to use.
    """
    config = tfv1.ConfigProto()
    config.allow_soft_placement = True
    # config.log_device_placement = True

    config.intra_op_parallelism_threads = 1
    config.inter_op_parallelism_threads = 0
    # The TF benchmarks use cpu_count() - gpu_thread_count() here
    # (e.g. 80 - 8 * 2); no noticeable difference was observed.

    gpu_opts = config.gpu_options
    gpu_opts.per_process_gpu_memory_fraction = mem_fraction
    # gpu_options.force_gpu_compatible = True is left off: it hurt the
    # performance of large data pipelines, see
    # https://github.com/tensorflow/benchmarks/commit/1528c46499cdcff669b5d7c006b7b971884ad0e6
    gpu_opts.allow_growth = True

    # Disabled options, kept for reference:
    #   from tensorflow.core.protobuf import rewriter_config_pb2 as rwc
    #   config.graph_options.rewrite_options.memory_optimization = \
    #       rwc.RewriterConfig.HEURISTICS
    #   # May hurt performance?
    #   config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    #   config.graph_options.place_pruned_graph = True
    return config
@graph_memoized
def get_global_step_var():
    """
    Get the global_step variable of the current graph, creating it if needed.

    The variable is looked up or created under the root variable scope
    (empty name, ``reuse=False``). This lets an enclosing variable scope at
    the call site not affect it. The result is memoized with
    ``graph_memoized`` (presumably once per graph; see that helper).

    Returns:
        tf.Tensor: the global_step variable in the current graph.
    """
    root_scope = tfv1.VariableScope(reuse=False, name='')
    with tfv1.variable_scope(root_scope):
        return tfv1.train.get_or_create_global_step()
def get_global_step_value():
    """
    Read the current value of global_step.

    Has to be called under a default session.

    Returns:
        int: global_step value in the current graph and default session.
    """
    session = tfv1.get_default_session()
    step_var = get_global_step_var()
    return tfv1.train.global_step(session, step_var)
def get_op_tensor_name(name):
    """
    Will automatically determine if ``name`` is a tensor name or an op name.

    A tensor name has the form ``op_name:index``, where ``index`` is the
    output index and may have any number of digits (``"conv:0"``,
    ``"split:10"``). Any other string is treated as an op name. For an op
    name, the corresponding tensor name is assumed to be ``op_name + ':0'``.

    Args:
        name(str): name of an op or a tensor

    Returns:
        tuple: (op_name, tensor_name)
    """
    op_name, _, index = name.rpartition(':')
    # Check the whole suffix after the last ':' rather than only the
    # second-to-last character, so multi-digit output indices such as
    # "split:10" are recognized. An empty op part (e.g. ":0") is not a
    # valid tensor name and falls through to the op-name branch.
    if op_name and index.isdigit():
        return op_name, name
    return name, name + ':0'
def get_tensors_by_names(names):
    """
    Get a list of tensors in the default graph by a list of names.

    Each name is normalized with ``get_op_tensor_name``, so an op name
    resolves to that op's ``:0`` output tensor.

    Args:
        names (list): names of tensors or ops.

    Returns:
        list: the tensors, in the same order as ``names``.
    """
    graph = tfv1.get_default_graph()
    return [graph.get_tensor_by_name(get_op_tensor_name(n)[1]) for n in names]
def get_op_or_tensor_by_name(name):
    """
    Get either tf.Operation or tf.Tensor from names.

    A name ending in ``:<output_index>`` (e.g. ``"conv:0"`` or
    ``"split:10"``) is looked up as a tensor. Any other name is looked up
    as an operation.

    Args:
        name (list[str], tuple[str] or str): names of operations or tensors.

    Returns:
        The tf.Tensor / tf.Operation for a single name, or a list of them
        (in input order) when ``name`` is a list or tuple.

    Raises:
        KeyError, if the name doesn't exist
    """
    G = tfv1.get_default_graph()

    def f(n):
        op_name, _, index = n.rpartition(':')
        # Check the whole suffix so multi-digit output indices ("split:10")
        # are looked up as tensors instead of being sent to
        # get_operation_by_name.
        if op_name and index.isdigit():
            return G.get_tensor_by_name(n)
        return G.get_operation_by_name(n)

    if isinstance(name, (list, tuple)):
        return [f(n) for n in name]
    return f(name)
def gpu_available_in_session():
    """
    Check whether the default session has a GPU device.

    Has to be called under a default session.

    Returns:
        bool: True if any device listed by the default session has device
        type "gpu" (compared case-insensitively), otherwise False.
    """
    devices = tfv1.get_default_session().list_devices()
    return any(dev.device_type.lower() == 'gpu' for dev in devices)
def get_tf_version_tuple():
    """
    Return TensorFlow version as a 2-element tuple of ints (for comparison).

    The tuple holds the integer values of the first two dot-separated
    fields of ``tf.__version__``, e.g. ``(1, 13)``.
    """
    leading_fields = tf.__version__.split('.')[:2]
    return tuple(int(field) for field in leading_fields)