# Source code for this module (the module name was lost when the page was extracted).

# -*- coding: utf-8 -*-
# File:

import tensorflow as tf
from contextlib import contextmanager
from time import time as timer
import traceback
import six

from .base import Callback
from .hooks import CallbackToHook
from ..utils import logger
from ..utils.utils import humanize_time_delta

if six.PY3:
    from time import perf_counter as timer  # noqa

__all__ = ['Callbacks']

class CallbackTimeLogger(object):
    """
    Accumulate the elapsed time spent in individual callbacks and log
    the ones that dominate the total.
    """

    def __init__(self):
        self.times = []   # (name, seconds) pairs, in the order they were recorded
        self.tot = 0      # running sum of all recorded seconds

    def add(self, name, time):
        """
        Record that the callback called ``name`` took ``time`` seconds.

        Args:
            name (str): display name of the callback.
            time (float): elapsed seconds.
        """
        self.tot += time
        self.times.append((name, time))

    @contextmanager
    def timed_callback(self, name):
        """
        Context manager that measures the enclosed block and records it
        under ``name`` via :meth:`add`.

        The time is only recorded when the block exits without raising;
        an exception propagates before ``add`` is reached.
        """
        start = timer()
        yield
        self.add(name, timer() - start)

    def log(self):
        """
        Log the time of some heavy callbacks.

        Nothing is logged when the total is below 3 seconds. Otherwise a
        single message reports the total and lists every callback that took
        more than 1 second AND more than 30% of the total.
        """
        if self.tot < 3:
            return
        msgs = []
        for name, t in self.times:
            if t / self.tot > 0.3 and t > 1:
                msgs.append(name + ": " + humanize_time_delta(t))
        logger.info(
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))

class Callbacks(Callback):
    """
    A container to hold all callbacks, and trigger them iteratively.
    Note that it does nothing to before_run/after_run.
    """

    def __init__(self, cbs):
        """
        Args:
            cbs(list): a list of :class:`Callback` instances.
        """
        # check type
        for cb in cbs:
            assert isinstance(cb, Callback), cb.__class__
        self.cbs = cbs

    def _setup_graph(self):
        # Set up every child callback's graph outside any enclosing name scope.
        with tf.name_scope(None):   # clear the name scope
            for cb in self.cbs:
                cb.setup_graph(self.trainer)

    def _before_train(self):
        for cb in self.cbs:
            cb.before_train()

    def _after_train(self):
        for cb in self.cbs:
            # make sure callbacks are properly finalized: a failure in one
            # callback is printed but does not prevent the others from running
            try:
                cb.after_train()
            except Exception:
                traceback.print_exc()

    def get_hooks(self):
        """
        Returns:
            list: one :class:`CallbackToHook` wrapper per child callback.
        """
        return [CallbackToHook(cb) for cb in self.cbs]

    def trigger_step(self):
        for cb in self.cbs:
            cb.trigger_step()

    def _trigger_epoch(self):
        # Time each callback's trigger_epoch and report the heavy ones afterwards.
        tm = CallbackTimeLogger()

        for cb in self.cbs:
            display_name = str(cb)
            with tm.timed_callback(display_name):
                cb.trigger_epoch()
        tm.log()

    def _before_epoch(self):
        for cb in self.cbs:
            cb.before_epoch()

    def _after_epoch(self):
        for cb in self.cbs:
            cb.after_epoch()