tensorpack.contrib package
class tensorpack.contrib.keras.KerasPhaseCallback(isTrain)
Bases: tensorpack.callbacks.base.Callback
Keras needs an extra input if learning_phase is used by the model. This callback supplies that input. It is used:
1. By the trainer, with isTrain=True.
2. By InferenceRunner, with isTrain=False, in the form of hooks.
If you use KerasModel or setup_keras_trainer(), this callback is added automatically when needed.
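The role of this callback can be illustrated with a small pure-Python model, assuming the Keras convention that the learning-phase input is fed 1 in training mode and 0 in inference mode. `PhaseFeedHook`, its `feed_dict()` method and the `"keras_learning_phase"` key are hypothetical stand-ins for the real placeholder and TensorFlow hook machinery; this is not the tensorpack implementation.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class PhaseFeedHook:
    """Hypothetical stand-in for KerasPhaseCallback: supplies the extra
    learning-phase input that a Keras model expects on every run step."""
    is_train: bool
    # Hypothetical name standing in for the learning-phase placeholder.
    placeholder_name: str = "keras_learning_phase"

    def feed_dict(self) -> Dict[str, int]:
        # Keras convention: 1 means training mode, 0 means inference mode.
        return {self.placeholder_name: 1 if self.is_train else 0}


# The trainer would own one instance with is_train=True ...
train_hook = PhaseFeedHook(is_train=True)
# ... and an inference runner would own one with is_train=False.
infer_hook = PhaseFeedHook(is_train=False)

print(train_hook.feed_dict())  # {'keras_learning_phase': 1}
print(infer_hook.feed_dict())  # {'keras_learning_phase': 0}
```

Keeping the flag on the object, rather than in global state, is what lets the training and inference sides feed different values to the same model.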
tensorpack.contrib.keras.setup_keras_trainer(trainer, get_model, input_signature, target_signature, input, optimizer, loss, metrics)
Parameters:
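trainer (SingleCostTrainer)
get_model (input1, input2, ... -> tf.keras.Model) – A function that takes tensors, builds and returns a Keras model. It will be part of the tower function.
input_signature ([tf.TensorSpec]) – the signature for the input tensors.
target_signature ([tf.TensorSpec]) – the signature for the target tensors.
input (InputSource)
optimizer (tf.train.Optimizer)
loss, metrics – lists of strings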
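Because get_model "will be part of the tower function", it is called with tensors every time a tower is built. The pure-Python sketch below illustrates that contract under two assumptions that are not confirmed by this page: the tower function receives one flat sequence of tensors ordered as input_signature followed by target_signature, and there is a single target. `split_by_signature` and `make_tower_fn` are hypothetical helpers; floats stand in for tensors and strings for tf.TensorSpec entries.

```python
from typing import Callable, List, Sequence, Tuple


def split_by_signature(tensors: Sequence, input_signature: Sequence,
                       target_signature: Sequence) -> Tuple[List, List]:
    """Split one flat sequence of tensors, assumed to be ordered as
    input_signature followed by target_signature, into (inputs, targets)."""
    n_in, n_tgt = len(input_signature), len(target_signature)
    if len(tensors) != n_in + n_tgt:
        raise ValueError(f"expected {n_in + n_tgt} tensors, got {len(tensors)}")
    return list(tensors[:n_in]), list(tensors[n_in:])


def make_tower_fn(get_model: Callable, loss_fn: Callable,
                  input_signature: Sequence, target_signature: Sequence) -> Callable:
    """Build a hypothetical tower function. Since get_model is called inside
    it, a model is built from the inputs each time a tower is built."""
    def tower_fn(*tensors):
        inputs, targets = split_by_signature(tensors, input_signature, target_signature)
        model = get_model(*inputs)      # user function: tensors -> model
        predictions = model(*inputs)    # toy stand-in for reading model outputs
        return loss_fn(predictions, targets[0])  # assumes a single target
    return tower_fn


def get_model(x):
    # Toy "model": a function that doubles its input.
    return lambda t: 2.0 * t


def squared_error(pred, target):
    return (pred - target) ** 2


tower = make_tower_fn(get_model, squared_error, ["image"], ["label"])
print(tower(3.0, 5.0))  # 1.0
```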
class tensorpack.contrib.keras.KerasModel(get_model, input_signature=None, target_signature=None, input=None, trainer=None)
Bases: object
__init__(get_model, input_signature=None, target_signature=None, input=None, trainer=None)
Parameters:
get_model (input1, input2, ... -> keras.Model) – A function that takes tensors, builds and returns a Keras model. It will be part of the tower function.
input_signature ([tf.TensorSpec]) – required. The signature for the input tensors.
target_signature ([tf.TensorSpec]) – required. The signature for the target tensors.
input (InputSource | DataFlow) – the InputSource or DataFlow that the input data comes from.
trainer (Trainer) – by default, the number of available GPUs is checked and all of them are used.
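The default-trainer rule can be summarized as a small decision function. This is a hedged sketch: `pick_trainer` and `describe_default_trainer` are hypothetical helpers that return descriptive strings instead of real tensorpack trainer objects, and the exact threshold (a single-device trainer for zero or one GPU) is an assumption, not something this page states.

```python
from typing import Optional


def describe_default_trainer(num_gpus: int) -> str:
    """Hypothetical sketch of the documented default: check how many GPUs
    are available and use all of them."""
    if num_gpus < 0:
        raise ValueError("num_gpus must be non-negative")
    if num_gpus <= 1:
        # Assumption: zero or one device needs only a single-tower trainer.
        return "single-device trainer"
    # Several devices: one tower per GPU, with every GPU in use.
    return f"multi-GPU trainer over {num_gpus} GPUs"


def pick_trainer(user_trainer: Optional[str], num_gpus: int) -> str:
    # A trainer passed explicitly always takes precedence over the default.
    if user_trainer is not None:
        return user_trainer
    return describe_default_trainer(num_gpus)


print(pick_trainer(None, 4))            # multi-GPU trainer over 4 GPUs
print(pick_trainer("my trainer", 4))    # my trainer
```

Passing trainer explicitly is the way to override this behaviour, for example to use fewer GPUs than are visible.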
compile(optimizer, loss, metrics=None)
Parameters:
optimizer (tf.train.Optimizer)
loss, metrics – a string or a list of strings
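compile accepts a string or a list of strings for loss and metrics, while setup_keras_trainer documents lists of strings. A normalization step like the one below bridges the two. It is a sketch: `as_name_list` is a hypothetical helper, and mapping metrics=None to an empty list is an assumption about how the default is treated.

```python
from typing import List, Optional, Sequence, Union

NameSpec = Union[str, Sequence[str]]


def as_name_list(spec: Optional[NameSpec]) -> List[str]:
    """Hypothetical helper: turn compile-style arguments (a string, a list of
    strings, or None) into a plain list of strings."""
    if spec is None:
        return []              # assumption: no metrics means an empty list
    if isinstance(spec, str):
        return [spec]          # a single name becomes a one-element list
    names = list(spec)
    if not all(isinstance(n, str) for n in names):
        raise TypeError("expected a string or a list of strings")
    return names


loss = as_name_list("categorical_crossentropy")
metrics = as_name_list(None)
print(loss, metrics)  # ['categorical_crossentropy'] []
```

The str check must come before the generic sequence branch, because a Python string is itself a sequence and would otherwise be split into single characters.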
fit(validation_data=None, **kwargs)
Parameters:
validation_data (DataFlow or InputSource) – data to be used for inference. The inference callback is added as the first callback in the callback list. If you need it in a different position, add it to the callback list manually.
kwargs – the same arguments as Trainer.train_with_defaults().
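The ordering rule for validation_data can be sketched with plain strings standing in for callback objects. `build_callbacks` is a hypothetical helper, and the entries "saver" and "lr_schedule" are placeholders, not real tensorpack callbacks.

```python
from typing import List, Optional


def build_callbacks(user_callbacks: List[str],
                    validation_data: Optional[str]) -> List[str]:
    """Hypothetical sketch of the documented rule: when validation data is
    given, its inference callback is placed first in the list."""
    callbacks = list(user_callbacks)  # copy, so the caller's list is untouched
    if validation_data is not None:
        callbacks.insert(0, f"InferenceRunner({validation_data})")
    return callbacks


print(build_callbacks(["saver", "lr_schedule"], "val_df"))
# ['InferenceRunner(val_df)', 'saver', 'lr_schedule']
```

To control the position yourself, leave validation_data as None and put the inference entry wherever it belongs in your own callback list.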