Source code for ttools.callbacks

"""Callbacks that can be added to a model trainer's main loop."""
# TODO: implement experiment logger
# TODO: implement csv logger

import abc
import logging
import random
import string
import subprocess
import time
import numpy as np

from tqdm import tqdm
from torchvision.utils import make_grid
import visdom

from .utils import ExponentialMovingAverage


__all__ = [
    "Callback",
    "CheckpointingCallback",
    "LoggingCallback",
    "ImageDisplayCallback",
    "ProgressBarCallback",
    "VisdomLoggingCallback",
    "MultiPlotCallback",
]


LOG = logging.getLogger(__name__)


[docs]class Callback(object): """Base class for all training callbacks.""" def __repr__(self): return self.__class__.__name__ def __init__(self): super(Callback, self).__init__() self.epoch = 0 self.batch = 0 self.datasize = 0 self.val_datasize = 0 def training_start(self, dataloader): self.datasize = len(dataloader) def training_end(self): pass
[docs] def epoch_start(self, epoch): """Hook to execute code when a new epoch starts. Note: self.epoch is never incremented. Instead, it should be set by the caller. """ self.epoch = epoch
[docs] def epoch_end(self): """Hook to execute code when an epoch ends. Note: self.epoch is never incremented, but it is set externally in `epoch_start`. """ pass
def validation_start(self, dataloader): self.val_datasize = len(dataloader) def validation_step(self, batch, fwd_data, val_data): pass def validation_end(self, val_data): pass def batch_start(self, batch, batch_data): self.batch = batch def batch_end(self, batch_data, fwd_result, bwd_result): pass
class KeyedCallback(Callback): """An abstract Callback that performs the same action for all keys in a list. The keys (resp. val_keys) are used to access the backward_data (resp. validation_data) produced by a ModelInterface. Args: keys (list of str): list of keys whose values will be logged during training. Defaults to ["loss"]. val_keys (list of str): list of keys whose values will be logged during validation """ def __init__(self, keys=None, val_keys=None, smoothing=0.99): super(KeyedCallback, self).__init__() if keys is None: self.keys = ["loss"] else: self.keys = keys if val_keys is None: self.val_keys = [] # self.val_keys = self.keys else: self.val_keys = val_keys self.ema = ExponentialMovingAverage(self.keys, alpha=smoothing) def batch_end(self, batch_data, fwd, bwd): for k in self.keys: self.ema.update(k, bwd[k])
[docs]class VisdomLoggingCallback(KeyedCallback): """A callback that logs scalar quantities to a visdom server. Args: keys (list of str): list of keys whose values will be logged during training. val_keys (list of str): list of keys whose values will be logged during validation frequency(int): number of steps between display updates. port (int): Port of the Visdom server to log to. env (string): name of the Visdom environment to log to. log (bool): if True, shows the data on a log-scale smoothing(float): smoothing factor for the exponential moving average. 0.0 disables smoothing. """ def __init__(self, keys=None, val_keys=None, frequency=100, server=None, port=8097, env="main", log=False, smoothing=0.99): super(VisdomLoggingCallback, self).__init__( keys=keys, val_keys=val_keys, smoothing=smoothing) if server is None: server = "http://localhost" self._api = visdom.Visdom(server=server, port=port, env=env) self._opts = {} # Cleanup previous plots and setup options all_keys = set(self.keys + self.val_keys) for k in list(all_keys): if self._api.win_exists(k): self._api.close(k) legend = [] if k in self.keys: legend.append("train") if k in self.val_keys: legend.append("val") self._opts[k] = { "legend": legend, "title": k, "xlabel": "epoch", "ylabel": k} if log: self._opts[k]["ytype"] = "log" self._step = 0 self.frequency = frequency def batch_end(self, batch_data, fwd, bwd): super(VisdomLoggingCallback, self).batch_end(batch_data, fwd, bwd) if self._step % self.frequency != 0: self._step += 1 return self._step = 0 t = self.batch / self.datasize + self.epoch for k in self.keys: self._api.line([self.ema[k]], [t], update="append", win=k, name="train", opts=self._opts[k]) self._step += 1 def validation_end(self, val_data): super(VisdomLoggingCallback, self).validation_end(val_data) t = self.epoch + 1 for k in self.val_keys: self._api.line([val_data[k]], [t], update="append", win=k, name="val", opts=self._opts[k])
[docs]class MultiPlotCallback(KeyedCallback): """A callback that logs scalar quantities to a single Visdom window. Args: keys (list of str): list of keys whose values will be logged during training. val_keys (list of str): list of keys whose values will be logged during validation frequency(int): number of steps between display updates. port (int): Port of the Visdom server to log to. env (string): name of the Visdom environment to log to. log (bool): if True, shows the data on a log-scale smoothing(float): smoothing factor for the exponential moving average. 0.0 disables smoothing. win(str): name of the window """ def __init__(self, keys=None, val_keys=None, frequency=100, server=None, port=8097, env="main", log=False, smoothing=0.99, win=None): super(MultiPlotCallback, self).__init__( keys=keys, val_keys=val_keys, smoothing=smoothing) if server is None: server = "http://localhost" self._api = visdom.Visdom(server=server, port=port, env=env) if win is None: self.win = _random_string() else: self.win = win if self._api.win_exists(win): self._api.close(win) # Cleanup previous plots and setup options all_keys = set(self.keys + self.val_keys) legend = [] for k in list(all_keys): if k in self.keys: legend.append(k) if k in self.val_keys: legend.append(k + "_val") self._opts = { "legend": legend, "title": self.win, "xlabel": "epoch", } if log: self._opts["ytype"] = "log" self._step = 0 self.frequency = frequency def batch_end(self, batch_data, fwd, bwd): super(MultiPlotCallback, self).batch_end(batch_data, fwd, bwd) if self._step % self.frequency != 0: self._step += 1 return self._step = 0 t = self.batch / self.datasize + self.epoch data = np.array([self.ema[k] for k in self.keys]) data = np.expand_dims(data, 1) self._api.line(data, [t], update="append", win=self.win, opts=self._opts) self._step += 1 def validation_end(self, val_data): pass
# super(MultiPlotCallback, self).validation_end(val_data) # t = self.epoch + 1 # for k in self.val_keys: # self._api.line([val_data[k]], [t], update="append", win=self.win, name=k + "_val", # opts=self._opts)
[docs]class LoggingCallback(KeyedCallback): """A callback that logs scalar quantities to the console. Make sure python's logging level is at least info to see the console prints. Args: name (str): name of the logger keys (list of str): list of keys whose values will be logged during training. val_keys (list of str): list of keys whose values will be logged during validation """ TABSTOPS = 2 def __init__(self, name, keys=None, val_keys=None, frequency=100, smoothing=0.99): super(LoggingCallback, self).__init__(keys=keys, val_keys=val_keys, smoothing=smoothing) self.log = logging.getLogger(name) self.log.setLevel(logging.INFO) self.m_indent = 0 self._step = 0 self.frequency = frequency def __print(self, s): self.log.info(self.m_indent*LoggingCallback.TABSTOPS*' ' + s) def __indent(self, n=1): self.m_indent += n def __unindent(self, n=1): self.m_indent = max(0, self.m_indent-n) def training_start(self, dataloader): super(LoggingCallback, self).training_start(dataloader) self.__print("Training start") def training_end(self): super(LoggingCallback, self).training_end() self.__print("Training ended at epoch {}".format(self.epoch + 1))
[docs] def epoch_start(self, epoch): super(LoggingCallback, self).epoch_start(epoch) self.__print("-- Epoch {} ".format(self.epoch + 1) + "-"*12)
def validation_start(self, dataloader): super(LoggingCallback, self).validation_start(dataloader) self.__indent() # self.__print("Validation {}".format(self.epoch)) def validation_end(self, val_data): super(LoggingCallback, self).validation_end(val_data) s = "Validation {} | ".format(self.epoch + 1) for k in self.keys: value = val_data.get(k, -1.0) # return -1 if the value is none s += "{} = {:.2f} ".format(k, value) self.__print(s) self.__unindent()
[docs] def batch_end(self, batch_data, fwd, bwd_data): """Logs training advancement Epoch.Batch""" super(LoggingCallback, self).batch_end(batch_data, fwd, bwd_data) if self._step % self.frequency != 0: self._step += 1 return self._step = 0 s = "Step {}.{}".format(self.epoch + 1, self.batch + 1) for k in self.keys: print(bwd_data) value = bwd_data[k] if value is not None: s += " | {} = {:.2f}".format(k, value) self.__print(s) self._step += 1
[docs]class ProgressBarCallback(KeyedCallback): """A progress bar optimization logger.""" def __init__(self, keys=None, val_keys=None, smoothing=0.99): super(ProgressBarCallback, self).__init__(keys=keys, val_keys=val_keys, smoothing=smoothing) self.pbar = None def training_start(self, dataloader): super(ProgressBarCallback, self).training_start(dataloader) print("Training start") def training_end(self): super(ProgressBarCallback, self).training_end() print("Training ends")
[docs] def epoch_start(self, epoch): super(ProgressBarCallback, self).epoch_start(epoch) self.pbar = tqdm(total=self.datasize, unit=" batches", desc="Epoch {}".format(self.epoch + 1))
[docs] def epoch_end(self): super(ProgressBarCallback, self).epoch_end() self.pbar.close() self.pbar = None
def validation_start(self, dataloader): super(ProgressBarCallback, self).validation_start(dataloader) print("Running validation...") self.pbar = tqdm(total=len(dataloader), unit=" batches", desc="Validation {}".format(self.epoch + 1)) def validation_step(self, batch, fwd_data, val_data): self.pbar.update(1) def validation_end(self, val_data): super(ProgressBarCallback, self).validation_end(val_data) self.pbar.close() self.pbar = None s = " "*ProgressBarCallback.TABSTOPS + "Validation {} | ".format( self.epoch + 1) for k in self.val_keys: s += "{} = {:.2f} ".format(k, val_data[k]) print(s) def batch_end(self, batch_data, fwd, bwd_data): super(ProgressBarCallback, self).batch_end(batch_data, fwd, bwd_data) d = {} for k in self.keys: d[k] = self.ema[k] self.pbar.update(1) self.pbar.set_postfix(d) TABSTOPS = 2
[docs]class CheckpointingCallback(Callback): """A callback that periodically saves model checkpoints to disk. Args: checkpointer (Checkpointer): actual checkpointer responsible for the I/O. start_epoch (int or None): index of the starting epoch (e.g. when resuming from a previous checkpoint). interval (int, optional): minimum time in seconds between periodic checkpoints (within an epoch). There is not periodic checkpoint if this value is None. max_files (int, optional): maximum number of periodic checkpoints to keep on disk. max_epochs (int, optional): maximum number of epoch checkpoints to keep on disk. """ PERIODIC_PREFIX = "periodic_" EPOCH_PREFIX = "epoch_" def __init__(self, checkpointer, start_epoch=None, interval=600, max_files=5, max_epochs=10): super(CheckpointingCallback, self).__init__() self.checkpointer = checkpointer self.interval = interval self.max_files = max_files self.max_epochs = max_epochs self.last_checkpoint_time = time.time() if start_epoch is None: self.start_epoch = 0 else: self.start_epoch = start_epoch def get_extras(self): return {"epoch": self.epoch + self.start_epoch}
[docs] def epoch_end(self): """Save a checkpoint at the end of each epoch.""" super(CheckpointingCallback, self).epoch_end() path = "{}{}".format(CheckpointingCallback.EPOCH_PREFIX, self.epoch) self.checkpointer.save(path, extras=self.get_extras()) self.__purge_old_files()
def training_end(self): super(CheckpointingCallback, self).training_end() self.checkpointer.save("training_end", extras=self.get_extras())
[docs] def batch_end(self, batch_data, fwd_result, bwd_result): """Save a periodic checkpoint if requested.""" super(CheckpointingCallback, self).batch_end( batch_data, fwd_result, bwd_result) if self.interval is None: # We skip periodic checkpoints return now = time.time() delta = now - self.last_checkpoint_time if delta < self.interval: # last checkpoint is too recent return LOG.debug("Periodic checkpoint") self.last_checkpoint_time = now filename = "{}{}".format(CheckpointingCallback.PERIODIC_PREFIX, time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())) self.checkpointer.save(filename, extras=self.get_extras()) self.__purge_old_files()
def __purge_old_files(self): """Delete checkpoints that are beyond the max to keep.""" chkpts = self.checkpointer.sorted_checkpoints() p_chkpts = [] e_chkpts = [] for c in chkpts: if c.startswith(self.checkpointer.prefix + CheckpointingCallback.PERIODIC_PREFIX): p_chkpts.append(c) if c.startswith(self.checkpointer.prefix + CheckpointingCallback.EPOCH_PREFIX): e_chkpts.append(c) # Delete periodic checkpoints if self.max_files is not None and len(p_chkpts) > self.max_files: for c in p_chkpts[self.max_files:]: LOG.debug("CheckpointingCallback deleting {}".format(c)) self.checkpointer.delete(c) # Delete older epochs if self.max_epochs is not None and len(e_chkpts) > self.max_epochs: for c in e_chkpts[self.max_epochs:]: LOG.debug("CheckpointingCallback deleting (epoch) {}".format(c)) self.checkpointer.delete(c)
[docs]class ImageDisplayCallback(Callback, abc.ABC): """Displays image periodically to a Visdom server. This is an abstract class, subclasses should implement the visualized_image method that synthesizes a [B, C, H, W] image to be visualized. Args: frequency(int): number of optimization steps between two updates port (int): Port of the Visdom server to log to. env (string): name of the Visdom environment to log to. """ def __init__(self, frequency=100, server=None, port=8097, env="main", win=None): super(ImageDisplayCallback, self).__init__() self.freq = frequency if server is None: server = "http://localhost" self._api = visdom.Visdom(server=server, port=port, env=env) self._step = 0 if win is None: self.win = _random_string() else: self.win = win @abc.abstractmethod def visualized_image(self, batch, fwd_result): pass def caption(self, batch, fwd_result): return "" def batch_end(self, batch, fwd_result, bwd_result): if self._step % self.freq != 0: self._step += 1 return self._step = 0 caption = self.caption(batch, fwd_result) opts = {"caption": "Epoch {}, batch {}: {}".format( self.epoch, self.batch, caption)} viz = self.visualized_image(batch, fwd_result) self._api.images(viz, win=self.win, opts=opts) self._step += 1 def validation_start(self, dataloader): super(ImageDisplayCallback, self).validation_start(dataloader) self.first_step = True def validation_step(self, batch, fwd_data, val_data): super(ImageDisplayCallback, self).validation_step(batch, fwd_data, val_data) if not self.first_step: return caption = self.caption(batch, fwd_data) opts = {"caption": "Validation {}, batch {}: {}".format( self.epoch, self.batch, caption)} viz = self.visualized_image(batch, fwd_data) self._api.images(viz, win=self.win+"_val", opts=opts) self.first_step = False
class ExperimentLoggerCallback(Callback): """A callback that logs experiment parameters in a log.""" def __init__(self, fname, meta=None): super(ExperimentLoggerCallback, self).__init__() LOG.error("ExperimentLoggerCallback is not implemented yet") raise NotImplementedError("ExperimentLoggerCallback is not implemented yet") def training_start(self, dataloader): super(ExperimentLoggerCallback, self).training_start(dataloader) print("logging experiment with", self.datasize) def training_end(self): super(ExperimentLoggerCallback, self).training_end() print("end logging experiment", self.epoch, self.batch) def _get_commit(self): return subprocess.check_output(["git", "rev-parse", "HEAD"]) class CSVLoggingCallback(KeyedCallback): """A callback that logs scalar quantities to a .csv file. Format is: epoch, step, event, key, value """ def __init__(self, fname, keys=None, val_keys=None, smoothing=0): super(CSVLoggingCallback, self).__init__(keys=keys, val_keys=val_keys, smoothing=smoothing) LOG.error("CSVLoggingCallback is not implemented yet") raise NotImplementedError("CSVLoggingCallback is not implemented yet") self.fname = fname self.fid = open(self.fname, 'w') self.fid.write("epoch, step, event, key, value\n") self.fid.write(",,logger_created,,\n") # open file, check last event def __del__(self): LOG.info("deleting csv logger") self.fid.write(",,logger_deleted,,\n") self.fid.close() def batch_end(self, batch_data, fwd, bwd_data): """Logs training advancement Batch""" super(CSVLoggingCallback, self).batch_end(batch_data, fwd, bwd_data) for k in self.keys: v = bwd_data[k] self.fid.write("%d,%d,batch_end,%s,%f\n" % (self.epoch, self.batch, k, v)) def training_start(self, dataloader): super(CSVLoggingCallback, self).training_start(dataloader) self.fid.write(",,training_start,,\n") def training_end(self): super(CSVLoggingCallback, self).training_end() self.fid.write(",,training_end,,\n") def _random_string(size=16): return ''.join([random.choice(string.ascii_letters) for i in range(size)]) # Tensorboard interface class TensorBoardLoggingCallback(Callback): """A callback that logs scalar quantities to TensorBoard Args: keys (list of str): list of keys whose values will be logged during training. val_keys (list of str): list of keys whose values will be logged during validation frequency(int): number of steps between display updates. log_di (str) """ def __init__(self, writer, val_writer, keys=None, val_keys=None, frequency=100, summary_type='scalar'): super(TensorBoardLoggingCallback, self).__init__() self.keys = keys self.val_keys = val_keys or self.keys self._writer = writer self._val_writer = val_writer self._step = 0 self.frequency = frequency self.summary_type = summary_type def batch_end(self, batch_data, fwd, bwd): super(TensorBoardLoggingCallback, self).batch_end(batch_data, fwd, bwd) if self._step % self.frequency != 0: self._step += 1 return self._step = 0 t = self.batch + self.datasize * self.epoch for k in self.keys: if self.summary_type == 'scalar': self._writer.add_scalar(k, bwd[k], global_step=t) elif self.summary_type == 'histogram': self._writer.add_histogram(k, bwd[k], global_step=t) self._step += 1 def validation_end(self, val_data): super(TensorBoardLoggingCallback, self).validation_end(val_data) t = self.datasize * (self.epoch+1) for k in self.val_keys: if self.summary_type == 'scalar': self._val_writer.add_scalar(k, val_data[k], global_step=t) elif self.summary_type == 'histogram': self._val_writer.add_histogram(k, val_data[k], global_step=t) class TensorBoardImageDisplayCallback(Callback, abc.ABC): """Displays image periodically to TensorBoard. This is an abstract class, subclasses should implement the visualized_image method that synthesizes a [B, C, H, W] image to be visualized. Args: frequency(int): number of optimization steps between two updates """ def __init__(self, writer, val_writer, frequency=100): super(TensorBoardImageDisplayCallback, self).__init__() self._writer = writer self._val_writer = val_writer self.freq = frequency self._step = 0 @abc.abstractmethod def visualized_image(self, batch, fwd_result): pass @abc.abstractmethod def tag(self): pass def batch_end(self, batch, fwd_result, bwd_result): if self._step % self.freq != 0: self._step += 1 return self._step = 0 viz = self.visualized_image(batch, fwd_result) t = self.batch + self.datasize * self.epoch self._writer.add_image(self.tag(), make_grid(viz), t) self._step += 1 def validation_start(self, dataloader): super(TensorBoardImageDisplayCallback, self).validation_start(dataloader) self.first_step = True def validation_step(self, batch, fwd_data, val_data): super(TensorBoardImageDisplayCallback, self).validation_step(batch, fwd_data, val_data) if not self.first_step: return viz = self.visualized_image(batch, fwd_data) t = self.datasize * (self.epoch+1) self._val_writer.add_image(self.tag(), make_grid(viz), t) self.first_step = False