dgs.models.engine.dgs_engine.DGSEngine¶
- class dgs.models.engine.dgs_engine.DGSEngine(*args: Any, **kwargs: Any)[source]¶
An engine class for training and testing the dynamically gated similarity tracker with static or dynamic gates.
For this model:

- get_data() should return the same as this similarity function's SimilarityModule.get_data() call
- get_target() should return the class IDs of the State object
- train_dl contains the training data as a torch DataLoader containing an ImageHistoryDataset dataset. Additionally, the training data should contain all the training sequences and not just a single video.
- test_dl contains the test data as a torch DataLoader containing regular ImageDataset or VideoDataset datasets
- val_dl contains the validation data. The validation data can be one of the following, depending on the configuration of params_train["eval_accuracy"]:
    - If eval_accuracy is True, the validation data is a torch DataLoader containing an ImageHistoryDataset dataset. Additionally, the validation data should contain all the validation sequences and not just a single video.
    - If eval_accuracy is False, the validation data is a torch DataLoader containing regular ImageDataset or VideoDataset datasets, with one dataset per video.
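As a rough usage sketch, the engine could be set up and run as follows. The configuration dict, the key path ["engine"], and the placeholder DataLoaders are assumptions for illustration; only the constructor signature and run() are taken from this page.

from torch.utils.data import DataLoader
from dgs.models.engine.dgs_engine import DGSEngine

# Placeholders - build these from your own configuration and datasets.
config: dict = ...          # full configuration dict containing the engine parameters
train_dl: DataLoader = ...  # DataLoader over an ImageHistoryDataset with all training sequences
val_dl: DataLoader = ...    # dataset type depends on params_train["eval_accuracy"], see above
test_dl: DataLoader = ...   # DataLoader over an ImageDataset or VideoDataset, one per video

engine = DGSEngine(
    config=config,
    path=["engine"],        # assumed key path to the engine parameters
    train_loader=train_dl,
    val_loader=val_dl,
    test_loader=test_dl,
)
engine.run()                # first train, then test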
Train Params¶
Test Params¶
- submission (Union[str, NodePath]):
The key or the path of keys in the configuration containing the information about the submission file, which is used to save the test data.
Optional Train Params¶
- acc_k_train (list[int|float], optional):
A list of values used during training to check whether the accuracy lies within a margin of k percent. Default DEF_VAL.engine.dgs.acc_k_train.
- acc_k_eval (list[int|float], optional):
A list of values used during evaluation to check whether the accuracy lies within a margin of k percent. Default DEF_VAL.engine.dgs.acc_k_eval.
- eval_accuracy (bool, optional):
Whether to evaluate the alpha-prediction accuracy or the MOTA / HOTA of the model during evaluation. Default DEF_VAL.engine.dgs.eval_accuracy.
- submission (Union[str, NodePath]):
The key or the path of keys in the configuration containing the information about the submission file, which is used to save the evaluation data if eval_accuracy is False. A configuration sketch follows the parameter lists below.
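As a hedged sketch, the optional training parameters above might be set like this; only the parameter names come from this page, while the surrounding structure and the values are assumptions:

# Hypothetical "train" section of the configuration, written as a Python dict.
params_train_excerpt = {
    "acc_k_train": [1, 5, 10],             # accuracy margins (in percent) checked during training
    "acc_k_eval": [1, 5],                  # accuracy margins (in percent) checked during evaluation
    "eval_accuracy": True,                 # evaluate alpha-prediction accuracy instead of MOTA / HOTA
    "submission": ["submission", "eval"],  # key path used to save evaluation data if eval_accuracy is False
}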
Optional Test Params¶
- draw_kwargs (dict[str, any]):
Additional keyword arguments to pass to State.draw(). Default DEF_VAL.engine.dgs.draw_kwargs.
- inactivity_threshold (int):
The number of steps after which an inactive Track will be removed. Removed tracks can be reactivated using Tracks.reactivate_track(). Use None to disable the removal of inactive tracks. Default DEF_VAL.tracks.inactivity_threshold.
- max_track_length (int):
The maximum number of State objects per Track. Default DEF_VAL.track.N.
- save_images (bool):
Whether to save the generated image-results. Default DEF_VAL.engine.dgs.save_images.
- show_keypoints (bool):
Whether to show the key-point coordinates when generating the image-results. This only has an influence if save_images is True. To be drawn correctly, the detection State has to contain the global key-point coordinates as 'keypoints' and possibly the joint visibility as 'joint_weight'. Default DEF_VAL.engine.dgs.show_skeleton.
- show_skeleton (bool):
Whether to connect the drawn key-point coordinates with the human skeleton. This only has an influence if save_images is True and show_keypoints is True as well. To be drawn correctly, the detection State has to contain a valid 'skeleton_name' key. Default DEF_VAL.engine.dgs.show_skeleton.
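Analogously, a hedged sketch of the test parameters; the values are placeholders and only the keys are taken from the lists above:

# Hypothetical "test" section of the configuration, written as a Python dict.
params_test_excerpt = {
    "submission": ["submission", "test"],  # required: key path to the submission file information
    "draw_kwargs": {},                     # extra keyword arguments for State.draw()
    "inactivity_threshold": 15,            # remove inactive Tracks after this many steps (None disables removal)
    "max_track_length": 30,                # maximum number of State objects per Track
    "save_images": True,                   # save the generated image-results
    "show_keypoints": True,                # draw key points (only relevant if save_images is True)
    "show_skeleton": True,                 # connect the key points (requires show_keypoints and save_images)
}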
Methods
- __init__(config: dict[str, any], path: list[str], *, test_loader: torch.utils.data.DataLoader | None = None, val_loader: torch.utils.data.DataLoader | None = None, train_loader: torch.utils.data.DataLoader | None = None, **_kwargs)[source]¶
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) torch.nn.Module ¶
Set compute mode and send model to the device or multiple parallel devices if applicable.
- Parameters:
module – The torch module instance to configure.
train – Whether to train or eval this module, defaults to the value set in the base config.
- Returns:
The module on the specified device or in parallel.
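For example, assuming an engine instance as above and an arbitrary torch module:

import torch.nn as nn

backbone = nn.Linear(128, 64)  # any torch module
# Moves the module to the engine's device(s) and sets it to train mode.
backbone = engine.configure_torch_module(backbone, train=True)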
- evaluate() dict[str, any] ¶
Run the model evaluation on the eval data.
Test whether the predicted alpha probability (\(\alpha_{\mathrm{pred}}\)) matches the number of correct predictions (\(\alpha_{\mathrm{correct}}\)) divided by the total number of predictions (\(N\)).
\(\alpha_{\mathrm{pred}}\) is counted as correct if \(\frac{\alpha_{\mathrm{correct}}}{N} - k \leq \alpha_{\mathrm{pred}} \leq \frac{\alpha_{\mathrm{correct}}}{N} + k\).
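The margin check reduces to simple arithmetic; the following is only an illustration with made-up numbers (k written as a fraction), not the library's implementation:

alpha_pred = 0.78             # predicted alpha probability
alpha_correct, N = 8, 10      # 8 of 10 predictions were correct -> ratio 0.8
ratio = alpha_correct / N
for k in (0.01, 0.05, 0.10):  # margins, e.g. taken from acc_k_eval
    within = ratio - k <= alpha_pred <= ratio + k
    print(f"k={k}: {'correct' if within else 'incorrect'}")
# k=0.01: incorrect, k=0.05: correct, k=0.1: correct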
- get_data(ds: State) list[torch.Tensor] [source]¶
Use the similarity models of the DGS module to obtain the similarity data of the current detections.
For the similarity engine, the data consists of a list of all the input data for the similarities. This means that for the visual similarity, the embedding is returned, and for the IoU or OKS similarities, the bbox and key-point data are returned. The
get_data()
function will be called twice, once for the current time-step and once for the previous one.
- get_hparam_dict() dict[str, any] ¶
Get the hyperparameters of the current engine. Child-modules can inherit this method and add additional hyperparameters.
By default, all parameters from test and training are added to the hparam_dict.
- get_target(ds: State) torch.Tensor [source]¶
Get the target data.
For the similarity engine, the target data consists of the dataset-unique class-id. The
get_target()
function will be called twice, once for the current time-step and once for the previous.
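A sketch of how both hooks are used together, assuming two consecutive State objects (current and previous detections) are available; the variable names are illustrative:

# curr_state, prev_state: State objects of two consecutive time-steps
curr_data = engine.get_data(curr_state)    # list of per-similarity inputs (embedding, bbox, key points, ...)
prev_data = engine.get_data(prev_state)
curr_ids = engine.get_target(curr_state)   # dataset-unique class IDs of the current detections
prev_ids = engine.get_target(prev_state)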
- initialize_optimizer() None ¶
Because the module might be set after the initial step, load the optimizer and scheduler at the start of the training.
- load_combine_alpha_weights(fp: str, new_id: int = 0, old_id: int = 0) None [source]¶
Given the path to a file containing at least the data of one module checkpoint, load the weights of the combine.alpha_weights module.
Notes
During training, the DGSEngine was trained with a single alpha model. For testing or (non-accuracy) evaluation, multiple alpha values are required; therefore, combine.alpha_models now contains more than one AlphaGenerator instance, and the indices of the state dict have to be modified accordingly. Additionally, in the case of the visual embedding generation modules, there are more parameters saved in the checkpoint file, which should not be loaded by this function.
- Parameters:
fp – The path to the checkpoint file.
new_id – The index of the alpha weight modules at which to insert the loaded weights.
old_id – The old ID. Only necessary if there are multiple combine.alpha_models in a single checkpoint, e.g. when multiple alpha weight generators have been trained in unison.
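A hedged call example; the checkpoint path is a placeholder, and the indices depend on how many alpha models the checkpoint and the current combine module hold:

# Insert the alpha weights of a single-alpha training checkpoint at index 2 of combine.alpha_models.
engine.load_combine_alpha_weights(
    fp="./weights/alpha_single.pth",  # placeholder checkpoint path
    new_id=2,                         # target index within the current alpha weight modules
    old_id=0,                         # index inside the checkpoint (0 if it contains a single model)
)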
- load_model(path: str) None ¶
Load the model from a file. Set the start epoch to one after the epoch stored in the loaded model.
Notes
Loads the states of the optimizer and lr_scheduler if they are present in the engine (e.g. during training) and the respective data is given in the checkpoint at the path.
- Parameters:
path – The path to the checkpoint where this model was saved.
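For example, to resume from a previously saved checkpoint (placeholder path):

engine.load_model(path="./checkpoints/epoch_010.pth")
# Restores optimizer and lr_scheduler states as well, if present in the engine and the checkpoint.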
- predict() None ¶
Given test data, predict the results without evaluation.
- print_results(results: dict[str, any]) None ¶
Given a dictionary of results, print them to the console if allowed.
- run() None ¶
Run the model. First train, then test!
- save_model(epoch: int, metrics: dict[str, any], optimizer: torch.optim.Optimizer, lr_sched: torch.optim.lr_scheduler.LRScheduler) None ¶
Save the current model and other weights into a ‘.pth’ file.
- Parameters:
epoch – The epoch this model is saved.
metrics – A dict containing the computed metrics for this module.
optimizer – The current optimizer.
lr_sched – The current learning rate scheduler.
- set_model_mode(mode: str) None ¶
Set model mode to train or test.
- test() dict[str, any] ¶
Test the DGS Tracker on the test_dl.
- train_model() torch.optim.Optimizer ¶
Train the given model using the given loss function, optimizer, and learning-rate schedulers.
After every epoch, the current model is tested and the current model is saved.
- Returns:
The current optimizer after training.
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') None ¶
Given per key validations, validate this module’s parameters.
Throws exceptions on invalid or nonexistent params.
- Parameters:
attrib_name – The name of the attribute to validate; should be "params", and only "config" for the base class.
validations –
Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
- The value for the validation can have multiple types:
    - A lambda function or other type of callable
    - A string as a reference to a predefined validation function with one argument
    - None for existence
    - A tuple with a string as a reference to a predefined validation function with one additional argument
It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.
Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.
Example
This example is an excerpt of the validation for the BaseModule-configuration.
>>> validations = {
...     "device": [
...         str,
...         ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)]),
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": (lambda value: value == 1),
... }
And within the class __init__() call:
>>> self.validate_params(validations)
- Raises:
InvalidParameterException – If one of the parameters is invalid.
ValidationException – If the validation list is invalid or contains an unknown validation.
- write_results(results: dict[str, any], prepend: str) None ¶
Given a dictionary of results, use the writer to save the values.
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.
Whether to load the image crops during training.
The DGS module containing the similarity models and the alpha model.
The tracks object containing all the active tracks of this engine.
The submission file to store the results when running the tests.
The torch DataLoader containing the validation data.
The torch DataLoader containing the train data.
The torch DataLoader containing the test data.