Source code for dgs.models.engine.visual_sim_engine

"""
Engine for training and testing visual embedding modules.

Notes:
    Somewhat obsolete, since the engines from |torchreid|_ can be used to train visual embedding models instead.
"""

import time
from datetime import timedelta

import torch as t
from torch.utils.data import DataLoader as TDataLoader
from tqdm import tqdm

from dgs.models.engine.engine import EngineModule
from dgs.models.metric import get_metric, metric, METRICS
from dgs.models.module import enable_keyboard_interrupt
from dgs.models.similarity.torchreid import TorchreidVisualSimilarity
from dgs.utils.config import DEF_VAL
from dgs.utils.state import State
from dgs.utils.timer import DifferenceTimer
from dgs.utils.types import Config, Metric, Results, Validations

train_validations: Validations = {
    "nof_classes": [int, ("gt", 0)],
    # optional
    "topk_acc": ["optional", ("forall", [int, ("gt", 0)])],
}

test_validations: Validations = {
    "metric": [("any", ["callable", ("in", METRICS.keys())])],
    # optional
    "metric_kwargs": ["optional", dict],
    "topk_cmc": ["optional", ("forall", [int, ("gt", 0)])],
    "write_embeds": ["optional", ("len", 2), ("forall", bool)],
    "image_key": ["optional", str],
}
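
# A minimal, purely illustrative sketch of parameter dicts that would pass the
# validations above. All values are hypothetical examples, not defaults of this
# module; the metric name must be a key of METRICS (or a callable), and the
# image key must exist in the State objects produced by the dataloaders.
_example_params_train = {
    "nof_classes": 751,  # number of person IDs in the training set
    "topk_acc": [1, 5, 10],  # optional: values of k for the top-k accuracy during training
}
_example_params_test = {
    "metric": "EuclideanDistance",  # placeholder name - any key of METRICS or a callable
    "metric_kwargs": {},  # optional: kwargs forwarded to the metric constructor
    "topk_cmc": [1, 5, 10, 20],  # optional: ranks for the CMC evaluation
    "write_embeds": [False, False],  # optional: write (query, gallery) embeddings to tensorboard
    "image_key": "image_crop",  # optional: key used by get_data() to load the image
}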


class VisualSimilarityEngine(EngineModule):
    """An engine class for training and testing visual similarities using visual embeddings.

    For this model:

    - ``get_data()`` should return the image crop
    - ``get_target()`` should return the target class IDs
    - ``train_dl`` contains the training data as usual
    - ``test_dl`` contains the query data
    - ``val_dl`` contains the gallery data

    Train Params
    ------------

    nof_classes (int):
        The number of classes in the training set.

    Test Params
    -----------

    metric (str|callable):
        The name or class of the metric used during testing / evaluation.
        The metric in the ``VisualSimilarityEngine`` is only used to compute the distance
        between the query and gallery embeddings.
        Therefore, a distance-based metric should be used.
        It is possible to pass additional initialization kwargs to the metric
        by adding them to the ``metric_kwargs`` parameter.

    Optional Train Params
    ---------------------

    topk_acc (list[int], optional):
        The values of k for the top-k accuracy evaluation during training.
        Default ``DEF_VAL.engine.visual.topk_acc``.

    Optional Test Params
    --------------------

    metric_kwargs (dict, optional):
        Specific kwargs for the metric.
        Default ``DEF_VAL.engine.visual.metric_kwargs``.

    topk_cmc (list[int], optional):
        The values of k for the top-k CMC evaluation during testing / evaluation.
        Default ``DEF_VAL.engine.visual.topk_cmc``.

    write_embeds (list[bool, bool], optional):
        Whether to write the embeddings of the query and the gallery dataset to the tensorboard writer.
        Only really feasible for smaller datasets of roughly 1k embeddings.
        Default ``DEF_VAL.engine.visual.write_embeds``.

    image_key (str, optional):
        Which key to use when loading the image from the state in :meth:`get_data`.
        Default ``DEF_VAL.engine.visual.image_key``.
    """

    # The heart of the project might get a little larger...
    # pylint: disable=too-many-arguments

    val_dl: TDataLoader
    """The torch DataLoader containing the validation (gallery) data."""

    model: TorchreidVisualSimilarity

    metric: Metric
    """A metric function used to compute the embedding distance."""
    def __init__(
        self,
        config: Config,
        model: TorchreidVisualSimilarity,
        test_loader: TDataLoader,
        val_loader: TDataLoader,
        *,
        train_loader: TDataLoader = None,
        **kwargs,
    ):
        super().__init__(config=config, model=model, test_loader=test_loader, train_loader=train_loader, **kwargs)
        self.val_dl = val_loader

        self.validate_params(test_validations, "params_test")

        self.topk_cmc: list[int] = self.params_test.get("topk_cmc", DEF_VAL["engine"]["visual"]["topk_cmc"])

        # get metric and kwargs
        self.metric = get_metric(self.params_test["metric"])(
            **self.params_test.get("metric_kwargs", DEF_VAL["engine"]["visual"]["metric_kwargs"])
        )

        self.image_key: str = self.params_test.get("image_key", DEF_VAL["engine"]["visual"]["image_key"])

        if self.is_training:
            self.validate_params(train_validations, attrib_name="params_train")
            self.nof_classes: int = self.params_train["nof_classes"]
            self.topk_acc: list[int] = self.params_train.get("topk_acc", DEF_VAL["engine"]["visual"]["topk_acc"])
    def get_target(self, ds: State) -> t.Tensor:
        """Get the target pIDs from the data."""
        return ds["class_id"].long()
    def get_data(self, ds: State) -> t.Tensor:
        """Get the image crop or other requested image from the state."""
        return ds[self.image_key]
    @enable_keyboard_interrupt
    def _get_train_loss(self, data: State, _curr_iter: int) -> t.Tensor:
        target_ids = self.get_target(data)
        crops = self.get_data(data)

        pred_id_probs = self.model.predict_ids(crops)

        loss = self.loss(pred_id_probs, target_ids)

        topk_accuracies = metric.compute_accuracy(prediction=pred_id_probs, target=target_ids, topk=self.topk_acc)
        self.writer.add_scalars(
            main_tag="Train/acc",
            tag_scalar_dict={str(k): v for k, v in topk_accuracies.items()},
            global_step=_curr_iter,
        )

        return loss

    @t.no_grad()
    @enable_keyboard_interrupt
    def _extract_data(self, dl: TDataLoader, desc: str, write_embeds: bool = False) -> tuple[t.Tensor, t.Tensor]:
        """Given a dataloader, extract the embeddings describing the people and the target pIDs using the model.

        Additionally, compute the accuracy and send the embeddings to the writer.

        Args:
            dl: The DataLoader to extract the data from.
            desc: A description for printing, writing, and saving the data.
            write_embeds: Whether to write the embeddings to the tensorboard writer.
                Only "smaller" datasets should be added. Default False.

        Returns:
            embeddings, target_ids
        """
        embed_l: list[t.Tensor] = []
        t_ids_l: list[t.Tensor] = []
        imgs_l: list[t.Tensor] = []
        batch_t: DifferenceTimer = DifferenceTimer()
        batch: State

        for batch_idx, batch in tqdm(enumerate(dl), desc=f"Extract {desc}", total=len(dl)):
            # batch start
            time_batch_start = time.time()  # reset timer for retrieving the data
            curr_iter = (self.curr_epoch - 1) * len(dl) + batch_idx

            # Extract the (cropped) input image and the target pID.
            # Then use the model to compute the predicted embedding and the predicted pID probabilities.
            t_id = self.get_target(batch)
            img_crop = self.get_data(batch)
            embed = self.model.get_data(batch)

            # keep the results in lists
            embed_l.append(embed)
            t_ids_l.append(t_id)
            if write_embeds:
                imgs_l.append(img_crop)

            # timing
            batch_t.add(time_batch_start)
            self.writer.add_scalars(
                main_tag="Test/time",
                tag_scalar_dict={f"batch_{desc}": batch_t[-1], f"indiv_{desc}": batch_t[-1] / len(batch)},
                global_step=curr_iter,
            )

            del t_id, embed, img_crop

        # concatenate the result lists
        p_embed: t.Tensor = t.cat(embed_l)  # 2D gt embeddings [N, E]
        t_ids: t.Tensor = t.cat(t_ids_l)  # 1D gt person labels [N]
        N: int = len(t_ids)

        assert len(t_ids) == len(p_embed), f"tids: {len(t_ids)}, embed: {len(p_embed)}"
        self.logger.debug(f"{desc} - Shapes - embeddings: {p_embed.shape}, target pIDs: {t_ids.shape}")

        del embed_l, t_ids_l

        # normalize the predicted embeddings if wanted
        p_embed = self._normalize_test(p_embed)

        if write_embeds:
            # write embedding results - take only the first 512 entries due to limitations in tensorboard
            self.logger.info("Add embeddings to writer.")
            self.writer.add_embedding(
                mat=p_embed[: min(512, N), :],
                metadata=t_ids[: min(512, N)].tolist(),
                label_img=t.cat(imgs_l)[: min(512, N)] if imgs_l else None,  # 4D images [N x C x h x w]
                tag=f"Test/{desc}_embeds_{self.curr_epoch}",
            )

        assert isinstance(p_embed, t.Tensor), f"p_embed is {p_embed}"
        assert isinstance(t_ids, t.Tensor), f"t_ids is {t_ids}"

        return p_embed, t_ids

    @t.no_grad()
    def test(self) -> dict[str, any]:
        r"""Test the embeddings predicted by the model on the Test-DataLoader.

        Compute Rank-N for every rank in ``self.topk_cmc``.
        Compute the mean average precision of the predicted target labels.
""" results: dict[str, any] = {} self.set_model_mode("eval") start_time: float = time.time() self.logger.info(f"#### Start Evaluating {self.name} - Epoch {self.curr_epoch} ####") self.logger.info("Loading, extracting, and predicting data, this might take a while...") q_embed, q_t_ids = self._extract_data( dl=self.test_dl, desc="Query", write_embeds=self.params_test.get("write_embeds", DEF_VAL["engine"]["visual"]["write_embeds"])[0], ) g_embed, g_t_ids = self._extract_data( dl=self.val_dl, desc="Gallery", write_embeds=self.params_test.get("write_embeds", DEF_VAL["engine"]["visual"]["write_embeds"])[1], ) self.logger.debug("Use metric to compute the distance matrix.") distance_matrix = self.metric(q_embed, g_embed) self.logger.debug(f"Shape of distance matrix: {distance_matrix.shape}") self.logger.debug("Computing CMC") results["cmc"] = metric.compute_cmc( distmat=distance_matrix, query_pids=q_t_ids, gallery_pids=g_t_ids, ranks=self.topk_cmc, ) # DUPLICATE # results["cmc_inv"] = metric.compute_cmc( distmat=self.metric(g_embed, q_embed), query_pids=g_t_ids, gallery_pids=q_t_ids, ranks=self.topk_cmc, ) self.print_results(results) self.write_results(results, prepend="Test") self.logger.info(f"Test time total: {str(timedelta(seconds=round(time.time() - start_time)))}") self.logger.info(f"#### Evaluation of {self.name} complete ####") return results
    def evaluate(self) -> Results:
        raise NotImplementedError
    @t.no_grad()
    def predict(self) -> t.Tensor:
        """Predict the visual embeddings for the test data.

        Notes:
            Depending on the number of predictions (``N``) and the embedding size (``E``),
            the resulting tensor(s) can get incredibly huge.
            The prediction for the validation data of the |PT21| dataset is roughly 300 MB.

        Returns:
            torch.Tensor: The predicted embeddings as a tensor of shape ``[N x E]``.
        """
        self.set_model_mode("eval")
        start_time: float = time.time()

        self.logger.info(f"#### Start Prediction {self.name} ####")

        embeds, _ = self._extract_data(
            dl=self.test_dl,
            desc="Predict",
            write_embeds=self.params_test.get("write_embeds", DEF_VAL["engine"]["visual"]["write_embeds"])[0],
        )

        self.logger.info(f"Predict time total: {str(timedelta(seconds=round(time.time() - start_time)))}")
        self.logger.info(f"#### Prediction of {self.name} complete ####")

        return embeds
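
# A minimal usage sketch, not part of the original module. All arguments are
# placeholders: ``cfg`` is assumed to be a valid dgs Config, ``model`` a
# TorchreidVisualSimilarity instance, and the loaders torch DataLoaders that
# yield State objects containing the configured image key and "class_id".
def _example_usage(cfg, model, query_dl, gallery_dl, train_dl=None):
    engine = VisualSimilarityEngine(
        config=cfg,
        model=model,
        test_loader=query_dl,  # query data
        val_loader=gallery_dl,  # gallery data
        train_loader=train_dl,  # only needed when training
    )
    results = engine.test()  # CMC ranks of the query against the gallery
    embeds = engine.predict()  # [N x E] tensor of predicted visual embeddings
    return results, embeds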