Training

Utilities to train a model.

class ttools.training.ModelInterface

An adapter to run or train a model.

backward(batch, forward_data)

Computes gradients, takes an optimizer step, and updates the model.

Parameters:
  • batch (dict) – batch of data provided by a data pipeline.
  • forward_data (dict) – outputs from the forward pass.
Returns:

a dictionary of outputs

Return type:

backward_data (dict)

finalize_validation(running_data)

Computes the final validation aggregates from the running data.

The default implementation is a no-op.

Parameters: running_data (dict) – current aggregates of the validation loop.
Returns: final validation aggregates
Return type: validation_data (dict)

forward(batch)

Runs the model on a batch of data.

Parameters: batch (dict) – batch of data provided by a data pipeline.
Returns: a dictionary of outputs
Return type: forward_data (dict)
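
The forward/backward contract can be illustrated with a small sketch. It deliberately avoids both ttools and PyTorch so that it runs on its own: `ModelInterface` below is a stand-in base class, and `LinearInterface`, its `lr` argument, and the `"x"`, `"y"`, `"pred"` and `"loss"` keys are hypothetical names chosen for illustration. A real adapter would typically wrap a `torch.nn.Module` and an optimizer instead.

```python
class ModelInterface:
    """Stand-in for ttools.training.ModelInterface (assumption, not the real class)."""

    def forward(self, batch):
        raise NotImplementedError

    def backward(self, batch, forward_data):
        raise NotImplementedError


class LinearInterface(ModelInterface):
    """Hypothetical adapter that fits y = w * x with plain gradient descent."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr

    def forward(self, batch):
        # Run the model on the batch and return the outputs as a dict.
        return {"pred": [self.w * x for x in batch["x"]]}

    def backward(self, batch, forward_data):
        # Mean squared error and its gradient with respect to w.
        n = len(batch["x"])
        errors = [p - y for p, y in zip(forward_data["pred"], batch["y"])]
        loss = sum(e * e for e in errors) / n
        grad = sum(2.0 * e * x for e, x in zip(errors, batch["x"])) / n
        # "Optimizer step": update the model in place.
        self.w -= self.lr * grad
        return {"loss": loss}


interface = LinearInterface(lr=0.1)
batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]}
losses = []
for _ in range(50):
    fwd = interface.forward(batch)
    losses.append(interface.backward(batch, fwd)["loss"])
```

Returning dictionaries from both methods keeps the interface open-ended: an adapter can report any number of named quantities without changing the method signatures.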

init_validation()

Initializes the quantities to be reported during validation.

The default implementation is a no-op.

Returns: initialized values
Return type: data (dict)

update_validation(batch, fwd, running_data)

Updates the running validation data using the current batch’s forward output.

The default implementation is a no-op.

Parameters:
  • batch (dict) – batch of data provided by a data pipeline.
  • fwd (dict) – data from one forward step in validation mode.
  • running_data (dict) – current aggregates of the validation loop.
Returns:

updated values

Return type:

updated_data (dict)
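
The three validation hooks are easiest to read together. The sketch below assumes they are driven in the order init, then one update per batch, then finalize; that ordering is inferred from the descriptions above rather than taken from the ttools source. `AvgLossInterface`, `run_validation` and the dictionary keys are hypothetical names, and the code is plain Python so it runs without ttools.

```python
class AvgLossInterface:
    """Hypothetical interface that tracks the mean squared error over a validation pass."""

    def forward(self, batch):
        # Toy "model": predict zero for every target.
        return {"pred": [0.0 for _ in batch["y"]]}

    def init_validation(self):
        # Start the running aggregates at zero.
        return {"loss_sum": 0.0, "count": 0}

    def update_validation(self, batch, fwd, running_data):
        # Fold the current batch into the running aggregates.
        sq = [(p - y) ** 2 for p, y in zip(fwd["pred"], batch["y"])]
        return {
            "loss_sum": running_data["loss_sum"] + sum(sq),
            "count": running_data["count"] + len(sq),
        }

    def finalize_validation(self, running_data):
        # Turn the running sums into the quantity that gets reported.
        return {"loss": running_data["loss_sum"] / max(running_data["count"], 1)}


def run_validation(interface, batches):
    """Assumed driver loop: init once, update per batch, finalize at the end."""
    running = interface.init_validation()
    for batch in batches:
        fwd = interface.forward(batch)
        running = interface.update_validation(batch, fwd, running)
    return interface.finalize_validation(running)


val_batches = [{"y": [1.0, 2.0]}, {"y": [3.0]}]
result = run_validation(AvgLossInterface(), val_batches)
```

Accumulating a sum and a count, and dividing only in `finalize_validation`, gives the correct mean even when the last batch is smaller than the others.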

class ttools.training.Trainer(interface)

Implements a simple training loop with hooks for callbacks.

Parameters: interface (ModelInterface) – adapter to run the forward and backward passes on the model being trained.

callbacks

hooks that will be called while training progresses.

Type: list of Callbacks

add_callback(callback)

Adds a callback to the list of training hooks.

train(dataloader, num_epochs=None, val_dataloader=None)

Main training loop. This starts the training procedure.

Parameters:
  • dataloader (DataLoader) – loader that yields training batches.
  • num_epochs (int, optional) – max number of epochs to run.
  • val_dataloader (DataLoader, optional) – loader that yields validation batches.
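
As a rough picture of how a loop of this shape ties an interface and callbacks together, consider the sketch below. It is a simplified stand-in, not the ttools implementation: `SketchTrainer`, `CountingInterface` and `Recorder` are hypothetical, the `batch_end` and `epoch_end` hook names are invented for illustration (the callback API is not described in this section), and treating `num_epochs=None` as "no limit" is an assumption.

```python
class SketchTrainer:
    """Simplified stand-in for the Trainer described above (not the ttools code)."""

    def __init__(self, interface):
        self.interface = interface
        self.callbacks = []

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def train(self, dataloader, num_epochs=None, val_dataloader=None):
        epoch = 0
        # Assumption: num_epochs=None means "run until interrupted".
        while num_epochs is None or epoch < num_epochs:
            for batch in dataloader:
                fwd = self.interface.forward(batch)
                bwd = self.interface.backward(batch, fwd)
                for cb in self.callbacks:
                    cb.batch_end(batch, fwd, bwd)  # hypothetical hook name
            # The validation pass over val_dataloader is omitted for brevity.
            for cb in self.callbacks:
                cb.epoch_end(epoch)  # hypothetical hook name
            epoch += 1


class CountingInterface:
    """Hypothetical interface that only counts optimizer steps."""

    def __init__(self):
        self.steps = 0

    def forward(self, batch):
        return {"out": batch}

    def backward(self, batch, forward_data):
        self.steps += 1
        return {"loss": 0.0}


class Recorder:
    """Hypothetical callback that logs the order in which hooks fire."""

    def __init__(self):
        self.events = []

    def batch_end(self, batch, fwd, bwd):
        self.events.append(("batch", batch))

    def epoch_end(self, epoch):
        self.events.append(("epoch", epoch))


trainer = SketchTrainer(CountingInterface())
recorder = Recorder()
trainer.add_callback(recorder)
trainer.train(dataloader=[0, 1, 2], num_epochs=2)
```

Keeping logging, checkpointing and similar concerns in callbacks leaves the loop itself small, which is the design the class description points to.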

class ttools.training.Checkpointer(root, model=None, meta=None, optimizers=None, prefix=None)

Save and restore model and optimizer variables.

Parameters:
  • root (string) – path to the root directory where the files are stored.
  • model (torch.nn.Module) – model whose parameters will be checkpointed.
  • optimizers (list of torch.optim.Optimizer) – optimizers whose parameters will be checkpointed together with the model.
  • meta (dict) – metaparameters of the model, saved with each checkpoint.

delete(path)

Delete checkpoint at path.

load(path)

Loads a checkpoint, updates the model and returns extra data.

Parameters: path (string) – path to the checkpoint file, relative to the root dir.
Returns:
  • extras (dict) – extra information passed by the user at save time.
  • meta (dict) – metaparameters of the model passed at save time.

load_latest()

Try to load the most recent checkpoint, skipping files that fail to load.

Returns:
  • extras (dict) – extra information passed by the user at save time.
  • meta (dict) – metaparameters of the model passed at save time.

static load_meta(root, prefix=None)

Fetch model metadata without touching the saved parameters.

save(path, extras=None)

Save the model, metaparameters and extras to the given relative path.

Parameters:
  • path (string) – relative path to the file being saved (without extension).
  • extras (dict) – extra user-provided information to be saved with the model.

sorted_checkpoints()

Get list of all checkpoints in root directory, sorted by creation date.
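
The "sort by date, then fall back past broken files" behaviour described for sorted_checkpoints and load_latest can be sketched with the standard library alone. Everything here is an assumption made for illustration: the free functions below are not the Checkpointer methods, the `.pkl` extension and pickle serialization are placeholders for whatever format ttools actually writes, and file modification time is used as a portable proxy for creation date.

```python
import os
import pickle
import tempfile


def sorted_checkpoints(root, ext=".pkl"):
    """Return checkpoint paths under root, oldest first (by modification time)."""
    paths = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(ext)]
    return sorted(paths, key=os.path.getmtime)


def load_latest(root):
    """Try checkpoints from newest to oldest and return the first that loads."""
    for path in reversed(sorted_checkpoints(root)):
        try:
            with open(path, "rb") as fid:
                return pickle.load(fid)
        except Exception:
            # Corrupted or partially written file: fall back to an older one.
            continue
    return None


with tempfile.TemporaryDirectory() as root:
    good = os.path.join(root, "epoch_1.pkl")
    bad = os.path.join(root, "epoch_2.pkl")
    with open(good, "wb") as fid:
        pickle.dump({"extras": {"epoch": 1}, "meta": {"depth": 3}}, fid)
    with open(bad, "wb") as fid:
        fid.write(b"not a pickle")
    # Set explicit timestamps so the ordering is deterministic.
    os.utime(good, (1000, 1000))
    os.utime(bad, (2000, 2000))

    order = [os.path.basename(p) for p in sorted_checkpoints(root)]
    restored = load_latest(root)
```

Skipping unreadable files matters in practice because a job killed in the middle of a save can leave a truncated newest checkpoint; falling back to the previous one lets training resume instead of failing.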

class ttools.training.BasicArgumentParser(*args, **kwargs)

A basic argument parser with commonly used training options.