dgs.models.embedding_generator.torchreid.TorchreidEmbeddingGenerator

class dgs.models.embedding_generator.torchreid.TorchreidEmbeddingGenerator(*args: Any, **kwargs: Any)

Given image crops, generate embeddings using the torchreid package.

The model can use the default pretrained weights or custom weights.

Notes

Currently, this model can only be used in evaluation mode. To use custom weights, train your models with the torchreid package (optionally with the custom PT21 data loaders) and then load the resulting weights. The classifier is not required for embedding generation.

Setting the parameter embedding_size does not change this module’s output, because torchreid does not support custom embedding sizes.

Module Name

torchreid

Params

model_name (str):

The name of the torchreid model to use. Must be one of torchreid.models.__model_factory.keys().
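The model_name check can be sketched as follows. This is a minimal illustration, not the DGS implementation: the helper validate_model_name and the hard-coded set KNOWN_MODEL_NAMES are hypothetical, and with torchreid installed the valid names would instead come from torchreid.models.__model_factory.keys().

```python
from typing import AbstractSet

# Hypothetical, hard-coded subset of torchreid model names so that the sketch
# runs without torchreid installed. In practice, use the keys of
# torchreid.models.__model_factory instead.
KNOWN_MODEL_NAMES: AbstractSet[str] = frozenset({"resnet50", "osnet_x1_0", "osnet_ain_x1_0"})


def validate_model_name(model_name: str, known: AbstractSet[str] = KNOWN_MODEL_NAMES) -> str:
    """Return model_name unchanged if it is a known model key, otherwise raise."""
    if not isinstance(model_name, str):
        raise TypeError(f"model_name must be a str, got {type(model_name).__name__}")
    if model_name not in known:
        raise ValueError(
            f"Unknown torchreid model {model_name!r}; expected one of {sorted(known)}"
        )
    return model_name
```

For example, validate_model_name("osnet_x1_0") returns the name unchanged, while a misspelled name fails early with a readable list of alternatives instead of failing later inside model construction.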

Optional Params

weights (Union[str, FilePath], optional):

A path to the model weights or the string ‘pretrained’ for the default pretrained torchreid model. Default DEF_VAL.embed_gen.torchreid.weights.

image_key (str, optional):

The key of the image to use when generating the embedding. Default DEF_VAL.embed_gen.torchreid.image_key.
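A minimal sketch of how the two optional parameters might be interpreted, assuming weights is either the literal string 'pretrained' or a path to an existing file, and that the state behaves like a dict of named inputs. The helpers resolve_weights and select_image and the key "image_crop" are hypothetical; they do not reflect the actual DGS code or the values stored in DEF_VAL.

```python
import os
from typing import Any, Mapping, Optional, Tuple, Union


def resolve_weights(weights: Union[str, os.PathLike]) -> Tuple[bool, Optional[str]]:
    """Return (use_pretrained, weights_path) for a hypothetical weights setting."""
    if weights == "pretrained":
        # Use torchreid's default pretrained weights; no local file is needed.
        return True, None
    path = os.fspath(weights)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"weights file not found: {path}")
    # Custom weights: build the model without pretrained weights, then load this file.
    return False, os.path.abspath(path)


def select_image(state: Mapping[str, Any], image_key: str) -> Any:
    """Fetch the image crops stored under image_key in a dict-like state."""
    if image_key not in state:
        raise KeyError(f"image_key {image_key!r} not found; available: {sorted(state)}")
    return state[image_key]
```

Failing early on a missing weights file or an unknown image_key keeps configuration errors out of the forward pass.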

Important Inherited Params

nof_classes (int):

The number of classes in the dataset. Used during training to predict the class ID. For most pretrained torchreid models, this is set to 1_000.

__init__(*args, **kwargs)

Methods

configure_torch_module(module[, train])

Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.

embedding_key_exists(s)

Return whether the embedding_key of this model exists in a given state.

forward(ds)

Predict embeddings given some input.

predict_embeddings(data)

Predict embeddings given some input.

predict_ids(data)

Predict class IDs given some input.

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.
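The relationship between predict_embeddings, predict_ids and nof_classes can be illustrated with a toy stand-in. Everything here is hypothetical: ToyEmbeddingGenerator replaces the torchreid backbone with an element-wise square and uses a plain linear classifier with nof_classes rows. It only shows why embeddings can be produced without the classifier, while class IDs are the arg-max over the classifier logits.

```python
from typing import List, Sequence

Vector = List[float]


class ToyEmbeddingGenerator:
    """Hypothetical stand-in for an embedding generator with a linear classifier."""

    def __init__(self, classifier: Sequence[Sequence[float]]) -> None:
        # One row per class: shape nof_classes x feature_dim.
        self.classifier = [list(row) for row in classifier]
        self.nof_classes = len(self.classifier)

    def predict_embeddings(self, data: Sequence[Sequence[float]]) -> List[Vector]:
        # Toy "backbone": the feature size is fixed by the backbone alone.
        # The classifier is never used here.
        return [[x * x for x in sample] for sample in data]

    def predict_ids(self, data: Sequence[Sequence[float]]) -> List[int]:
        ids = []
        for emb in self.predict_embeddings(data):
            if any(len(row) != len(emb) for row in self.classifier):
                raise ValueError("classifier rows must match the embedding size")
            # Linear classifier: one logit per class, then pick the best class.
            logits = [sum(w * e for w, e in zip(row, emb)) for row in self.classifier]
            ids.append(max(range(self.nof_classes), key=logits.__getitem__))
        return ids
```

For example, with classifier=[[1.0, 0.0], [0.0, 1.0]], the input [[1.0, 3.0], [2.0, 1.0]] yields the embeddings [[1.0, 9.0], [4.0, 1.0]] and the class IDs [1, 0].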

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module, usable in file paths, with spaces and underscores replaced.

precision

Get the (floating point) precision used in multiple parts of this module.

model

embedding_size

The size of the embedding.

nof_classes

The number of classes in the dataset / embedding.