BatchTimeCallback
- class health_ml.utils.BatchTimeCallback(max_batch_load_time_seconds=0.5, max_load_time_warnings=3, max_load_time_epochs=5)[source]
Bases:
pytorch_lightning.callbacks.base.Callback
This callback provides tools to measure batch loading time and other diagnostic information. It prints alerts to the console or to logging if the batch loading time is over a threshold for several epochs. Metrics for loading time, as well as epoch time, and maximum and average batch processing time are logged to the loggers that are set up on the module. In distributed training, all logging to the console and to the Lightning loggers will only happen on global rank 0.
The loading time for a minibatch is estimated by the difference between the start time of a minibatch and the end time of the previous minibatch. It will consequently also include other operations that happen between the end of a batch and the start of the next one. For example, computationally expensive callbacks could also drive up this time.
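The estimation described above can be sketched as follows. This is a minimal illustration of the timing idea, not the callback's actual implementation: the gap between the end of one batch and the start of the next is attributed to data loading.

```python
import time

def measure_loading_times(batches, process_batch):
    """Estimate per-batch loading time as the gap between the end of the
    previous batch and the start of the current one (illustration only)."""
    loading_times = []
    previous_batch_end = time.monotonic()
    for batch in batches:
        batch_start = time.monotonic()
        # Time elapsed since the previous batch finished. This is assumed to
        # be dominated by data loading, but also includes anything else that
        # runs between batches (e.g., expensive callbacks).
        loading_times.append(batch_start - previous_batch_end)
        process_batch(batch)
        previous_batch_end = time.monotonic()
    return loading_times

times = measure_loading_times(range(3), lambda b: time.sleep(0.01))
```

With a real PyTorch `DataLoader`, the iteration step itself performs the loading, so the measured gap captures it.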
- Usage example:
>>> from health_ml.utils import BatchTimeCallback
>>> from pytorch_lightning import Trainer
>>> batchtime = BatchTimeCallback(max_batch_load_time_seconds=0.5)
>>> trainer = Trainer(callbacks=[batchtime])
- Parameters
max_batch_load_time_seconds (float) – The maximum expected loading time for a minibatch, in seconds. If the loading time exceeds this threshold, a warning is printed. The maximum number of such warnings is controlled by the other arguments.
max_load_time_warnings (int) – The maximum number of warnings about increased loading time that are printed per epoch. For example, if max_load_time_warnings=3, at most 3 such warnings are printed within an epoch; the 4th minibatch with loading time over the threshold no longer triggers a warning. If set to 0, no warnings are printed at all.
max_load_time_epochs (int) – The maximum number of epochs in which warnings about the loading time are printed. For example, if max_load_time_epochs=2, and at least 1 batch with increased loading time is observed in epochs 0 and 3, no further warnings about increased loading time are printed from epoch 4 onwards.
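The interplay of the two warning limits can be sketched as follows. This is an illustration of the documented semantics, not the callback's actual bookkeeping code; the function name and structure are hypothetical:

```python
def collect_warnings(load_times_per_epoch, threshold=0.5,
                     max_load_time_warnings=3, max_load_time_epochs=5):
    """Return the (epoch, batch_idx, time) tuples that would trigger a
    warning, capped per epoch and by the number of epochs with slow batches."""
    warnings = []
    epochs_with_slow_batches = 0
    for epoch, load_times in enumerate(load_times_per_epoch):
        printed_this_epoch = 0
        epoch_had_slow_batch = False
        for batch_idx, t in enumerate(load_times):
            if t > threshold:
                epoch_had_slow_batch = True
                # Warn only while both the per-epoch cap and the epoch cap
                # have not yet been reached.
                if (printed_this_epoch < max_load_time_warnings
                        and epochs_with_slow_batches < max_load_time_epochs):
                    warnings.append((epoch, batch_idx, t))
                    printed_this_epoch += 1
        if epoch_had_slow_batch:
            epochs_with_slow_batches += 1
    return warnings

# One epoch with 4 slow batches and max_load_time_warnings=3:
# only the first 3 produce warnings.
w = collect_warnings([[0.6, 0.7, 0.8, 0.9]], max_load_time_warnings=3)
```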
Attributes Summary
BATCH_TIME: The name that is used to log the execution time per batch.
EPOCH_TIME: The name that is used to log the execution time per epoch.
EXCESS_LOADING_TIME: The name that is used to log the time spent loading all the batches that exceed the loading time threshold.
METRICS_PREFIX: The prefix for all metrics collected by this callback.
TRAIN_PREFIX: The prefix for all metrics collected during training.
VAL_PREFIX: The prefix for all metrics collected during validation.
Methods Summary
batch_end(is_training): Shared code to keep track of minibatch loading times.
batch_start(batch_idx, is_training): Shared code to keep track of minibatch loading times.
get_timers(is_training): Gets the object that holds all metrics and timers, for either the validation or the training epoch.
log_metric(name_suffix, value, is_training): Write a metric given as a name/value pair to the currently trained module.
on_fit_start(trainer, pl_module): This is called at the start of training.
on_train_batch_end(trainer, pl_module, …): Called when the train batch ends.
on_train_batch_start(trainer, pl_module, …): Called when the train batch begins.
on_train_epoch_start(trainer, pl_module): Called when the train epoch begins.
on_validation_batch_end(trainer, pl_module, …): Called when the validation batch ends.
on_validation_batch_start(trainer, …): Called when the validation batch begins.
on_validation_epoch_end(trainer, pl_module): This is a hook called at the end of a training or validation epoch.
on_validation_epoch_start(trainer, pl_module): Called when the val epoch begins.
write_and_log_epoch_time(is_training): Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the time per epoch.
Attributes Documentation
- BATCH_TIME = 'batch_time [sec]'
The name that is used to log the execution time per batch.
- EPOCH_TIME = 'epoch_time [sec]'
The name that is used to log the execution time per epoch.
- EXCESS_LOADING_TIME = 'batch_loading_over_threshold [sec]'
The name that is used to log the time spent loading all the batches that exceed the loading time threshold.
- METRICS_PREFIX = 'timing/'
The prefix for all metrics collected by this callback.
- TRAIN_PREFIX = 'train/'
The prefix for all metrics collected during training.
- VAL_PREFIX = 'val/'
The prefix for all metrics collected during validation.
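Per the log_metric documentation below, full metric names are composed from these constants: the fixed prefix, then the train or validation prefix, then the metric name. A small sketch of that composition (the helper function here is illustrative, not part of the library's API):

```python
# Constants as documented above.
METRICS_PREFIX = "timing/"
TRAIN_PREFIX = "train/"
VAL_PREFIX = "val/"
BATCH_TIME = "batch_time [sec]"
EPOCH_TIME = "epoch_time [sec]"

def full_metric_name(suffix, is_training):
    """Compose the full logged metric name from the documented prefixes."""
    prefix = TRAIN_PREFIX if is_training else VAL_PREFIX
    return METRICS_PREFIX + prefix + suffix

# During training, the per-batch metric is logged under
# "timing/train/batch_time [sec]".
train_name = full_metric_name(BATCH_TIME, is_training=True)
val_name = full_metric_name(EPOCH_TIME, is_training=False)
```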
Methods Documentation
- batch_end(is_training)[source]
Shared code to keep track of minibatch loading times. This is only done on global rank zero.
- Parameters
is_training (bool) – If true, this has been called from on_train_batch_end, otherwise it has been called from on_validation_batch_end.
- Return type
None
- batch_start(batch_idx, is_training)[source]
Shared code to keep track of minibatch loading times. This is only done on global rank zero.
- Parameters
batch_idx (int) – The index of the current minibatch.
is_training (bool) – If true, this has been called from on_train_batch_start, otherwise it has been called from on_validation_batch_start.
- Return type
None
- get_timers(is_training)[source]
Gets the object that holds all metrics and timers, for either the validation or the training epoch.
- Return type
EpochTimers
- log_metric(name_suffix, value, is_training, reduce_max=False)[source]
Write a metric given as a name/value pair to the currently trained module. The full name of the metric is composed of a fixed prefix “timing/”, followed by either “train/” or “val/”, and then the given suffix.
- Parameters
name_suffix (str) – The suffix for the logged metric name.
value (float) – The value to log.
is_training (bool) – If True, use “train/” in the metric name, otherwise “val/”.
reduce_max (bool) – If True, use torch.max as the aggregation function for the logged values. If False, use torch.mean.
- Return type
None
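The effect of the reduce_max flag can be illustrated with plain numbers; the actual aggregation is performed by torch.max or torch.mean inside Lightning's logging machinery, but the builtins behave the same way here:

```python
from statistics import mean

# Batch times logged during one epoch; the aggregation function decides
# which single value the epoch reports for this metric.
batch_times = [0.1, 0.2, 0.9, 0.15]

# reduce_max=True corresponds to torch.max: report the slowest batch,
# useful for spotting worst-case loading stalls.
max_value = max(batch_times)

# reduce_max=False corresponds to torch.mean: report the average batch time.
mean_value = mean(batch_times)
```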
- on_fit_start(trainer, pl_module)[source]
This is called at the start of training. It stores the model that is being trained, because it will be used later to log values.
- Return type
None
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]
Called when the train batch ends.
- Return type
None
- on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]
Called when the train batch begins.
- Return type
None
- on_train_epoch_start(trainer, pl_module)[source]
Called when the train epoch begins.
- Return type
None
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]
Called when the validation batch ends.
- Return type
None
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]
Called when the validation batch begins.
- Return type
None
- on_validation_epoch_end(trainer, pl_module)[source]
This is a hook called at the end of a training or validation epoch. At this point, metrics can still be written to a logger.
- Return type
None
- on_validation_epoch_start(trainer, pl_module)[source]
Called when the val epoch begins.
- Return type
None
- write_and_log_epoch_time(is_training)[source]
Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the time per epoch.
- Parameters
is_training (bool) – If True, show and log the data for the training epoch. If False, use the data for the validation epoch.
- Return type
None