Training
Utilities to train a model.
-
class ttools.training.ModelInterface
An adapter to run or train a model.
-
backward(batch, forward_data)
Computes gradients, takes an optimizer step, and updates the model.
Parameters: - batch (dict) – batch of data provided by a data pipeline.
- forward_data (dict) – outputs from the forward pass.
Returns: a dictionary of outputs
Return type: backward_data (dict)
-
finalize_validation(running_data)
Computes the final validation aggregates from the running data.
The default implementation is a no-op.
Parameters: running_data (dict) – current aggregates of the validation loop.
Returns: final validation aggregates
Return type: validation_data (dict)
-
forward(batch)
Runs the model on a batch of data.
Parameters: batch (dict) – batch of data provided by a data pipeline.
Returns: a dictionary of outputs
Return type: forward_data (dict)
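The forward/backward contract described above can be illustrated with a minimal sketch. `LinearInterface` is a hypothetical stand-in, not part of ttools: it mirrors the documented method names and dict-in, dict-out convention, but uses plain Python arithmetic instead of subclassing `ModelInterface`, so it runs without ttools or PyTorch. The dictionary keys (`x`, `y`, `pred`, `loss`) are assumptions made for the example.

```python
class LinearInterface:
    """Stand-in that mirrors the documented forward/backward methods.

    Hypothetical: it does not subclass ttools.training.ModelInterface, so the
    snippet runs without ttools or PyTorch installed.
    """

    def __init__(self, lr=0.1):
        self.weight = 0.0  # single learnable parameter of y = weight * x
        self.lr = lr       # step size for plain gradient descent

    def forward(self, batch):
        # Run the model on the batch and return the outputs as a dict.
        preds = [self.weight * x for x in batch["x"]]
        return {"pred": preds}

    def backward(self, batch, forward_data):
        # Mean squared error and its gradient with respect to the weight.
        n = len(batch["x"])
        errors = [p - y for p, y in zip(forward_data["pred"], batch["y"])]
        loss = sum(e * e for e in errors) / n
        grad = sum(2.0 * e * x for e, x in zip(errors, batch["x"])) / n
        # "Optimizer step": update the model in place.
        self.weight -= self.lr * grad
        return {"loss": loss}


interface = LinearInterface(lr=0.1)
batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]}  # targets follow y = 2x

losses = []
for _ in range(50):
    fwd = interface.forward(batch)          # forward_data
    bwd = interface.backward(batch, fwd)    # backward_data
    losses.append(bwd["loss"])
```

In a real subclass, `backward` would presumably compute a loss from `forward_data`, call `loss.backward()`, and step a PyTorch optimizer; the manual gradient here only keeps the sketch dependency-free.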
-
init_validation()
Initializes the quantities to be reported during validation.
The default implementation is a no-op.
Returns: initialized values
Return type: data (dict)
-
update_validation(batch, fwd, running_data)
Updates the running validation data using the current batch's forward output.
The default implementation is a no-op.
Parameters: - batch (dict) – batch of data provided by a data pipeline.
- fwd (dict) – data from one forward step in validation mode.
- running_data (dict) – current aggregates of the validation loop.
Returns: updated values
Return type: updated_data (dict)
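The three validation hooks are easiest to read together. The sketch below assumes the validation loop calls `init_validation` once, threads the returned dictionary through `update_validation` for each batch, and passes the result to `finalize_validation`. `MeanLossValidation` and the keys `loss_sum`, `count`, `pred`, and `y` are hypothetical names chosen for illustration; the class is a plain-Python stand-in rather than a `ModelInterface` subclass.

```python
class MeanLossValidation:
    """Stand-in exposing the three documented validation hooks (hypothetical)."""

    def init_validation(self):
        # Quantities to accumulate over the validation set.
        return {"loss_sum": 0.0, "count": 0}

    def update_validation(self, batch, fwd, running_data):
        # Fold the current batch's forward output into the running aggregates.
        n = len(batch["y"])
        batch_loss = sum((p - y) ** 2 for p, y in zip(fwd["pred"], batch["y"]))
        return {
            "loss_sum": running_data["loss_sum"] + batch_loss,
            "count": running_data["count"] + n,
        }

    def finalize_validation(self, running_data):
        # Turn the running sums into the value to report.
        return {"loss": running_data["loss_sum"] / max(running_data["count"], 1)}


hooks = MeanLossValidation()
val_batches = [
    {"y": [1.0, 2.0]},
    {"y": [3.0]},
]
# Pretend forward outputs, one per batch (a real loop would call forward()).
val_outputs = [
    {"pred": [1.0, 4.0]},
    {"pred": [2.0]},
]

running = hooks.init_validation()
for batch, fwd in zip(val_batches, val_outputs):
    running = hooks.update_validation(batch, fwd, running)
validation_data = hooks.finalize_validation(running)
```

Accumulating a sum and a count, then dividing only in `finalize_validation`, gives a correct mean even when the last batch is smaller than the others.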
-
-
class ttools.training.Trainer(interface)
Implements a simple training loop with hooks for callbacks.
Parameters: interface (ModelInterface) – adapter to run the forward and backward passes on the model being trained.
-
callbacks
Hooks that will be called while training progresses.
Type: list of Callbacks
-
train(dataloader, num_epochs=None, val_dataloader=None)
Main training loop. This starts the training procedure.
Parameters: - dataloader (DataLoader) – loader that yields training batches.
- num_epochs (int, optional) – maximum number of epochs to run.
- val_dataloader (DataLoader, optional) – loader that yields validation batches.
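To make the role of the interface concrete, here is a rough sketch of the kind of loop `train` might run. It is not the ttools implementation: `run_epochs` and `CountingInterface` are hypothetical, callbacks are left out entirely, `num_epochs` is required here so the example terminates, and the order of the validation calls is an assumption based on the hook descriptions.

```python
def run_epochs(interface, dataloader, num_epochs, val_dataloader=None):
    """Hypothetical loop; not the ttools implementation, callbacks omitted."""
    history = []
    for _ in range(num_epochs):
        # Training pass: one forward and one backward call per batch.
        for batch in dataloader:
            fwd = interface.forward(batch)
            interface.backward(batch, fwd)

        # Optional validation pass using the three validation hooks.
        val = None
        if val_dataloader is not None:
            running = interface.init_validation()
            for batch in val_dataloader:
                fwd = interface.forward(batch)
                running = interface.update_validation(batch, fwd, running)
            val = interface.finalize_validation(running)
        history.append(val)
    return history


class CountingInterface:
    """Toy interface that only counts calls, to show the call pattern."""

    def __init__(self):
        self.steps = 0

    def forward(self, batch):
        return {"out": batch}

    def backward(self, batch, forward_data):
        self.steps += 1
        return {}

    def init_validation(self):
        return {"n": 0}

    def update_validation(self, batch, fwd, running_data):
        return {"n": running_data["n"] + 1}

    def finalize_validation(self, running_data):
        return running_data


interface = CountingInterface()
# Plain lists stand in for DataLoader objects; any iterable of batches works here.
history = run_epochs(interface, dataloader=[1, 2, 3], num_epochs=2,
                     val_dataloader=[4, 5])
```

With three training batches and two epochs, `backward` runs six times, and each epoch's validation pass sees two batches.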
-
-
class ttools.training.Checkpointer(root, model=None, meta=None, optimizers=None, prefix=None)
Save and restore model and optimizer variables.
Parameters: - root (string) – path to the root directory where the files are stored.
- model (torch.nn.Module) – model whose parameters are saved and restored.
- optimizers (list of torch.optim.Optimizer) – optimizers whose parameters will be checkpointed together with the model.
- meta (dict) – metaparameters of the model, stored with each checkpoint.
-
load(path)
Loads a checkpoint, updates the model and returns extra data.
Parameters: path (string) – path to the checkpoint file, relative to the root dir.
Returns: - extras (dict) – extra information passed by the user at save time.
- meta (dict) – metaparameters of the model passed at save time.
-
load_latest()
Tries to load the most recent checkpoint, skipping files that fail to load.
Returns: - extras (dict) – extra information passed by the user at save time.
- meta (dict) – metaparameters of the model passed at save time.
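The "most recent first, skip failing files" behaviour can be sketched independently of ttools. The function below is a hypothetical stand-in with the same name: it assumes checkpoints are JSON files ordered by modification time, which is an assumption of this example and not the on-disk format or ordering rule that `Checkpointer` uses.

```python
import json
import os
import tempfile


def load_latest(root):
    """Return the contents of the newest readable file in root, or None.

    Hypothetical stand-in: checkpoints are JSON files here, which is an
    assumption of this sketch and not the ttools on-disk format.
    """
    paths = [os.path.join(root, name) for name in os.listdir(root)]
    # Newest first, by modification time.
    paths.sort(key=os.path.getmtime, reverse=True)
    for path in paths:
        try:
            with open(path) as f:
                return json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt file: skip it and try the next most recent.
            continue
    return None


with tempfile.TemporaryDirectory() as root:
    good = os.path.join(root, "epoch_1.json")
    bad = os.path.join(root, "epoch_2.json")
    with open(good, "w") as f:
        json.dump({"epoch": 1}, f)
    with open(bad, "w") as f:
        f.write("{truncated")  # simulates a partially written checkpoint
    # Make the corrupt file the most recent one.
    os.utime(good, (1000, 1000))
    os.utime(bad, (2000, 2000))
    restored = load_latest(root)
```

Falling back to an older file is useful when a job is interrupted while writing a checkpoint, since the newest file may then be incomplete.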
-
static load_meta(root, prefix=None)
Fetches model metadata without touching the saved parameters.