Interfaces

A collection of fully-specified model interfaces.

class ttools.interfaces.GANInterface(gen, discrim, lr=0.0001, ncritic=1, opt='rmsprop', cuda=<sphinx.ext.autodoc.importer._MockObject object>, gan_weight=1.0, max_grad_norm=None)

Abstract GAN interface.

Parameters:
  • gen (th.nn.Module) – generator.
  • discrim (th.nn.Module) – discriminator.
  • lr (float) – learning rate for both discriminator and generator.
  • ncritic (int) – number of discriminator updates per generator update.
  • opt (str) – optimizer type for both discriminator and generator.
  • cuda (bool) – whether or not to use CUDA.
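  • gan_weight (float) – weight of the GAN loss in the generator objective; when 0, only extra_generator_loss is used.
  • max_grad_norm (None or scalar) – if provided, clip gradients whose norm exceeds this threshold.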
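
As a rough illustration of what max_grad_norm implies, the sketch below clips a flat list of gradient values by their global L2 norm. It is plain Python so it runs without PyTorch. The helper name clip_by_norm is hypothetical, and the rule shown (rescale only when the norm exceeds the threshold) is the usual norm-clipping convention, assumed here rather than taken from the ttools source.

```python
import math


def clip_by_norm(grads, max_grad_norm=None):
    """Rescale `grads` so their L2 norm does not exceed `max_grad_norm`.

    Hypothetical helper for illustration only; not part of ttools.
    """
    if max_grad_norm is None:
        # No threshold provided: gradients are returned unchanged.
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_grad_norm:
        # Already within the threshold: nothing to clip.
        return list(grads)
    # Shrink every component by the same factor so the direction is preserved.
    scale = max_grad_norm / norm
    return [g * scale for g in grads]
```

Clipping by the global norm keeps the gradient direction and only limits its magnitude, which is why it is often preferred over clamping each component independently.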
backward(batch, fwd_data)

Generic GAN backward step.

Alternates between ncritic discriminator updates and a single generator update. When gan_weight == 0, the generator objective consists of extra_generator_loss alone.
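
The alternation described above can be sketched in plain Python. This is one plausible reading, not the ttools implementation: update_schedule and generator_objective are hypothetical names, and combining the two loss terms as a weighted sum is an assumption.

```python
def update_schedule(num_steps, ncritic=1):
    """Return which network is updated at each step: 'D' or 'G'.

    Hypothetical helper: `ncritic` discriminator updates are followed
    by a single generator update, then the cycle repeats.
    """
    schedule = []
    critic_updates = 0
    for _ in range(num_steps):
        if critic_updates < ncritic:
            schedule.append("D")
            critic_updates += 1
        else:
            schedule.append("G")
            critic_updates = 0
    return schedule


def generator_objective(gan_loss, extra_loss, gan_weight=1.0):
    """Combine the GAN loss with the extra generator loss.

    Assumption: the two terms are added, with the GAN term scaled by
    `gan_weight`. With gan_weight == 0 only the extra loss remains.
    """
    if gan_weight == 0:
        return extra_loss
    return gan_weight * gan_loss + extra_loss
```

For instance, update_schedule(6, ncritic=2) yields two discriminator steps before each generator step.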

forward(batch)

Abstract method that computes the generator output.

Implement in derived classes.

class ttools.interfaces.LSGANInterface(*args, **kwargs)

Least-squares GAN interface [Mao2017].
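
For orientation, a minimal sketch of the least-squares objectives on plain lists of discriminator scores. The target labels (1 for real, 0 for fake) and the 0.5 factor follow a common convention and are assumptions; the function names are hypothetical and the ttools implementation may differ.

```python
def mean(values):
    return sum(values) / len(values)


def lsgan_discriminator_loss(real_scores, fake_scores):
    """Push scores of real samples toward 1 and of fake samples toward 0."""
    real_term = mean([(s - 1.0) ** 2 for s in real_scores])
    fake_term = mean([s ** 2 for s in fake_scores])
    return 0.5 * (real_term + fake_term)


def lsgan_generator_loss(fake_scores):
    """Push scores of generated samples toward the 'real' target of 1."""
    return 0.5 * mean([(s - 1.0) ** 2 for s in fake_scores])
```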

class ttools.interfaces.RGANInterface(*args, **kwargs)

Relativistic GAN interface [Jolicoeur-Martineau2018].

https://arxiv.org/abs/1807.00734
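
A minimal sketch of the relativistic objective, assuming the discriminator returns raw logits and that real and fake samples are paired element by element in equally sized batches. Function names are hypothetical; this illustrates the formulation in the cited paper, not necessarily the ttools code.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def rgan_discriminator_loss(real_logits, fake_logits):
    """Reward the discriminator when real scores exceed paired fake scores."""
    # -log(sigmoid(r - f)) equals softplus(-(r - f)).
    return mean([softplus(-(r - f)) for r, f in zip(real_logits, fake_logits)])


def rgan_generator_loss(real_logits, fake_logits):
    """Reward the generator when fake scores exceed paired real scores."""
    return mean([softplus(-(f - r)) for r, f in zip(real_logits, fake_logits)])
```

Unlike the standard GAN loss, the generator loss here depends on the real samples as well, since only the difference between the two scores matters.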

class ttools.interfaces.RaGANInterface(*args, **kwargs)

Relativistic average GAN interface [Jolicoeur-Martineau2018].

https://arxiv.org/abs/1807.00734
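
The averaged variant compares each score against the mean score of the opposite class instead of a paired sample. The sketch below assumes raw logits as inputs and uses hypothetical function names; it follows the formulation in the cited paper rather than the ttools source.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def ragan_discriminator_loss(real_logits, fake_logits):
    mean_real, mean_fake = mean(real_logits), mean(fake_logits)
    # Real samples should score above the average fake sample...
    real_term = mean([softplus(-(r - mean_fake)) for r in real_logits])
    # ...and fake samples below the average real sample.
    fake_term = mean([softplus(f - mean_real) for f in fake_logits])
    return real_term + fake_term


def ragan_generator_loss(real_logits, fake_logits):
    mean_real, mean_fake = mean(real_logits), mean(fake_logits)
    # Same two terms with the roles of real and fake swapped.
    fake_term = mean([softplus(-(f - mean_real)) for f in fake_logits])
    real_term = mean([softplus(r - mean_fake) for r in real_logits])
    return fake_term + real_term
```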

class ttools.interfaces.RaLSGANInterface(*args, **kwargs)

Relativistic average least-squares GAN interface [Jolicoeur-Martineau2018].

https://arxiv.org/abs/1807.00734
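
A minimal sketch of the relativistic average least-squares objective on plain lists of scores. The targets of +1 and -1 follow the cited paper; the function names are hypothetical and the ttools implementation may scale or arrange the terms differently.

```python
def mean(values):
    return sum(values) / len(values)


def ralsgan_discriminator_loss(real_scores, fake_scores):
    mean_real, mean_fake = mean(real_scores), mean(fake_scores)
    # Real scores should sit 1 above the average fake score.
    real_term = mean([(r - mean_fake - 1.0) ** 2 for r in real_scores])
    # Fake scores should sit 1 below the average real score.
    fake_term = mean([(f - mean_real + 1.0) ** 2 for f in fake_scores])
    return real_term + fake_term


def ralsgan_generator_loss(real_scores, fake_scores):
    mean_real, mean_fake = mean(real_scores), mean(fake_scores)
    # Targets are mirrored: fake scores should sit 1 above the average real score.
    fake_term = mean([(f - mean_real - 1.0) ** 2 for f in fake_scores])
    real_term = mean([(r - mean_fake + 1.0) ** 2 for r in real_scores])
    return fake_term + real_term
```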

class ttools.interfaces.SGANInterface(*args, **kwargs)

Standard GAN interface [Goodfellow2014].
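
For comparison with the other variants, a minimal sketch of the standard GAN objective, assuming the discriminator returns raw logits. The generator term uses the non-saturating form, which is a common choice and an assumption here; the function names are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def sgan_discriminator_loss(real_logits, fake_logits):
    """Binary cross-entropy with target 1 for real and 0 for fake samples."""
    # -log(sigmoid(r)) equals softplus(-r); -log(1 - sigmoid(f)) equals softplus(f).
    real_term = mean([softplus(-r) for r in real_logits])
    fake_term = mean([softplus(f) for f in fake_logits])
    return real_term + fake_term


def sgan_generator_loss(fake_logits):
    """Non-saturating generator loss: maximize log(sigmoid(D(G(z))))."""
    return mean([softplus(-f) for f in fake_logits])
```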

class ttools.interfaces.WGANInterface(*args, c=0.1, **kwargs)

Wasserstein GAN.

Parameters:
  • c (float) – clipping parameter for the Lipschitz constant of the discriminator.
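
To make the role of c concrete, the sketch below clamps a flat list of discriminator weights to [-c, c] and evaluates the Wasserstein critic and generator losses on plain lists of scores. Weight clamping is the scheme from the original WGAN formulation and is assumed here to be what the clipping parameter controls; the function names are hypothetical.

```python
def mean(values):
    return sum(values) / len(values)


def clip_weights(weights, c=0.1):
    """Clamp every weight to the interval [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


def wgan_critic_loss(real_scores, fake_scores):
    """The critic tries to score real samples higher than fake ones."""
    return mean(fake_scores) - mean(real_scores)


def wgan_generator_loss(fake_scores):
    """The generator tries to raise the critic's score on its samples."""
    return -mean(fake_scores)
```

Bounding the weights limits how steeply the critic can vary with its input, which is a crude way of keeping its Lipschitz constant finite.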