dgs.models.similarity.torchreid.TorchreidVisualSimilarity

class dgs.models.similarity.torchreid.TorchreidVisualSimilarity(*args: Any, **kwargs: Any)

Given image crops, generate Re-ID embeddings using the torchreid package.

The model can use either the default pretrained weights or custom weights.

Notes

This model cannot be trained at the moment. Instead, pretrain your model with the torchreid package (optionally using the custom PT21 data loaders), then load the resulting weights.

When computing the similarity during evaluation, most models should reuse the distance function they were trained with.

Params:
  • metric (str) – The name of the metric to use. Must be one of METRICS.

  • embedding_generator_path (Path) – The path to the configuration of the embedding generator within the config.

Optional Params:

  • metric_kwargs (dict, optional) – Additional keyword arguments passed to the similarity function. Default: DEF_VAL.similarity.torchreid.sim_kwargs.
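The following is a minimal sketch of how a metric name and the optional metric_kwargs could be resolved into a similarity function. The METRICS registry and its "cosine" and "euclidean" entries are hypothetical stand-ins; the real METRICS names and the config layout may differ.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for METRICS: name -> callable mapping embeddings of
# shapes [a x E] and [b x E] to a matrix of shape [a x b].
METRICS = {
    # Pairwise cosine similarity via broadcasting [a x 1 x E] against [1 x b x E].
    "cosine": lambda a, b, **kw: F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1, **kw),
    # Pairwise distances (lower means more similar).
    "euclidean": lambda a, b, **kw: torch.cdist(a, b, **kw),
}


def build_metric(params: dict):
    """Resolve params["metric"] and bind the optional params["metric_kwargs"]."""
    name = params["metric"]
    if name not in METRICS:
        raise ValueError(f"metric must be one of {sorted(METRICS)}, got {name!r}")
    # An empty dict is an assumed default; the real module uses DEF_VAL.
    kwargs = params.get("metric_kwargs", {})
    return lambda a, b: METRICS[name](a, b, **kwargs)
```

Binding the kwargs once at construction keeps the per-call signature of the metric identical for every configured metric.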

Methods

__init__(*args, **kwargs)
configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module

Set the compute mode (train or eval) and move the module to the device, or to multiple parallel devices if applicable.

Parameters:
  • module – The torch module instance to configure.

  • train – Whether to put this module into training or evaluation mode. Defaults to the value set in the base config.

Returns:

The module on the specified device or in parallel.
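A minimal sketch of this behavior as a standalone function. The default_train and device arguments are hypothetical substitutes for values the real module reads from its base config, and wrapping with torch.nn.DataParallel is an assumption about how "multiple parallel devices" is handled.

```python
from typing import Optional

import torch
from torch import nn


def configure_torch_module(module: nn.Module, train: Optional[bool] = None,
                           default_train: bool = False,
                           device: str = "cpu") -> nn.Module:
    """Set train/eval mode and move the module to a device (sketch).

    `default_train` and `device` are hypothetical stand-ins for config values.
    """
    is_train = default_train if train is None else train
    module.train(is_train)  # train(False) is equivalent to eval()
    module = module.to(device)
    # Wrap in DataParallel only when several CUDA devices are available.
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        module = nn.DataParallel(module)
    return module
```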

forward(data: State, target: State) → torch.Tensor

Forward call of the torchreid model, used to compute the similarities between visual embeddings.

Load the visual embeddings of the data and the target, or compute them from the image crops using the model. The embeddings are tensors of shapes [a x E] and [b x E], respectively. This module’s metric then computes the similarity between the two sets of embeddings.

Notes

Torchreid expects images to have float values.
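One way to satisfy this for uint8 crops is sketched below. The helper name is hypothetical, and scaling to [0, 1] is an assumption: whether it is needed depends on the preprocessing the torchreid weights were trained with.

```python
import torch


def to_float_crops(crops: torch.Tensor, scale: bool = True) -> torch.Tensor:
    """Convert a batch of image crops [N x C x H x W] to float32 (sketch).

    Scaling from [0, 255] to [0, 1] is optional and model-dependent.
    """
    if crops.is_floating_point():
        return crops  # already float, leave untouched
    crops = crops.to(torch.float32)
    return crops / 255.0 if scale else crops
```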

Parameters:
  • data – A State containing either the predicted embedding or the image crop. If a predicted embedding exists, it should be stored as ‘embedding’ in the State. self.get_data() then extracts the embedding as a tensor of shape [a x E].

  • target – A State containing either the target embedding or the image crop. If a target embedding exists, it should be stored as ‘embedding’ in the State. self.get_target() then extracts the embedding as a tensor of shape [b x E].

Returns:

A similarity matrix describing the similarity between every current and every target embedding. The result is a (Float)Tensor of shape [a x b] and should contain values in [0..1]. If the chosen metric does not return a probability distribution, consider changing the metric or setting the ‘softmax’ parameter of this module (or of the parent DGSModule, if this is a submodule). Computing the softmax ensures correct behavior when combining this similarity with others. If requested, the softmax is computed along the last dimension (-1), so that each row becomes a probability distribution over the target embeddings for the corresponding input.
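The computation described above can be sketched as follows, assuming cosine similarity as the metric (the real module uses whatever metric was configured). Cosine similarity lies in [-1, 1], which is exactly the case where the optional softmax helps. The function name similarity_matrix is hypothetical.

```python
import torch
import torch.nn.functional as F


def similarity_matrix(data_emb: torch.Tensor, target_emb: torch.Tensor,
                      softmax: bool = False) -> torch.Tensor:
    """Compute an [a x b] similarity matrix from embeddings [a x E] and [b x E]."""
    # Broadcast [a x 1 x E] against [1 x b x E] and reduce over E.
    sim = F.cosine_similarity(data_emb.unsqueeze(1), target_emb.unsqueeze(0), dim=-1)
    if softmax:
        # Normalize along the last dimension: each row becomes a probability
        # distribution over the b target embeddings.
        sim = F.softmax(sim, dim=-1)
    return sim
```

With softmax=True every row sums to one, so each of the a inputs receives a distribution over the b targets.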

get_data(ds: State) → torch.Tensor

Given a State, return the current embedding, or compute it from the image crop if no embedding is stored.

get_target(ds: State) → torch.Tensor

Given a State, return the target embedding, or compute it from the image crop if no embedding is stored.
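A minimal sketch of the "use the stored embedding, else compute it" behavior shared by get_data() and get_target(). A plain dict stands in for State; the key "image_crop" and the class EmbeddingGetter are hypothetical, and the real State keys and model call may differ.

```python
import torch
from torch import nn


class EmbeddingGetter:
    """Hypothetical helper: return a stored embedding or compute one."""

    def __init__(self, model: nn.Module):
        self.model = model

    def get_embedding(self, ds: dict) -> torch.Tensor:
        if "embedding" in ds:
            return ds["embedding"]  # already computed, shape [N x E]
        crops = ds["image_crop"].to(torch.float32)  # torchreid expects floats
        with torch.no_grad():
            return self.model(crops)

    # get_data() and get_target() differ only in which State they receive.
    get_data = get_embedding
    get_target = get_embedding
```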

get_train_data(ds: State) → any

Get special data for training purposes. If “train_key” is not given, this falls back to the module’s regular get_data() function.
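A sketch of this fallback, assuming State behaves like a dict and that “train_key” is read from the module’s params. Both assumptions, and the class name, are for illustration only.

```python
from typing import Any


class TrainDataSketch:
    """Hypothetical module showing the get_train_data() fallback."""

    def __init__(self, params: dict):
        self.params = params

    def get_data(self, ds: dict) -> Any:
        return ds["embedding"]

    def get_train_data(self, ds: dict) -> Any:
        train_key = self.params.get("train_key")
        if train_key is None:
            return self.get_data(ds)  # no special training data configured
        return ds[train_key]
```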

terminate() → None

Terminate this module and all of its submodules.

Modules with nothing to clean up can simply pass. This is used to terminate parallel execution and threads in specific models.
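For a module that does hold resources, terminate() might look like the hypothetical sketch below: it propagates termination to its submodules and then shuts down a concurrent.futures.ThreadPoolExecutor. The class and its attributes are illustrative only.

```python
from concurrent.futures import ThreadPoolExecutor


class ParallelModuleSketch:
    """Hypothetical module owning a thread pool and optional submodules."""

    def __init__(self, submodules=(), workers: int = 2):
        self.submodules = list(submodules)
        self.pool = ThreadPoolExecutor(max_workers=workers)
        self.terminated = False

    def terminate(self) -> None:
        # Terminate children first, then wait for running jobs and
        # release this module's worker threads.
        for sub in self.submodules:
            sub.terminate()
        self.pool.shutdown(wait=True)
        self.terminated = True
```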

validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None

Validate this module’s parameters against the given per-key validations.

Raises an exception if a parameter is invalid or missing.

Parameters:
  • attrib_name – Name of the attribute to validate. This should be “params”; only the base class uses “config”.

  • validations

    Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.

    The value for the validation can have multiple types:
    • A lambda function or another callable

    • A string referencing a predefined validation function that takes one argument

    • None to check for existence

    • A tuple of a string referencing a predefined validation function and one additional argument

    • Nested validations are possible, but every nested validation has to be a tuple or a tuple of tuples. For convenience, “any”, “all”, “not”, “eq”, “neq”, and “xor” are implemented. Their data can be a tuple containing other tuples or validations, or a single validation.

    • Lists and other iterables can be validated using “forall”, which runs the given validations on every item of the input. A single validation or a tuple of (nested) validations is accepted as data.
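The rules above can be illustrated with a small standalone validator. It is not the dgs implementation: it covers only a subset (types, callables, None, one-argument strings, the "in" and "instance" tuples, and the "any", "all", "not", and "forall" combinators), and the predefined validation name "not_empty" is hypothetical.

```python
from typing import Any

# Hypothetical predefined validations, referenced by name.
ONE_ARG = {"not_empty": lambda v: len(v) > 0}
TWO_ARG = {
    "in": lambda v, data: v in data,
    "instance": lambda v, data: isinstance(v, data),
}


def check(value: Any, validation: Any) -> bool:
    """Evaluate a single validation against a value."""
    if validation is None:
        return value is not None  # existence check
    if isinstance(validation, type):  # before callable(): types are callable
        return isinstance(value, validation)
    if isinstance(validation, str):
        return ONE_ARG[validation](value)
    if isinstance(validation, tuple):
        name, data = validation
        if name == "any":
            return any(check(value, v) for v in data)
        if name == "all":
            return all(check(value, v) for v in data)
        if name == "not":
            return not check(value, data)
        if name == "forall":
            return all(check(item, data) for item in value)
        return TWO_ARG[name](value, data)
    if callable(validation):
        return bool(validation(value))
    raise TypeError(f"unsupported validation: {validation!r}")


def validate_params(params: dict, validations: dict[str, list]) -> None:
    """Raise if a parameter is missing or any of its validations fails."""
    for key, checks in validations.items():
        if key not in params:
            raise KeyError(f"missing parameter {key!r}")
        for validation in checks:
            if not check(params[key], validation):
                raise ValueError(f"parameter {key!r} failed validation {validation!r}")
```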

Example

This example is an excerpt of the validations for the BaseModule configuration.

>>> validations = {
...     "device": [
...         str,
...         ("any",
...             [
...                 ("in", ["cuda", "cpu"]),
...                 ("instance", torch.device)
...             ]
...         )
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": [lambda value: value == 1],
... }

And within the class __init__() call:

>>> self.validate_params(validations)

Raises:

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get an escaped version of the module name, with spaces and underscores replaced, that can be used in file paths.

precision

Get the (floating point) precision used in multiple parts of this module.

model

func

softmax