tensorpack.train package

Relevant tutorials: Trainers, Training Interface

exception tensorpack.train.StopTraining

Bases: Exception

An exception thrown to stop training.
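
The control flow this exception enables can be sketched in plain Python. This is a toy model rather than tensorpack code: the local StopTraining class stands in for tensorpack.train.StopTraining, and EarlyStopAt and toy_main_loop are hypothetical names. As an assumption of the sketch, a callback raises the exception and the main loop catches it, so training ends early instead of crashing.

```python
class StopTraining(Exception):
    """Local stand-in for tensorpack.train.StopTraining."""


class EarlyStopAt:
    """Hypothetical callback that asks training to stop at a given step."""

    def __init__(self, stop_step):
        self.stop_step = stop_step

    def trigger_step(self, global_step):
        if global_step >= self.stop_step:
            raise StopTraining("reached step {}".format(global_step))


def toy_main_loop(max_steps, callbacks):
    """Run at most max_steps steps and return how many were actually run."""
    global_step = 0
    try:
        for _ in range(max_steps):
            global_step += 1  # stands in for one training iteration
            for cb in callbacks:
                cb.trigger_step(global_step)
    except StopTraining as e:
        # Stopping through the exception is a normal exit, not a crash.
        print("Training was stopped:", e)
    return global_step
```

Here toy_main_loop(100, [EarlyStopAt(5)]) runs 5 steps instead of 100.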

class tensorpack.train.Trainer

Bases: object

Base class for a trainer.

epoch_num

The number of the currently ongoing epoch.

An epoch is defined to cover the moment before calling before_epoch until after calling trigger_epoch, i.e., in the trigger_epoch of epoch 3, self.epoch_num is 3. If you need to use self.epoch_num in your callback, you’ll need to know this.

global_step

The tensorflow global_step, i.e. how many times hooked_sess.run has been called.

Note:

  1. global_step is incremented after each hooked_sess.run returns from TF runtime.

  2. If you make zero or more than one calls to hooked_sess.run in one run_step(), local_step and global_step may increment at different speeds.
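
The relation between epoch_num, local_step and global_step can be modeled without TensorFlow. In this hypothetical sketch, ToyLoop and its methods are illustrative names, and _hooked_sess_run only bumps a counter where the real call would run the train op and all hooks. It shows the second point above: when run_step() calls hooked_sess.run twice, global_step grows twice as fast as the number of finished steps.

```python
class ToyLoop:
    """Hypothetical model of the Trainer counters; not tensorpack code."""

    def __init__(self, runs_per_step):
        self.runs_per_step = runs_per_step  # hooked_sess.run calls per run_step()
        self.global_step = 0
        self.local_step = 0
        self.epoch_num = 0

    def _hooked_sess_run(self):
        # The real call would execute the train op and all hooks in TF.
        self.global_step += 1  # incremented only after the run returns

    def run_step(self):
        for _ in range(self.runs_per_step):
            self._hooked_sess_run()

    def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
        for epoch in range(starting_epoch, max_epoch + 1):
            self.epoch_num = epoch      # set before the epoch's steps start
            self.local_step = 0
            for _ in range(steps_per_epoch):
                self.run_step()
                self.local_step += 1    # one more finished step in this epoch
```

With ToyLoop(runs_per_step=2) and main_loop(3, 1, 2), the loop ends with epoch_num == 2, local_step == 3 and global_step == 12.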

hooked_sess = None

The tf.train.MonitoredSession object the trainer is using. It contains all the before_run/after_run hooks the callbacks have registered. It is used for running the training iterations. Available after initialize().

Note that using hooked_sess.run will evaluate all the hooks, just like running a training iteration. It may do the following:

  1. Take a datapoint from the InputSource

  2. Increase the global_step

  3. Evaluate some summaries

Typically you do not want to use hooked_sess.run in callbacks, because it is for the “training iteration”. If you just want to evaluate some tensors, use sess.run if the tensors do not depend on the inputs, or more generally, use before_run/after_run to evaluate the tensors along with the training iterations.
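
The before_run/after_run pattern recommended above can be modeled in plain Python. In this hypothetical sketch, every ToyHookedSession.run() is a training iteration: it consumes one datapoint and increments the step, which is why calling it from a callback disturbs training. LossRecorder instead asks for a value in before_run and receives it in after_run, piggybacking on an iteration that happens anyway. The real hooks exchange TensorFlow run-argument objects rather than the plain lists and dicts used here.

```python
class ToyHookedSession:
    """Hypothetical stand-in for hooked_sess: every run() is a training iteration."""

    def __init__(self, datapoints, hooks):
        self.datapoints = list(datapoints)  # stands in for the InputSource
        self.hooks = hooks
        self.global_step = 0

    def run(self):
        # Each hook may ask for extra values to be fetched with this iteration.
        requests = [hook.before_run() for hook in self.hooks]
        x = self.datapoints.pop(0)          # consumes one datapoint
        values = {"loss": x * 0.5}          # stands in for tensors computed from x
        self.global_step += 1
        for hook, names in zip(self.hooks, requests):
            hook.after_run({name: values[name] for name in names})


class LossRecorder:
    """Hypothetical callback that reads 'loss' alongside each training iteration."""

    def __init__(self):
        self.seen = []

    def before_run(self):
        return ["loss"]  # names of the values to fetch with this run

    def after_run(self, fetched):
        self.seen.append(fetched["loss"])
```

After two run() calls on datapoints [2.0, 4.0, 6.0], the recorder has seen [1.0, 2.0] without issuing any extra run, and one datapoint remains.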

initialize(session_creator, session_init)

Create the session and set self.sess, call self.initialize_hooks(), and finalize the graph.

It must be called after callbacks are setup.

initialize_hooks()

Create SessionRunHooks for all callbacks, and hook them onto self.sess to create self.hooked_sess.

A new trainer may override this method to create multiple groups of hooks, which can be useful when the training is not done by a single train_op.

is_chief = True

Whether this process is the chief worker in distributed training. Certain callbacks will only be run by chief worker.

local_step

The number of steps that have finished in the current epoch.

main_loop(steps_per_epoch, starting_epoch, max_epoch)

Run the main training loop.

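Parameters: steps_per_epoch, starting_epoch, max_epoch (int) – the number of steps in each epoch, the index of the first epoch, and the maximum number of epochs to run.

register_callback(cb)
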
Register callbacks to the trainer. It can only be called before Trainer.train().

Parameters: cb (Callback or [Callback]) – a callback or a list of callbacks
Returns: whether the registration succeeded
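
A plausible reading of “succeed or not”, tied to is_chief below, is that a chief-only callback is skipped on a non-chief worker. The toy sketch below models that assumption together with the rule that registration is only allowed before training; ToyTrainer, ToyCallback and the chief_only=True default are hypothetical, not tensorpack's implementation.

```python
class ToyCallback:
    """Hypothetical callback; chief_only=True is an assumed default."""

    def __init__(self, name, chief_only=True):
        self.name = name
        self.chief_only = chief_only


class ToyTrainer:
    """Hypothetical model of callback registration; not tensorpack code."""

    def __init__(self, is_chief=True):
        self.is_chief = is_chief
        self.callbacks = []
        self._setup_done = False

    def register_callback(self, cb):
        if isinstance(cb, (list, tuple)):
            # Register every element (no short-circuit), report overall success.
            return all([self.register_callback(x) for x in cb])
        if self._setup_done:
            raise RuntimeError("Cannot register callbacks after setup")
        if not self.is_chief and cb.chief_only:
            return False  # skipped on non-chief workers
        self.callbacks.append(cb)
        return True

    def setup_callbacks(self):
        self._setup_done = True  # no more registration after this point
```

On a worker created with is_chief=False, registering ToyCallback("saver") returns False, while a callback with chief_only=False is accepted.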

run_step()

Defines what to do in one iteration. The default is: self.hooked_sess.run(self.train_op).

The behavior of each iteration can be changed by either setting trainer.train_op, or overriding this method.
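
Overriding run_step() is a natural way to run more than one op per iteration, for example alternating discriminator and generator updates. The sketch below is a hypothetical plain-Python model: RecordingSession stands in for hooked_sess and only records what would be run, and the op names are placeholders. Because each hooked_sess.run advances global_step, such a trainer advances global_step twice per step, as noted under global_step.

```python
class RecordingSession:
    """Hypothetical stand-in for hooked_sess that records what was run."""

    def __init__(self):
        self.history = []

    def run(self, op):
        self.history.append(op)


class ToyTrainer:
    def __init__(self, sess):
        self.hooked_sess = sess
        self.train_op = "train_op"

    def run_step(self):
        # Default behavior: one hooked_sess.run of train_op per iteration.
        self.hooked_sess.run(self.train_op)


class AlternatingTrainer(ToyTrainer):
    """Overrides run_step() to run two ops per iteration, e.g. for a GAN."""

    def __init__(self, sess, d_op, g_op):
        super().__init__(sess)
        self.d_op, self.g_op = d_op, g_op

    def run_step(self):
        self.hooked_sess.run(self.d_op)
        self.hooked_sess.run(self.g_op)
```

Two iterations of AlternatingTrainer(sess, "d_min", "g_min") record ["d_min", "g_min", "d_min", "g_min"].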

sess = None

The tf.Session object the trainer is using. Available after initialize().

Using trainer.sess.run to evaluate tensors that depend on the inputs can lead to unexpected effects:

For example, if you use trainer.sess.run to evaluate a tensor that depends on the inputs coming from a StagingArea, this will take a datapoint from the StagingArea, making the StagingArea empty, and as a result make the training hang.

setup_callbacks(callbacks, monitors)

Setup callbacks and monitors. Must be called after the main graph is built.

train(callbacks, monitors, session_creator, session_init, steps_per_epoch, starting_epoch=1, max_epoch=9999999)

Implemented by three lines:

self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)

You can call those methods yourself to have finer control over the details if needed.

train_with_defaults(_sentinel=None, callbacks=None, monitors=None, session_creator=None, session_init=None, steps_per_epoch=None, starting_epoch=1, max_epoch=9999999, extra_callbacks=None)

Same as train(), except:

  1. Add extra_callbacks to callbacks. The default value for extra_callbacks is DEFAULT_CALLBACKS().

  2. Default value for monitors is DEFAULT_MONITORS().

  3. Provide default values for every option except steps_per_epoch.
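
The default handling can be sketched as a small pure-Python function. This is a hypothetical model, not the tensorpack code: strings stand in for callback and monitor objects, the default lists are copied from DEFAULT_CALLBACKS() and DEFAULT_MONITORS() as documented on this page, and raising ValueError for a missing steps_per_epoch is an assumption of the sketch.

```python
DEFAULT_CALLBACK_NAMES = ["MovingAverageSummary", "ProgressBar",
                          "MergeAllSummaries", "RunUpdateOps"]
DEFAULT_MONITOR_NAMES = ["TFEventWriter", "JSONWriter", "ScalarPrinter"]


def resolve_defaults(steps_per_epoch, callbacks=None, monitors=None,
                     extra_callbacks=None, starting_epoch=1, max_epoch=9999999):
    """Hypothetical model of how train_with_defaults() fills in arguments."""
    if steps_per_epoch is None:
        raise ValueError("steps_per_epoch has no default and must be given")
    callbacks = list(callbacks or [])
    if extra_callbacks is None:
        extra_callbacks = list(DEFAULT_CALLBACK_NAMES)
    if monitors is None:
        monitors = list(DEFAULT_MONITOR_NAMES)
    return {
        "callbacks": callbacks + list(extra_callbacks),  # user callbacks first
        "monitors": monitors,
        "steps_per_epoch": steps_per_epoch,
        "starting_epoch": starting_epoch,
        "max_epoch": max_epoch,
    }
```

Passing extra_callbacks=[] drops every default callback, while leaving it as None appends the four defaults after your own callbacks.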

class tensorpack.train.TrainConfig(dataflow=None, data=None, model=None, callbacks=None, extra_callbacks=None, monitors=None, session_creator=None, session_config=None, session_init=None, starting_epoch=1, steps_per_epoch=None, max_epoch=99999, **kwargs)

Bases: object

A collection of options to be used for single-cost trainers.

Note that you do not have to use TrainConfig. You can use the API of Trainer directly, to have more fine-grained control of the training.

__init__(dataflow=None, data=None, model=None, callbacks=None, extra_callbacks=None, monitors=None, session_creator=None, session_config=None, session_init=None, starting_epoch=1, steps_per_epoch=None, max_epoch=99999, **kwargs)
Parameters:
  • dataflow (DataFlow) –

  • data (InputSource) –

  • model (ModelDesc) –

  • callbacks (list[Callback]) – a list of Callback to use during training.

  • extra_callbacks (list[Callback]) –

    This argument is only used to provide the defaults in addition to callbacks. The list of callbacks that will be used in the end is simply callbacks + extra_callbacks.

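    It is usually left as None, and the default value for this argument is DEFAULT_CALLBACKS(). You can override it when you don’t like any of the default callbacks. For example, if you’d like to let the progress bar print tensors, you can use extra_callbacks=[ProgressBar(names=['name']), MovingAverageSummary(), MergeAllSummaries(), RunUpdateOps()], where 'name' is the name of a tensor to print.
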
  • monitors (list[MonitorBase]) – Defaults to DEFAULT_MONITORS().

  • session_creator (tf.train.SessionCreator) – Defaults to sesscreate.NewSessionCreator() with the config returned by tfutils.get_default_sess_config().

  • session_config (tf.ConfigProto) – when session_creator is None, use this to create the session.

  • session_init (SessionInit) – how to initialize variables of a session. Defaults to doing nothing.

  • starting_epoch (int) – The index of the first epoch.

  • steps_per_epoch (int) – the number of steps (defined by Trainer.run_step()) to run in each epoch. Defaults to the input data size.

  • max_epoch (int) – maximum number of epochs to run training.
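
Two of the defaults above, steps_per_epoch and the session creator, can be summarized by a hypothetical pure-Python resolver. Strings stand in for real objects such as a session creator or a tf.ConfigProto, and input_size stands in for the size of the input data; this is a sketch of the documented behavior, not TrainConfig's code.

```python
def resolve_train_config(input_size, steps_per_epoch=None,
                         session_creator=None, session_config=None):
    """Hypothetical model of two TrainConfig defaults (no TF objects involved)."""
    if steps_per_epoch is None:
        steps_per_epoch = input_size  # defaults to the input data size
    if session_creator is None:
        # session_config is only consulted when no session_creator is given.
        if session_config is None:
            session_config = "default_sess_config"
        session_creator = "NewSessionCreator(config={})".format(session_config)
    return steps_per_epoch, session_creator
```

A user-supplied session_creator is used as-is; otherwise a new-session creator is built from session_config or from the default config.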

class tensorpack.train.AutoResumeTrainConfig(always_resume=True, **kwargs)

Bases: tensorpack.train.config.TrainConfig

Same as TrainConfig, but does the following to automatically resume training:

  1. If a checkpoint was found in logger.get_logger_dir(), set session_init option to load it.

  2. If a JSON history was found in logger.get_logger_dir(), try to load the epoch number from it and set the starting_epoch option to continue training.

You can choose whether the above two options overwrite user-provided arguments, as explained below.

Note that the functionality requires the logging directory to obtain necessary information from a previous run. In some cases (e.g. when using Horovod), the directory is not available or different for different workers and this class may not function properly.

__init__(always_resume=True, **kwargs)
Parameters:
  • always_resume (bool) – If False, user-provided arguments session_init and starting_epoch will take priority. Otherwise, resume will take priority.

  • kwargs – same as in TrainConfig.

Note:

The main goal of this class is to let a training job resume without changing any line of code or command line arguments. So it is sometimes useful to let resume take priority over user-provided arguments.

For example, if your training starts from a pre-trained model, you would want it to use the user-provided model loader at the beginning, but a “resume” model loader when the job is interrupted and restarted.
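
The resolution rules above can be sketched in plain Python. This is a hypothetical model, not AutoResumeTrainConfig's code: the file names checkpoint and stats.json, the record layout, and resuming at the last recorded epoch plus one are assumptions of the sketch, and a tuple stands in for a real model loader.

```python
import json
import os


def resolve_resume(logdir, session_init=None, starting_epoch=None,
                   always_resume=True):
    """Hypothetical model of AutoResumeTrainConfig's argument resolution.

    Assumed layout (not taken from tensorpack): a file named 'checkpoint'
    marks a saved model, and 'stats.json' holds a JSON list of per-epoch
    records, each with an 'epoch_num' key. Returns (session_init, starting_epoch).
    """
    resume_init, resume_epoch = None, None
    ckpt = os.path.join(logdir, "checkpoint")
    if os.path.isfile(ckpt):
        resume_init = ("restore", ckpt)  # stands in for a model loader
    history = os.path.join(logdir, "stats.json")
    if os.path.isfile(history):
        with open(history) as f:
            records = json.load(f)
        if records:
            resume_epoch = records[-1]["epoch_num"] + 1  # continue after it

    # always_resume=True: resume wins; False: user arguments win when given.
    if resume_init is not None and (always_resume or session_init is None):
        session_init = resume_init
    if resume_epoch is not None and (always_resume or starting_epoch is None):
        starting_epoch = resume_epoch
    if starting_epoch is None:
        starting_epoch = 1
    return session_init, starting_epoch
```

On a fresh log directory, resolve_resume(logdir, 'pretrained') keeps the user's loader and starts at epoch 1; once a checkpoint and history exist, the same call switches to the resume loader, which is the pre-trained scenario described above.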

tensorpack.train.DEFAULT_CALLBACKS()

Return the default callbacks, which will be used in TrainConfig and Trainer.train_with_defaults(). They are:

  1. MovingAverageSummary()

  2. ProgressBar()

  3. MergeAllSummaries()

  4. RunUpdateOps()

tensorpack.train.DEFAULT_MONITORS()

Return the default monitors, which will be used in TrainConfig and Trainer.train_with_defaults(). They are:

  1. TFEventWriter()

  2. JSONWriter()

  3. ScalarPrinter()

tensorpack.train.launch_train_with_config(config, trainer)

Train with a TrainConfig and a Trainer, providing the simple, older training interface. It basically does the following 3 things (and you can easily do them yourself if you need more control):

  1. Setup the input with automatic prefetching heuristics, from config.data or config.dataflow.

  2. Call trainer.setup_graph with the input as well as config.model.

  3. Call trainer.train with the rest of the attributes of config.

Example:

launch_train_with_config(
    config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))

class tensorpack.train.SingleCostTrainer

Bases: tensorpack.train.tower.TowerTrainer

Base class for single-cost trainer.

Single-cost trainer has a setup_graph() method which takes (inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training graph from them.

To use a SingleCostTrainer object, call trainer.setup_graph(…); trainer.train(…).

AGGREGATION_METHOD

See tf.gradients.

COLOCATE_GRADIENTS_WITH_OPS

See tf.gradients. It can sometimes heavily affect performance when the backward op does not support the device of the forward op.

GATE_GRADIENTS

See tf.gradients.

XLA_COMPILE

Use xla.compile() to compile the tower function. Note that XLA has very strong requirements on the tower function, e.g.:

  1. limited op support

  2. inferrable shape

  3. no summary support

and many tower functions cannot be compiled by XLA. Don’t use it if you don’t understand it.

setup_graph(inputs_desc, input, get_cost_fn, get_opt_fn)

Responsible for building the main training graph for single-cost training.

Parameters:
  • inputs_desc ([InputDesc]) –

  • input (InputSource) –

  • get_cost_fn ([tf.Tensor] -> tf.Tensor) – callable, takes some input tensors and returns a cost tensor.

  • get_opt_fn (-> tf.train.Optimizer) – callable which returns an optimizer. Will only be called once.

Note:

get_cost_fn will be part of the tower function. It must follow the rules of tower functions.

class tensorpack.train.TowerTrainer

Bases: tensorpack.train.base.Trainer

Base trainer for models that can be built by calling a tower function under a TowerContext.

This is required by some features that replicate the model automatically, e.g. creating a predictor.

To use features of TowerTrainer, set tower_func and use it to build the graph. Note that tower_func can only be set once per instance.

get_predictor(input_names, output_names, device=0)

This method will build the trainer’s tower function under TowerContext(is_training=False), and returns a callable predictor with input placeholders & output tensors in this tower.

Parameters:
  • input_names (list) – list of input names, matching the inputs declared for the trainer.

  • output_names (list) – list of tensor names without the tower prefix.

  • device (int) – build the predictor on device ‘/gpu:{device}’ or use -1 for ‘/cpu:0’.

Returns: an OnlinePredictor.

Example:

# in the graph:
interesting_tensor = tf.identity(x, name='fun')
# in _setup_graph callback method:
self._predictor = self.trainer.get_predictor(['input1', 'input2'], ['fun'])
# After session is initialized (see Tutorials - Write a Callback), can use it by:
outputs = self._predictor(input1, input2)

The CycleGAN example and DQN example have more concrete use of this method.

inputs_desc

Returns: list[InputDesc] – metainfo about the inputs to the tower.

tower_func

A TowerFuncWrapper instance. See the tutorial on tower functions (http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer) for more information.

towers

Returns: a TowerTensorHandles object, to access the tower handles by either indices or names.

It is accessible only after the graph is set up. With trainer.towers, you can then access many attributes of each tower:

Example:

# Access the conv1/output tensor in the first training tower
trainer.towers.training()[0].get_tensor('conv1/output')

class tensorpack.train.NoOpTrainer

Bases: tensorpack.train.trainers.SimpleTrainer

A special trainer that builds the graph (if given a tower function) and does nothing in each step. It is used only to run the callbacks.

Note that steps_per_epoch and max_epoch are still valid options.

run_step()

Defines what to do in one iteration. The default is: self.hooked_sess.run(self.train_op).

The behavior of each iteration can be changed by either setting trainer.train_op, or overriding this method.

class tensorpack.train.SimpleTrainer

Bases: tensorpack.train.tower.SingleCostTrainer

Single-GPU single-cost single-tower trainer.

tensorpack.train.SyncMultiGPUTrainer(gpus)

Return a default multi-GPU trainer, if you don’t care about the details. It may not be the most efficient one for your task.

Parameters: gpus (list[int]) – list of GPU ids.

class tensorpack.train.SyncMultiGPUTrainerReplicated(gpus, average=True, mode=None, use_nccl=None)

Bases: tensorpack.train.tower.SingleCostTrainer

Data-parallel training in “replicated” mode, where each GPU contains a replica of the whole model. It will build one tower on each GPU under its own variable scope. Each gradient update is averaged or summed across all GPUs through NCCL.

It is an equivalent of --variable_update=replicated in tensorflow/benchmarks.

BROADCAST_EVERY_EPOCH

Whether to broadcast the variables every epoch. Theoretically this is a no-op (because the variables are supposed to be in-sync). But this cheap operation may help prevent certain numerical issues in practice.

__init__(gpus, average=True, mode=None, use_nccl=None)
Parameters:
  • gpus (int or [int]) – list of GPU ids, or the number of GPUs to use.

  • average (bool) – whether to average or sum gradients.

  • mode (str or None) – Gradient aggregation mode. Supported values: [‘nccl’, ‘hierarchical’, ‘cpu’]. Defaults to picking one automatically by heuristics. These modes may have slight (within 5%) differences in speed. “hierarchical” mode was designed for DGX-like 8-GPU machines.

  • use_nccl – deprecated option
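
The effect of average can be seen in a toy all-reduce written in plain Python; NCCL performs the real reduction on the GPUs. allreduce and normalize_gpus are hypothetical helpers, and reading an int gpus as “the first N GPUs” is an assumption consistent with the SyncMultiGPUTrainerParameterServer(8, ps_device='gpu') example above.

```python
def normalize_gpus(gpus):
    """Assumed convention: an int N means GPUs 0..N-1; a list is used as given."""
    if isinstance(gpus, int):
        return list(range(gpus))
    return list(gpus)


def allreduce(tower_grads, average=True):
    """Toy all-reduce over towers.

    tower_grads[i][j] is the gradient of variable j computed on tower i.
    Every tower receives its own copy of the same reduced gradient, since
    each GPU holds a replica of the model in "replicated" mode.
    """
    n = len(tower_grads)
    summed = [sum(per_var) for per_var in zip(*tower_grads)]
    reduced = [s / n for s in summed] if average else summed
    return [list(reduced) for _ in range(n)]
```

For two towers with gradients [1.0, 2.0] and [3.0, 4.0], both replicas receive [2.0, 3.0] with average=True and [4.0, 6.0] with average=False.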

devices = None

List of GPU ids.

class tensorpack.train.SyncMultiGPUTrainerParameterServer(gpus, ps_device=None)

Bases: tensorpack.train.tower.SingleCostTrainer

Data-parallel training in ‘ParameterServer’ mode. It builds one tower on each GPU with shared variable scope. It synchronizes the gradients computed from each tower, averages them and applies them to the shared variables.

It is an equivalent of --variable_update=parameter_server in tensorflow/benchmarks.

__init__(gpus, ps_device=None)
Parameters:
  • gpus ([int]) – list of GPU ids.

  • ps_device – either ‘gpu’ or ‘cpu’, where variables are stored. The default value is subject to change.
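
A single parameter-server update can be sketched in plain Python: gradients from every tower are averaged, then applied once to the shared variables. ps_sgd_step is a hypothetical name, and plain SGD with a fixed learning rate is an assumption; the real update uses whatever optimizer get_opt_fn returns.

```python
def ps_sgd_step(shared_vars, tower_grads, lr):
    """Toy parameter-server step: average tower gradients, then apply once.

    shared_vars[j] is variable j; tower_grads[i][j] is its gradient on tower i.
    Plain SGD is an assumption of this sketch.
    """
    n = len(tower_grads)
    avg = [sum(per_var) / n for per_var in zip(*tower_grads)]
    return [v - lr * g for v, g in zip(shared_vars, avg)]
```

With variables [10.0, 10.0], tower gradients [1.0, 2.0] and [3.0, 0.0], and lr=0.5, the averaged gradient is [2.0, 1.0] and the updated variables are [9.0, 9.5].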

devices = None

List of GPU ids.

class tensorpack.train.AsyncMultiGPUTrainer(gpus, scale_gradient=True)

Bases: tensorpack.train.tower.SingleCostTrainer

Data-parallel training with async update. It builds one tower on each GPU with shared variable scope. Every tower computes the gradients and independently applies them to the variables, without synchronizing and averaging across towers.

__init__(gpus, scale_gradient=True)
Parameters:
  • gpus ([int]) – list of GPU ids.

  • scale_gradient (bool) – if True, will scale each gradient by 1.0/nr_gpu.
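
The async scheme can be contrasted with synchronous averaging in a toy model. In this hypothetical sketch each tower applies its own gradient in turn, scaled by 1.0/nr_gpu when scale_gradient is True; plain SGD and a fixed tower order are assumptions. In real async training a tower may compute its gradient from variables that another tower has already changed, so results are not this deterministic.

```python
def async_updates(shared_vars, tower_grads, lr, scale_gradient=True):
    """Toy async update: towers apply their own gradients one after another."""
    n = len(tower_grads)
    scale = 1.0 / n if scale_gradient else 1.0
    current = list(shared_vars)
    for grads in tower_grads:  # no averaging and no synchronization
        current = [v - lr * scale * g for v, g in zip(current, grads)]
    return current
```

With plain SGD and no staleness, the scaled result equals one synchronous step with averaged gradients: async_updates([10.0], [[2.0], [4.0]], 1.0) gives [7.0], and 10.0 - 1.0 * (2.0 + 4.0) / 2 is also 7.0.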

devices = None

List of GPU ids.

class tensorpack.train.DistributedTrainerParameterServer(gpus, server, caching_device='cpu')

Bases: tensorpack.train.trainers.DistributedTrainerBase

Distributed parameter server training. A single copy of parameters is scattered across the PS. Gradients across GPUs are averaged within the worker, and applied to PS. Each worker also caches the variables for reading.

It is an equivalent of --variable_update=parameter_server in tensorflow/benchmarks. However, this implementation hasn’t been well tested. It probably still has issues in model saving, etc. Also, the TensorFlow team is not actively maintaining distributed training features. Check HorovodTrainer and ResNet-Horovod for better distributed training support.

Note:

  1. Gradients are not averaged across workers, but applied to PS variables directly (either with or without locking depending on the optimizer).

__init__(gpus, server, caching_device='cpu')
Parameters:
  • gpus ([int]) – list of GPU ids.

  • server (tf.train.Server) – the server with ps and workers.

  • caching_device (str) – either ‘cpu’ or ‘gpu’. The device on which to cache variables copied from PS.

class tensorpack.train.DistributedTrainerReplicated(gpus, server)

Bases: tensorpack.train.trainers.DistributedTrainerBase

Distributed replicated training. Each worker process builds the same model on one or more GPUs. Gradients across GPUs are averaged within the worker, and get synchronously applied to the global copy of variables located on PS. Then each worker copies the latest variables from PS back to local.

It is an equivalent of --variable_update=distributed_replicated in tensorflow/benchmarks. Note that the performance of this trainer is still not satisfactory, and the TensorFlow team is not actively maintaining distributed training features. Check HorovodTrainer and ResNet-Horovod for better distributed training support.

Note:

  1. Gradients are not averaged across workers, but applied to PS variables directly (either with or without locking depending on the optimizer).

  2. Some details about collections: all variables created inside a tower will become local variables, and a clone will be made in global variables for all trainable/model variables.

Example:

import tensorflow as tf
from tensorpack.tfutils.common import get_default_sess_config

# Create the server object like this:
hosts = ['host1.com', 'host2.com']
cluster_spec = tf.train.ClusterSpec({
    'ps': [h + ':2222' for h in hosts],
    'worker': [h + ':2223' for h in hosts]
})
# args.job / args.task are assumed to be parsed from the --job / --task flags shown below
server = tf.train.Server(
    cluster_spec, job_name=args.job, task_index=args.task,
    config=get_default_sess_config())
# initialize trainer with this server object, e.g. DistributedTrainerReplicated(gpus, server)

# Start training like this:
(host1)$ ./train.py --job worker --task 0
(host1)$ CUDA_VISIBLE_DEVICES= ./train.py --job ps --task 0
(host2)$ ./train.py --job worker --task 1
(host2)$ CUDA_VISIBLE_DEVICES= ./train.py --job ps --task 1

__init__(gpus, server)
Parameters:
  • gpus (list[int]) – list of GPU ids.

  • server (tf.train.Server) – the server with ps and workers.

class tensorpack.train.HorovodTrainer(average=True, compression=None)

Bases: tensorpack.train.tower.SingleCostTrainer

Horovod trainer, supports both multi-GPU and distributed training.

To use for multi-GPU training:

# First, change trainer to HorovodTrainer(), then
CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_DEBUG=INFO mpirun -np 4 --output-filename mylog python train.py

To use for distributed training:

# First, change trainer to HorovodTrainer(), then
mpirun -np 8 -H server1:4,server2:4  \
    -bind-to none -map-by slot \
    --output-filename mylog -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH \
    python train.py
# Add other environment variables you need by -x, e.g. PYTHONPATH, PATH.
# If using all GPUs, you can always skip the `CUDA_VISIBLE_DEVICES` option.
# There are other MPI options that can potentially improve performance, especially on special hardware.

Note:

  1. To reach the maximum speed in your system, there are many options to tune for Horovod installation and in the MPI command line. See Horovod docs for details.

  2. Due to a TF bug, you must not initialize CUDA context before the trainer starts training. Therefore TF functions like is_gpu_available() or list_local_devices() must be avoided.

  3. MPI does not like fork(). If your dataflow contains multiprocessing, it may cause problems.

  4. MPI sometimes fails to kill all processes in the end. Be sure to check it afterwards.

  5. Keep in mind that there is one process running the script per GPU, therefore:

    • Make sure your InputSource has reasonable randomness.

    • If your data processing is heavy, doing it in a single dedicated process might be a better choice than doing it repeatedly in each process.

    • You need to make sure log directories in each process won’t conflict. You can set it only for the chief process, or set a different one for each process.

    • Callbacks have an option to be run only in the chief process, or in all processes. See Callback.set_chief_only(). Most callbacks have a reasonable default already, but certain callbacks may not behave properly by default. Report an issue if you find any.

    • You can use the Horovod API, such as hvd.rank(), to know which process you are in and choose a different code path. The chief process has rank 0.

  6. Due to these caveats, see ResNet-Horovod for a full example that handles these common issues. This example can train ImageNet in roughly an hour following the paper’s setup.
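
The per-process choices listed above can be sketched as one hypothetical helper. In a real Horovod job, rank would come from hvd.rank(); per_process_settings, the default log directory name, and seeding the data with base_seed + rank are assumptions of this sketch, not tensorpack behavior.

```python
import os


def per_process_settings(rank, base_logdir="train_log", base_seed=42,
                         chief_only_logdir=True):
    """Hypothetical helper: pick a log dir and a data seed for one process.

    In a real Horovod job, `rank` would come from hvd.rank(); rank 0 is chief.
    """
    is_chief = rank == 0
    if chief_only_logdir:
        logdir = base_logdir if is_chief else None  # only the chief logs
    else:
        logdir = os.path.join(base_logdir, "rank{}".format(rank))  # one per process
    seed = base_seed + rank  # different data randomness in each process
    return {"is_chief": is_chief, "logdir": logdir, "seed": seed}
```

The two logdir modes correspond to the two options above: logging only in the chief, or giving every process its own directory.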

__init__(average=True, compression=None)
Parameters:
  • average (bool) – whether to average or sum the gradients across processes.

  • compression – hvd.Compression.fp16 or hvd.Compression.none