dgs.models.similarity.torchreid.TorchreidVisualSimilarity
- class dgs.models.similarity.torchreid.TorchreidVisualSimilarity(*args: Any, **kwargs: Any)
Given image crops, generate Re-ID embeddings using the torchreid package.
The model can use either the default pretrained weights or custom weights.
Notes
This model cannot be trained at the moment. Pretrain your models using the torchreid package, and possibly the custom PT21 data loaders, then load the weights.
For computing the similarity during evaluation, most models should reuse the distance function that was used during training.
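A minimal sketch of what reusing the training distance can look like, assuming the embeddings were trained with a cosine-based objective: the evaluation similarity is then the pairwise cosine similarity between the current embeddings (`data`) and the target embeddings (`target`). Plain Python lists stand in for tensors, and the helper name `cosine_similarity_matrix` is hypothetical; it is not part of `dgs` or torchreid.

```python
import math


def cosine_similarity_matrix(data, target):
    """Return an N x T matrix of cosine similarities (hypothetical helper).

    data:   N embeddings, each a list of D floats
    target: T embeddings, each a list of D floats
    """

    def normalize(vec):
        norm = math.sqrt(sum(x * x for x in vec))
        # Leave zero vectors unchanged to avoid division by zero.
        return [x / norm for x in vec] if norm > 0 else list(vec)

    data_n = [normalize(v) for v in data]
    target_n = [normalize(v) for v in target]
    # Dot products of unit vectors are cosine similarities in [-1, 1].
    return [[sum(a * b for a, b in zip(d, t)) for t in target_n] for d in data_n]


sim = cosine_similarity_matrix([[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [1.0, 1.0]])
```

If the embedding model had instead been trained with a Euclidean-distance loss, the consistent choice would be a (negated) Euclidean distance rather than cosine similarity, so that "more similar" means the same thing at training and evaluation time.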
Params
- metric (str): The name of the metric to use. Has to be one of METRICS.
- embedding_generator_path (Path): The path to the configuration of the embedding generator within the config.
Optional Params
- metric_kwargs (dict, optional): Additional kwargs to pass to the similarity function. Default DEF_VAL.similarity.torchreid.sim_kwargs.
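How these parameters fit together can be illustrated with a small, self-contained sketch. It assumes, purely for illustration, that the configuration is a nested `dict`, that `embedding_generator_path` is a sequence of keys into it, and that `METRICS` maps metric names to callables. `METRICS`, `resolve_path` and `build_similarity` below are hypothetical stand-ins, not the real `dgs` objects, whose exact definitions are not shown in this section.

```python
from typing import Any, Callable, Sequence


def _cosine(a: list, b: list, eps: float = 1e-8) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = sum(x * x for x in a) ** 0.5
    nb = sum(y * y for y in b) ** 0.5
    # eps keeps the denominator away from zero.
    return dot / max(na * nb, eps)


def _neg_euclidean(a: list, b: list, scale: float = 1.0) -> float:
    return -scale * sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


# Hypothetical registry standing in for the real METRICS.
METRICS: dict = {"cosine": _cosine, "neg_euclidean": _neg_euclidean}


def resolve_path(config: dict, path: Sequence[str]) -> Any:
    """Walk a nested dict along a sequence of keys (assumed Path semantics)."""
    node = config
    for key in path:
        node = node[key]
    return node


def build_similarity(config: dict, params: dict) -> tuple:
    """Validate `metric`, resolve the generator config, bind `metric_kwargs`."""
    metric = params["metric"]
    if metric not in METRICS:
        raise ValueError(f"metric must be one of {sorted(METRICS)}, got {metric!r}")
    gen_cfg = resolve_path(config, params["embedding_generator_path"])
    kwargs = params.get("metric_kwargs", {})
    fn: Callable[..., float] = METRICS[metric]
    return (lambda a, b: fn(a, b, **kwargs)), gen_cfg


# Example values only; "osnet_x1_0" is just an illustrative model name.
config = {"models": {"embed_gen": {"model_name": "osnet_x1_0"}}}
params = {
    "metric": "cosine",
    "embedding_generator_path": ["models", "embed_gen"],
    "metric_kwargs": {"eps": 1e-6},
}
sim, gen_cfg = build_similarity(config, params)
```

Validating the metric name before touching the config means a typo in `metric` fails fast with a readable message instead of surfacing later as a missing-key error.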
- __init__(*args, **kwargs)
Methods
- configure_torch_module(module[, train]): Set compute mode and send model to the device or multiple parallel devices if applicable.
- forward(data, target): Forward call of the torchreid model used to compute the similarities between visual embeddings.
- get_data(ds): Given a State, get the current embedding or compute it using the image crop.
- get_target(ds): Given a State, get the target embedding or compute it using the image crop.
- get_train_data(ds): A custom function to get special data for training purposes.
- Terminate this module and all of its submodules.
- validate_params(validations[, attrib_name]): Given per key validations, validate this module's parameters.
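The "get the embedding or compute it using the image crop" behaviour of `get_data` and `get_target`, followed by `forward`, can be sketched as below. This is a hypothetical illustration only: a plain `dict` stands in for `State`, the keys `"embedding"` and `"crop"` and the `embed` callable are assumptions, and in this sketch both getters share one helper and differ only in which State the caller passes. The real methods may store or compute embeddings differently.

```python
from typing import Callable

Embedding = list[float]


class SketchVisualSimilarity:
    """Hypothetical stand-in; not the real TorchreidVisualSimilarity."""

    def __init__(self, embed: Callable[[list], Embedding],
                 metric: Callable[[Embedding, Embedding], float]):
        self.embed = embed    # stands in for the torchreid embedding model
        self.metric = metric  # similarity between two embeddings
        self.calls = 0        # counts how often an embedding was computed

    def _embedding_of(self, state: dict) -> Embedding:
        # Reuse a stored embedding; otherwise compute it from the crop and store it.
        if state.get("embedding") is None:
            self.calls += 1
            state["embedding"] = self.embed(state["crop"])
        return state["embedding"]

    def get_data(self, ds: dict) -> Embedding:
        return self._embedding_of(ds)

    def get_target(self, ds: dict) -> Embedding:
        return self._embedding_of(ds)

    def forward(self, data: list, target: list) -> list:
        # One row per data embedding, one column per target embedding.
        return [[self.metric(d, t) for t in target] for d in data]


def embed(crop):
    # Toy "embedding": sum of the crop values plus a constant component.
    return [float(sum(crop)), 1.0]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


model = SketchVisualSimilarity(embed, dot)
detection = {"crop": [1, 2], "embedding": None}   # embedding must be computed
track = {"crop": [0, 0], "embedding": [2.0, 0.5]}  # embedding already known
d = model.get_data(detection)
t = model.get_target(track)
scores = model.forward([d], [t])
```

Caching the computed embedding on the state means a crop is passed through the (expensive) Re-ID model at most once, even if the same state is queried repeatedly.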
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.