dgs.models.similarity.torchreid.TorchreidVisualSimilarity

class dgs.models.similarity.torchreid.TorchreidVisualSimilarity(*args: Any, **kwargs: Any)

Given image crops, generate Re-ID embeddings using the torchreid package.

The model can use either the default pretrained weights or custom weights.

Notes

This model cannot currently be trained. Instead, pretrain the model with the torchreid package (optionally using the custom PT21 data loaders) and then load the resulting weights.

When computing the similarity during evaluation, most models should reuse the distance function that was used during training.
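To illustrate why the evaluation metric should match the training metric, the following minimal sketch (plain Python, independent of dgs and torchreid) scores one embedding against two targets with cosine similarity and with negated Euclidean distance. The helper names are hypothetical and only illustrate the idea; they are not part of the dgs API.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity: dot product divided by the product of the L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def negative_euclidean(a: Vector, b: Vector) -> float:
    """Negated Euclidean distance, so that larger values mean 'more similar'."""
    return -math.dist(a, b)


def similarity_matrix(
    data: Sequence[Vector],
    target: Sequence[Vector],
    fn: Callable[[Vector, Vector], float],
) -> List[List[float]]:
    """Score every data embedding against every target embedding."""
    return [[fn(d, t) for t in target] for d in data]
```

For the embedding `[1.0, 1.0]` and the targets `[2.0, 2.0]` and `[1.0, 0.5]`, cosine similarity prefers the first target (same direction), while negated Euclidean distance prefers the second (closer point). An embedding space trained with one of these distances can therefore rank candidates differently if it is evaluated with the other.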

Params

metric (str):

The name of the metric to use. Must be one of METRICS.

embedding_generator_path (Path):

The path to the configuration of the embedding generator within the config.

Optional Params

metric_kwargs (dict, optional):

Additional keyword arguments passed to the similarity function. Defaults to DEF_VAL.similarity.torchreid.sim_kwargs.
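A minimal sketch of how a `metric` name and `metric_kwargs` could be turned into a similarity function, assuming METRICS is a name-to-function mapping. The contents of the `METRICS` dict below, the `build_metric` helper, and the `eps` keyword are hypothetical stand-ins chosen for illustration; the actual entries of METRICS and the exact validation performed by dgs are not documented here.

```python
import functools
import math
from typing import Any, Callable, Dict, Optional, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity; `eps` guards against division by zero for zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def negative_euclidean(a: Vector, b: Vector) -> float:
    """Negated Euclidean distance, so larger values mean 'more similar'."""
    return -math.dist(a, b)


# Hypothetical stand-in for METRICS; the real mapping may differ.
METRICS: Dict[str, Callable[..., float]] = {
    "cosine": cosine_similarity,
    "euclidean": negative_euclidean,
}


def build_metric(
    metric: str, metric_kwargs: Optional[Dict[str, Any]] = None
) -> Callable[[Vector, Vector], float]:
    """Validate the metric name, then bind any extra keyword arguments to it."""
    if metric not in METRICS:
        raise ValueError(f"Unknown metric '{metric}', expected one of {sorted(METRICS)}")
    return functools.partial(METRICS[metric], **(metric_kwargs or {}))
```

Binding the keyword arguments once with `functools.partial` keeps the per-pair call signature uniform, so the caller does not need to know which metric accepts which options.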

__init__(*args, **kwargs)

Methods

configure_torch_module(module[, train])

Set the compute mode and move the model to the device, or to multiple parallel devices if applicable.

forward(data, target)

Forward call of the torchreid model used to compute the similarities between visual embeddings.

get_data(ds)

Given a State, get the current embedding, or compute it from the image crop.

get_target(ds)

Given a State, get the target embedding, or compute it from the image crop.

get_train_data(ds)

A custom function to get special data for training purposes.

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.
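The "get the embedding or compute it from the image crop" behaviour of get_data and get_target can be sketched as a simple compute-on-miss cache. This is a minimal illustration under stated assumptions: the `State` class, its `data` dict, the `"embedding"` key, and the `embed` callable are hypothetical stand-ins, not the real dgs State API or the torchreid embedding generator.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]


class State:
    """Hypothetical minimal stand-in for a dgs State: an image crop plus cached values."""

    def __init__(self, crop: Sequence[float]) -> None:
        self.crop = list(crop)
        self.data: Dict[str, Vector] = {}


def get_embedding(
    state: State,
    embed: Callable[[Sequence[float]], Vector],
    key: str = "embedding",
) -> Vector:
    """Return the cached embedding if present; otherwise compute it from the crop and cache it."""
    if key not in state.data:
        state.data[key] = embed(state.crop)
    return state.data[key]
```

With this pattern, the (expensive) embedding model runs at most once per State, and later calls reuse the stored result.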

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get an escaped version of the module name that is safe to use in file paths (spaces and underscores are replaced).

precision

Get the (floating point) precision used in multiple parts of this module.

model

func

softmax
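The `func` and `softmax` attributes suggest that the forward pass applies a similarity function to each data/target pair and then normalizes the scores with a softmax. The sketch below shows that composition in plain Python under two assumptions not confirmed by this page: that `func` scores a single pair of embeddings, and that the softmax runs over the targets for each data embedding. The `forward` and `softmax` functions here are illustrative, not the module's actual implementation.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the row maximum before exponentiating."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def forward(
    data: Sequence[Vector],
    target: Sequence[Vector],
    func: Callable[[Vector, Vector], float],
) -> List[List[float]]:
    """Score each data embedding against all targets, then softmax over the targets."""
    return [softmax([func(d, t) for t in target]) for d in data]
```

Each output row sums to 1, so it can be read as a distribution over candidate targets; subtracting the maximum leaves the result unchanged but avoids overflow in `math.exp` for large scores.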