Source code for ttools.interfaces

"""A collection of fully-specified model interfaces."""
import abc
import logging

import torch as th

from . import ModelInterface

from .utils import get_logger


LOG = get_logger(__name__)


# HAS_AMP = False
# if th.cuda.is_available():
#     try:
#         HAS_AMP = True
#         from apex import amp, optimizers
#         LOG.info("Amp FP16 available")
#     except:
#         LOG.warn("Amp FP16 is not available")


class GANInterface(ModelInterface, abc.ABC):
    """Abstract GAN interface.

    Holds a generator, an optional discriminator and one optimizer for each,
    and alternates discriminator and generator updates in `backward`. Derived
    classes provide the forward pass, the discriminator inputs and the GAN
    losses.

    Args:
        gen(th.nn.Module): generator.
        discrim(th.nn.Module): discriminator. If None, only the extra
            generator losses are optimized.
        lr(float): learning rate for both discriminator and generator.
        ncritic(int): number of discriminator updates per generator update.
        opt(str): optimizer type for both discriminator and generator, one
            of "sgd", "adam" or "rmsprop".
        cuda(bool): whether or not to use CUDA.
        gan_weight(float): weight applied to the GAN loss. When equal to 0
            the discriminator is dropped.
        max_grad_norm(None or scalar): clip gradients above that threshold if
            provided.

    Raises:
        ValueError: if `opt` is not a supported optimizer type.
    """

    def __init__(self, gen, discrim, lr=1e-4, ncritic=1, opt="rmsprop",
                 cuda=th.cuda.is_available(), gan_weight=1.0,
                 max_grad_norm=None):
        super(GANInterface, self).__init__()
        self.gen = gen
        self.discrim = discrim
        self.ncritic = ncritic
        self.gan_weight = gan_weight
        self.max_grad_norm = max_grad_norm

        if self.gan_weight == 0:
            LOG.warning("GAN interface %s has gan_weight==0",
                        self.__class__.__name__)
            self.discrim = None

        if self.discrim is None:
            LOG.warning("Using a GAN interface (%s) with no discriminator",
                        self.__class__.__name__)
        else:
            LOG.info("Using GAN (%s) loss with weight %.5f",
                     self.__class__.__name__, self.gan_weight)

        # number of discriminator iterations
        self.iter = 0

        self.device = "cuda" if cuda else "cpu"
        self.gen.to(self.device)
        if self.discrim is not None:
            self.discrim.to(self.device)

        # Pick the optimizer class once; generator and discriminator share
        # the same type and hyper-parameters.
        if opt == "sgd":
            opt_class = th.optim.SGD
            opt_kwargs = {}
        elif opt == "adam":
            LOG.warning("Using a momentum-based optimizer in the discriminator,"
                        " this can be problematic.")
            opt_class = th.optim.Adam
            opt_kwargs = {"betas": (0.5, 0.999)}
        elif opt == "rmsprop":
            opt_class = th.optim.RMSprop
            opt_kwargs = {}
        else:
            raise ValueError("invalid optimizer %s" % opt)

        self.opt_g = opt_class(self.gen.parameters(), lr=lr, **opt_kwargs)
        self.opt_d = None
        if self.discrim is not None:
            self.opt_d = opt_class(
                self.discrim.parameters(), lr=lr, **opt_kwargs)

    @abc.abstractmethod
    def forward(self, batch):
        """Abstract method that computes the generator output.

        Implement in derived classes.
        """

    @abc.abstractmethod
    def _discriminator_input(self, batch, fwd_data, fake=False):
        """Abstract method that selects the discriminator's input.

        The discriminator input is typically the output of the forward pass,
        `fwd_data` when testing a `fake` sample or some true data from the
        input `batch`.

        Args:
            batch: a batch of data generated by a `Dataset` class.
            fwd_data: the output of this class's `forward` method.
            fake(bool): if True we're providing a fake sample to the
                discriminator, otherwise a true example.

        Returns:
            Tensor or list of tensors

        Implement in derived classes.
        """

    @abc.abstractmethod
    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Compute the GAN loss for the discriminator.

        Args:
            fake_pred(th.Tensor): discriminator output for the fake sample.
            real_pred(th.Tensor): discriminator output for the real sample.

        Returns:
            th.Tensor: a scalar loss value.

        Implement in derived classes.
        """

    @abc.abstractmethod
    def _generator_gan_loss(self, fake_pred, real_pred):
        """Compute the GAN loss for the generator.

        Args:
            fake_pred(th.Tensor): discriminator output for the fake sample.
            real_pred(th.Tensor): discriminator output for the real sample.

        Returns:
            th.Tensor: a scalar loss value.

        Implement in derived classes.
        """

    def _extra_generator_loss(self, batch, fwd_data):
        """Computes extra losses for the generator if needed.

        Returns:
            None or list of th.Tensor with shape [1], the total extra loss.
        """
        return None

    def _eval_d(self, d_inputs, backprop):
        """Eval the discriminators (optionally prevent backprop to inputs).

        Args:
            d_inputs (Tensor or list of Tensor): inputs to the discriminator.
            backprop: if False, the inputs are detached from the graph (e.g.
                for the discriminator update we do not update the generated
                tensors).

        Returns:
            The output of the discriminator called on the inputs.
        """
        # A non-list input is assumed to be a single discriminator argument.
        args = d_inputs if isinstance(d_inputs, list) else [d_inputs]

        # Detach the inputs to avoid backprops
        if not backprop:
            args = [a.detach() for a in args]

        return self.discrim(*args)

    def _clip_gradients(self, parameters, name):
        """Clip the gradient norm of `parameters` if `max_grad_norm` is set.

        Args:
            parameters: iterable of parameters whose gradients are clipped.
            name(str): label of the network, used in the warning message.
        """
        if self.max_grad_norm is None:
            return
        nrm = th.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
        if nrm > self.max_grad_norm:
            LOG.warning("Clipping %s gradients. norm = %.3f > %.3f",
                        name, nrm, self.max_grad_norm)

    def backward(self, batch, fwd_data):
        """Generic GAN backward step.

        Alternates between `ncritic` discriminator updates and a single
        generator update. Only uses `_extra_generator_loss` as objective
        when there is no discriminator (e.g. `gan_weight==0`).

        Args:
            batch: a batch of data generated by a `Dataset` class.
            fwd_data: the output of this class's `forward` method.

        Returns:
            dict: with keys "loss_g" and "loss_d" (scalar GAN losses, None
            for the network that was not updated in this step), "loss" (sum
            of the extra generator losses, or None) and "extra_losses" (list
            with the scalar value of each extra loss).

        Raises:
            RuntimeError: if there is neither a discriminator nor an extra
                loss to optimize.
        """
        losses = self._extra_generator_loss(batch, fwd_data)
        losses = [] if losses is None else list(losses)
        if losses:
            extra_losses = [loss.item() for loss in losses]
            extra_g_loss = sum(losses)
        else:
            # An empty list is treated like None: `sum([])` would be the int
            # 0, which has neither `.backward()` nor `.item()`.
            extra_losses = []
            extra_g_loss = None

        # No discriminator needed, just use the extra losses
        if self.discrim is None:
            if extra_g_loss is None:
                LOG.error("Training a GAN with no discriminator and no extra "
                          "loss: nothing to optimize!")
                raise RuntimeError("Training a GAN with no discriminator"
                                   " and no extra loss: nothing to optimize!")

            # Update the generator with only the non-GAN losses
            self.opt_g.zero_grad()
            extra_g_loss.backward()
            self._clip_gradients(self.gen.parameters(), "generator")
            self.opt_g.step()

            return {"loss_g": None, "loss_d": None,
                    "loss": extra_g_loss.item(),
                    "extra_losses": extra_losses}

        # If we reach this point, we have a discriminator
        loss_g = None
        loss_d = None

        if self.iter < self.ncritic:  # Update discriminator
            # We detach the generated samples, so that no grads propagate to
            # the generator here.
            fake_pred = self._eval_d(
                self._discriminator_input(batch, fwd_data, fake=True), False)
            real_pred = self._eval_d(
                self._discriminator_input(batch, fwd_data, fake=False), True)
            loss_d = self._update_discriminator(fake_pred, real_pred)
            self.iter += 1
        else:  # Update generator
            self.iter = 0  # reset discrim it counter

            # classify real/fake
            fake_in = self._discriminator_input(batch, fwd_data, fake=True)
            fake_pred_g = self._eval_d(fake_in, True)
            real_in = self._discriminator_input(batch, fwd_data, fake=False)
            real_pred_g = self._eval_d(real_in, True)
            loss_g = self._update_generator(fake_pred_g, real_pred_g,
                                            extra_g_loss)

        if extra_g_loss is not None:
            extra_g_loss = extra_g_loss.item()

        return {"loss_g": loss_g, "loss_d": loss_d, "loss": extra_g_loss,
                "extra_losses": extra_losses}

    def _update_discriminator(self, fake_pred, real_pred):
        """Generic discriminator update.

        Args:
            fake_pred(th.Tensor): output of the discriminator on fake
                predictions.
            real_pred(th.Tensor): output of the discriminator on real
                predictions.

        Returns:
            float: the unweighted discriminator GAN loss.
        """
        loss_d = self._discriminator_gan_loss(fake_pred, real_pred)
        total_loss = loss_d * self.gan_weight

        self.opt_d.zero_grad()
        total_loss.backward()
        self._clip_gradients(self.discrim.parameters(), "discriminator")
        self.opt_d.step()

        return loss_d.item()

    def _update_generator(self, fake_pred, real_pred, extra_loss):
        """Generic generator update.

        Combines the GAN objective with extra losses if provided.

        Args:
            fake_pred(th.Tensor): output of the discriminator on fake
                predictions.
            real_pred(th.Tensor): output of the discriminator on real
                predictions.
            extra_loss(None or th.Tensor): non-GAN loss added to the weighted
                GAN loss, if any.

        Returns:
            float: the unweighted generator GAN loss.
        """
        loss_g = self._generator_gan_loss(fake_pred, real_pred)
        total_loss = loss_g * self.gan_weight

        # We have non-GAN terms in the loss
        if extra_loss is not None:
            total_loss = total_loss + extra_loss

        self.opt_g.zero_grad()
        total_loss.backward()
        self._clip_gradients(self.gen.parameters(), "generator")
        self.opt_g.step()

        return loss_g.item()
class SGANInterface(GANInterface):
    """Standard GAN interface [Goodfellow2014].

    Both losses are binary cross-entropies computed on the discriminator
    outputs with `th.nn.BCEWithLogitsLoss`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cross_entropy = th.nn.BCEWithLogitsLoss()

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Average of the cross-entropies: real vs. ones and fake vs. zeros."""
        real_target = th.ones_like(real_pred)
        real_term = self.cross_entropy(real_pred, real_target)

        fake_target = th.zeros_like(fake_pred)
        fake_term = self.cross_entropy(fake_pred, fake_target)

        return 0.5*(fake_term + real_term)

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Cross-entropy of the fake prediction against a target of ones."""
        fooled_target = th.ones_like(fake_pred)
        return self.cross_entropy(fake_pred, fooled_target)
class RGANInterface(SGANInterface):
    """Relativistic GAN interface [Jolicoeur-Martineau2018].

    https://arxiv.org/abs/1807.00734

    The cross-entropy is applied to the difference between the real and fake
    discriminator outputs instead of to each output separately.
    """

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Cross-entropy of `real_pred - fake_pred` against ones."""
        real_minus_fake = real_pred - fake_pred
        target = th.ones_like(real_pred)
        return self.cross_entropy(real_minus_fake, target)

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Cross-entropy of `fake_pred - real_pred` against ones."""
        fake_minus_real = fake_pred - real_pred
        target = th.ones_like(fake_pred)
        return self.cross_entropy(fake_minus_real, target)
class RaGANInterface(SGANInterface):
    """Relativistic average GAN interface [Jolicoeur-Martineau2018].

    https://arxiv.org/abs/1807.00734

    Each prediction is offset by the mean prediction of the opposite class
    before the cross-entropy is applied.
    """

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Average of the real (target 1) and fake (target 0) relative terms."""
        real_rel = real_pred-fake_pred.mean()
        real_term = self.cross_entropy(real_rel, th.ones_like(real_pred))

        fake_rel = fake_pred-real_pred.mean()
        fake_term = self.cross_entropy(fake_rel, th.zeros_like(fake_pred))

        return 0.5*(real_term + fake_term)

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Same relative terms as the discriminator, with swapped targets."""
        real_rel = real_pred-fake_pred.mean()
        real_term = self.cross_entropy(real_rel, th.zeros_like(real_pred))

        fake_rel = fake_pred-real_pred.mean()
        fake_term = self.cross_entropy(fake_rel, th.ones_like(fake_pred))

        return 0.5*(real_term + fake_term)
class LSGANInterface(GANInterface):
    """Least-squares GAN interface [Mao2017].

    Both losses are mean squared errors between the discriminator outputs
    and 0/1 targets, computed with `th.nn.MSELoss`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mse = th.nn.MSELoss()

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Average of the squared errors: fake vs. zeros and real vs. ones."""
        fake_target = th.zeros_like(fake_pred)
        fake_term = self.mse(fake_pred, fake_target)

        real_target = th.ones_like(real_pred)
        real_term = self.mse(real_pred, real_target)

        return 0.5*(fake_term + real_term)

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Squared error of the fake prediction against a target of ones."""
        fooled_target = th.ones_like(fake_pred)
        return self.mse(fake_pred, fooled_target)
class RaLSGANInterface(LSGANInterface):
    """Relativistic average Least-squares GAN interface
    [Jolicoeur-Martineau2018].

    https://arxiv.org/abs/1807.00734

    Each prediction is offset by the mean prediction of the opposite class
    and regressed towards -1 or 1.
    """

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Average of the real (target 1) and fake (target -1) terms."""
        # NOTE: -1, 1 targets
        real_rel = real_pred-fake_pred.mean()
        real_term = self.mse(real_rel, th.ones_like(real_pred))

        fake_rel = fake_pred-real_pred.mean()
        fake_term = self.mse(fake_rel, -th.ones_like(fake_pred))

        return 0.5*(real_term + fake_term)

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Same relative terms as the discriminator, with swapped targets."""
        # NOTE: -1, 1 targets
        real_rel = real_pred-fake_pred.mean()
        real_term = self.mse(real_rel, -th.ones_like(real_pred))

        fake_rel = fake_pred-real_pred.mean()
        fake_term = self.mse(fake_rel, th.ones_like(fake_pred))

        return 0.5*(real_term + fake_term)
class WGANInterface(GANInterface):
    """Wasserstein GAN.

    After every discriminator update, the discriminator parameters are
    clamped to `[-c, c]`.

    Args:
        c (float): clipping parameter for the Lipschitz constant of the
            discriminator.

    Raises:
        AssertionError: if `c` is not strictly positive.
    """

    def __init__(self, *args, c=0.1, **kwargs):
        super(WGANInterface, self).__init__(*args, **kwargs)
        # Raised explicitly rather than with an `assert` statement so the
        # check is not stripped when Python runs with -O. The exception type
        # and message are kept for backward compatibility.
        if not c > 0:
            raise AssertionError("clipping param should be positive.")
        self.c = c

    def _discriminator_gan_loss(self, fake_pred, real_pred):
        """Negated difference between the mean real and fake predictions."""
        # minus sign for gradient ascent
        loss_d = - (real_pred.mean() - fake_pred.mean())
        return loss_d

    def _update_discriminator(self, fake_pred, real_pred):
        """Run the generic discriminator update, then clip its parameters.

        Returns:
            The scalar loss returned by the parent class's update.
        """
        loss_d_scalar = super(WGANInterface, self)._update_discriminator(
            fake_pred, real_pred)

        # Clip discriminator parameters to enforce Lipschitz constraint.
        # Done under no_grad so the in-place clamp is not tracked by
        # autograd, without going through `.data`.
        with th.no_grad():
            for p in self.discrim.parameters():
                p.clamp_(-self.c, self.c)

        return loss_d_scalar

    def _generator_gan_loss(self, fake_pred, real_pred):
        """Negated mean of the discriminator output on the fake sample."""
        loss_g = -fake_pred.mean()
        return loss_g