# Source code for ttools.utils

"""Helper classes and functions."""
import time

import torch as th
import numpy as np

import logging
try:
    import coloredlogs
except ImportError:
    # coloredlogs is an optional dependency; without it, plain stdlib
    # logging configuration is used (see set_logger).
    HAS_COLORED_LOGS = False
else:
    # Install colored console output as soon as the module is imported.
    coloredlogs.install()
    HAS_COLORED_LOGS = True


__all__ = ["ExponentialMovingAverage", "Averager", "Timer", "tensor2image", "get_logger", "set_logger"]


def set_logger(debug=False):
    """Configure the default logging level and message format.

    Uses coloredlogs when it is available, otherwise falls back to
    ``logging.basicConfig``.

    Args:
      debug(bool): if True, log at DEBUG level and include the source
        file name and line number in every record; otherwise log at INFO.
    """
    level = logging.DEBUG if debug else logging.INFO

    # Assemble "<prefix>[ <file>:<line>] | <message>".
    pieces = ["[%(process)d] %(levelname)s %(name)s"]
    if debug:
        pieces.append(" %(filename)s:%(lineno)s")
    pieces.append(" | %(message)s")
    fmt = "".join(pieces)

    configure = coloredlogs.install if HAS_COLORED_LOGS else logging.basicConfig
    configure(level=level, format=fmt)
def get_logger(name):
    """Return the logger registered under a given name.

    Args:
      name(string): name of the logger to fetch.
    """
    named_logger = logging.getLogger(name)
    return named_logger
class ExponentialMovingAverage(object):
    """Track an exponential moving average independently for each key.

    Args:
      keys(list of str): keys to track.
      alpha(float): smoothing factor in the update
        ``avg = avg * alpha + value * (1 - alpha)``; higher is smoother.
    """

    def __init__(self, keys, alpha=0.999):
        # Per-key flag: the first non-None value seeds the average directly.
        self._is_first_update = dict.fromkeys(keys, True)
        self._alpha = alpha
        self._values = dict.fromkeys(keys, 0)

    def __getitem__(self, key):
        return self._values[key]

    def update(self, key, value):
        """Fold ``value`` into the average for ``key``; None is ignored."""
        if value is None:
            return
        if not self._is_first_update[key]:
            a = self._alpha
            self._values[key] = self._values[key] * a + value * (1.0 - a)
            return
        self._values[key] = value
        self._is_first_update[key] = False
class Averager(object):
    """Accumulate weighted running averages, one per key.

    Args:
      keys(list of str): keys to track.
    """

    def __init__(self, keys):
        self.values = dict.fromkeys(keys, 0.0)  # weighted sum per key
        self.counts = dict.fromkeys(keys, 0)    # total weight per key

    def __getitem__(self, key):
        """Return the current average for ``key`` (0.0 if nothing recorded)."""
        total_count = self.counts[key]
        if total_count == 0:
            return 0.0
        return self.values[key] * 1.0 / total_count

    def reset(self):
        """Clear all sums and counts in place."""
        for name in self.values:
            self.values[name] = 0.0
            self.counts[name] = 0

    def update(self, key, value, count=1):
        """Add ``value`` with weight ``count`` to ``key``; None is ignored."""
        if value is not None:
            self.values[key] += value * count
            self.counts[key] += count
def tensor2image(t, normalize=False, dtype=np.uint8):
    """Convert a single-image 4D tensor to an unsigned-integer numpy image.

    Args:
      t(th.Tensor): input tensor with dimensions [bs, c, h, w], where bs=1
        and c is 1 or 3.
      normalize(bool): if True, linearly rescale the tensor using its global
        min and max so its range becomes [0, 1] before clipping.
      dtype: output type; any spec ``np.dtype`` accepts that resolves to
        uint8 or uint16 (e.g. ``np.uint8``, ``np.dtype("uint16")``,
        ``"uint8"``).

    Returns:
      (np.array): [h, w, c] image of the requested dtype. Values are clipped
      to [0, 1], multiplied by 255 (uint8) or 65535 (uint16), and truncated
      toward zero by the integer cast.

    Raises:
      AssertionError: if ``t`` is not a [1, c, h, w] tensor with c in {1, 3}.
      ValueError: if ``dtype`` is not (equivalent to) uint8 or uint16.
    """
    assert len(t.shape) == 4, "expected 4D tensor, got %d dimensions" % len(t.shape)
    bs, c, h, w = t.shape
    assert bs == 1, "expected batch_size 1 tensor, got %d" % bs
    assert c == 3 or c == 1, "expected tensor with 1 or 3 channels, got %d" % c

    # Resolve the dtype spec first so an unsupported one fails before any
    # tensor work; np.dtype lets equivalent specs ("uint8", np.dtype(...)) in.
    try:
        target = np.dtype(dtype)
    except TypeError:
        raise ValueError("dtype %s not recognized" % (dtype,))
    if target == np.uint8:
        scale = 255.0
    elif target == np.uint16:
        scale = 2 ** 16 - 1
    else:
        raise ValueError("dtype %s not recognized" % (dtype,))

    # Detach up front: the result is a numpy array, so no autograd history
    # needs to be recorded for the arithmetic below.
    t = t.detach().squeeze(0)
    if normalize:
        lo = t.min()
        hi = t.max()
        # The epsilon avoids division by zero for constant images.
        t = (t - lo) / (hi - lo + 1e-8)

    img = th.clamp(t.permute(1, 2, 0), 0, 1).cpu().numpy()
    return (scale * img).astype(target)
class Timer(object):
    """A simple timer context manager; reports elapsed time in milliseconds.

    Example::

        with Timer() as timer:
            run()
        print(timer.elapsed)

    Args:
      sync(bool): if True and CUDA is available, synchronize CUDA kernels on
        entry and exit so queued GPU work is included in the measurement.

    Attributes:
      elapsed(float or None): duration of the most recent ``with`` block in
        milliseconds; None until a block has exited.
    """

    def __init__(self, sync=True):
        self._time = 0
        self.sync = sync
        self.elapsed = None

    def _maybe_synchronize(self):
        # Honour the ``sync`` flag, and skip on CPU-only setups where
        # th.cuda.synchronize() would raise.
        if self.sync and th.cuda.is_available():
            th.cuda.synchronize()

    def __enter__(self):
        self._maybe_synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self._time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._maybe_synchronize()
        self.elapsed = (time.perf_counter() - self._time) * 1000
        # Returning None lets any exception from the block propagate.