# Source code for tensorpack.predict.config

# -*- coding: utf-8 -*-
# File: config.py


import tensorflow as tf
import six

from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..utils import logger

__all__ = ['PredictConfig']


class PredictConfig(object):
    """Settings describing how to build the graph and session of a predictor."""

    def __init__(self,
                 model=None,
                 tower_func=None,
                 inputs_desc=None,

                 input_names=None,
                 output_names=None,

                 session_creator=None,
                 session_init=None,
                 return_input=False,
                 create_graph=True,
                 ):
        """
        You need to set either `model`, or `inputs_desc` plus `tower_func`.
        They are needed to construct the graph.
        You'll also have to set `output_names` as it does not have a default.

        Args:
            model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
            tower_func: a callable which takes input tensors (by positional args) and construct a tower.
                or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
            inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
                the list of inputs it takes.

            input_names (list): a list of input tensor names. Defaults to match inputs_desc.
            output_names (list): a list of names of the output tensors to predict, the
                tensors can be any computable tensor in the graph.

            session_creator (tf.train.SessionCreator): how to create the
                session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
            session_init (SessionInit): how to initialize variables of the session.
                Defaults to do nothing.

            return_input (bool): same as in :attr:`PredictorBase.return_input`.
            create_graph (bool): create a new graph, or use the default graph
                when predictor is first initialized.

        Raises:
            AssertionError: if `model` is combined with `inputs_desc` or
                `tower_func`; if neither `model` nor a usable
                `inputs_desc`/`tower_func` pair is given; if an argument has
                the wrong type; or if `output_names` is empty.
        """
        def assert_type(v, tp):
            # Report both the expected and the actual type to ease debugging.
            assert isinstance(v, tp), \
                "Expected an instance of {}, but got {}".format(tp, v.__class__)

        if model is not None:
            assert_type(model, ModelDescBase)
            assert inputs_desc is None and tower_func is None, \
                "`model` cannot be used together with `inputs_desc` or `tower_func`."
            self.inputs_desc = model.get_inputs_desc()
            self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
        else:
            # A TowerFuncWrapper already carries its own inputs_desc, which
            # takes precedence over the `inputs_desc` argument.
            if isinstance(tower_func, TowerFuncWrapper):
                inputs_desc = tower_func.inputs_desc
            assert inputs_desc is not None and tower_func is not None, \
                "Set either `model`, or `inputs_desc` plus `tower_func`."
            self.inputs_desc = inputs_desc
            self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)

        if session_init is None:
            session_init = JustCurrentSession()
        self.session_init = session_init
        assert_type(self.session_init, SessionInit)

        if session_creator is None:
            self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config())
        else:
            self.session_creator = session_creator

        # inputs & outputs
        self.input_names = input_names
        if self.input_names is None:
            self.input_names = [k.name for k in self.inputs_desc]
        self.output_names = output_names
        assert_type(self.output_names, list)
        assert_type(self.input_names, list)
        # An empty `input_names` is tolerated (warning only), whereas an empty
        # `output_names` is rejected below.
        if len(self.input_names) == 0:
            logger.warn('PredictConfig receives empty "input_names".')
        for v in self.input_names:
            assert_type(v, six.string_types)
        assert len(self.output_names), self.output_names

        self.return_input = bool(return_input)
        self.create_graph = bool(create_graph)

    def _maybe_create_graph(self):
        """Return a new ``tf.Graph`` if `create_graph` is set, else the current default graph."""
        if self.create_graph:
            return tf.Graph()
        return tf.get_default_graph()