tensorpack.tfutils package

tensorpack.tfutils.argscope module

tensorpack.tfutils.argscope.argscope(layers, **kwargs)[source]
Parameters

layers (list or layer) – layer or list of layers to apply the arguments to.

Returns

a context in which every use of these layers has, by default, the arguments specified by kwargs.

Example

with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
    x = Conv2D('conv0', x)
    x = Conv2D('conv1', x)
    x = Conv2D('conv2', x, out_channel=64)  # override argscope
tensorpack.tfutils.argscope.get_arg_scope()[source]
Returns

dict – the current argscope.

An argscope is a dict of dict: dict[layername] = {arg: val}
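
The dict-of-dict structure can be illustrated without TensorFlow. The following is a minimal pure-Python sketch of the idea, not tensorpack's implementation; toy_argscope, toy_get_arg_scope and toy_conv are hypothetical names introduced only for this illustration.

```python
from contextlib import contextmanager
import copy

# Stack of scopes; each scope is dict[layername] = {arg: val}.
_SCOPE_STACK = [{}]


def toy_get_arg_scope():
    """Return the innermost scope (hypothetical stand-in for get_arg_scope)."""
    return _SCOPE_STACK[-1]


@contextmanager
def toy_argscope(layers, **kwargs):
    """Push a new scope in which `layers` receive `kwargs` as defaults."""
    if not isinstance(layers, (list, tuple)):
        layers = [layers]
    new_scope = copy.deepcopy(toy_get_arg_scope())
    for layer in layers:
        new_scope.setdefault(layer.__name__, {}).update(kwargs)
    _SCOPE_STACK.append(new_scope)
    try:
        yield
    finally:
        _SCOPE_STACK.pop()  # leaving the block restores the outer scope


def toy_conv(name, **kwargs):
    """A fake layer: merge scope defaults with explicit arguments."""
    args = dict(toy_get_arg_scope().get("toy_conv", {}))
    args.update(kwargs)  # explicit arguments override the scope
    return args


with toy_argscope(toy_conv, kernel_shape=3, out_channel=32):
    inside_default = toy_conv("conv0")
    inside_override = toy_conv("conv1", out_channel=64)
outside = toy_conv("conv2")
```

Inside the block the fake layer picks up the scoped defaults, an explicit out_channel overrides them, and outside the block no defaults apply.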

tensorpack.tfutils.argscope.enable_argscope_for_module(module, log_shape=True)[source]

Overwrite all functions of a given module to support argscope. Note that this function monkey-patches the module and could therefore have unexpected consequences. It has only been tested with the tf.layers module.

Example

import tensorflow as tf
enable_argscope_for_module(tf.layers)
Parameters

log_shape (bool) – print input/output shapes of each function.

tensorpack.tfutils.argscope.enable_argscope_for_function(func, log_shape=True)[source]

Decorator that makes a function support argscope.

Example

from mylib import myfunc
myfunc = enable_argscope_for_function(myfunc)
Parameters
  • func – A function mapping one or multiple tensors to one or multiple tensors.

  • log_shape (bool) – whether to print the shapes of the first input tensor and the first output tensor once.

Remarks:

If the function func takes multiple input tensors or returns multiple output tensors, only the shape of the first input/output tensor is displayed during logging.

Returns

The decorated function.

tensorpack.tfutils.collection module

tensorpack.tfutils.collection.backup_collection(keys=None)[source]
Parameters

keys (list) – list of collection keys to back up. Defaults to all keys in the graph.

Returns

dict – the backup

tensorpack.tfutils.collection.restore_collection(backup)[source]

Restore from a collection backup.

Parameters

backup (dict) –

tensorpack.tfutils.collection.freeze_collection(keys)[source]
Parameters

keys (list) – list of collection keys to freeze.

Returns

a context in which the collections are restored to their initial state on exit.
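
The relationship between the three functions above can be sketched in plain Python. This is only an illustration under the assumption that collections behave like a dict of lists; the `collections` dict and the toy_* helpers are hypothetical and do not touch a TensorFlow graph.

```python
from contextlib import contextmanager

# A toy stand-in for graph collections: key -> list of items.
collections = {"summaries": ["loss"], "losses": []}


def toy_backup(keys=None):
    """Copy the listed collections (all of them when keys is None)."""
    if keys is None:
        keys = list(collections)
    return {k: list(collections.get(k, [])) for k in keys}


def toy_restore(backup):
    """Overwrite each backed-up collection with its saved contents."""
    for key, items in backup.items():
        collections[key] = list(items)


@contextmanager
def toy_freeze(keys):
    """Restore the listed collections when the block exits."""
    backup = toy_backup(keys)
    try:
        yield
    finally:
        toy_restore(backup)


with toy_freeze(["summaries"]):
    collections["summaries"].append("accuracy")  # undone on exit
    collections["losses"].append("l2")           # not frozen, so kept
```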

tensorpack.tfutils.gradproc module

class tensorpack.tfutils.gradproc.GradientProcessor[source]

Bases: object

Base class for all gradient processors. Gradient processors can be applied to optimizers by optimizer.apply_grad_processors().

Subclasses should override the _process() method.

process(grads)[source]

Process the symbolic gradients.

Parameters

grads (list) – list of (grad, var).

Returns

list – processed gradients, with the same type as input.

class tensorpack.tfutils.gradproc.FilterNoneGrad(verbose=True)[source]

Bases: tensorpack.tfutils.gradproc.GradientProcessor

Skip the update and print a warning (instead of crashing) when the gradient of a variable is None.

__init__(verbose=True)[source]
Parameters

verbose (bool) – whether to print a warning about None gradients.

class tensorpack.tfutils.gradproc.GlobalNormClip(global_norm)[source]

Bases: tensorpack.tfutils.gradproc.GradientProcessor

Clip by global norm. The global norm is the square root of the sum of the squared L2 norms of all gradients.

See tf.clip_by_global_norm() for more information.

__init__(global_norm)[source]
Parameters

global_norm (float) – the threshold to clip with.
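
As a worked illustration, tf.clip_by_global_norm() computes the global norm as the square root of the sum of squared L2 norms and rescales every gradient by clip_norm / max(global_norm, clip_norm). The sketch below reproduces that arithmetic on plain Python lists; the function names are local to this example and the gradients are made-up numbers.

```python
import math


def global_norm(grads):
    """Square root of the sum of squared entries over every gradient."""
    return math.sqrt(sum(g * g for grad in grads for g in grad))


def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients by clip_norm / max(global_norm, clip_norm)."""
    norm = global_norm(grads)
    scale = clip_norm / max(norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads], norm


grads = [[3.0, 4.0], [12.0]]  # individual L2 norms are 5 and 12
clipped, norm = clip_by_global_norm(grads, clip_norm=6.5)
```

Here the global norm is 13 (not 5 + 12 = 17), so a threshold of 6.5 halves every gradient; a threshold above 13 would leave the gradients unchanged.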

class tensorpack.tfutils.gradproc.MapGradient(func, regex='.*')[source]

Bases: tensorpack.tfutils.gradproc.GradientProcessor

Apply a function to every gradient whose variable name matches regex. Keep the other gradients unchanged.

It can be used for gradient clipping, etc.

__init__(func, regex='.*')[source]
Parameters
  • func – a user-supplied function which takes one or two arguments. The argument(s) can be either a grad tensor, or grad and var. The function should return the new gradient to be used. If it returns None, the gradient is discarded (hence no update to the variable will happen).

  • regex (str) – used to match variables. Defaults to match all variables.

class tensorpack.tfutils.gradproc.SummaryGradient(regex='.*', collections=None)[source]

Bases: tensorpack.tfutils.gradproc.MapGradient

For each gradient tensor, summarize its histogram and add it to moving summaries.

__init__(regex='.*', collections=None)[source]
Parameters
class tensorpack.tfutils.gradproc.PrintGradient(regex='.*')[source]

Bases: tensorpack.tfutils.gradproc.MapGradient

Print the gradients every step with symbolic_functions.print_stat().

__init__(regex='.*')[source]
Parameters

regex (str) – same as in MapGradient.

class tensorpack.tfutils.gradproc.CheckGradient[source]

Bases: tensorpack.tfutils.gradproc.MapGradient

Run tf.check_numerics() for each gradient.

class tensorpack.tfutils.gradproc.ScaleGradient(multipliers, verbose=True)[source]

Bases: tensorpack.tfutils.gradproc.MapGradient

Scale certain gradient by a multiplier.

__init__(multipliers, verbose=True)[source]
Parameters
  • multipliers (tuple or list) – tuple of (regex, float), or list of such tuples.

  • verbose (bool) – whether to print logs or not

Example

Use a doubled learning rate for all biases (as in Caffe), and freeze layer0:

from tensorpack.tfutils import optimizer, gradproc
opt = optimizer.apply_grad_processors(
    opt, [gradproc.ScaleGradient(
        [('.*/b', 2.), ('layer0/.*', 0.)]
    )])
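
To see which variables such a multiplier list would select, the following pure-Python sketch applies the regexes to a few made-up variable names. It assumes that a regex must match the whole variable name and that the first matching entry wins; both are assumptions of this sketch rather than documented behavior, and pick_multiplier is a hypothetical helper.

```python
import re


def pick_multiplier(var_name, multipliers):
    """Return the multiplier of the first matching regex, or 1.0 if none match."""
    for regex, value in multipliers:
        if re.fullmatch(regex, var_name):  # assumption: whole-name match
            return value
    return 1.0  # unmatched gradients are left unscaled


multipliers = [('.*/b', 2.), ('layer0/.*', 0.)]
names = ['layer0/W', 'fc/W', 'fc/b']  # hypothetical variable names
scales = {name: pick_multiplier(name, multipliers) for name in names}
```

Under the first-match assumption, a name such as layer0/b would pick up the 2.0 multiplier instead of being frozen, so the order of overlapping patterns is worth checking against the tensorpack source.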

tensorpack.tfutils.tower module

tensorpack.tfutils.tower.get_current_tower_context()[source]

When called inside a TowerContext, returns the TowerContext.

Returns

a BaseTowerContext instance or None, if not called under a TowerContext.

class tensorpack.tfutils.tower.BaseTowerContext(ns_name, vs_name='')[source]

Bases: object

A context where the current model is built in. You need to use TowerContext() to create a BaseTowerContext.

__init__(ns_name, vs_name='')[source]

This is not supposed to be used by users. You need to use TowerContext() to create a BaseTowerContext.

Parameters
  • ns_name (str) – The name scope of the tower.

  • vs_name (str) – Open a new variable scope with this name.

get_collection_in_tower(key)[source]

From a collection, get items that are added to the collection in this tower.

Note that it works by tracking the collection at the beginning and end of the tower function. Therefore it does not guarantee that the items are created in this tower.

abstract property has_own_variables

Whether this tower is supposed to have its own trainable variables.

Type

bool

abstract property is_main_training_tower

Whether this tower is the main (i.e., the first) training tower.

Type

bool

property is_training

Whether the context is training or not.

Type

bool

property name

The name scope of the tower.

Type

str

property ns_name

The name scope of the tower.

Type

str

property vs_name

The variable scope of the tower.

Type

str

tensorpack.tfutils.tower.TowerContext(tower_name, is_training, vs_name='')[source]

The context for a tower function, containing metadata about the current tower. Tensorpack trainers use TowerContext to manage tower functions. Many tensorpack layers have to be called under a TowerContext.

Example:

with TowerContext('', is_training=True):
    # call a tensorpack layer or a tower function
class tensorpack.tfutils.tower.TowerFunc(tower_fn, input_signature)[source]

Bases: object

A tower function (see the tutorial on tower functions). It keeps track of the name scope, variable scope and input/output tensors each time the function is called.

TowerTrainer needs this so that it knows how to build a predictor.

Conceptually, this class is roughly equivalent to tf.function with input signature, introduced in TF 2.0.

__init__(tower_fn, input_signature)[source]
Parameters
  • tower_fn – a function which builds one tower in the graph. It takes several input tensors and could return anything.

  • input_signature ([TensorSpec]) – list of tf.TensorSpec. They are used to figure out the names for the input tensors.

property input_signature
property towers

a TowerTensorHandles object, that can access the tower handles by either indices or names.

Type

TowerTensorHandles

class tensorpack.tfutils.tower.TowerTensorHandle(ctx, inputs, outputs, input_signature=None)[source]

Bases: object

When a function is called multiple times under each tower, it becomes hard to keep track of the scope and access those tensors in each tower. This class provides easy access to the tensors as well as the inputs/outputs created in each tower.

__getitem__(name)[source]

The same as get_tensor().

get_collection(key=None, name=None)[source]

See BaseTowerContext.get_collection_in_tower().

Parameters
  • key (str) – the key of the collection

  • name – deprecated

get_tensor(name)[source]

Get a tensor in this tower. The name argument can be:

  1. The name of a tensor/variable without any tower prefix.

  2. A name in the input signature, if it is used when building the tower.

In the second case, this method will return the tensor that’s used as the corresponding input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).

get_tensors(names)[source]

Like get_tensor(), but takes a list and returns a list.

get_variable(name)[source]

Get a variable used in this tower. The name should not contain the variable scope prefix of the tower.

When the tower has the same variable scope and name scope, this is equivalent to get_tensor().

get_variables(names)[source]

Like get_variable(), but takes a list and returns a list.

property inputs

The list of input tensors used to build the tower.

Type

list[Tensor]

property is_training
property ns_name
property outputs

The outputs returned by the tower function.

Type

list[Tensor]

property vs_name
class tensorpack.tfutils.tower.TowerTensorHandles(handles)[source]

Bases: object

Wrap a list of TowerTensorHandle, to support access to them by index or names.

__getitem__(name_or_index)[source]
Parameters

name_or_index (str or int) –

Returns

a TowerTensorHandle.

inference()[source]
Returns

A TowerTensorHandles, containing only the inference towers.

training()[source]
Returns

A TowerTensorHandles, containing only the training towers.

tensorpack.tfutils.scope_utils module

tensorpack.tfutils.scope_utils.auto_reuse_variable_scope(func)[source]

A decorator which automatically reuses the current variable scope if the function has been called with the same variable scope before.

Example:

@auto_reuse_variable_scope
def myfunc(x):
    return tf.layers.conv2d(x, 128, 3)

myfunc(x1)  # will inherit parent scope reuse
myfunc(x2)  # will reuse
with tf.variable_scope('newscope'):
    myfunc(x3)  # will inherit parent scope reuse
    myfunc(x4)  # will reuse
tensorpack.tfutils.scope_utils.cached_name_scope(name, top_level=True)[source]

Return a context which either opens and caches a new name scope, or re-enters an existing one.

Parameters

top_level (bool) – if True, the name scope will always be top-level. It will not be nested under any existing name scope of the caller.

tensorpack.tfutils.scope_utils.under_name_scope(name_scope=None)[source]
Parameters

name_scope (str) – the default scope to use. If None, will use the name of the function.

Returns

A decorator which makes the function run under a name scope. The name scope is obtained from the following, in order of priority:

  1. The ‘name_scope’ keyword argument when the decorated function is called.

  2. The ‘name_scope’ argument of the decorator.

  3. (default) The name of the decorated function itself.

If the name is taken and cannot be used, a warning will be printed in the first case.

Example:

@under_name_scope()
def rms(x):
    return tf.sqrt(
        tf.reduce_mean(tf.square(x)))

rms(tensor)  # will be called under name scope 'rms'
rms(tensor, name_scope='scope')  # will be called under name scope 'scope'
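
The way the scope name is chosen can be sketched without TensorFlow. The decorator below is a hypothetical stand-in (toy_under_name_scope) that only resolves and returns the name it would use; it does not open a real name scope.

```python
import functools


def toy_under_name_scope(name_scope=None):
    """Decorator factory that resolves a scope name; opens no real TF scope."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # 1. call-time keyword, 2. decorator argument, 3. function name
            scope = kwargs.pop('name_scope', None) or name_scope or func.__name__
            return scope, func(*args, **kwargs)
        return wrapper
    return decorator


@toy_under_name_scope()
def rms(values):
    return (sum(v * v for v in values) / len(values)) ** 0.5


@toy_under_name_scope(name_scope='stats')
def mean(values):
    return sum(values) / len(values)


scope_default, rms_value = rms([3.0, 4.0])
scope_call, _ = rms([3.0, 4.0], name_scope='scope')
scope_decorator, _ = mean([1.0, 3.0])
```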

Todo

Add a reuse option.

tensorpack.tfutils.optimizer module

tensorpack.tfutils.optimizer.apply_grad_processors(opt, gradprocs)[source]

Wrapper around optimizers to apply gradient processors.

Parameters
  • opt (tf.train.Optimizer) –

  • gradprocs (list[GradientProcessor]) – gradient processors to add to the optimizer.

Returns

a tf.train.Optimizer instance which runs the gradient processors before updating the variables.

class tensorpack.tfutils.optimizer.ProxyOptimizer(opt, name='ProxyOptimizer')[source]

Bases: tensorflow.python.training.optimizer.Optimizer

A transparent proxy which delegates all methods of tf.train.Optimizer.

class tensorpack.tfutils.optimizer.PostProcessOptimizer(opt, func, colocate=True)[source]

Bases: tensorpack.tfutils.optimizer.ProxyOptimizer

An optimizer which applies some “post-processing operation” per variable (e.g. clipping, quantization) after the gradient update.

__init__(opt, func, colocate=True)[source]
Parameters
  • opt (tf.train.Optimizer) –

  • func (tf.Variable -> tf.Operation or None) – the operation to perform on this variable after the gradient update.

  • colocate (boolean) – colocate the function with the variable. No effect since TF 1.13.

class tensorpack.tfutils.optimizer.VariableAssignmentOptimizer(opt, func)[source]

Bases: tensorpack.tfutils.optimizer.PostProcessOptimizer

An optimizer which assigns each variable a new value (e.g. clipping, quantization) after the gradient update.

__init__(opt, func)[source]
Parameters
  • opt (tf.train.Optimizer) –

  • func (tf.Variable -> tf.Tensor or None) – the new value to be assigned to this variable after the gradient update.

class tensorpack.tfutils.optimizer.AccumGradOptimizer(opt, niter)[source]

Bases: tensorpack.tfutils.optimizer.ProxyOptimizer

An optimizer which accumulates gradients across \(k\) minimize() executions, and applies them together in every \(k\)-th minimize() execution. This is roughly the same as using a \(k\) times larger batch size plus a \(k\) times larger learning rate, but uses much less memory.

This optimizer can be used in any TensorFlow code (with or without tensorpack).

Example:

from tensorpack.tfutils.optimizer import AccumGradOptimizer
myopt = tf.train.GradientDescentOptimizer(0.01)
myopt = AccumGradOptimizer(myopt, niter=5)
train_op = myopt.minimize(loss)
__init__(opt, niter)[source]
Parameters
  • opt (tf.train.Optimizer) – the underlying sub-optimizer.

  • niter (int) – number of iterations to accumulate gradients.
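
The accumulation schedule can be illustrated with a scalar toy optimizer. This is a sketch under the assumption that the accumulated gradients are summed before the update, which is consistent with the "larger batch size plus larger learning rate" description; ToyAccumSGD is a hypothetical class and not the tensorpack implementation.

```python
class ToyAccumSGD:
    """Accumulate gradients and apply their sum on every niter-th step."""

    def __init__(self, lr, niter):
        self.lr = lr
        self.niter = niter
        self.accum = 0.0
        self.counter = 0

    def step(self, weight, grad):
        """Return the (possibly updated) weight after seeing one gradient."""
        self.accum += grad
        self.counter += 1
        if self.counter % self.niter == 0:
            weight = weight - self.lr * self.accum  # one update for niter grads
            self.accum = 0.0
        return weight


opt = ToyAccumSGD(lr=0.1, niter=3)
w = 1.0
history = []
for g in [1.0, 2.0, 3.0, 4.0]:
    w = opt.step(w, g)
    history.append(w)
```

The weight stays at 1.0 for the first two steps, moves to about 0.4 on the third step (1.0 - 0.1 * 6.0), and the fourth gradient is only stored for the next update.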

tensorpack.tfutils.sesscreate module

class tensorpack.tfutils.sesscreate.NewSessionCreator(target='', config=None)[source]

Bases: tensorflow.python.training.monitored_session.SessionCreator

__init__(target='', config=None)[source]
Parameters
  • target, config – same as in Session.__init__().

  • config – a tf.ConfigProto instance, defaults to tfutils.get_default_sess_config()

create_session()[source]
class tensorpack.tfutils.sesscreate.ReuseSessionCreator(sess)[source]

Bases: tensorflow.python.training.monitored_session.SessionCreator

Returns an existing session.

__init__(sess)[source]
Parameters

sess (tf.Session) – the session to reuse

create_session()[source]
class tensorpack.tfutils.sesscreate.SessionCreatorAdapter(session_creator, func)[source]

Bases: tensorflow.python.training.monitored_session.SessionCreator

Apply a function on the output of a SessionCreator. Can be used to create a debug session.

Note: Since TF 1.6, debug session may not work properly with Monitored session. This is a tensorflow bug. To use tfdbg, use the TFLocalCLIDebugHook callback instead.

__init__(session_creator, func)[source]
Parameters
  • session_creator (tf.train.SessionCreator) – a session creator

  • func (tf.Session -> tf.Session) – takes a session created by session_creator, and returns a new session to be returned by self.create_session.

create_session()[source]

tensorpack.tfutils.sessinit module

class tensorpack.tfutils.sessinit.SessionInit[source]

Bases: object

Base class for utilities to load variables into an existing session.

init(sess)[source]

Initialize a session

Parameters

sess (tf.Session) – the session

class tensorpack.tfutils.sessinit.ChainInit(sess_inits)[source]

Bases: tensorpack.tfutils.sessinit.SessionInit

Initialize a session by a list of SessionInit instances, executed one by one. This can be useful for, e.g., loading several models from different files to form a composition of models.

__init__(sess_inits)[source]
Parameters

sess_inits (list[SessionInit]) – list of SessionInit instances.

class tensorpack.tfutils.sessinit.SaverRestore(model_path, prefix=None, ignore=)[source]

Bases: tensorpack.tfutils.sessinit.SessionInit

Restore a tensorflow checkpoint saved by tf.train.Saver or ModelSaver.

__init__(model_path, prefix=None, ignore=)[source]
Parameters
  • model_path (str) – a model name (model-xxxx) or a checkpoint file.

  • prefix (str) – during restore, prepend prefix/ to the name of every variable in this checkpoint.

  • ignore (tuple[str]) – tensor names that should be ignored during loading, e.g. learning-rate

class tensorpack.tfutils.sessinit.SaverRestoreRelaxed(model_path, prefix=None, ignore=)[source]

Bases: tensorpack.tfutils.sessinit.SaverRestore

Same as SaverRestore, but has more relaxed constraints.

It allows upcasting certain variables, or reshaping certain variables, when there is a mismatch that can be fixed.

When variable shape and value shape do not match, it will print a warning but will not crash.

Another advantage is that it doesn’t add any new ops to the graph.

class tensorpack.tfutils.sessinit.DictRestore(variable_dict, ignore_mismatch=False)[source]

Bases: tensorpack.tfutils.sessinit.SessionInit

Restore variables from a dictionary.

__init__(variable_dict, ignore_mismatch=False)[source]
Parameters
  • variable_dict (dict) – a dict of {name: value}

  • ignore_mismatch (bool) – ignore failures when the value and the variable do not match in their shapes. If False, it will throw an exception on such errors. If True, it will only print a warning.

class tensorpack.tfutils.sessinit.JustCurrentSession[source]

Bases: tensorpack.tfutils.sessinit.SessionInit

This is a no-op placeholder.

tensorpack.tfutils.sessinit.SmartInit(obj, *, ignore_mismatch=False)[source]

Create a SessionInit to be loaded to a session, automatically from any supported objects, with some smart heuristics. The object can be:

  • A TF checkpoint

  • A dict of numpy arrays

  • A npz file, to be interpreted as a dict

  • An empty string or None, in which case the sessinit will be a no-op

  • A list of supported objects, to be initialized one by one

Parameters
  • obj – a supported object

  • ignore_mismatch (bool) – ignore failures when the value and the variable do not match in their shapes. If False, it will throw an exception on such errors. If True, it will only print a warning.

Returns

SessionInit

tensorpack.tfutils.summary module

tensorpack.tfutils.summary.add_tensor_summary(x, types, name=None, collections=None, main_tower_only=True)[source]

Summarize a tensor by different methods.

Parameters
  • x (tf.Tensor) – a tensor to summarize

  • types (list[str]) – summary types, can be scalar/histogram/sparsity/mean/rms

  • name (str) – summary name. Defaults to be the op name.

  • collections (list[str]) – collections of the summary ops.

  • main_tower_only (bool) – Only run under main training tower. If set to True, calling this function under other TowerContext has no effect.

Example:

with tf.name_scope('mysummaries'):  # to not mess up tensorboard
    add_tensor_summary(
        tensor, ['histogram', 'rms', 'sparsity'], name='mytensor')
tensorpack.tfutils.summary.add_param_summary(*summary_lists, **kwargs)[source]

Add summary ops for all trainable variables matching the regex, under a reused ‘param-summary’ name scope. This function is a no-op if not called from the main training tower.

Parameters
  • summary_lists (list) – each is (regex, [list of summary type]). Summary type is defined in add_tensor_summary().

  • collections (list[str]) – collections of the summary ops.

Example:

add_param_summary(
    ('.*/W', ['histogram', 'rms']),
    ('.*/gamma', ['scalar']),
)
tensorpack.tfutils.summary.add_activation_summary(x, types=None, name=None, collections=None)[source]

Call add_tensor_summary() under a reused ‘activation-summary’ name scope. This function is a no-op if not called from the main training tower.

Parameters
  • x (tf.Tensor) – the tensor to summarize.

  • types (list[str]) – summary types, defaults to ['sparsity', 'rms', 'histogram'].

  • name (str) – if None, use x.name.

  • collections (list[str]) – collections of the summary ops.

tensorpack.tfutils.summary.add_moving_summary(*args, **kwargs)[source]

Summarize the moving average for scalar tensors. This function is a no-op if not called from the main training tower. See tutorial at https://tensorpack.readthedocs.io/tutorial/summary.html

Parameters
  • args – scalar tensors to summarize

  • decay (float) – the decay rate. Defaults to 0.95.

  • collection (str or None) – the name of the collection to add EMA-maintaining ops. The default will work together with the default MovingAverageSummary callback.

  • summary_collections ([str]) – the names of collections to add the summary op. Default is TF’s default (tf.GraphKeys.SUMMARIES).

Returns

[tf.Tensor] – list of tensors returned by assign_moving_average, which can be used to maintain the EMA.
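
The effect of the decay parameter can be seen in a plain exponential moving average. This sketch uses the textbook update rule and omits any zero-debiasing that TensorFlow's assign_moving_average may apply; ema_update is a hypothetical helper and the loss values are made up.

```python
def ema_update(ema, value, decay=0.95):
    """One exponential-moving-average step: decay * ema + (1 - decay) * value."""
    return decay * ema + (1.0 - decay) * value


ema = 0.0
trace = []
for loss in [4.0, 2.0, 0.0]:
    ema = ema_update(ema, loss, decay=0.5)  # 0.5 keeps the arithmetic exact
    trace.append(ema)
```

With the default decay of 0.95 each new value contributes only 5% to the summary, which is why moving summaries look much smoother than the raw scalars.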

tensorpack.tfutils.varmanip module

tensorpack.tfutils.varmanip.dump_session_params(path)[source]

Dump value of all TRAINABLE + MODEL variables to a dict, and save as npz format (loadable by sessinit.SmartInit()).

Parameters

path (str) – the file name to save the parameters. Must end with npz.

tensorpack.tfutils.varmanip.load_checkpoint_vars(path)[source]

Load all variables from a checkpoint to a dict.

Parameters

path (str) – path to a checkpoint.

Returns

dict – a name:value dict

tensorpack.tfutils.varmanip.save_checkpoint_vars(dic, path)[source]

Save variables in dic to path.

Parameters
  • dic – {name: value}. values have to be numpy arrays

  • path – save as npz if the name ends with ‘.npz’, otherwise save as a checkpoint.

tensorpack.tfutils.varmanip.get_all_checkpoints(dir: str, prefix: str = 'model')[source]

Get a sorted list of all checkpoints found in directory.

Parameters
  • dir (str) – checkpoint directory

  • prefix (str) – common prefix among all checkpoints (without the final “-”)

Returns

list[(str, int)] – list of (name, step) sorted by step. Name is a checkpoint handle that can be passed to tf.train.NewCheckpointReader or load_checkpoint_vars().
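
The shape of the return value can be illustrated on a list of file names. The sketch below assumes checkpoints are recognised by files named '<prefix>-<step>.index'; that naming rule, the file list and toy_all_checkpoints are assumptions of this example, and the real function inspects a directory rather than a list.

```python
import re


def toy_all_checkpoints(filenames, prefix='model'):
    """Return [(name, step)] sorted by step, from '<prefix>-<step>.index' files."""
    pattern = re.compile(re.escape(prefix) + r'-(\d+)\.index$')
    found = []
    for fname in filenames:
        match = pattern.match(fname)
        if match:
            step = int(match.group(1))
            found.append((fname[:-len('.index')], step))
    return sorted(found, key=lambda item: item[1])  # numeric, not lexical, order


files = ['model-2000.index', 'model-500.index', 'checkpoint',
         'model-500.data-00000-of-00001', 'graph-0904.meta']
checkpoints = toy_all_checkpoints(files)
```

Sorting by the integer step matters: a lexical sort of the names would place model-2000 before model-500.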

tensorpack.tfutils.varreplace module

tensorpack.tfutils.varreplace.custom_getter_scope(custom_getter)[source]
Parameters

custom_getter – the same as in tf.get_variable()

Returns

The current variable scope with a custom_getter.

tensorpack.tfutils.varreplace.freeze_variables(stop_gradient=True, skip_collection=False)[source]

Return a context to freeze variables, by wrapping tf.get_variable with a custom getter. It works by either applying tf.stop_gradient on the variables, or keeping them out of the TRAINABLE_VARIABLES collection, or both. Both options have their own pros and cons.

Example

from tensorpack.tfutils import varreplace
with varreplace.freeze_variables(stop_gradient=False, skip_collection=True):
    x = FullyConnected('fc', x, 1000)   # fc/* will not be trained
Parameters
  • stop_gradient (bool) –

    if True, variables returned from get_variable will be wrapped with tf.stop_gradient.

    Note that the created variables may still have gradient when accessed by other approaches (e.g. by name, or by collection). For example, they may still have a gradient in weight decay. Also note that this makes tf.get_variable return a Tensor instead of a Variable, which may break existing contracts. Therefore, it’s recommended to use the skip_collection option instead.

  • skip_collection (bool) – if True, do not add the variable to TRAINABLE_VARIABLES collection, but to MODEL_VARIABLES collection. As a result they will not be trained by default.

Note:

stop_gradient only prevents variables returned by get_variable within the context from contributing gradients in this context. Therefore it may not completely freeze the variables. For example:

  1. If a variable is created, or reused outside of the context, it can still contribute to the gradient of other tensors.

  2. If a frozen variable is accessed by other approaches (e.g., by names, by collections), it can still contribute to the gradient of other tensors. For example, weight decay cannot be stopped by a stop_gradient context.

skip_collection has to be used the first time the variable is created. Once skip_collection is used, the variable is no longer a trainable variable, and will be completely frozen from gradient updates in tensorpack’s single-cost trainer.

Choose the option carefully depending on what you need.

tensorpack.tfutils.varreplace.remap_variables(fn)[source]

Use fn to map the output of any variable getter.

Parameters

fn (tf.Variable -> tf.Tensor) –

Returns

The current variable scope with a custom_getter that maps all the variables by fn.

Example

from tensorpack.tfutils import varreplace
with varreplace.remap_variables(lambda var: quantize(var)):
    x = FullyConnected('fc', x, 1000)   # fc/{W,b} will be quantized

tensorpack.tfutils.export module

A collection of functions to ease the process of exporting a model for production.

class tensorpack.tfutils.export.ModelExporter(config)[source]

Bases: object

Export models for inference.

__init__(config)[source]

Initialise the export process.

Parameters

config (PredictConfig) – the config to use. The graph will be built with the tower function defined by this PredictConfig. Then the input / output names will be used to export models for inference.

export_compact(filename, optimize=True, toco_compatible=False)[source]

Create a self-contained inference-only graph and write final graph (in pb format) to disk.

Parameters
  • filename (str) – path to the output graph

  • optimize (bool) – whether to use TensorFlow’s optimize_for_inference to prune and optimize the graph. This does not work on all types of graphs.

  • toco_compatible (bool) – See TensorFlow’s optimize_for_inference for details. Only available after TF 1.8.

export_serving(filename, tags=None, signature_name='prediction_pipeline')[source]

Converts a checkpoint and graph to a servable for TensorFlow Serving. Use TF’s SavedModelBuilder to export a trained model without tensorpack dependency.

Parameters
  • filename (str) – path for export directory

  • tags (tuple) – tuple of user specified tags. Defaults to just “SERVING”.

  • signature_name (str) – name of signature for prediction

Note

This produces

variables/       # output from the vanilla Saver
    variables.data-?????-of-?????
    variables.index
saved_model.pb   # a `SavedModel` protobuf

Currently, we only support a single signature, which is the general PredictSignatureDef: https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md

tensorpack.tfutils.dependency module

tensorpack.tfutils.dependency.dependency_of_targets(targets, op)[source]

Check that op is in the subgraph induced by the dependencies of targets. The result is memoized.

This is useful if some SessionRunHooks should be run only together with certain ops.

Parameters
  • targets – a tuple of ops or tensors. The targets to find dependencies of.

  • op (tf.Operation or tf.Tensor) –

Returns

bool – True if any one of targets depends on op.
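
Conceptually this is a reachability query on the dataflow graph. The sketch below runs the same kind of search on a toy graph stored as a dict that maps each op name to the ops it consumes; depends_on, the graph and the op names are hypothetical, and the memoization mentioned above is left out.

```python
def depends_on(targets, op, inputs_of):
    """True if `op` is reachable from any target by following input edges."""
    stack = list(targets)
    seen = set()
    while stack:
        node = stack.pop()
        if node == op:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(inputs_of.get(node, []))
    return False


# Toy graph: each op maps to the ops it consumes.
inputs_of = {
    'train_op': ['loss'],
    'loss': ['logits', 'labels'],
    'logits': ['images'],
    'summary': ['images'],
}

train_needs_images = depends_on(('train_op',), 'images', inputs_of)
train_needs_summary = depends_on(('train_op',), 'summary', inputs_of)
```

In this toy graph train_op depends on images through loss and logits, but not on summary, so a hook tied to summary would not need to run together with train_op.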

tensorpack.tfutils.dependency.dependency_of_fetches(fetches, op)[source]

Check that op is in the subgraph induced by the dependencies of fetches. fetches may have more general structure.

Parameters
  • fetches – An argument to sess.run. Nested structure will affect performance.

  • op (tf.Operation or tf.Tensor) –

Returns

bool – True if any of fetches depends on op.

Other functions in tensorpack.tfutils module

tfutils.get_default_sess_config()

Return a tf.ConfigProto to use as default session config. You can modify the returned config to fit your needs.

Parameters

mem_fraction (float) – see the per_process_gpu_memory_fraction option in TensorFlow’s GPUOptions protobuf: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto

Returns

tf.ConfigProto – the config to use.

tfutils.get_global_step_var()
Returns

tf.Tensor – the global_step variable in the current graph. It is created if it does not exist.

tfutils.get_global_step_value()
Returns

int – global_step value in current graph and session

Has to be called under a default session.

tfutils.get_tf_version_tuple()

Return TensorFlow version as a 2-element tuple (for comparison).
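
A tuple is convenient because Python compares tuples element by element, which orders versions correctly where string comparison would not. The sketch below shows the idea on literal version strings; version_tuple is a hypothetical helper for illustration, not the tensorpack function.

```python
def version_tuple(version_string):
    """Return the (major, minor) integers of a dotted version string."""
    major, minor = version_string.split('.')[:2]
    return int(major), int(minor)


tf_version = version_tuple('1.15.2')
supports_feature = tf_version >= (1, 13)  # e.g. gate code on a minimum version
```

Note that '1.9.0' < '1.13.0' is False as a string comparison, while the corresponding tuples (1, 9) and (1, 13) compare as expected.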