dgs.models.embedding_generator.torchreid.TorchreidEmbeddingGenerator
- class dgs.models.embedding_generator.torchreid.TorchreidEmbeddingGenerator(*args: Any, **kwargs: Any)
Given image crops, generate embeddings using the torchreid package.
The model can use the default pretrained weights or custom weights.
Notes
This model is currently evaluation-only. Pretrain your models using the torchreid package, possibly with the custom PT21 data loaders, then load the weights. The classifier is not required for embedding generation.
Notes
Setting the parameter embedding_size does not change this module’s output. Torchreid does not support custom embedding sizes.
Module Name
torchreid
- Params:
model_name (str) – The name of the torchreid model to use. Has to be one of torchreid.models.__model_factory.keys().
- Optional Params:
weights (Union[str, FilePath], optional) – A path to the model weights, or the string ‘pretrained’ for the default pretrained torchreid model. Default DEF_VAL.embed_gen.torchreid.weights.
image_key (str, optional) – The key of the image to use when generating the embedding. Default DEF_VAL.embed_gen.torchreid.image_key.
- Important Inherited Params:
nof_classes (int) – The number of classes in the dataset. Used during training to predict the class ID. For most of the pretrained torchreid models, this is set to 1_000.
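As a rough illustration, the parameters listed above might be collected in a mapping like the one below. This is a sketch only: the dict layout, the `image_crop` key, and the helper `missing_required` are hypothetical and not taken from the dgs package; `osnet_x1_0` is one of the model names registered in the torchreid model factory.

```python
# Hypothetical parameter mapping for the torchreid embedding generator.
# Key names follow the parameter list above; the layout itself is an assumption.
torchreid_params = {
    "module_name": "torchreid",  # the module name documented above
    "model_name": "osnet_x1_0",  # must be a key of torchreid's model factory
    "weights": "pretrained",     # or a path to custom weights
    "image_key": "image_crop",   # hypothetical key; the real default is DEF_VAL.embed_gen.torchreid.image_key
    "nof_classes": 1_000,        # inherited param, matches most pretrained models
}

REQUIRED = ("model_name",)  # only model_name is listed as non-optional


def missing_required(params: dict) -> list[str]:
    """Return the required parameter names that are absent from params."""
    return [key for key in REQUIRED if key not in params]
```

Leaving out `weights` and `image_key` should fall back to the documented defaults, whereas a missing `model_name` is expected to be rejected during parameter validation.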
Methods
- __init__(*args, **kwargs)
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module
Set compute mode and send model to the device or multiple parallel devices if applicable.
- Parameters:
module – The torch module instance to configure.
train – Whether to train or eval this module, defaults to the value set in the base config.
- Returns:
The module on the specified device or in parallel.
- embedding_key_exists(s: State) → bool
Return whether the embedding_key of this model exists in a given state.
- forward(ds: State) → torch.Tensor
Predict embeddings given some input.
Notes
Torchreid models return different results depending on whether they are in evaluation or training mode. Make sure forward is only called in evaluation mode.
- Parameters:
ds – A State containing the cropped image as input for the model. Image or FloatTensor of shape [B x C x w x h].
- Returns:
A batch of embeddings as tensor of shape [B x E].
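The note about evaluation mode can be illustrated with a small stand-in module. The sketch below does not use torchreid or dgs; `ToyReID` is a hypothetical class that only mimics the behaviour described above, returning class scores in training mode and embeddings in evaluation mode.

```python
import torch
from torch import nn


class ToyReID(nn.Module):
    """Hypothetical stand-in: class scores in training mode, features in eval mode."""

    def __init__(self, embed_dim: int = 8, num_classes: int = 5) -> None:
        super().__init__()
        self.backbone = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B x C x 1 x 1]
            nn.Flatten(),             # [B x C]
            nn.Linear(3, embed_dim),  # [B x E]
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        if self.training:
            return self.classifier(features)  # [B x num_classes]
        return features  # [B x E]


model = ToyReID()
crops = torch.rand(4, 3, 64, 32)  # [B x C x w x h], arbitrary values

model.eval()  # switch to evaluation mode before generating embeddings
with torch.no_grad():
    embeddings = model(crops)

model.train()  # in training mode the same call yields class scores instead
scores = model(crops)
```

With these toy sizes, `embeddings` has shape `[4 x 8]` while `scores` has shape `[4 x 5]`, which is why calling forward in the wrong mode silently produces tensors of the wrong meaning.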
- predict_embeddings(data: torch.Tensor) → torch.Tensor
Predict embeddings given some input.
- Parameters:
data – The input for the model, most likely a cropped image.
- Returns:
Tensor containing a batch B of embeddings. Shape:
[B x E]
- predict_ids(data: torch.Tensor) → torch.Tensor
Predict class IDs given some input.
- Parameters:
data – The input for the model, most likely a cropped image.
- Returns:
Tensor containing class predictions, which are not necessarily a probability distribution. Shape:
[B x num_classes]
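Because the class predictions are not necessarily a probability distribution, a caller may want to normalise them. A minimal sketch, assuming the returned tensor holds raw scores (logits); the values below are made up for illustration:

```python
import torch

# Hypothetical raw class scores of shape [B x num_classes].
scores = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.1, 3.0]])

probs = torch.softmax(scores, dim=-1)  # each row now sums to 1
predicted_ids = probs.argmax(dim=-1)   # most likely class per sample, shape [B]
```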
- terminate() → None
Terminate this module and all of its submodules.
If nothing has to be done, just pass. Used for terminating parallel execution and threads in specific models.
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None
Given per key validations, validate this module’s parameters.
Throws exceptions on invalid or nonexistent params.
- Parameters:
attrib_name – Name of the attribute to validate. Should be “params”, and “config” only for the base class.
validations – Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
The value for the validation can have multiple types:
- A lambda function or other type of callable
- A string as reference to a predefined validation function with one argument
- None for existence
- A tuple with a string as reference to a predefined validation function with one additional argument
It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.
Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.
Example
This example is an excerpt of the validation for the BaseModule-configuration.
>>> validations = {
...     "device": [
...         str,
...         ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)]),
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": [lambda value: value == 1],
... }
And within the class __init__() call:
>>> self.validate_params(validations)
- Raises:
InvalidParameterException – If one of the parameters is invalid.
ValidationException – If the validation list is invalid or contains an unknown validation.
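To make the validation format more concrete, here is a minimal pure-Python sketch of how such nested validations could be evaluated. It is not the dgs implementation: the names `NAMED`, `check`, and `validate` are hypothetical, only a few named validations are covered, plain-string validations are left out, and built-in exceptions stand in for InvalidParameterException and ValidationException.

```python
from typing import Any, Callable

# Hypothetical registry of named validations; each takes (value, data).
NAMED: dict[str, Callable[[Any, Any], bool]] = {
    "in": lambda value, data: value in data,
    "instance": lambda value, data: isinstance(value, data),
    "eq": lambda value, data: value == data,
}


def check(value: Any, validation: Any) -> bool:
    """Evaluate a single (possibly nested) validation against a value."""
    if validation is None:  # existence only; handled by the caller
        return True
    if isinstance(validation, type):  # must come before callable(): types are callable
        return isinstance(value, validation)
    if callable(validation):
        return bool(validation(value))
    if isinstance(validation, tuple):
        name, data = validation
        if name == "any":
            return any(check(value, v) for v in data)
        if name == "all":
            return all(check(value, v) for v in data)
        if name == "not":
            return not check(value, data)
        if name == "forall":  # run the validation on every item of an iterable
            return all(check(item, data) for item in value)
        if name in NAMED:
            return NAMED[name](value, data)
    raise ValueError(f"Unknown validation: {validation!r}")


def validate(params: dict[str, Any], validations: dict[str, list]) -> None:
    """Raise if a parameter is missing or fails one of its validations."""
    for key, checks in validations.items():
        if key not in params:
            raise KeyError(f"Missing parameter: {key}")
        for validation in checks:
            if not check(params[key], validation):
                raise ValueError(f"Invalid value for {key!r}: {params[key]!r}")


validations = {
    "device": [str, ("any", [("in", ["cuda", "cpu"]), ("instance", int)])],
    "ids": [("forall", ("instance", int))],
}
validate({"device": "cpu", "ids": [1, 2, 3]}, validations)  # passes silently
```

Every entry in a parameter's list has to pass, while "any" succeeds as soon as one nested validation does, which matches the description of the nested tuple form above.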
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.
The size of the embedding.
The number of classes in the dataset / embedding.