Save and Load models
Work with a TF Checkpoint
The ModelSaver callback saves the model to the directory defined by logger.get_logger_dir(),
in TensorFlow checkpoint format.
A TF checkpoint typically includes a
.data-xxxxx file and a
.index file.
Both are necessary.
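As a quick sanity check of the two-file layout described above, the sketch below tests whether a checkpoint prefix has both parts on disk. The helper name checkpoint_is_complete and the example path are hypothetical; the only assumption about TensorFlow is that the files are named after the checkpoint prefix with an .index and a .data-xxxxx suffix.

```python
import glob
import os


def checkpoint_is_complete(prefix):
    """Return True if both the .index file and at least one .data-* file exist.

    `prefix` is the checkpoint path without a suffix, e.g. "train_log/model-1000"
    (a hypothetical path used only for illustration).
    """
    has_index = os.path.isfile(prefix + ".index")
    # glob.escape guards against special characters in the directory name.
    has_data = len(glob.glob(glob.escape(prefix) + ".data-*")) > 0
    return has_index and has_data
```

Copying only one of the two files to another machine would make this check fail, which is usually the first thing to rule out when a checkpoint refuses to load.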
scripts/ls-checkpoint.py demonstrates how to print all variables and their shapes in a checkpoint.
Tensorpack also includes a tool to save variables to a TF checkpoint; see save_chkpt_vars.
Work with npz Files in Model Zoo
Most models provided by tensorpack are in npz (dictionary) format,
because it’s easy to manipulate without a TF dependency.
You can read/write them with np.load and np.savez.
scripts/dump-model-params.py can be used to remove unnecessary variables in a checkpoint
and save the result to an npz file.
It takes a metagraph file (which is also saved by
ModelSaver) and only saves variables that the model needs at inference time.
It dumps the model to a
var-name: value dict saved in npz format.
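Because an npz file is just a dictionary of arrays, it can be inspected and edited with NumPy alone. The following is a minimal sketch; the variable names (conv0/W, conv0/b, global_step) and the temporary file path are made up for illustration and do not come from any particular model in the zoo.

```python
import os
import tempfile

import numpy as np

# Hypothetical parameters standing in for a model from the zoo.
params = {
    "conv0/W": np.zeros((3, 3, 3, 16), dtype=np.float32),
    "conv0/b": np.zeros((16,), dtype=np.float32),
    "global_step": np.array(1000, dtype=np.int64),
}

path = os.path.join(tempfile.mkdtemp(), "model.npz")
np.savez(path, **params)      # write: one array per variable name

with np.load(path) as f:
    loaded = dict(f)          # read: a plain name -> array dict

# Drop a variable that is not needed at inference time, then list the rest.
loaded.pop("global_step", None)
for name in sorted(loaded):
    print(name, loaded[name].shape)
```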
Load a Model to a Session
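Model loading (in both training and inference) is through the session_init interface.
For training, use session_init in TrainConfig(…) or Trainer.train(…).
For inference, use session_init in PredictConfig(…).
There are two ways a session can be initialized:
session_init=SaverRestore(…) which restores a TF checkpoint,
or session_init=DictRestore(…) which restores a dict.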
DictRestore is the most general loader because you can make arbitrary changes
you need (e.g., remove variables, rename variables) to the dict.
To load multiple models, use ChainInit.
To load an npz file from tensorpack model zoo to a session, you can use DictRestore(dict(np.load(filename))).
You can also use get_model_loader(filename),
a small helper which returns either a
SaverRestore or a
DictRestore based on the file name.
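To make the idea of choosing a loader from the file name concrete, here is a small sketch of that kind of dispatch. It is an illustration only, not tensorpack's implementation: the function choose_loader_kind is hypothetical, and it assumes the simplest possible rule, where a file ending in .npz is treated as a dict of arrays and anything else as a TF checkpoint.

```python
import os


def choose_loader_kind(filename):
    """Name the loader a file-name-based helper could pick (illustration only).

    Assumed rule for this sketch: ".npz" means a dict of arrays (DictRestore),
    anything else is treated as a TF checkpoint prefix (SaverRestore).
    """
    _, ext = os.path.splitext(filename)
    if ext == ".npz":
        return "DictRestore"
    return "SaverRestore"


# Hypothetical file names.
for name in ["ResNet50.npz", "train_log/model-10000"]:
    print(name, "->", choose_loader_kind(name))
```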
Variable restoring is completely based on exact name match between
variables in the current graph and variables in the session_init initializer.
Variables that appear on only one side are printed as a warning.
Variables with the same name but incompatible shapes cause an error.
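The matching rules above can be sketched in plain Python. This is not tensorpack's code: match_variables is a hypothetical helper, variables are represented only by name and shape, and "incompatible" is simplified to "not exactly equal".

```python
def match_variables(graph_vars, loaded_vars):
    """Plan a restore by exact name match.

    Both arguments map variable name -> shape tuple.
    Returns (to_restore, only_in_graph, only_in_loader) and raises ValueError
    when a name exists on both sides with different shapes.
    """
    to_restore = sorted(set(graph_vars) & set(loaded_vars))
    for name in to_restore:
        if tuple(graph_vars[name]) != tuple(loaded_vars[name]):
            raise ValueError(
                "Shape mismatch for %r: graph %s vs. loader %s"
                % (name, graph_vars[name], loaded_vars[name]))
    only_in_graph = sorted(set(graph_vars) - set(loaded_vars))
    only_in_loader = sorted(set(loaded_vars) - set(graph_vars))
    return to_restore, only_in_graph, only_in_loader


# Hypothetical variable names and shapes.
graph = {"conv0/W": (3, 3, 3, 16), "fc/W": (256, 10)}
loaded = {"conv0/W": (3, 3, 3, 16), "fc_old/W": (256, 1000)}

restore, missing, unused = match_variables(graph, loaded)
print("restore:", restore)
print("warn, in graph but not in loader:", missing)
print("warn, in loader but not in graph:", unused)
```

Here fc/W keeps its fresh initialization and fc_old/W is ignored, each with a warning, while conv0/W is restored.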
Therefore, transfer learning is trivial.
If you want to load a pre-trained model, just use the same variable names. If you want to re-train some layer, either rename the variables in the graph, or rename/remove the variables in your loader.
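The "rename/remove the variables in your loader" option amounts to editing a dict before handing it to DictRestore. A minimal sketch follows; drop_prefix and rename_prefix are hypothetical helpers, the variable names are invented, and short strings stand in for the weight arrays.

```python
def drop_prefix(params, prefix):
    """Return a copy of `params` without variables whose name starts with `prefix`."""
    return {k: v for k, v in params.items() if not k.startswith(prefix)}


def rename_prefix(params, old, new):
    """Return a copy of `params` with a leading `old` in each name replaced by `new`."""
    return {(new + k[len(old):] if k.startswith(old) else k): v
            for k, v in params.items()}


# Hypothetical pre-trained weights; strings stand in for the arrays.
pretrained = {"conv0/W": "w0", "conv0/b": "b0", "linear/W": "w1", "linear/b": "b1"}

# Re-train the classifier: remove its weights so nothing is restored for it.
backbone_only = drop_prefix(pretrained, "linear/")

# Reuse conv0 under a different scope name in the new graph.
renamed = rename_prefix(backbone_only, "conv0/", "backbone/conv0/")
print(sorted(renamed))
```

Both helpers return new dicts, so the original parameters stay available if several variants are needed.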
“Resume training” is mostly just “loading the last known checkpoint”. To load a model, you should refer to the previous section: Load a Model to a Session.
A checkpoint does not resume everything!
Loading the checkpoint does most of the work in “resume training”, but note that a TensorFlow checkpoint only saves TensorFlow variables, which means other Python state that is not stored in TensorFlow variables will not be saved and resumed. This means:
Training epoch number will not be resumed. You can set it by providing a starting_epoch to your TrainConfig.
State in your callbacks will not be resumed. Certain callbacks maintain a state (e.g., current best accuracy) in Python, which cannot be saved automatically.
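One way to deal with Python-side state is to persist it yourself next to the checkpoint. The sketch below shows the idea with a JSON file; BestAccuracyTracker is a hypothetical plain-Python class, not a tensorpack callback, and the file name is arbitrary.

```python
import json
import os
import tempfile


class BestAccuracyTracker:
    """Hypothetical Python-side state that survives a restart via a JSON file."""

    def __init__(self, state_path):
        self.state_path = state_path
        self.best = 0.0
        if os.path.isfile(state_path):      # resume from a previous run
            with open(state_path) as f:
                self.best = json.load(f)["best"]

    def update(self, accuracy):
        """Record a new accuracy; persist it when it beats the best so far."""
        if accuracy > self.best:
            self.best = accuracy
            with open(self.state_path, "w") as f:
                json.dump({"best": self.best}, f)
            return True
        return False


state_file = os.path.join(tempfile.mkdtemp(), "best_acc.json")
tracker = BestAccuracyTracker(state_file)
tracker.update(0.71)
tracker.update(0.65)                        # not an improvement, not saved

resumed = BestAccuracyTracker(state_file)   # simulates a restarted process
print(resumed.best)
```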
AutoResumeTrainConfig is an alternative to
TrainConfig which applies some heuristics to load the latest epoch number and latest checkpoint.
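The heuristics themselves are not spelled out here. As an illustration of what recovering the epoch number could look like, the sketch below reads it from a per-epoch log. Everything in it is an assumption made for the example: the helper next_starting_epoch, the stats.json file name, and the record layout (a JSON list of objects with an integer epoch_num field).

```python
import json
import os
import tempfile


def next_starting_epoch(log_dir, stats_name="stats.json"):
    """Guess the epoch to resume from, or return 1 if no history is found.

    Assumes (for this sketch) that `stats_name` holds a JSON list of per-epoch
    records, each with an integer "epoch_num" field.
    """
    path = os.path.join(log_dir, stats_name)
    if not os.path.isfile(path):
        return 1
    with open(path) as f:
        records = json.load(f)
    if not records:
        return 1
    return max(r["epoch_num"] for r in records) + 1


log_dir = tempfile.mkdtemp()
fresh = next_starting_epoch(log_dir)        # no history yet

with open(os.path.join(log_dir, "stats.json"), "w") as f:
    json.dump([{"epoch_num": 1}, {"epoch_num": 2}], f)
after_two = next_starting_epoch(log_dir)    # two finished epochs on record

print(fresh, after_two)
```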