Inference During Training
There are two ways to do inference during training.
If your inference follows the paradigm “evaluate some tensors for each input, and aggregate the results at the end”, you can use the InferenceRunner interface with some Inferencer. This further supports prefetching and data-parallel inference.
Currently this lacks documentation, but you can refer to the examples that use Inferencer to learn more.
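The “evaluate per input, aggregate at the end” paradigm can be sketched without TensorFlow. The class and function below (MeanCostInferencer, run_inference) are hypothetical stand-ins that do not import tensorpack. Their hook names are only loosely modeled on tensorpack’s Inferencer, and the evaluate callable takes the place of running the fetched tensors on one input.

```python
class MeanCostInferencer:
    """Hypothetical inferencer: averages one scalar fetched per input."""

    def before_inference(self):
        # Reset the accumulators before each inference pass.
        self.total = 0.0
        self.count = 0

    def get_fetches(self):
        # Names of the tensors to evaluate for every input (hypothetical name).
        return ['cost']

    def on_fetches(self, results):
        # Called once per input with the evaluated values, in fetch order.
        self.total += results[0]
        self.count += 1

    def after_inference(self):
        # Aggregate once at the end and return named statistics.
        return {'mean_cost': self.total / self.count}


def run_inference(evaluate, dataset, inferencer):
    """Hypothetical runner: evaluate the fetches per input, then aggregate."""
    inferencer.before_inference()
    fetches = inferencer.get_fetches()
    for datapoint in dataset:
        inferencer.on_fetches(evaluate(fetches, datapoint))
    return inferencer.after_inference()


if __name__ == '__main__':
    # Stand-in for a session run: the "cost" of a datapoint is its square.
    def fake_evaluate(fetches, datapoint):
        return [float(datapoint) ** 2 for _ in fetches]

    print(run_inference(fake_evaluate, [1, 2, 3], MeanCostInferencer()))
```

In this sketch the inferencer never decides how inputs are produced or fed. It only sees per-input results and combines them in after_inference, so the runner is free to change how data arrives without touching the aggregation logic.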
In both methods, your tower function will be called again, with
You can use this predicate to choose a different code path in inference mode.
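As a framework-free illustration of such a code path, the hypothetical tower_func below applies inverted dropout only when an is_training flag is true and acts as the identity otherwise. Here the flag is a plain argument and the “tensor” is a list of floats. In tensorpack, the flag would come from the predicate mentioned above.

```python
import random


def tower_func(x, is_training, drop_prob=0.5, rng=None):
    """Hypothetical tower function with a training-only code path.

    x: a list of floats standing in for an activation tensor.
    """
    if not is_training:
        # Inference mode: dropout is disabled and the input passes through unchanged.
        return list(x)
    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    # Training mode (inverted dropout): zero each value with probability
    # drop_prob and rescale the survivors so the expected value is unchanged.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


if __name__ == '__main__':
    print(tower_func([1.0, 2.0, 3.0], is_training=False))
    print(tower_func([1.0, 2.0, 3.0], is_training=True, rng=random.Random(0)))
```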
Inference After Training: What Tensorpack Does
Tensorpack provides some small tools for the most basic types of inference, for demo purposes. You can use them, but these approaches are often suboptimal and may fail: they can be inefficient or lack functionality you need.
If you need anything more complicated, please learn what TensorFlow can do and do it on your own, because Tensorpack is a training interface and doesn’t focus on what happens after training.
Tensorpack provides OfflinePredictor for inference demos after training. It builds the graph, loads the checkpoint, and returns a callable for simple prediction. Refer to its docs for details.
To use it, you need to provide your model and checkpoint, and specify the input and output tensors to infer with. You can obtain the names of tensors by print(), or assign a name to a tensor by
A simple example of how it works:
from tensorpack import OfflinePredictor, PredictConfig, SmartInit

pred_config = PredictConfig(
    model=YourModel(),
    session_init=SmartInit(model_path),
    input_names=['input1', 'input2'],  # tensor names in the graph, or names of the declared inputs
    output_names=['output1', 'output2'])  # tensor names in the graph
predictor = OfflinePredictor(pred_config)
output1_array, output2_array = predictor(input1_array, input2_array)
It’s common to use a different graph for inference, e.g., to use the NHWC format or to accept images in an encoded format.
You can make these changes inside the
tower_func in your
The example in examples/basics/export-model.py demonstrates such an altered inference graph.
OfflinePredictor is only for quick demo purposes. It runs inference on NumPy arrays and therefore may not be the most efficient option. It also has very limited functionality.
In addition to the standard checkpoint format that tensorpack saves during training, you can also export your model into other formats after training, which may be more convenient for inference:
SavedModel format for TensorFlow Serving:
from tensorpack.tfutils.export import ModelExporter
ModelExporter(pred_config).export_serving('/path/to/export')
This format contains both the graph and the variables. Refer to the TensorFlow Serving documentation on how to use it.
Export to a frozen and pruned graph for TensorFlow’s built-in tools such as TOCO:
This format is just a serialized tf.Graph. The export process:
Converts all variables to constants to embed the variables directly in the graph.
Removes all unnecessary operations (training-only ops, e.g., the learning rate) to compress the graph.
This creates a self-contained graph which includes all necessary information to run inference.
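The two export steps can be illustrated on a toy graph. In the sketch below, a graph is a hypothetical dict mapping node names to an op type and a list of input names, and a checkpoint is a dict from variable name to value. This is not TensorFlow’s GraphDef format or tensorpack’s exporter code. In TF 1.x, functions such as tf.graph_util.convert_variables_to_constants and tf.graph_util.extract_sub_graph perform the real operations on a GraphDef. The sketch only shows the idea.

```python
from collections import deque


def freeze_and_prune(graph, checkpoint, output_names):
    """Toy freeze-and-prune on a dict-based graph (not a real GraphDef).

    graph: {node_name: {'op': str, 'inputs': [node_name, ...]}}
    checkpoint: {variable_name: trained_value}
    output_names: nodes that inference needs to evaluate.
    """
    # Prune: walk backwards from the outputs. Anything not reached
    # (e.g. the training step or the learning rate) is dropped.
    needed = set()
    queue = deque(output_names)
    while queue:
        name = queue.popleft()
        if name in needed:
            continue
        needed.add(name)
        queue.extend(graph[name]['inputs'])

    # Freeze: replace each remaining variable with a constant holding
    # its trained value, so the graph no longer needs the checkpoint.
    frozen = {}
    for name in needed:
        node = graph[name]
        if node['op'] == 'Variable':
            frozen[name] = {'op': 'Const', 'inputs': [], 'value': checkpoint[name]}
        else:
            frozen[name] = {'op': node['op'], 'inputs': list(node['inputs'])}
    return frozen


if __name__ == '__main__':
    # Hypothetical training graph; 'TrainStep' is a made-up op type.
    training_graph = {
        'input': {'op': 'Placeholder', 'inputs': []},
        'W': {'op': 'Variable', 'inputs': []},
        'output': {'op': 'MatMul', 'inputs': ['input', 'W']},
        'learning_rate': {'op': 'Variable', 'inputs': []},
        'train_op': {'op': 'TrainStep', 'inputs': ['W', 'learning_rate', 'output']},
    }
    checkpoint = {'W': [[0.5]], 'learning_rate': 0.1}
    frozen = freeze_and_prune(training_graph, checkpoint, ['output'])
    for name in sorted(frozen):
        print(name, frozen[name])
```

Only input, W, and output survive. W becomes a constant, and learning_rate and train_op are removed because the output does not depend on them.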
To load the saved graph, you can simply:
import tensorflow as tf

graph_def = tf.GraphDef()
with open(graph_file, 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def)
demonstrates the usage of such a frozen/pruned graph.
Again, you may often want to use a different graph for inference, and you can do so through the arguments of
Note that the exporter relies on TensorFlow’s automatic graph transformation, which does not always work reliably: it is often suboptimal and sometimes fails. It’s safer to write the inference graph yourself.
Inference After Training: Do It Yourself
Tensorpack is a training interface: it doesn’t care what happens after training. During training, it already provides everything you need for inference or model diagnosis after training:
The model (the graph): you’ve already written it yourself with TF symbolic functions. Nothing about it is related to the tensorpack interface. If you use tensorpack layers, they are not so different from
The trained parameters: tensorpack saves them in standard TF checkpoint format. Nothing about the format is related to tensorpack.
With the model and the trained parameters, you can do inference with whatever approaches TensorFlow supports. Usually it involves the following steps:
Step 1: build the model (graph)
You can build a graph however you like, with pure TensorFlow. If your model is written with ModelDesc, you can also build it like this:
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

a, b = tf.placeholder(...), tf.placeholder(...)
# call ANY symbolic functions on a, b. e.g.:
with TowerContext('', is_training=False):
    model.build_graph(a, b)
Do not use metagraph for inference!
Tensorpack saves a metagraph during training. Users should not try to load it for inference.
Metagraph is the wrong abstraction for a “model”. It stores the entire graph which contains not only the mathematical model, but also all the training settings (queues, iterators, summaries, evaluations, multi-gpu replications). Therefore it is usually wrong to import a training metagraph for inference.
It’s especially error-prone to load a metagraph on top of a non-empty graph. The potential name conflicts between the current graph and the nodes in the metagraph can lead to esoteric bugs or sometimes completely ruin the model.
It’s also very common to change the graph for inference. For example, you may need a different data layout for CPU inference, or you may need placeholders in the inference graph (which may not even exist in the training graph). However metagraph is not designed to be easily modified at all.
Due to the above reasons, to do inference, it’s best to recreate a clean graph (and save it if needed) by yourself.
Step 2: load the checkpoint
You can just use tf.train.Saver for all the work.
Alternatively, use tensorpack’s
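Both approaches match the variables in your newly built graph to checkpoint entries by name, which is why recreating the graph with the same model code lets the checkpoint load. The hypothetical restore_by_name below sketches this name matching with plain dicts. It is deliberately lenient and reports mismatches, whereas tf.train.Saver.restore raises an error when a graph variable is missing from the checkpoint.

```python
def restore_by_name(graph_variables, checkpoint):
    """Hypothetical name-matching restore on plain dicts.

    graph_variables: names of the variables in the inference graph.
    checkpoint: {variable_name: saved_value}
    Returns (values to assign, graph variables missing from the
    checkpoint, checkpoint entries that the graph does not use).
    """
    assigned = {}
    missing = []
    for name in graph_variables:
        if name in checkpoint:
            assigned[name] = checkpoint[name]
        else:
            missing.append(name)
    unused = sorted(set(checkpoint) - set(graph_variables))
    return assigned, missing, unused


if __name__ == '__main__':
    # Example names only: a training checkpoint often also holds
    # optimizer slots and the global step, which inference does not need.
    ckpt = {'conv0/W': 1.0, 'fc/W': 2.0, 'fc/W/Momentum': 0.3, 'global_step': 1000}
    print(restore_by_name(['conv0/W', 'fc/W'], ckpt))
```

Entries that only training needs simply end up in the unused list, which is one reason a clean inference graph can still load a training checkpoint.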
Now, you’ve already built a graph for inference, and the checkpoint is also loaded. You may now:
use sess.run to do inference
save the graph to some formats for further processing
apply graph transformation for efficient inference
These steps are unrelated to tensorpack, and you’ll need to learn TensorFlow and do it yourself.