TrainingRunner
- class health_ml.TrainingRunner(experiment_config, container, project_root=None)[source]
Bases: health_ml.runner_base.RunnerBase
Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml as a package!
- Parameters
  - experiment_config (ExperimentConfig) – The ExperimentConfig object to use for training.
  - container (LightningContainer) – The LightningContainer object to use for training.
  - project_root (Optional[Path]) – Project root. This should only be omitted when calling run_ml from the test suite. Supplying it is crucial when using hi-ml as a package or submodule!
Methods Summary
- after_ddp_cleanup(environ_before_training): Clean up processes after the DDP context to prepare for single-device inference.
- end_training(environ_before_training): Cleanup after training is done.
- get_data_module(): Reads the datamodule that should be used for training or validation from the container.
- init_inference(): Prepare the trainer for running inference on the validation and test set.
- init_training(): Execute some bookkeeping tasks only once if running distributed, and initialize the runner's trainer object.
- is_crossval_disabled_or_child_0(): Returns True if the present run is a non-cross-validation run, or child run 0 of a cross-validation run.
- run(): Driver function to run an ML experiment.
- run_training(): The main training loop.
- run_validation(): Run validation on the validation set for all models to save time/memory-consuming outputs.
Methods Documentation
- after_ddp_cleanup(environ_before_training)[source]
Clean up processes after the DDP context to prepare for single-device inference. Kills all DDP processes besides rank 0.
- Return type
None
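The environ_before_training argument is a snapshot of os.environ taken before training started. A minimal sketch of what such a restore could look like (the helper names here are hypothetical, not hi-ml's actual implementation):

```python
import os


def snapshot_environ() -> dict:
    """Take a copy of os.environ before DDP training mutates it."""
    return dict(os.environ)


def restore_environ(environ_before_training: dict) -> None:
    """Restore the pre-training environment after the DDP context ends."""
    # Delete keys that appeared after the snapshot (e.g. DDP rendezvous settings).
    for key in set(os.environ) - set(environ_before_training):
        del os.environ[key]
    # Restore any values that were changed during training.
    os.environ.update(environ_before_training)
```

Restoring the environment like this ensures that variables set up for distributed training do not leak into the subsequent single-device inference phase.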
- end_training(environ_before_training)[source]
Cleanup after training is done. This is called after the trainer has finished fitting the data. It updates the checkpoint handler state and removes redundant checkpoint files. If running inference on a single device, it also kills all processes besides rank 0.
- Return type
None
- get_data_module()[source]
Reads the datamodule that should be used for training or validation from the container. This must be overridden in subclasses.
- Return type
LightningDataModule
- init_inference()[source]
Prepare the trainer for running inference on the validation and test set. This chooses a checkpoint, initializes the PL Trainer object, and chooses the right data module. The hook for running inference on the validation set (LightningContainer.on_run_extra_validation_epoch) is called first, to reflect any changes to the model or datamodule state before running inference.
- Return type
None
- init_training()[source]
Execute some bookkeeping tasks only once if running distributed, and initialize the runner's trainer object.
- Return type
None
- is_crossval_disabled_or_child_0()[source]
Returns True if the present run is a non-cross-validation run, or child run 0 of a cross-validation run.
- Return type
bool
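A standalone sketch of this check, assuming the run knows its number of cross-validation splits and the index of the current child run (both parameter names are hypothetical, not hi-ml's actual attributes):

```python
def crossval_disabled_or_child_0(crossval_count: int, crossval_index: int) -> bool:
    """True when cross-validation is off (at most one split) or this is child run 0."""
    return crossval_count <= 1 or crossval_index == 0
```

A check like this is a common way to make sure one-time bookkeeping (such as registering a model) runs exactly once across all children of a cross-validation run.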
- run_training()[source]
The main training loop. It creates the PyTorch model based on the configuration options passed in, creates a PyTorch Lightning trainer, and trains the model. If a checkpoint was specified, it loads that checkpoint before resuming training. The cwd is changed to the outputs folder so that the model can write to the current working directory and everything still ends up in the right place in AzureML (only the contents of the "outputs" folder are treated as result files).
- Return type
None
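The working-directory switch described above can be sketched with a small context manager (a hypothetical helper, not hi-ml's own code):

```python
import os
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def change_working_directory(folder: Path):
    """Run the enclosed block with `folder` as the current working directory.

    Anything the model writes to the cwd then lands in `folder`; with an
    "outputs" folder, that is what AzureML collects as result files.
    """
    old_cwd = Path.cwd()
    folder.mkdir(parents=True, exist_ok=True)
    os.chdir(folder)
    try:
        yield
    finally:
        # Always restore the original cwd, even if training raised.
        os.chdir(old_cwd)
```

Wrapping the training call this way means model code can use plain relative paths while the runner guarantees the files end up in the outputs folder.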
- run_validation()[source]
Run validation on the validation set for all models to save time/memory-consuming outputs. This is done in inference-only mode, or when the user has requested an extra validation epoch. The cwd is changed to the outputs folder.
- Return type
None