TrainingRunner

class health_ml.TrainingRunner(experiment_config, container, project_root=None)

Bases: health_ml.runner_base.RunnerBase

Driver class to run an ML experiment. Note that the project_root argument MUST be supplied when using hi-ml as a package!

Parameters
  • experiment_config (ExperimentConfig) – The ExperimentConfig object to use for training.

  • container (LightningContainer) – The LightningContainer object to use for training.

  • project_root (Optional[Path]) – Project root. This should only be omitted if calling run_ml from the test suite. Supplying it is crucial when using hi-ml as a package or submodule!
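To see why project_root matters, consider a minimal sketch of how a runner might resolve it. It assumes that, when project_root is None, the runner falls back to a fixed default tied to the hi-ml code itself, which is only meaningful inside the hi-ml repository (for example, its test suite). DEFAULT_PROJECT_ROOT and resolve_project_root are hypothetical names, not part of the hi-ml API.

```python
from pathlib import Path
from typing import Optional

# Hypothetical stand-in for the library's built-in fallback location (assumption).
DEFAULT_PROJECT_ROOT = Path("/path/to/hi-ml/checkout")


def resolve_project_root(project_root: Optional[Path]) -> Path:
    """Return the explicitly supplied project root, or fall back to a fixed default.

    The fallback points at the hi-ml code itself, so a project that uses hi-ml as a
    package or submodule would end up with the wrong root if it omitted the argument.
    """
    return project_root if project_root is not None else DEFAULT_PROJECT_ROOT


# A caller using hi-ml as a package would pass its own root explicitly, e.g.
# TrainingRunner(experiment_config, container, project_root=Path(__file__).parent)
```

The explicit argument always wins; only the test-suite path relies on the default.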

Methods Summary

after_ddp_cleanup(environ_before_training)

Run process cleanup after the DDP context to prepare for single-device inference.

end_training(environ_before_training)

Clean up after training is done.

get_data_module()

Reads the datamodule that should be used for training or evaluation from the container.

get_multiple_trainloader_mode()

Return type

str

init_inference()

Prepare the trainer for running inference on the validation and test sets.

init_training()

Execute bookkeeping tasks that must run only once when training is distributed, and initialize the runner’s trainer object.

is_crossval_disabled_or_child_0()

Returns True if the present run is a non-cross-validation run, or child run 0 of a cross-validation run.

run()

Driver function to run an ML experiment.

run_regression_test()

Return type

None

run_training()

The main training loop.

run_validation()

Run validation on the validation set for all models, to save outputs that are time- or memory-consuming to compute.

Methods Documentation

after_ddp_cleanup(environ_before_training)

Run process cleanup after the DDP context to prepare for single-device inference. Kill all DDP processes except rank 0.

Return type

None
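A minimal sketch of what such a cleanup step could look like. It assumes that environ_before_training is a snapshot of os.environ taken before training and that the cleanup restores it; after_ddp_cleanup_sketch and its global_rank parameter are hypothetical. Instead of terminating non-zero ranks, as the runner does, the sketch only returns whether the current process should continue into single-device inference.

```python
import os
from typing import Dict


def after_ddp_cleanup_sketch(environ_before_training: Dict[str, str], global_rank: int) -> bool:
    """Hypothetical post-DDP cleanup.

    Restores the environment variables captured before training (assumption) and
    reports whether this process should carry on; only rank 0 does.
    """
    # Undo any variables that the DDP launcher added or changed during training.
    os.environ.clear()
    os.environ.update(environ_before_training)
    # The real runner kills non-zero ranks; here we only report the decision.
    return global_rank == 0
```

Usage: take `snapshot = dict(os.environ)` before training starts and pass it in afterwards.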

end_training(environ_before_training)

Clean up after training is done. This is called after the trainer has finished fitting the data, to update the checkpoint handler state and remove redundant checkpoint files. If inference runs on a single device, it also kills all processes except rank 0.

Return type

None
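The checkpoint pruning mentioned above can be illustrated with a short sketch. The folder layout, the *.ckpt naming, and the rule of keeping a fixed set of file names are assumptions for illustration; remove_redundant_checkpoints is a hypothetical helper, not the hi-ml checkpoint handler.

```python
from pathlib import Path
from typing import List


def remove_redundant_checkpoints(checkpoint_folder: Path, keep: List[str]) -> List[Path]:
    """Delete every *.ckpt file in checkpoint_folder whose name is not in `keep`.

    Returns the paths that were deleted, in sorted order.
    """
    removed = []
    for ckpt in sorted(checkpoint_folder.glob("*.ckpt")):
        if ckpt.name not in keep:
            ckpt.unlink()
            removed.append(ckpt)
    return removed
```

Returning the deleted paths makes it easy to log what was pruned.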

get_data_module()

Reads the datamodule that should be used for training or evaluation from the container. This must be overridden in subclasses.

Return type

LightningDataModule
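The override pattern can be sketched with plain Python stand-ins. RunnerBaseSketch, TrainingRunnerSketch and FakeContainer are hypothetical classes; the assumption that the training runner simply delegates to a get_data_module method on the container is illustrative, and the fake container returns a placeholder string instead of a LightningDataModule.

```python
from abc import ABC, abstractmethod
from typing import Any


class RunnerBaseSketch(ABC):
    """Hypothetical stand-in for a runner base class (not hi-ml's RunnerBase)."""

    def __init__(self, container: Any) -> None:
        self.container = container

    @abstractmethod
    def get_data_module(self) -> Any:
        """Subclasses decide which data module to use."""


class TrainingRunnerSketch(RunnerBaseSketch):
    def get_data_module(self) -> Any:
        # Assumption: the training runner reads the data module from the container.
        return self.container.get_data_module()


class FakeContainer:
    """Stand-in for a LightningContainer; returns a placeholder value."""

    def get_data_module(self) -> str:
        return "my_datamodule"
```

Because the base method is abstract, forgetting the override fails at construction time rather than mid-run.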

get_multiple_trainloader_mode()

Return type

str

init_inference()

Prepare the trainer for running inference on the validation and test sets. This chooses a checkpoint, initializes the PL Trainer object, and chooses the right data module. The hook for running inference on the validation set (LightningContainer.on_run_extra_validation_epoch) is called first, so that any changes to the model or datamodule states are reflected before inference runs.

Return type

None
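The ordering described above (hook first, then checkpoint, trainer and data module) can be sketched with a recording stub. ContainerStub and init_inference_sketch are hypothetical; only the hook name on_run_extra_validation_epoch comes from this page, and the three later steps are placeholders.

```python
from typing import List


class ContainerStub:
    """Hypothetical container that records when its hook is called."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def on_run_extra_validation_epoch(self) -> None:
        # A real hook might adjust model or datamodule state before inference.
        self.log.append("on_run_extra_validation_epoch")


def init_inference_sketch(container: ContainerStub, log: List[str]) -> None:
    # 1. Let the container update model/datamodule state first.
    container.on_run_extra_validation_epoch()
    # 2. Then choose a checkpoint, create the trainer and pick the data module.
    log.append("choose_checkpoint")
    log.append("create_trainer")
    log.append("get_data_module")
```

Calling the hook first means the trainer and data module are built from the already-updated state.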

init_training()

Execute bookkeeping tasks that must run only once when training is distributed, and initialize the runner’s trainer object.

Return type

None
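"Only once when running distributed" usually means "only on the process with global rank 0". A minimal sketch of that guard, assuming the rank is read from the RANK environment variable that torch.distributed launchers such as torchrun set (a missing variable is treated as a single-process run). run_on_rank_zero_only and write_run_metadata are hypothetical; PyTorch Lightning ships a similar rank_zero_only utility.

```python
import functools
import os
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


def run_on_rank_zero_only(func: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Run `func` only on global rank 0; return None on every other rank."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[T]:
        # Assumption: RANK holds the global rank; absent means single process.
        if int(os.environ.get("RANK", "0")) == 0:
            return func(*args, **kwargs)
        return None

    return wrapper


@run_on_rank_zero_only
def write_run_metadata(name: str) -> str:
    # Hypothetical bookkeeping task, e.g. creating output folders or logging config.
    return f"metadata for {name}"
```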

is_crossval_disabled_or_child_0()

Returns True if the present run is a non-cross-validation run, or child run 0 of a cross-validation run.

Return type

bool
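The condition can be written out as a small function. This is a sketch assuming that cross-validation is enabled when a split count is greater than 1 and that child runs carry a zero-based index; the parameter names crossval_count and crossval_index are assumptions, not taken from this page.

```python
def is_crossval_disabled_or_child_0(crossval_count: int, crossval_index: int) -> bool:
    """True for a plain (non-cross-validation) run, or for child 0 of a cross-validation run."""
    crossval_enabled = crossval_count > 1  # assumption: >1 splits means cross-validation
    return not crossval_enabled or crossval_index == 0
```

A check like this is typically used to gate work that should happen once per experiment rather than once per fold.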

run()

Driver function to run an ML experiment.

Return type

None
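A hedged sketch of how the documented methods might fit together in run(). The exact call order in hi-ml is not stated on this page; the flags run_inference_only and run_extra_val_epoch are assumed names for "inference-only mode" and "extra validation epoch requested", and RunnerLifecycleSketch only records step names.

```python
from typing import List


class RunnerLifecycleSketch:
    """Hypothetical stub that records the order of the documented steps."""

    def __init__(self, run_inference_only: bool, run_extra_val_epoch: bool) -> None:
        self.run_inference_only = run_inference_only
        self.run_extra_val_epoch = run_extra_val_epoch
        self.calls: List[str] = []

    def run(self) -> None:
        self.calls.append("init_training")
        if not self.run_inference_only:
            # Training path: fit the model, then clean up DDP and checkpoints.
            self.calls.append("run_training")
            self.calls.append("end_training")
        self.calls.append("init_inference")
        # Validation runs in inference-only mode or when an extra epoch is requested.
        if self.run_inference_only or self.run_extra_val_epoch:
            self.calls.append("run_validation")
```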

run_regression_test()

Return type

None

run_training()

The main training loop. It creates the PyTorch model based on the configuration options passed in, creates a PyTorch Lightning trainer, and trains the model. If a checkpoint was specified, it loads the checkpoint before resuming training. The cwd is changed to the outputs folder so that the model can write to the current working directory while everything still ends up in the right place in AzureML (only the contents of the “outputs” folder are treated as result files).

Return type

None
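Temporarily switching the working directory, as described above, is commonly done with a context manager that restores the previous directory even if training raises. A minimal sketch; change_working_directory here is a self-written helper for illustration (Python 3.11+ also provides contextlib.chdir with the same behavior).

```python
import contextlib
import os
from pathlib import Path
from typing import Iterator


@contextlib.contextmanager
def change_working_directory(path: Path) -> Iterator[None]:
    """Make `path` the current working directory, restoring the old one on exit."""
    previous = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous)
```

Inside the block, a relative write such as `Path("metrics.txt").write_text(...)` lands in the outputs folder.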

run_validation()

Run validation on the validation set for all models, to save outputs that are time- or memory-consuming to compute. This is done in inference-only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs folder.

Return type

None
