dgs.models.engine.dgs_engine.DGSEngine

class dgs.models.engine.dgs_engine.DGSEngine(*args: Any, **kwargs: Any)

An engine class for training and testing the dynamically gated similarity tracker with static or dynamic gates.

For this model:

  • get_data() should return the same as the SimilarityModule.get_data() call of this engine's similarity functions

  • get_target() should return the class IDs of the State object

  • train_dl contains the training data as a torch DataLoader containing an ImageHistoryDataset dataset. The training data should contain all the training sequences, not just a single video.

  • test_dl contains the test data as a torch DataLoader containing regular ImageDataset or VideoDataset datasets

  • val_dl contains the validation data. The validation data can be one of the following, depending on the configuration of params_train["eval_accuracy"]:

    • If eval_accuracy is True, the evaluation data is a torch DataLoader containing an ImageHistoryDataset dataset. The validation data should contain all the validation sequences, not just a single video.

    • If eval_accuracy is False, the evaluation data is a torch DataLoader containing regular ImageDataset or VideoDataset datasets, with one dataset per video.

Train Params

Test Params

submission (Union[str, NodePath]):

The key or the path of keys in the configuration containing the information about the submission file, which is used to save the test data.
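Because submission accepts either a single key or a path of keys, the lookup can be pictured as walking down a nested configuration dictionary. This is a minimal sketch assuming the configuration is a plain nested dict and that a NodePath behaves like a list of string keys (as the path: list[str] argument of __init__ suggests); resolve_node and the example keys "eval" and "file" are hypothetical and not part of the dgs API.

```python
from typing import Any, Union

# Hypothetical stand-in for NodePath: a list of nested keys.
NodePath = list[str]


def resolve_node(config: dict[str, Any], key_or_path: Union[str, NodePath]) -> Any:
    """Return the sub-configuration addressed by a single key or by a path of keys."""
    # A single key is treated as a path of length one.
    path = [key_or_path] if isinstance(key_or_path, str) else list(key_or_path)
    node: Any = config
    for key in path:
        if not isinstance(node, dict) or key not in node:
            raise KeyError(f"Key path {path} not found in configuration (failed at {key!r}).")
        node = node[key]
    return node


config = {
    "submission": {"file": "./results/test_submission.json"},
    "eval": {"submission": {"file": "./results/val_submission.json"}},
}

print(resolve_node(config, "submission"))  # {'file': './results/test_submission.json'}
print(resolve_node(config, ["eval", "submission"])["file"])  # ./results/val_submission.json
```

Accepting both forms keeps simple configurations short while still allowing, for example, separate submission settings for test and evaluation runs.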

Optional Train Params

acc_k_train (list[int|float], optional):

A list of values used during training to check whether the accuracy lies within a margin of k percent. Default DEF_VAL.engine.dgs.acc_k_train.

acc_k_eval (list[int|float], optional):

A list of values used during evaluation to check whether the accuracy lies within a margin of k percent. Default DEF_VAL.engine.dgs.acc_k_eval.

eval_accuracy (bool, optional):

Whether to evaluate the alpha-prediction accuracy or the MOTA / HOTA of the model during evaluation. Default DEF_VAL.engine.dgs.eval_accuracy.

submission (Union[str, NodePath]):

The key or the path of keys in the configuration containing the information about the submission file, which is used to save the evaluation data, if eval_accuracy is False.
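One possible reading of acc_k_train and acc_k_eval: for every k in the list, a prediction counts as correct if its absolute error to the target is at most k percent. The sketch below assumes that alpha predictions and targets lie in [0, 1], so that k percent corresponds to an absolute margin of k / 100; accuracy_within_k is a hypothetical helper, not the dgs implementation.

```python
from typing import Sequence, Union

Number = Union[int, float]


def accuracy_within_k(
    predictions: Sequence[float],
    targets: Sequence[float],
    acc_k: Sequence[Number],
) -> dict[Number, float]:
    """For every k, return the fraction of predictions lying within k percent of their target.

    Assumes values in [0, 1], so a margin of k percent equals an absolute error of k / 100.
    """
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not predictions:
        raise ValueError("at least one prediction is required")
    errors = [abs(p - t) for p, t in zip(predictions, targets)]
    # Summing booleans counts the predictions inside the margin.
    return {k: sum(e <= k / 100 for e in errors) / len(errors) for k in acc_k}


preds = [0.92, 0.50, 0.14, 0.73]
targs = [1.00, 0.50, 0.00, 0.70]
print(accuracy_within_k(preds, targs, acc_k=[1, 5, 10]))  # {1: 0.25, 5: 0.5, 10: 0.75}
```

Reporting several margins at once shows whether the alpha predictions are roughly right but imprecise, or wrong altogether.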

Optional Test Params

draw_kwargs (dict[str, any]):

Additional keyword arguments to pass to State.draw(). Default DEF_VAL.engine.dgs.draw_kwargs.

inactivity_threshold (int):

The number of steps after which an inactive Track will be removed. Removed tracks can be reactivated using Tracks.reactivate_track(). Use None to disable the removal of inactive tracks. Default DEF_VAL.tracks.inactivity_threshold.

max_track_length (int):

The maximum number of State objects per Track. Default DEF_VAL.track.N.

save_images (bool):

Whether to save the generated image-results. Default DEF_VAL.engine.dgs.save_images.

show_keypoints (bool):

Whether to show the key-point-coordinates when generating the image-results. This only has an effect if save_images is True. To be drawn correctly, the detections State has to contain the global key-point-coordinates as ‘keypoints’ and, optionally, the joint-visibility as ‘joint_weight’. Default DEF_VAL.engine.dgs.show_keypoints.

show_skeleton (bool):

Whether to connect the drawn key-point-coordinates with the human skeleton. This only has an effect if both save_images and show_keypoints are True. To be drawn correctly, the detections State has to contain a valid ‘skeleton_name’ key. Default DEF_VAL.engine.dgs.show_skeleton.
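The interplay of inactivity_threshold and max_track_length can be illustrated with a small bookkeeping sketch. It is not the dgs Tracks implementation: ToyTracks and its methods are hypothetical, the strings stand in for State objects, and the sketch assumes that a track is removed once its count of consecutive unmatched steps reaches the threshold, that the oldest State is dropped when a track is full, and that removed tracks keep their history so they can be reactivated (in dgs this is the role of Tracks.reactivate_track()).

```python
from collections import deque
from typing import Optional


class ToyTracks:
    """Hypothetical, simplified track container (not the dgs Tracks class)."""

    def __init__(self, inactivity_threshold: Optional[int], max_track_length: int) -> None:
        self.inactivity_threshold = inactivity_threshold
        self.max_track_length = max_track_length
        self.states: dict[int, deque] = {}  # track ID -> bounded history of states
        self.inactive_steps: dict[int, int] = {}  # track ID -> consecutive unmatched steps
        self.removed: dict[int, deque] = {}  # removed tracks, kept for reactivation

    def step(self, matches: dict[int, str]) -> None:
        """Advance one time step; `matches` maps track IDs to the state matched in this step."""
        for tid, state in matches.items():
            if tid in self.removed:
                self.reactivate(tid)
            if tid not in self.states:
                # deque(maxlen=N) silently drops the oldest state once N states are stored.
                self.states[tid] = deque(maxlen=self.max_track_length)
            self.states[tid].append(state)
            self.inactive_steps[tid] = 0
        for tid in list(self.inactive_steps):
            if tid in matches:
                continue
            self.inactive_steps[tid] += 1
            # None disables the removal of inactive tracks.
            if self.inactivity_threshold is not None and self.inactive_steps[tid] >= self.inactivity_threshold:
                del self.inactive_steps[tid]
                self.removed[tid] = self.states.pop(tid)

    def reactivate(self, tid: int) -> None:
        """Move a removed track, including its state history, back to the active tracks."""
        self.states[tid] = self.removed.pop(tid)
        self.inactive_steps[tid] = 0


tracks = ToyTracks(inactivity_threshold=2, max_track_length=3)
tracks.step({1: "p1@0", 2: "p2@0"})
for frame in range(1, 4):
    tracks.step({1: f"p1@{frame}"})  # track 2 stays unmatched

print(list(tracks.states[1]))  # ['p1@1', 'p1@2', 'p1@3']
print(sorted(tracks.removed))  # [2]
```

Keeping removed tracks instead of deleting them is what makes reactivation possible when a person re-enters the scene after a short occlusion.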

__init__(config: dict[str, any], path: list[str], *, test_loader: torch.utils.data.DataLoader | None = None, val_loader: torch.utils.data.DataLoader | None = None, train_loader: torch.utils.data.DataLoader | None = None, **_kwargs)

Methods

configure_torch_module(module[, train])

Set compute mode and send model to the device or multiple parallel devices if applicable.

evaluate

get_data(ds)

Use the similarity models of the DGS module to obtain the similarity data of the current detections.

get_hparam_dict()

Get the hyperparameters of the current engine.

get_target(ds)

Get the target data.

initialize_optimizer()

Because the module might be set after the initial step, load the optimizer and scheduler at the start of the training.

load_combine_alpha_weights(fp[, new_id, old_id])

Given the path to a file containing at least the data of one module checkpoint, load the weights of the combine.alpha_weights module.

load_model(path)

Load the model from a file.

predict

print_results(results)

Given a dictionary of results, print them to the console if allowed.

run()

Run the model.

save_model(epoch, metrics, optimizer, lr_sched)

Save the current model and other weights into a '.pth' file.

set_model_mode(mode)

Set model mode to train or test.

terminate()

Handle forceful termination, e.g., via Ctrl+C.

test

train_model()

Train the given model using the given loss function, optimizer, and learning-rate schedulers.

validate_params(validations[, attrib_name])

Given per key validations, validate this module's parameters.

write_results(results, prepend)

Given a dictionary of results, use the writer to save the values.
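The signature of load_combine_alpha_weights(fp[, new_id, old_id]) suggests that it picks the combine.alpha_weights entries out of a module checkpoint and can move the weights stored under one index to another. The sketch below shows one plausible way to do that on a plain dictionary; it is an assumption, not the dgs implementation. extract_alpha_weights is hypothetical, the key layout 'combine.alpha_weights.<id>.<param>' is assumed, and the lists stand in for tensors (in practice the dictionary would presumably come from torch.load).

```python
from typing import Any, Optional

PREFIX = "combine.alpha_weights."


def extract_alpha_weights(
    state_dict: dict[str, Any], new_id: Optional[int] = None, old_id: Optional[int] = None
) -> dict[str, Any]:
    """Keep only the 'combine.alpha_weights.*' entries, optionally moving sub-module old_id to new_id.

    Hypothetical: assumes keys of the form 'combine.alpha_weights.<id>.<param>'.
    """
    if (new_id is None) != (old_id is None):
        raise ValueError("new_id and old_id must be given together")
    result: dict[str, Any] = {}
    for key, value in state_dict.items():
        if not key.startswith(PREFIX):
            continue  # skip the weights of all other modules
        idx, _, param = key[len(PREFIX):].partition(".")  # e.g. '1.weight' -> '1', 'weight'
        if old_id is not None:
            if int(idx) != old_id:
                continue  # only the requested sub-module is transferred
            idx = str(new_id)
        result[f"{PREFIX}{idx}.{param}"] = value
    return result


checkpoint = {
    "combine.alpha_weights.0.weight": [0.1, 0.2],
    "combine.alpha_weights.1.weight": [0.3, 0.4],
    "similarity.0.weight": [9.9],
}
print(extract_alpha_weights(checkpoint, new_id=0, old_id=1))
# {'combine.alpha_weights.0.weight': [0.3, 0.4]}
```

Remapping indices like this would let weights trained for one alpha sub-module initialize a differently ordered model.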

Attributes

curr_epoch

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module usable in filepaths by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

train_load_image_crops

Whether to load the image crops during training.

model

The DGS module containing the similarity models and the alpha model.

tracks

The tracks object containing all the active tracks of this engine.

submission

The submission file to store the results when running the tests.

val_dl

The torch DataLoader containing the validation data.

train_dl

The torch DataLoader containing the train data.

loss

writer

test_dl

The torch DataLoader containing the test data.