Source code for tensorpack.callbacks.trigger

# -*- coding: utf-8 -*-
# File: trigger.py


from .base import Callback, ProxyCallback

__all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf']


[docs]class PeriodicTrigger(ProxyCallback): """ Trigger a callback every k global steps or every k epochs by its :meth:`trigger()` method. Most existing callbacks which do something every epoch are implemented with :meth:`trigger()` method. By default the :meth:`trigger()` method will be called every epoch. This wrapper can make the callback run at a different frequency. All other methods (``before/after_run``, ``trigger_step``, etc) of the given callback are unaffected. They will still be called as-is. """
[docs] def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None, before_train=False): """ Args: triggerable (Callback): a Callback instance with a trigger method to be called. every_k_steps (int): trigger when ``global_step % k == 0``. Set to None to ignore. every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to None to ignore. before_train (bool): trigger in the :meth:`before_train` method. every_k_steps and every_k_epochs can be both set, but cannot be both None unless before_train is True. """ assert isinstance(triggerable, Callback), type(triggerable) super(PeriodicTrigger, self).__init__(triggerable) if before_train is False: assert (every_k_epochs is not None) or (every_k_steps is not None), \ "Arguments to PeriodicTrigger have disabled the triggerable!" self._step_k = every_k_steps self._epoch_k = every_k_epochs self._do_before_train = before_train
def _before_train(self): self.cb.before_train() if self._do_before_train: self.cb.trigger() def _trigger_step(self): self.cb.trigger_step() if self._step_k is None: return if self.global_step % self._step_k == 0: self.cb.trigger() def _trigger_epoch(self): if self._epoch_k is None: return if self.epoch_num % self._epoch_k == 0: self.cb.trigger() def __str__(self): return "PeriodicTrigger-" + str(self.cb)
[docs]class EnableCallbackIf(ProxyCallback): """ Disable the ``{before,after}_epoch``, ``{before,after}_run``, ``trigger_{epoch,step}`` methods of a callback, unless some condition satisfies. The other methods are unaffected. A more accurate name for this callback should be "DisableCallbackUnless", but that's too ugly. Note: If you use ``{before,after}_run``, ``pred`` will be evaluated only in ``before_run``. """
[docs] def __init__(self, callback, pred): """ Args: callback (Callback): pred (self -> bool): a callable predicate. Has to be a pure function. The callback is disabled unless this predicate returns True. """ self._pred = pred super(EnableCallbackIf, self).__init__(callback)
def _before_run(self, ctx): if self._pred(self): self._enabled = True return super(EnableCallbackIf, self)._before_run(ctx) else: self._enabled = False def _after_run(self, ctx, rv): if self._enabled: super(EnableCallbackIf, self)._after_run(ctx, rv) def _before_epoch(self): if self._pred(self): super(EnableCallbackIf, self)._before_epoch() def _after_epoch(self): if self._pred(self): super(EnableCallbackIf, self)._after_epoch() def _trigger_epoch(self): if self._pred(self): super(EnableCallbackIf, self)._trigger_epoch() def _trigger_step(self): if self._pred(self): super(EnableCallbackIf, self)._trigger_step() def __str__(self): return "EnableCallbackIf-" + str(self.cb)
[docs]class PeriodicCallback(EnableCallbackIf): """ The ``{before,after}_epoch``, ``{before,after}_run``, ``trigger_{epoch,step}`` methods of the given callback will be enabled only when ``global_step % every_k_steps == 0` or ``epoch_num % every_k_epochs == 0``. The other methods are unaffected. Note that this can only makes a callback **less** frequent than itself. If you have a callback that by default runs every epoch by its :meth:`trigger()` method, use :class:`PeriodicTrigger` to schedule it more frequent than itself. """
[docs] def __init__(self, callback, every_k_steps=None, every_k_epochs=None): """ Args: callback (Callback): a Callback instance. every_k_steps (int): enable the callback when ``global_step % k == 0``. Set to None to ignore. every_k_epochs (int): enable the callback when ``epoch_num % k == 0``. Also enable when the last step finishes (``epoch_num == max_epoch`` and ``local_step == steps_per_epoch - 1``). Set to None to ignore. every_k_steps and every_k_epochs can be both set, but cannot be both None. """ assert isinstance(callback, Callback), type(callback) assert (every_k_epochs is not None) or (every_k_steps is not None), \ "every_k_steps and every_k_epochs cannot be both None!" self._step_k = every_k_steps self._epoch_k = every_k_epochs super(PeriodicCallback, self).__init__(callback, PeriodicCallback.predicate)
def predicate(self): if self._step_k is not None and self.global_step % self._step_k == 0: return True if self._epoch_k is not None and self.epoch_num % self._epoch_k == 0: return True if self._epoch_k is not None: if self.local_step == self.trainer.steps_per_epoch - 1 and \ self.epoch_num == self.trainer.max_epoch: return True return False def __str__(self): return "PeriodicCallback-" + str(self.cb)