Source code for dgs.models.embedding_generator.torchreid

"""
Use a model out of the torchreid package as an embedding generator.
"""

import logging
import warnings

import torch as t
from torch import nn

from dgs.models.embedding_generator.embedding_generator import EmbeddingGeneratorModule
from dgs.utils.config import DEF_VAL, get_sub_config, insert_into_config
from dgs.utils.exceptions import InvalidPathException
from dgs.utils.files import to_abspath
from dgs.utils.state import State
from dgs.utils.torchtools import configure_torch_module, load_pretrained_weights
from dgs.utils.types import Config

with warnings.catch_warnings():
    # ignore cython warning
    warnings.filterwarnings("ignore", message="Cython evaluation.*is unavailable", category=UserWarning)
    try:
        # If torchreid is installed using `./dependencies/torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.models import __model_factory as torchreid_models, build_model
    except ModuleNotFoundError:
        # if torchreid is installed using `pip install torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.reid.models import __model_factory as torchreid_models, build_model

torchreid_validations: Config = {
    "model_name": [str, ("in", torchreid_models.keys())],
    # optional
    "weights": [
        "optional",
        (
            "or",
            [("eq", "pretrained"), "file exists", "file exists in project", ("file exists in folder", "./weights/")],
        ),
    ],
    "image_key": ["optional", str],
}


@configure_torch_module
class TorchreidEmbeddingGenerator(EmbeddingGeneratorModule):
    """Given image crops, generate embeddings using the torchreid package.

    The model can use the default pretrained weights or custom weights.

    Notes:
        This model is currently set to evaluation mode only!
        Pretrain your models using the |torchreid| package and possibly the custom PT21 data loaders,
        then load the weights.
        The classifier is not required for embedding generation.

    Notes:
        Setting the parameter ``embedding_size`` does not change this module's output.
        Torchreid does not support custom embedding sizes.

    Module Name
    -----------

    torchreid

    Params
    ------

    model_name (str):
        The name of the torchreid model used.
        Has to be one of ``~torchreid.models.__model_factory.keys()``.

    Optional Params
    ---------------

    weights (Union[str, FilePath], optional):
        A path to the model weights or the string 'pretrained' for the default pretrained torchreid model.
        Default ``DEF_VAL.embed_gen.torchreid.weights``.
    image_key (str, optional):
        The key of the image to use when generating the embedding.
        Default ``DEF_VAL.embed_gen.torchreid.image_key``.

    Important Inherited Params
    --------------------------

    nof_classes (int):
        The number of classes in the dataset.
        Used during training to predict the class-id.
        For most of the pretrained torchreid models, this is set to ``1_000``.
    """

    model: nn.Module

    def __init__(self, config, path):
        if path is None:
            raise InvalidPathException("path is required but got None")

        sub_cfg = get_sub_config(config, path)
        if "embedding_size" in sub_cfg and sub_cfg["embedding_size"] != 512:
            warnings.warn(
                "The embedding size will be overwritten in torchreid embedding generators, "
                "because torchreid does not support different sizes."
            )
        new_cfg = insert_into_config(path=path, value={"embedding_size": 512}, original=config)
        del config

        EmbeddingGeneratorModule.__init__(self, config=new_cfg, path=path)

        self.model_weights = self.params.get("weights", DEF_VAL["embed_gen"]["torchreid"]["weights"])

        model = self._init_model(self.model_weights == "pretrained")
        self.register_module(name="model", module=self.configure_torch_module(model))

        self.image_key = self.params.get("image_key", DEF_VAL["embed_gen"]["torchreid"]["image_key"])

    def _init_model(self, pretrained: bool) -> nn.Module:
        """Initialize the torchreid model."""
        m = build_model(
            name=self.params["model_name"],
            num_classes=self.params["nof_classes"],
            pretrained=pretrained,
            use_gpu=self.device.type == "cuda",
            loss="triplet",
        )
        if not pretrained:  # pragma: no cover
            # custom model params
            load_pretrained_weights(
                m, to_abspath(self.model_weights), verbose=self.logger.isEnabledFor(logging.DEBUG)
            )
        return m

    def predict_embeddings(self, data: t.Tensor) -> t.Tensor:
        """Predict embeddings given some input.

        Args:
            data: The input for the model, most likely a cropped image.

        Returns:
            Tensor containing a batch B of embeddings.
            Shape: ``[B x E]``
        """

        def _get_torchreid_embeds(r) -> t.Tensor:
            """Torchreid returns embeddings during eval and ids during training."""
            if isinstance(r, t.Tensor):
                # During model building, triplet loss was forced for torchreid models.
                # Therefore, only one return value means that only the embeddings are returned.
                return r
            if len(r) == 2:
                _, embeddings = r
                return embeddings
            raise NotImplementedError("Unknown torchreid model output.")

        results = self.model(data)
        return _get_torchreid_embeds(results)

    def predict_ids(self, data: t.Tensor) -> t.Tensor:
        """Predict class IDs given some input.

        Args:
            data: The input for the model, most likely a cropped image.

        Returns:
            Tensor containing class predictions, which are not necessarily a probability distribution.
            Shape: ``[B x num_classes]``
        """

        def _get_torchreid_ids(r) -> t.Tensor:
            """Torchreid returns embeddings during eval and ids during training."""
            if isinstance(r, t.Tensor):
                # During model building, triplet loss was forced for torchreid models.
                # Therefore, only one return value means that only the embeddings are returned,
                # and the classifier has to be applied to obtain the class predictions.
                return self.model.classifier(r)
            if len(r) == 2:
                ids, _ = r
                return ids
            raise NotImplementedError("Unknown torchreid model output.")

        results = self.model(data)
        return _get_torchreid_ids(results)

    def forward(self, ds: State) -> t.Tensor:
        """Predict embeddings given some input.

        Notes:
            Torchreid models will return different results based on whether they are in eval or training mode.
            Make sure forward is only called in evaluation mode.

        Args:
            ds: A :class:`State` containing the cropped image as input for the model.
                :class:`Image` or FloatTensor of shape ``[B x C x w x h]``.

        Returns:
            A batch of embeddings as tensor of shape: ``[B x E]``.
        """
        if self.embedding_key_exists(ds):
            return ds[self.embedding_key]
        embeddings = self.model(getattr(ds, self.image_key) if hasattr(ds, self.image_key) else ds[self.image_key])
        if self.save_embeddings:
            ds[self.embedding_key] = embeddings
        return embeddings
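
# Usage sketch (hedged): how this generator might be constructed and called.
# The config layout, the "embed_gen" path, and the tensor shapes below are
# illustrative assumptions derived from this module's parameters, not verified calls.
#
# cfg = {
#     "device": "cuda",
#     "embed_gen": {
#         "module_name": "torchreid",
#         "model_name": "osnet_x1_0",
#         "nof_classes": 1000,
#         "weights": "pretrained",
#     },
# }
# gen = TorchreidEmbeddingGenerator(config=cfg, path=["embed_gen"])
# gen.eval()  # torchreid models return embeddings only in eval mode
# crops = t.rand(8, 3, 256, 128)  # batch of image crops [B x C x H x W]
# embeds = gen.predict_embeddings(crops)  # -> shape [B x 512]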