dgs.models.engine.visual_sim_engine.VisualSimilarityEngine

class dgs.models.engine.visual_sim_engine.VisualSimilarityEngine(*args: Any, **kwargs: Any)

An engine class for training and testing visual similarities using visual embeddings.

For this model:

  • get_data() should return the image crop

  • get_target() should return the target class IDs

  • train_dl contains the training data as usual

  • test_dl contains the query data

  • val_dl contains the gallery data

Train Params

nof_classes (int):

The number of classes in the training set.

Test Params

metric (str|callable):

The name or class of the metric used during testing / evaluation. The metric in the VisualSimilarityEngine is only used to compute the distance between the query and gallery embeddings. Therefore, a distance-based metric should be used.

It is possible to pass additional initialization kwargs to the metric by adding them to the metric_kwargs parameter.
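To illustrate what a distance-based metric between query and gallery embeddings produces, here is a minimal sketch in plain Python. The helper name `pairwise_euclidean` and the toy embeddings are assumptions made for this example; the engine itself works on torch tensors and uses whatever metric is configured.

```python
import math


def pairwise_euclidean(query, gallery):
    """Return a len(query) x len(gallery) matrix of Euclidean distances."""
    return [[math.dist(q, g) for g in gallery] for q in query]


# Toy 2-dimensional embeddings (hypothetical values).
query = [[0.0, 0.0], [1.0, 1.0]]
gallery = [[0.0, 0.0], [3.0, 4.0], [1.0, 2.0]]

# dist[i][j] is the distance between query i and gallery item j.
dist = pairwise_euclidean(query, gallery)
```

A smaller entry means the two embeddings are more alike, which is why a similarity score (where larger is better) would need to be converted before it is used here.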

Optional Train Params

topk_acc (list[int], optional):

The values for k for the top-k accuracy evaluation during training. Default DEF_VAL.engine.visual.topk_acc.
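As a rough illustration of what top-k accuracy measures for several values of k, consider the following sketch. The function `topk_accuracy` and the toy scores are hypothetical and written in plain Python; they are not the engine's implementation.

```python
def topk_accuracy(scores, targets, ks):
    """Fraction of samples whose target class is among the k best-scored classes.

    scores:  one list of class scores per sample
    targets: the true class ID of every sample
    ks:      the values of k to evaluate
    """
    result = {}
    for k in ks:
        hits = 0
        for row, target in zip(scores, targets):
            # Indices of the k classes with the highest score.
            top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            hits += target in top
        result[k] = hits / len(targets)
    return result


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
acc = topk_accuracy(scores, targets, ks=[1, 2])
```

In this toy data one of three samples is correct at k=1 and two of three at k=2.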

Optional Test Params

metric_kwargs (dict, optional):

Specific kwargs for the metric. Default DEF_VAL.engine.visual.metric_kwargs.

topk_cmc (list[int], optional):

The values for k for the top-k CMC evaluation during testing / evaluation. Default DEF_VAL.engine.visual.topk_cmc.

write_embeds (list[bool, bool], optional):

Whether to write the embeddings for the Query and Gallery Dataset to the tensorboard writer. This is only feasible for smaller datasets of roughly 1k embeddings. Default DEF_VAL.engine.visual.write_embeds.

image_key (str, optional):

Which key to use when loading the image from the state in get_data(). Default DEF_VAL.engine.visual.image_key.
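The parameters listed above could be collected as in the sketch below. Every value is an illustrative placeholder rather than a library default, the metric name and image key are hypothetical, and how these dictionaries are nested into the full configuration is not shown here.

```python
# Required and optional training parameters (placeholder values).
train_params = {
    "nof_classes": 1000,  # required: number of classes in the training set
    "topk_acc": [1, 5],   # optional: values of k for top-k accuracy
}

# Required and optional test parameters (placeholder values).
test_params = {
    "metric": "MyDistanceMetric",    # hypothetical name of a distance-based metric
    "metric_kwargs": {},             # extra init kwargs for the metric
    "topk_cmc": [1, 5, 10],          # values of k for the CMC evaluation
    "write_embeds": [False, False],  # presumably [query, gallery]
    "image_key": "image_crop",       # hypothetical key read in get_data()
}

# A quick sanity check on the required keys.
missing = [k for k in ("nof_classes",) if k not in train_params]
missing += [k for k in ("metric",) if k not in test_params]
```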

Methods

__init__(config: dict[str, any], model: TorchreidVisualSimilarity, test_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, *, train_loader: torch.utils.data.DataLoader | None = None, **kwargs)

configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module

Set compute mode and send model to the device or multiple parallel devices if applicable.

Parameters:
  • module – The torch module instance to configure.

  • train – Whether to train or eval this module, defaults to the value set in the base config.

Returns:

The module on the specified device or in parallel.

evaluate() → dict[str, any]

Run the tests defined in the sub-engine.

Returns:

A dictionary containing all the computed accuracies, metrics, etc.

Return type:

dict[str, any]

get_data(ds: State) → torch.Tensor

Get the image crop or other requested image from the state.

get_hparam_dict() → dict[str, any]

Get the hyperparameters of the current engine. Child modules can override this method to add additional hyperparameters.

By default, all parameters from test and training are added to the hparam_dict.

get_target(ds: State) → torch.Tensor

Get the target person IDs (pIDs) from the data.

initialize_optimizer() → None

Because the module might be set after the initial step, load the optimizer and scheduler at the start of the training.

load_model(path: str) → None

Load the model from a file. Set the start epoch to the epoch stored in the loaded checkpoint plus one.

Notes

Loads the states of the optimizer and lr_scheduler if they are present in the engine (e.g. during training) and the respective data is given in the checkpoint at the path.

Parameters:

path – The path to the checkpoint where this model was saved.
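The resume logic described above can be sketched with a plain dictionary standing in for the checkpoint. The key names (`"epoch"`, `"optimizer"`, `"lr_scheduler"`) and both helper functions are assumptions made for illustration; the actual checkpoint layout is not documented in this section.

```python
def start_epoch_from_checkpoint(checkpoint):
    """Training resumes one epoch after the one stored in the checkpoint."""
    return checkpoint["epoch"] + 1


def restorable_states(checkpoint, engine_has):
    """Names of optional states present in both the engine and the checkpoint."""
    return [
        name
        for name in ("optimizer", "lr_scheduler")
        if engine_has.get(name) and name in checkpoint
    ]


# Hypothetical checkpoint that stores no scheduler state.
checkpoint = {"epoch": 4, "model": {}, "optimizer": {}}
engine_has = {"optimizer": True, "lr_scheduler": True}

start_epoch = start_epoch_from_checkpoint(checkpoint)
to_restore = restorable_states(checkpoint, engine_has)
```

Here training would continue at epoch 5 and only the optimizer state would be restored, because the checkpoint holds no scheduler data.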

predict() → torch.Tensor

Predict the visual embeddings for the test data.

Notes

Depending on the number of predictions (N) and the embedding size (E), the resulting tensor(s) can become very large. The prediction for the validation data of the PoseTrack21 dataset is roughly 300 MB.

Returns:

The predicted embeddings as tensor of shape: [N x E]

Return type:

torch.Tensor
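To get a feeling for how N and E drive the size of the result, the memory footprint of a dense [N x E] tensor can be estimated as below. The values of N and E are made up to show the scale and are not the actual PoseTrack21 numbers; 4 bytes per value assumes 32-bit floats.

```python
def embedding_tensor_bytes(n, e, bytes_per_value=4):
    """Size in bytes of a dense [n x e] tensor, by default with 32-bit floats."""
    return n * e * bytes_per_value


# Hypothetical example: 150,000 predictions with 512-dimensional embeddings.
size_bytes = embedding_tensor_bytes(n=150_000, e=512)
size_mib = size_bytes / 2**20
```

Halving the precision to 16-bit floats (`bytes_per_value=2`) halves the estimate.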

print_results(results: dict[str, any]) → None

Given a dictionary of results, print them to the console if allowed.

run() → None

Run the model. First train, then test!

save_model(epoch: int, metrics: dict[str, any], optimizer: torch.optim.Optimizer, lr_sched: torch.optim.lr_scheduler.LRScheduler) → None

Save the current model and other weights into a ‘.pth’ file.

Parameters:
  • epoch – The epoch at which this model is saved.

  • metrics – A dict containing the computed metrics for this module.

  • optimizer – The current optimizer.

  • lr_sched – The current learning rate scheduler.

set_model_mode(mode: str) → None

Set model mode to train or test.

terminate() → None

Handle forceful termination, e.g., Ctrl+C.

test() → dict[str, any]

Test the embeddings predicted by the model on the Test-DataLoader.

Compute Rank-N for every rank in self.topk_cmc. Compute mean average precision of predicted target labels.
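The two evaluation quantities named above can be sketched from a query-by-gallery distance matrix as follows. This is a simplified plain-Python illustration under stated assumptions: the function name and toy data are hypothetical, queries without any gallery match are skipped for the mean average precision, and filtering steps used by some re-identification protocols (such as excluding same-camera matches) are omitted.

```python
def rank_k_and_map(dist, query_ids, gallery_ids, ks):
    """Return ({k: Rank-k accuracy}, mean average precision)."""
    cmc_hits = {k: 0 for k in ks}
    average_precisions = []
    for row, qid in zip(dist, query_ids):
        # Gallery indices ordered from the closest to the farthest item.
        order = sorted(range(len(row)), key=lambda j: row[j])
        matches = [gallery_ids[j] == qid for j in order]
        # Rank-k: is there at least one correct ID among the k closest items?
        for k in ks:
            cmc_hits[k] += any(matches[:k])
        # Average precision: mean of the precision at every correct match.
        hits, precisions = 0, []
        for rank, is_match in enumerate(matches, start=1):
            if is_match:
                hits += 1
                precisions.append(hits / rank)
        if precisions:
            average_precisions.append(sum(precisions) / len(precisions))
    cmc = {k: cmc_hits[k] / len(query_ids) for k in ks}
    mean_ap = (
        sum(average_precisions) / len(average_precisions)
        if average_precisions
        else 0.0
    )
    return cmc, mean_ap


# Toy data: two queries, three gallery items (hypothetical values).
dist = [[0.2, 0.9, 0.4], [0.1, 0.5, 0.3]]
query_ids = [7, 8]
gallery_ids = [7, 8, 7]
cmc, mean_ap = rank_k_and_map(dist, query_ids, gallery_ids, ks=[1, 3])
```

In the toy data the first query finds its ID at rank 1 and the second only at rank 3, so Rank-1 is 0.5 while Rank-3 is 1.0.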

train_model() → torch.optim.Optimizer

Train the given model using the given loss function, optimizer, and learning-rate schedulers.

After every epoch, the current model is tested and saved.

Returns:

The current optimizer after training.

validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None

Given per key validations, validate this module’s parameters.

Throws exceptions on invalid or nonexistent params.

Parameters:
  • attrib_name – The name of the attribute to validate. Should be “params”; “config” is used only by the base class.

  • validations

    Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.

    The value for the validation can have multiple types:
    • A lambda function or other type of callable

    • A string as reference to a predefined validation function with one argument

    • None for existence

    • A tuple with a string as reference to a predefined validation function with one additional argument

    • It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.

    • Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.

Example

This example is an excerpt of the validation for the BaseModule-configuration.

>>> validations = {
...     "device": [
...         str,
...         ("any",
...             [
...                 ("in", ["cuda", "cpu"]),
...                 ("instance", torch.device),
...             ]
...         ),
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": [lambda value: value == 1],
... }

And within the class __init__() call:

>>> self.validate_params(validations)

Raises:

write_results(results: dict[str, any], prepend: str) → None

Given a dictionary of results, use the writer to save the values.

Attributes

curr_epoch

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module usable in filepaths by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

train_load_image_crops

Whether to load the image crops during training.

val_dl

The torch DataLoader containing the validation (gallery) data.

model

metric

A metric function used to compute the embedding distance.

loss

writer

test_dl

The torch DataLoader containing the test data.

train_dl

The torch DataLoader containing the training data.