dgs.models.engine.engine.EngineModule
- class dgs.models.engine.engine.EngineModule(*args: Any, **kwargs: Any)
Module for training, validating, and testing other Modules.
Most of the settings are defined within the configuration file in the training section.
- Params:
model_path (NodePath) – Path to the configuration setting up the model to be trained or tested.
model_type (str) – The type of BaseModule to be loaded as the model for training and testing. The value will be passed as module_type in the module_loader() call.
- Optional Params:
writer_kwargs (dict, optional) – Additional kwargs for the torch writer. Default DEF_VAL.engine.test.writer_kwargs.
writer_log_dir_suffix (str, optional) – Additional subdirectory or name suffix for the torch writer. Default DEF_VAL.engine.test.writer_log_dir_suffix.
Test Params
Train Params
- loss (str|callable):
The name or class of the loss function used to compute the loss during training. Additional initialization kwargs can be passed to the loss by adding them to the loss_kwargs parameter.
- optimizer (str|callable):
The name or class of the optimizer used to optimize the model based on the loss during training. Additional initialization kwargs can be passed to the optimizer by adding them to the optimizer_kwargs parameter.
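Both loss and optimizer accept either a name or a class, with the initialization kwargs supplied separately through loss_kwargs or optimizer_kwargs. A minimal sketch of such a name-or-class lookup follows; the LOSSES registry, the resolve_component helper and the two placeholder loss classes are hypothetical stand-ins, not the dgs API.

```python
from typing import Any, Callable, Dict, Optional, Union


# Placeholder classes standing in for real losses (e.g. from torch.nn).
class MSELoss:
    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction


class L1Loss:
    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction


# Hypothetical registry mapping names to classes.
LOSSES: Dict[str, type] = {"MSELoss": MSELoss, "L1Loss": L1Loss}


def resolve_component(
    spec: Union[str, Callable[..., Any]],
    registry: Dict[str, type],
    kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Instantiate a component given either its registered name or its class."""
    if isinstance(spec, str):
        if spec not in registry:
            raise KeyError(f"Unknown component name: {spec!r}")
        cls = registry[spec]
    elif callable(spec):
        cls = spec
    else:
        raise TypeError(f"Expected str or callable, got {type(spec).__name__}")
    # The matching *_kwargs parameter supplies the extra init arguments.
    return cls(**(kwargs or {}))


loss = resolve_component("MSELoss", LOSSES, {"reduction": "sum"})
print(type(loss).__name__, loss.reduction)  # MSELoss sum
```

Accepting a class as well as a name lets custom losses or optimizers be used without registering them first.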
Optional Test Params
- normalize (bool, optional):
Whether to normalize the prediction and target during testing. Default DEF_VAL.engine.test.normalize.
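The normalize option does not state how normalization is done. One common choice, assumed here purely for illustration, is to L2-normalize each prediction and target vector before comparing them, so that only their direction matters. The helpers l2_normalize and compare are hypothetical.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float], eps: float = 1e-12) -> List[float]:
    """Scale a vector to unit Euclidean length (assumed meaning of 'normalize')."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def compare(prediction: Sequence[float], target: Sequence[float], normalize: bool) -> float:
    """Return the Euclidean distance between prediction and target."""
    if normalize:
        prediction, target = l2_normalize(prediction), l2_normalize(target)
    return math.dist(prediction, target)


pred, tgt = [3.0, 4.0], [6.0, 8.0]
print(compare(pred, tgt, normalize=False))  # 5.0
print(compare(pred, tgt, normalize=True))   # 0.0
```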
Optional Train Params
- epochs (int, optional):
The number of epochs to run the training for. Default DEF_VAL.engine.train.epochs.
- optimizer_kwargs (dict, optional):
Additional kwargs for the optimizer. Default DEF_VAL.engine.train.optim_kwargs.
- scheduler (str|callable, optional):
The name or instance of a scheduler. To use different or multiple schedulers, chain them using torch.optim.lr_scheduler.ChainedScheduler or create and register a custom scheduler. Default DEF_VAL.engine.train.scheduler.
- scheduler_kwargs (dict, optional):
Additional kwargs for the scheduler. Keep in mind that different schedulers need quite different kwargs. The optimizer is passed to the scheduler during initialization as the optimizer keyword argument. Default DEF_VAL.engine.train.scheduler_kwargs.
- loss_kwargs (dict, optional):
Additional kwargs for the loss. Default DEF_VAL.engine.train.loss_kwargs.
- save_interval (int, optional):
The interval for saving (and evaluating) the model during training. Default DEF_VAL.engine.train.save_interval.
- start_epoch (int, optional):
The epoch at which to start. (Epochs are ultimately 1-indexed, but this should not matter as long as you stick with one format.) Default DEF_VAL.engine.train.start_epoch.
- train_load_image_crops (bool, optional):
Whether to load the image crops during training. Default DEF_VAL.engine.train.load_image_crops.
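The scheduler options fit together as follows: the optimizer is handed to the scheduler as the optimizer keyword argument alongside scheduler_kwargs, and several schedulers can be combined with torch.optim.lr_scheduler.ChainedScheduler. The sketch uses real PyTorch classes (StepLR, ExponentialLR, ChainedScheduler); the toy model and the concrete values are assumptions for illustration, not DEF_VAL defaults.

```python
import torch
from torch.optim.lr_scheduler import ChainedScheduler, ExponentialLR, StepLR

# Toy model and optimizer standing in for the engine's model.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Illustrative scheduler_kwargs; the optimizer is passed as a keyword argument.
scheduler_kwargs = {"step_size": 2, "gamma": 0.5}
step_sched = StepLR(optimizer=optimizer, **scheduler_kwargs)
decay_sched = ExponentialLR(optimizer=optimizer, gamma=0.9)

# ChainedScheduler calls step() on every wrapped scheduler in order.
scheduler = ChainedScheduler([step_sched, decay_sched])

for epoch in range(4):
    optimizer.step()  # normally preceded by loss.backward()
    scheduler.step()

lr = optimizer.param_groups[0]["lr"]
print(lr)  # approximately 0.1 * 0.5**2 * 0.9**4
```

Because both schedulers act on the same optimizer, their multiplicative factors compound: StepLR halves the rate every second step, while ExponentialLR decays it on every step.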
Methods
- __init__(config: dict[str, any], path: list[str], test_loader: torch.utils.data.DataLoader, *, val_loader: torch.utils.data.DataLoader | None = None, train_loader: torch.utils.data.DataLoader | None = None, **_kwargs)
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module
Set compute mode and send model to the device or multiple parallel devices if applicable.
- Parameters:
module – The torch module instance to configure.
train – Whether to train or eval this module, defaults to the value set in the base config.
- Returns:
The module on the specified device or in parallel.
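A rough sketch of configuring a module in plain PyTorch follows. The function name configure_module is hypothetical, and passing the device explicitly and using torch.nn.DataParallel when several GPUs are present are assumptions, not necessarily what the engine does. The mode is set after wrapping so that the wrapper's own training flag matches too.

```python
import torch


def configure_module(module: torch.nn.Module, device: torch.device, train: bool) -> torch.nn.Module:
    """Hypothetical stand-in for configure_torch_module()."""
    module = module.to(device)
    # Assumption: wrap in DataParallel only if several CUDA devices are available.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        module = torch.nn.DataParallel(module)
    # nn.Module.train(mode) toggles train/eval behaviour (dropout, batch norm, ...)
    # recursively, including the wrapper if there is one.
    module.train(train)
    return module


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = configure_module(
    torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Dropout(0.5)), device, train=False
)
print(net.training)  # False
```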
- abstract evaluate() → dict[str, any]
Run the tests; must be implemented by the sub-engine.
- Returns:
A dictionary containing all the computed accuracies, metrics, … .
- Return type:
dict[str, any]
- abstract get_data(ds: State) → any
Retrieve the data used for the model's prediction from the train and test DataLoaders.
- get_hparam_dict() → dict[str, any]
Get the hyperparameters of the current engine. Child modules can override this method to add further hyperparameters.
By default, all parameters from testing and training are added to the hparam_dict.
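The override pattern for get_hparam_dict() can be sketched as follows. The two classes, the params_test/params_train attribute names and the margin hyperparameter are hypothetical; only the idea of merging test and training parameters and then extending the parent's dictionary comes from the description above.

```python
from typing import Any, Dict


class BaseEngineSketch:
    """Hypothetical stand-in for EngineModule holding test and train params."""

    def __init__(self, params_test: Dict[str, Any], params_train: Dict[str, Any]) -> None:
        self.params_test = params_test
        self.params_train = params_train

    def get_hparam_dict(self) -> Dict[str, Any]:
        # By default, merge all test and training parameters.
        return {**self.params_test, **self.params_train}


class SubEngineSketch(BaseEngineSketch):
    def __init__(self, params_test: Dict[str, Any], params_train: Dict[str, Any], margin: float) -> None:
        super().__init__(params_test, params_train)
        self.margin = margin

    def get_hparam_dict(self) -> Dict[str, Any]:
        # Extend, rather than replace, the parent's hyperparameters.
        hparams = super().get_hparam_dict()
        hparams["margin"] = self.margin
        return hparams


engine = SubEngineSketch({"normalize": True}, {"epochs": 10}, margin=0.3)
print(engine.get_hparam_dict())  # {'normalize': True, 'epochs': 10, 'margin': 0.3}
```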
- abstract get_target(ds: State) → any
Retrieve the evaluation targets from the train and test DataLoaders.
- initialize_optimizer() → None
Load the optimizer and scheduler at the start of training, because the module might be set after the initial step.
- load_model(path: str) → None
Load the model from a checkpoint file and set the start epoch to the epoch stored in the loaded checkpoint plus one.
Notes
Loads the states of the optimizer and lr_scheduler if they are present in the engine (e.g. during training) and the respective data is given in the checkpoint at the path.
- Parameters:
path – The path to the checkpoint where this model was saved.
- abstract predict() → any
Given test data, predict the results without evaluation.
- Returns:
The predicted results. Datatype might vary depending on the used engine.
- print_results(results: dict[str, any]) → None
Given a dictionary of results, print them to the console if allowed.
- save_model(epoch: int, metrics: dict[str, any], optimizer: torch.optim.Optimizer, lr_sched: torch.optim.lr_scheduler.LRScheduler) → None
Save the current model and other weights into a ‘.pth’ file.
- Parameters:
epoch – The epoch at which this model is saved.
metrics – A dict containing the computed metrics for this module.
optimizer – The current optimizer.
lr_sched – The current learning rate scheduler.
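save_model() and load_model() together form a checkpoint round trip. A minimal sketch with torch.save and torch.load follows; the helper names and the checkpoint keys (epoch, metrics, model_state_dict, optimizer_state_dict, lr_scheduler_state_dict) are assumptions, not the engine's actual file layout. The resume rule (saved epoch plus one) and the optional optimizer and scheduler restore follow the load_model() description.

```python
import os
import tempfile

import torch
from torch.optim.lr_scheduler import StepLR


def save_checkpoint(path, model, epoch, metrics, optimizer, lr_sched):
    # Hypothetical key names; the engine's real checkpoint layout may differ.
    torch.save(
        {
            "epoch": epoch,
            "metrics": metrics,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_sched.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer=None, lr_sched=None) -> int:
    """Restore states and return the epoch to resume from (saved epoch + 1)."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    # Only restore optimizer/scheduler if the engine has them and they were saved.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if lr_sched is not None and "lr_scheduler_state_dict" in checkpoint:
        lr_sched.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    return checkpoint["epoch"] + 1


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_sched = StepLR(optimizer, step_size=1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "epoch_3.pth")
    save_checkpoint(path, model, 3, {"acc": 0.9}, optimizer, lr_sched)
    fresh = torch.nn.Linear(4, 2)  # e.g. a model in a test-only engine
    start_epoch = load_checkpoint(path, fresh)

print(start_epoch)  # 4
```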
- abstract test() → dict[str, any]
Run the tests; must be implemented by the sub-engine.
- Returns:
A dictionary containing all the computed accuracies, metrics, … .
- Return type:
dict[str, any]
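Because get_data(), get_target() and test() are abstract, a concrete sub-engine has to fill them in, typically with predict() alongside. The sketch below shows that division of labour with abc.ABC. It is a heavily reduced, hypothetical stand-in: the State is modelled as a plain dict, and the keys "image" and "label", the toy model and the accuracy metric are assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List


class EngineSketch(ABC):
    """Hypothetical, heavily reduced stand-in for EngineModule."""

    def __init__(self, model: Callable[[Any], Any], test_loader: Iterable[Dict[str, Any]]) -> None:
        self.model = model
        self.test_loader = test_loader

    @abstractmethod
    def get_data(self, ds: Dict[str, Any]) -> Any: ...

    @abstractmethod
    def get_target(self, ds: Dict[str, Any]) -> Any: ...

    @abstractmethod
    def test(self) -> Dict[str, Any]: ...


class ClassificationEngineSketch(EngineSketch):
    def get_data(self, ds: Dict[str, Any]) -> Any:
        return ds["image"]  # hypothetical key holding the model input

    def get_target(self, ds: Dict[str, Any]) -> Any:
        return ds["label"]  # hypothetical key holding the ground truth

    def predict(self) -> List[Any]:
        # Predict without evaluating against the targets.
        return [self.model(self.get_data(ds)) for ds in self.test_loader]

    def test(self) -> Dict[str, Any]:
        targets = [self.get_target(ds) for ds in self.test_loader]
        preds = self.predict()
        correct = sum(p == t for p, t in zip(preds, targets))
        return {"accuracy": correct / len(targets)}


batches = [{"image": 1, "label": 1}, {"image": 2, "label": 0}, {"image": 3, "label": 1}]
engine = ClassificationEngineSketch(model=lambda x: x % 2, test_loader=batches)
print(engine.test())  # {'accuracy': 1.0}
```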
- train_model() → torch.optim.Optimizer
Train the given model using the given loss function, optimizer, and learning-rate schedulers.
After every epoch, the current model is tested and saved.
- Returns:
The current optimizer after training.
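How start_epoch, epochs and save_interval could interact in the training loop is sketched below. Two points are assumptions for illustration: that the model is saved (and evaluated) every save_interval epochs and always after the final epoch, and that the epoch range is 1-indexed and inclusive. The helper epochs_to_save is hypothetical.

```python
from typing import List


def epochs_to_save(start_epoch: int, epochs: int, save_interval: int) -> List[int]:
    """Return the (1-indexed) epochs after which the model would be saved and evaluated."""
    if save_interval < 1:
        raise ValueError("save_interval must be a positive integer")
    saved = []
    for epoch in range(start_epoch, epochs + 1):
        # ... one pass over the train_loader and a scheduler step would happen here ...
        if epoch % save_interval == 0 or epoch == epochs:
            saved.append(epoch)
    return saved


print(epochs_to_save(start_epoch=1, epochs=10, save_interval=3))  # [3, 6, 9, 10]
# Resuming after a checkpoint from epoch 6 was loaded (start epoch = 6 + 1):
print(epochs_to_save(start_epoch=7, epochs=10, save_interval=3))  # [9, 10]
```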
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None
Given per-key validations, validate this module's parameters.
Raises an exception if a parameter is invalid or missing.
- Parameters:
attrib_name – Name of the attribute to validate. This should be “params”, except in the base class, where it is “config”.
validations –
Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
- Each validation can be one of the following:
A lambda function or other callable
A string referencing a predefined validation function with one argument
None, to check for existence
A tuple of a string referencing a predefined validation function and one additional argument
Validations can be nested, but every nested validation then has to be a tuple or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Their data can be a single validation or a tuple containing other tuples or validations.
Lists and other iterables can be validated using “forall”, which runs the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.
Example
This example is an excerpt of the validation for the BaseModule configuration.
>>> validations = {
...     "device": [str, ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)])],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": [lambda value: value == 1],
... }
And within the class __init__() call:
>>> self.validate_params(validations)
- Raises:
InvalidParameterException – If one of the parameters is invalid.
ValidationException – If the validation list is invalid or contains an unknown validation.
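The validation rules above are easier to follow in code. The following is a simplified, self-contained re-implementation written for illustration only; it is not the dgs implementation. Several details are assumptions: the ONE_ARG and TWO_ARG registries with the checks "int", "str", "in", "gte" and "instance"; treating a bare type such as str as an isinstance check; passing the data of “any”, “all” and “xor” as a list, as in the excerpt; reading “xor” as "exactly one holds"; and giving “not” and “forall” a single validation.

```python
from typing import Any, Callable, Dict, List


class InvalidParameterException(Exception):
    """Raised when a parameter is missing or fails a validation."""


class ValidationException(Exception):
    """Raised when a validation itself is malformed or unknown."""


# Hypothetical predefined validations (not the dgs registry).
ONE_ARG: Dict[str, Callable[[Any], bool]] = {
    "int": lambda v: isinstance(v, int) and not isinstance(v, bool),
    "str": lambda v: isinstance(v, str),
}
TWO_ARG: Dict[str, Callable[[Any, Any], bool]] = {
    "in": lambda v, d: v in d,
    "gte": lambda v, d: v >= d,
    "instance": lambda v, d: isinstance(v, d),
}


def check(value: Any, validation: Any) -> bool:
    """Evaluate one (possibly nested) validation against a value."""
    if validation is None:  # existence, already checked by the caller
        return True
    if isinstance(validation, type):  # a bare type: isinstance check
        return isinstance(value, validation)
    if isinstance(validation, str):  # name of a one-argument check
        if validation not in ONE_ARG:
            raise ValidationException(f"unknown validation {validation!r}")
        return ONE_ARG[validation](value)
    if isinstance(validation, tuple):  # (name, data)
        name, data = validation
        if name == "any":  # data: list of validations
            return any(check(value, v) for v in data)
        if name == "all":
            return all(check(value, v) for v in data)
        if name == "xor":  # assumed: exactly one validation holds
            return sum(check(value, v) for v in data) == 1
        if name == "not":  # data: a single validation
            return not check(value, data)
        if name == "forall":  # data: one validation applied to every item
            return all(check(item, data) for item in value)
        if name in TWO_ARG:
            return TWO_ARG[name](value, data)
        raise ValidationException(f"unknown validation {name!r}")
    if callable(validation):  # lambda or other callable
        return bool(validation(value))
    raise ValidationException(f"invalid validation {validation!r}")


def validate_params(params: Dict[str, Any], validations: Dict[str, List[Any]]) -> None:
    """Raise if any validated parameter is missing or fails one of its checks."""
    for key, checks in validations.items():
        if key not in params:
            raise InvalidParameterException(f"missing parameter {key!r}")
        for validation in checks:
            if not check(params[key], validation):
                raise InvalidParameterException(
                    f"parameter {key!r}={params[key]!r} failed {validation!r}"
                )


params = {"device": "cpu", "batch_size": 8, "ids": [1, 2, 3]}
validations = {
    "device": [str, ("any", [("in", ["cuda", "cpu"]), ("instance", int)])],
    "batch_size": ["int", ("gte", 1), lambda v: v % 2 == 0],
    "ids": [None, ("forall", ("all", ["int", ("gte", 0)]))],
}
validate_params(params, validations)  # passes silently
```

Missing parameters are rejected before any check runs, which is why None can simply succeed inside check().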
- write_results(results: dict[str, any], prepend: str) → None
Given a dictionary of results, use the writer to save the values.
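write_results() forwards the results to the torch writer. Suppose the writer behaves like TensorBoard's SummaryWriter, with add_scalar(tag, scalar_value, global_step), and prepend acts as a tag prefix. Under those assumptions the idea can be sketched as follows. The RecordingWriter stub, the "prepend/key" tag format and the extra step argument are hypothetical.

```python
from typing import Any, Dict, List, Tuple


class RecordingWriter:
    """Stub mimicking SummaryWriter.add_scalar(tag, scalar_value, global_step)."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0) -> None:
        self.calls.append((tag, scalar_value, global_step))


def write_results(writer: Any, results: Dict[str, Any], prepend: str, step: int) -> None:
    for key, value in results.items():
        # Only numeric values can be logged as scalars; others are skipped here.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            writer.add_scalar(f"{prepend}/{key}", float(value), step)


writer = RecordingWriter()
write_results(writer, {"accuracy": 0.75, "note": "ok", "loss": 1}, prepend="Test", step=3)
print(writer.calls)  # [('Test/accuracy', 0.75, 3), ('Test/loss', 1.0, 3)]
```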
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.
Whether to load the image crops during training.
The torch DataLoader containing the test data.
The torch DataLoader containing the validation data.
The torch DataLoader containing the training data.