dgs.models.engine.engine.EngineModule

class dgs.models.engine.engine.EngineModule(*args: Any, **kwargs: Any)

Module for training, validating, and testing other Modules.

Most of the settings are defined within the configuration file in the training section.

Params

model_path (NodePath):

Path to the configuration setting up the model to be trained or tested.

model_type (str):

The type of BaseModule to be loaded as the model for training and testing. The value will be passed as module_type in the module_loader() call.
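As a rough illustration of how these two params relate, the sketch below assumes that the configuration is a nested dict and that a NodePath is simply a list of keys into it. This matches the config: dict[str, any] and path: list[str] arguments of __init__, but is otherwise an assumption. The key names, the type string, and the get_node helper are hypothetical and not part of dgs.

```python
from typing import Any, Dict, List

# Hypothetical configuration; the real key names in a dgs config may differ.
config: Dict[str, Any] = {
    "engine": {
        "model_path": ["my_model"],          # NodePath to the model's own section
        "model_type": "hypothetical_type",   # would be forwarded as module_type
    },
    "my_model": {"module_name": "example_model"},
}


def get_node(cfg: Dict[str, Any], path: List[str]) -> Any:
    """Hypothetical helper: follow a list of keys (a NodePath) into a nested dict."""
    node: Any = cfg
    for key in path:
        node = node[key]
    return node


engine_cfg = get_node(config, ["engine"])
model_cfg = get_node(config, engine_cfg["model_path"])
print(engine_cfg["model_type"], model_cfg["module_name"])  # hypothetical_type example_model
```

An empty NodePath returns the whole configuration, and a missing key raises a KeyError, which keeps misconfigured paths from failing silently.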

Optional Params

writer_kwargs (dict, optional):

Additional kwargs for the torch writer. Default DEF_VAL.engine.test.writer_kwargs.

writer_log_dir_suffix (str, optional):

Additional subdirectory or name suffix for the torch writer. Default DEF_VAL.engine.test.writer_log_dir_suffix.

Test Params

Train Params

loss (str|callable):

The name or class of the loss function used to compute the loss during training. Additional initialization kwargs can be passed to the loss via the loss_kwargs parameter.

optimizer (str|callable):

The name or class of the optimizer used to optimize the model with respect to the loss during training. Additional initialization kwargs can be passed to the optimizer via the optimizer_kwargs parameter.
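A minimal sketch of the "name or class" convention, assuming that a string is looked up in a registry while a class is used directly, and that the matching *_kwargs dict is forwarded to the constructor. DummyLoss, DummyOptimizer, the registries, and the build_* helpers are hypothetical placeholders. The optimizer follows the torch.optim convention of taking the parameters to optimize as its first argument.

```python
from typing import Any, Callable, Dict, Optional, Union

NameOrClass = Union[str, Callable[..., Any]]


# Placeholder classes standing in for a real loss and optimizer.
class DummyLoss:
    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction


class DummyOptimizer:
    def __init__(self, params: list, lr: float = 0.01) -> None:
        self.params = list(params)
        self.lr = lr


# Hypothetical name-to-class registries; dgs keeps its own.
LOSSES: Dict[str, Callable[..., Any]] = {"DummyLoss": DummyLoss}
OPTIMIZERS: Dict[str, Callable[..., Any]] = {"DummyOptimizer": DummyOptimizer}


def get_class(value: NameOrClass, registry: Dict[str, Callable[..., Any]]) -> Callable[..., Any]:
    """Look a name up in the registry, or return the given class unchanged."""
    return registry[value] if isinstance(value, str) else value


def build_loss(loss: NameOrClass, loss_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    # loss_kwargs are forwarded to the loss constructor.
    return get_class(loss, LOSSES)(**(loss_kwargs or {}))


def build_optimizer(optimizer: NameOrClass, params: list,
                    optimizer_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    # The parameters to optimize come first, optimizer_kwargs follow.
    return get_class(optimizer, OPTIMIZERS)(params, **(optimizer_kwargs or {}))


loss_fn = build_loss("DummyLoss", {"reduction": "sum"})   # by name
opt = build_optimizer(DummyOptimizer, [1.0, 2.0], {"lr": 0.1})  # by class
print(loss_fn.reduction, opt.lr)  # sum 0.1
```

Accepting both forms keeps configuration files short (a plain name) while still allowing a custom class to be passed directly from Python code.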

Optional Test Params

normalize (bool, optional):

Whether to normalize the prediction and target during testing. Default DEF_VAL.engine.test.normalize.

Optional Train Params

epochs (int, optional):

The number of epochs to run the training for. Default DEF_VAL.engine.train.epochs.

optimizer_kwargs (dict, optional):

Additional kwargs for the optimizer. Default DEF_VAL.engine.train.optim_kwargs.

scheduler (str|callable, optional):

The name or class of a scheduler. To use different or multiple schedulers, chain them using torch.optim.lr_scheduler.ChainedScheduler or create and register a custom scheduler. Default DEF_VAL.engine.train.scheduler.

scheduler_kwargs (dict, optional):

Additional kwargs for the scheduler. Keep in mind that different schedulers expect quite different kwargs. The optimizer is passed to the scheduler during initialization as the optimizer keyword argument. Default DEF_VAL.engine.train.scheduler_kwargs.

loss_kwargs (dict, optional):

Additional kwargs for the loss. Default DEF_VAL.engine.train.loss_kwargs.

save_interval (int, optional):

The interval for saving (and evaluating) the model during training. Default DEF_VAL.engine.train.save_interval.

start_epoch (int, optional):

The epoch at which to start. Epochs are ultimately 1-indexed, but this should not matter as long as you use one convention consistently. Default DEF_VAL.engine.train.start_epoch.

train_load_image_crops (bool, optional):

Whether to load the image crops during training. Default DEF_VAL.engine.train.load_image_crops.
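The scheduler and epoch parameters could interact roughly as follows. This is a sketch under assumptions: DummyScheduler, build_scheduler, and run_epochs are hypothetical, epochs are treated as 1-indexed and inclusive of the last epoch, and a checkpoint is assumed whenever the epoch is a multiple of save_interval. Only the convention of passing the optimizer via the optimizer keyword comes from the scheduler_kwargs description above.

```python
from typing import Any, Callable, Dict, List, Optional


class DummyScheduler:
    """Placeholder for a learning-rate scheduler class."""

    def __init__(self, optimizer: Any, step_size: int = 1) -> None:
        self.optimizer = optimizer
        self.step_size = step_size
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def build_scheduler(scheduler_cls: Callable[..., Any], optimizer: Any,
                    scheduler_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    # The optimizer is handed over as the ``optimizer`` keyword argument;
    # the remaining kwargs depend on the chosen scheduler.
    return scheduler_cls(optimizer=optimizer, **(scheduler_kwargs or {}))


def run_epochs(start_epoch: int, epochs: int, save_interval: int, scheduler: Any) -> List[int]:
    """Return the epochs at which a checkpoint would be saved (and evaluated)."""
    saved = []
    for epoch in range(start_epoch, epochs + 1):  # assumes 1-indexed, inclusive epochs
        # ... one pass over the training data would go here ...
        scheduler.step()
        if epoch % save_interval == 0:
            saved.append(epoch)
    return saved


optimizer = object()  # stand-in for a real optimizer instance
sched = build_scheduler(DummyScheduler, optimizer, {"step_size": 5})
print(run_epochs(start_epoch=1, epochs=10, save_interval=3, scheduler=sched))  # [3, 6, 9]
print(sched.steps)  # 10
```

Resuming with start_epoch=4 under the same assumptions would save at epochs 6 and 9 only.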

__init__(config: dict[str, any], path: list[str], test_loader: torch.utils.data.DataLoader, *, val_loader: torch.utils.data.DataLoader | None = None, train_loader: torch.utils.data.DataLoader | None = None, **_kwargs)

Methods

configure_torch_module(module[, train])

Set compute mode and send model to the device or multiple parallel devices if applicable.

evaluate()

Run tests; the tests are defined in the sub-engine.

get_data(ds)

Retrieve the data used for the model's prediction from the train- and test-DataLoaders.

get_hparam_dict()

Get the hyperparameters of the current engine.

get_target(ds)

Retrieve the evaluation targets from the train- and test-DataLoaders.

initialize_optimizer()

Load the optimizer and scheduler at the start of training, because the module might only be set after the initial step.

load_model(path)

Load the model from a file.

predict()

Given test data, predict the results without evaluation.

print_results(results)

Given a dictionary of results, print them to the console if allowed.

run()

Run the model.

save_model(epoch, metrics, optimizer, lr_sched)

Save the current model and other weights into a '.pth' file.

set_model_mode(mode)

Set model mode to train or test.

terminate()

Handle forceful termination, e.g., Ctrl+C.

test()

Run tests; the tests are defined in the sub-engine.

train_model()

Train the given model using the given loss function, optimizer, and learning-rate schedulers.

validate_params(validations[, attrib_name])

Given per key validations, validate this module's parameters.

write_results(results, prepend)

Given a dictionary of results, use the writer to save the values.
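Several methods above, test() and evaluate() in particular, are left to a sub-engine, and terminate() handles forceful termination. The following hypothetical sketch shows one common way to express both with the standard library: an abstract base class whose test() must be overridden, and a run() that catches KeyboardInterrupt (raised on Ctrl+C) and calls terminate(). SketchEngine and its subclasses are illustrative only; whether EngineModule uses abc or wraps run() like this is not stated in this reference.

```python
from abc import ABC, abstractmethod


class SketchEngine(ABC):
    """Hypothetical, heavily simplified stand-in for an engine base class."""

    def __init__(self) -> None:
        self.terminated = False

    @abstractmethod
    def test(self) -> dict:
        """Each sub-engine defines its own tests."""

    def terminate(self) -> None:
        # Clean-up on forceful termination, e.g. closing writers or saving state.
        self.terminated = True

    def run(self) -> dict:
        try:
            return self.test()
        except KeyboardInterrupt:  # raised when the user presses Ctrl+C
            self.terminate()
            return {}


class AccuracyEngine(SketchEngine):
    def test(self) -> dict:
        return {"accuracy": 1.0}


class InterruptedEngine(SketchEngine):
    def test(self) -> dict:
        raise KeyboardInterrupt  # simulate Ctrl+C during testing


print(AccuracyEngine().run())  # {'accuracy': 1.0}
engine = InterruptedEngine()
print(engine.run(), engine.terminated)  # {} True
```

Because test() is abstract, instantiating SketchEngine directly raises a TypeError, so a sub-engine that forgets to implement its tests fails early.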

Attributes

curr_epoch

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module, usable in file paths, by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

train_load_image_crops

Whether to load the image crops during training.

loss

model

writer

test_dl

The torch DataLoader containing the test data.

val_dl

The torch DataLoader containing the validation data.

train_dl

The torch DataLoader containing the training data.