BatchTimeCallback

class health_ml.utils.BatchTimeCallback(max_batch_load_time_seconds=0.5, max_load_time_warnings=3, max_load_time_epochs=5)[source]

Bases: pytorch_lightning.callbacks.base.Callback

This callback provides tools to measure batch loading time and other diagnostic information. It prints alerts to the console or via logging if the batch loading time exceeds a threshold, for a limited number of epochs. Metrics for loading time, epoch time, and maximum and average batch processing time are logged to the loggers that are set up on the module. In distributed training, all logging to the console and to the Lightning loggers happens on global rank 0 only.

The loading time for a minibatch is estimated by the difference between the start time of a minibatch and the end time of the previous minibatch. It will consequently also include other operations that happen between the end of a batch and the start of the next one. For example, computationally expensive callbacks could also drive up this time.
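The estimation described above can be sketched with a pair of timestamps (a hypothetical simplification for illustration, not the library's actual implementation):

```python
import time


class LoadTimeEstimator:
    """Illustrative sketch: attribute the gap between the end of one
    minibatch and the start of the next to data loading."""

    def __init__(self):
        self.previous_batch_end = None

    def batch_start(self):
        # The first batch of an epoch has no preceding batch, so no
        # loading time can be estimated for it here.
        now = time.monotonic()
        if self.previous_batch_end is None:
            return 0.0
        return now - self.previous_batch_end

    def batch_end(self):
        self.previous_batch_end = time.monotonic()
```

Because the estimate is simply a time difference, it includes everything that runs between the two hooks, such as other callbacks.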

Usage example:
>>> from health_ml.utils import BatchTimeCallback
>>> from pytorch_lightning import Trainer
>>> batchtime = BatchTimeCallback(max_batch_load_time_seconds=0.5)
>>> trainer = Trainer(callbacks=[batchtime])
Parameters
  • max_batch_load_time_seconds (float) – The maximum expected loading time for a minibatch (given in seconds). If the loading time exceeds this threshold, a warning is printed. The maximum number of such warnings is controlled by the other arguments.

  • max_load_time_warnings (int) – The maximum number of warnings about increased loading time that will be printed per epoch. For example, if max_load_time_warnings=3, at most 3 of these warnings are printed within an epoch; a 4th minibatch with loading time over the threshold would no longer trigger a warning. If set to 0, no warnings are printed at all.

  • max_load_time_epochs (int) – The maximum number of epochs in which warnings about the loading time are printed. For example, if max_load_time_epochs=2, and batches with increased loading time are observed in epochs 0 and 3, warnings are printed in those two epochs only; no further warnings about increased loading time are printed from epoch 4 onwards.
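Taken together, the two limits throttle warnings both within an epoch and across epochs. A minimal sketch of that logic, using hypothetical names (the library's internal bookkeeping may differ):

```python
class WarningThrottle:
    """Illustrative sketch of per-epoch and across-epoch warning limits."""

    def __init__(self, max_warnings_per_epoch=3, max_epochs=5):
        self.max_warnings_per_epoch = max_warnings_per_epoch
        self.max_epochs = max_epochs
        self.epochs_with_warnings = set()
        self.warnings_this_epoch = 0

    def epoch_start(self):
        # The per-epoch counter resets at every epoch boundary.
        self.warnings_this_epoch = 0

    def should_warn(self, epoch):
        # Stop entirely once warnings were issued in max_epochs distinct epochs.
        if (epoch not in self.epochs_with_warnings
                and len(self.epochs_with_warnings) >= self.max_epochs):
            return False
        # Cap the number of warnings within a single epoch.
        if self.warnings_this_epoch >= self.max_warnings_per_epoch:
            return False
        self.epochs_with_warnings.add(epoch)
        self.warnings_this_epoch += 1
        return True
```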

Attributes Summary

BATCH_TIME

The name that is used to log the execution time per batch.

EPOCH_TIME

The name that is used to log the execution time per epoch.

EXCESS_LOADING_TIME

The name that is used to log the time spent loading all the batches that exceed the loading time threshold.

METRICS_PREFIX

The prefix for all metrics collected by this callback.

TRAIN_PREFIX

The prefix for all metrics collected during training.

VAL_PREFIX

The prefix for all metrics collected during validation.

Methods Summary

batch_end(is_training)

Shared code to keep track of minibatch loading times.

batch_start(batch_idx, is_training)

Shared code to keep track of minibatch loading times.

get_timers(is_training)

Gets the object that holds all metrics and timers, for either the validation or the training epoch.

log_metric(name_suffix, value, is_training)

Write a metric given as a name/value pair to the currently trained module.

on_fit_start(trainer, pl_module)

This is called at the start of training.

on_train_batch_end(trainer, pl_module, …)

Called when the train batch ends.

on_train_batch_start(trainer, pl_module, …)

Called when the train batch begins.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

on_validation_batch_end(trainer, pl_module, …)

Called when the validation batch ends.

on_validation_batch_start(trainer, …)

Called when the validation batch begins.

on_validation_epoch_end(trainer, pl_module)

This hook is called at the end of the validation epoch.

on_validation_epoch_start(trainer, pl_module)

Called when the val epoch begins.

write_and_log_epoch_time(is_training)

Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the time per epoch.

Attributes Documentation

BATCH_TIME = 'batch_time [sec]'

The name that is used to log the execution time per batch.

EPOCH_TIME = 'epoch_time [sec]'

The name that is used to log the execution time per epoch.

EXCESS_LOADING_TIME = 'batch_loading_over_threshold [sec]'

The name that is used to log the time spent loading all the batches that exceed the loading time threshold.

METRICS_PREFIX = 'timing/'

The prefix for all metrics collected by this callback.

TRAIN_PREFIX = 'train/'

The prefix for all metrics collected during training.

VAL_PREFIX = 'val/'

The prefix for all metrics collected during validation.

Methods Documentation

batch_end(is_training)[source]

Shared code to keep track of minibatch loading times. This is only done on global rank zero.

Parameters

is_training (bool) – If true, this has been called from on_train_batch_end, otherwise it has been called from on_validation_batch_end.

Return type

None

batch_start(batch_idx, is_training)[source]

Shared code to keep track of minibatch loading times. This is only done on global rank zero.

Parameters
  • batch_idx (int) – The index of the current minibatch.

  • is_training (bool) – If true, this has been called from on_train_batch_start, otherwise it has been called from on_validation_batch_start.

Return type

None

get_timers(is_training)[source]

Gets the object that holds all metrics and timers, for either the validation or the training epoch.

Return type

EpochTimers

log_metric(name_suffix, value, is_training, reduce_max=False)[source]

Write a metric given as a name/value pair to the currently trained module. The full name of the metric is composed of a fixed prefix “timing/”, followed by either “train/” or “val/”, and then the given suffix.

Parameters
  • name_suffix (str) – The suffix for the logged metric name.

  • value (float) – The value to log.

  • is_training (bool) – If True, use “train/” in the metric name, otherwise “val/”.

  • reduce_max (bool) – If True, use torch.max as the aggregation function for the logged values. If False, use torch.mean.

Return type

None
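As a quick illustration of the name composition described above, using the prefix constants documented for this callback:

```python
# The prefix constants as documented on BatchTimeCallback.
METRICS_PREFIX = "timing/"
TRAIN_PREFIX = "train/"
VAL_PREFIX = "val/"


def full_metric_name(name_suffix: str, is_training: bool) -> str:
    """Sketch of how the full metric name is composed:
    fixed prefix, then the train/val split, then the suffix."""
    split = TRAIN_PREFIX if is_training else VAL_PREFIX
    return METRICS_PREFIX + split + name_suffix


# full_metric_name("batch_time [sec]", is_training=True)
# gives "timing/train/batch_time [sec]"
```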

on_fit_start(trainer, pl_module)[source]

This is called at the start of training. It stores the model that is being trained, because it will be used later to log values.

Return type

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the train batch ends.

Return type

None

on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]

Called when the train batch begins.

Return type

None

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

Return type

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the validation batch ends.

Return type

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]

Called when the validation batch begins.

Return type

None

on_validation_epoch_end(trainer, pl_module)[source]

This hook is called at the end of the validation epoch. At this point, metrics can still be written to a logger.

Return type

None

on_validation_epoch_start(trainer, pl_module)[source]

Called when the val epoch begins.

Return type

None

write_and_log_epoch_time(is_training)[source]

Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the time per epoch.

Parameters

is_training (bool) – If True, show and log the data for the training epoch. If False, use the data for the validation epoch.

Return type

None