dgs.models.engine.visual_sim_engine.VisualSimilarityEngine
- class dgs.models.engine.visual_sim_engine.VisualSimilarityEngine(*args: Any, **kwargs: Any)
- An engine class for training and testing visual similarities using visual embeddings.
- For this model:
- get_data() should return the image crop.
- get_target() should return the target class IDs.
- train_dl contains the training data as usual.
- test_dl contains the query data.
- val_dl contains the gallery data.
- Train Params
- nof_classes (int):
- The number of classes in the training set.
- Test Params
- metric (str|callable):
- The name or class of the metric used during testing / evaluation. The metric in the VisualSimilarityEngine is only used to compute the distance between the query and gallery embeddings, so a distance-based metric should be used.
- Additional initialization kwargs can be passed to the metric by adding them to the metric_kwargs parameter.
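To make "distance between the query and gallery embeddings" concrete, here is a minimal dependency-free sketch of the query x gallery distance matrix such a metric produces. It assumes plain Euclidean distance; euclidean_distance_matrix is a hypothetical helper, not part of dgs, and in PyTorch the same matrix would typically come from torch.cdist on [N x E] tensors.

```python
import math


def euclidean_distance_matrix(query, gallery):
    """Return a len(query) x len(gallery) matrix of Euclidean distances.

    query and gallery are lists of equal-length embedding vectors.
    Smaller values mean more similar embeddings.
    """
    return [[math.dist(q, g) for g in gallery] for q in query]


# Two query embeddings and three gallery embeddings (toy 2-D values).
query = [[0.0, 0.0], [1.0, 1.0]]
gallery = [[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]]
distmat = euclidean_distance_matrix(query, gallery)
# distmat[0] is [0.0, 5.0, sqrt(2)]
```

Because rankings are built by sorting each row in ascending order, a similarity score (where larger means closer) would invert the ranking, which is why a distance-based metric is required here.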
- Optional Train Params
- topk_acc (list[int], optional):
- The values of k for the top-k accuracy evaluation during training. Default: DEF_VAL.engine.visual.topk_acc.
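As a rough illustration of what topk_acc controls, the sketch below computes top-k accuracy from per-class scores over the nof_classes training classes. It assumes the training step produces one score per class for each sample; topk_accuracy is a hypothetical helper, not the engine's own implementation.

```python
def topk_accuracy(scores, targets, ks):
    """Share of samples whose true class is among the k highest-scoring classes."""
    results = {}
    for k in ks:
        hits = 0
        for row, target in zip(scores, targets):
            # Class indices sorted by descending score, truncated to the top k.
            topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            if target in topk:
                hits += 1
        results[k] = hits / len(targets)
    return results


scores = [
    [0.1, 0.7, 0.2],  # true class 1 is ranked first
    [0.5, 0.3, 0.2],  # true class 2 is ranked last
    [0.2, 0.3, 0.5],  # true class 1 is ranked second
]
targets = [1, 2, 1]
acc = topk_accuracy(scores, targets, ks=[1, 2])
# acc[1] == 1/3, acc[2] == 2/3
```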
- Optional Test Params
- metric_kwargs (dict, optional):
- Specific kwargs for the metric. Default: DEF_VAL.engine.visual.metric_kwargs.
- topk_cmc (list[int], optional):
- The values of k for the top-k CMC evaluation during testing / evaluation. Default: DEF_VAL.engine.visual.topk_cmc.
- write_embeds (list[bool, bool], optional):
- Whether to write the embeddings of the query and gallery datasets to the TensorBoard writer. This is only feasible for smaller datasets (around 1k embeddings). Default: DEF_VAL.engine.visual.write_embeds.
- image_key (str, optional):
- Which key to use when loading the image from the state in get_data(). Default: DEF_VAL.engine.visual.image_key.
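The topk_cmc values select which points of the cumulative matching characteristic (CMC) curve are reported. The following is a minimal sketch, assuming a query x gallery distance matrix and integer identity labels; cmc is a hypothetical helper. Common re-ID evaluation protocols additionally discard gallery entries sharing both identity and camera with the query, which this simplified version does not do.

```python
def cmc(distmat, query_ids, gallery_ids, ks):
    """Share of queries whose first correct gallery match is within the k nearest entries."""
    first_hit_ranks = []
    for q_idx, row in enumerate(distmat):
        order = sorted(range(len(row)), key=lambda g: row[g])  # nearest gallery entries first
        matches = [gallery_ids[g] == query_ids[q_idx] for g in order]
        if True in matches:
            first_hit_ranks.append(matches.index(True) + 1)  # 1-based rank of the first match
        # Queries without any matching gallery entry are skipped.
    if not first_hit_ranks:
        return {k: 0.0 for k in ks}
    n = len(first_hit_ranks)
    return {k: sum(r <= k for r in first_hit_ranks) / n for k in ks}


distmat = [
    [0.2, 0.9, 0.4],  # query id 7: nearest gallery entry has id 7 -> rank 1
    [0.1, 0.8, 0.3],  # query id 5: its match is the farthest entry -> rank 3
]
query_ids = [7, 5]
gallery_ids = [7, 5, 9]
ranks = cmc(distmat, query_ids, gallery_ids, ks=[1, 2, 3])
# ranks == {1: 0.5, 2: 0.5, 3: 1.0}
```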
- Methods
- __init__(config: dict[str, any], model: TorchreidVisualSimilarity, test_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, *, train_loader: torch.utils.data.DataLoader | None = None, **kwargs)
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module
- Set the compute mode and send the module to the device, or to multiple parallel devices if applicable.
- Parameters:
- module – The torch module instance to configure.
- train – Whether to set this module to train or eval mode. Defaults to the value set in the base config.
 
- Returns:
- The module on the specified device or in parallel. 
 
- evaluate() → dict[str, any]
- Run the tests defined in the sub-engine.
- Returns:
- A dictionary containing all the computed accuracies, metrics, …
- Return type:
- dict[str, any] 
 
- get_data(ds: State) → torch.Tensor
- Get the image crop or other requested image from the state. 
- get_hparam_dict() → dict[str, any]
- Get the hyperparameters of the current engine. Child modules can inherit this method and add additional hyperparameters.
- By default, all parameters from testing and training are added to the hparam_dict.
- initialize_optimizer() → None
- Because the module might be set after the initial step, the optimizer and scheduler are loaded at the start of training.
- load_model(path: str) → None
- Load the model from a file and set the start epoch to the epoch stored in the loaded checkpoint plus one.
- Notes
- Loads the states of the optimizer and lr_scheduler if they are present in the engine (e.g., during training) and the respective data is given in the checkpoint at the path.
- Parameters:
- path – The path to the checkpoint where this model was saved. 
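The conditional restore described in the notes can be sketched with plain dictionaries. This is a hypothetical illustration: the checkpoint keys ("epoch", "model", "optimizer", "lr_scheduler") and restore_from_checkpoint are assumptions, not the engine's actual format. In PyTorch the checkpoint would come from torch.load(path) and be applied through the load_state_dict methods of the module, optimizer, and scheduler.

```python
def restore_from_checkpoint(checkpoint, engine_state):
    """Apply a loaded checkpoint dict to a simplified, dict-based engine state.

    engine_state only has 'optimizer' / 'lr_scheduler' keys while training.
    """
    engine_state["model"] = checkpoint["model"]
    for key in ("optimizer", "lr_scheduler"):
        # Restore only when both the engine and the checkpoint carry this state.
        if key in engine_state and key in checkpoint:
            engine_state[key] = checkpoint[key]
    # Resume with the epoch after the one that was saved.
    engine_state["start_epoch"] = checkpoint["epoch"] + 1
    return engine_state


checkpoint = {"epoch": 4, "model": {"w": 1.0}, "optimizer": {"lr": 0.01}}
training_engine = restore_from_checkpoint(checkpoint, {"optimizer": None, "lr_scheduler": None})
testing_engine = restore_from_checkpoint(checkpoint, {})
# Both engines resume at epoch 5; only the training engine gets optimizer state.
```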
 
- predict() → torch.Tensor
- Predict the visual embeddings for the test data.
- Notes
- Depending on the number of predictions (N) and the embedding size (E), the resulting tensor(s) can get very large. The prediction for the validation data of the PoseTrack21 dataset is roughly 300MB.
- Returns:
- The predicted embeddings as a tensor of shape [N x E].
- Return type:
- torch.Tensor 
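The size of the prediction grows linearly in both N and E. A dense [N x E] float32 tensor needs N · E · 4 bytes; the sketch below makes that arithmetic explicit. The values of N and E are hypothetical round numbers chosen for illustration, not measured PoseTrack21 sizes.

```python
def embedding_bytes(n, e, bytes_per_element=4):
    """Memory of a dense [n x e] tensor; 4 bytes per element for float32."""
    return n * e * bytes_per_element


# Hypothetical sizes, not taken from PoseTrack21.
size = embedding_bytes(n=150_000, e=512)
print(f"{size / 1e6:.1f} MB")  # 307.2 MB
```

Halving the element size (for example, storing float16) halves the footprint, as does reducing E.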
 
- print_results(results: dict[str, any]) → None
- Given a dictionary of results, print them to the console if allowed. 
- run() → None
- Run the model. First train, then test! 
- save_model(epoch: int, metrics: dict[str, any], optimizer: torch.optim.Optimizer, lr_sched: torch.optim.lr_scheduler.LRScheduler) → None
- Save the current model and other weights into a ‘.pth’ file.
- Parameters:
- epoch – The epoch at which this model is saved.
- metrics – A dict containing the computed metrics for this module.
- optimizer – The current optimizer.
- lr_sched – The current learning rate scheduler. 
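A checkpoint that can be resumed bundles the epoch, the metrics, and the model, optimizer, and scheduler states into one file. The stdlib sketch below assumes a hypothetical key layout and file naming scheme and uses pickle in place of torch.save; in PyTorch one would pass the state_dict() of the model, optimizer, and scheduler to torch.save instead.

```python
import os
import pickle
import tempfile


def save_checkpoint(directory, epoch, metrics, model_state, optimizer_state, lr_sched_state):
    """Bundle everything needed to resume training into a single file."""
    checkpoint = {
        "epoch": epoch,
        "metrics": metrics,
        "model": model_state,
        "optimizer": optimizer_state,
        "lr_scheduler": lr_sched_state,
    }
    path = os.path.join(directory, f"epoch{epoch:03d}.pth")  # hypothetical naming scheme
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = save_checkpoint(tmp, 3, {"mAP": 0.42}, {"w": [1.0]}, {"lr": 1e-3}, {"step": 3})
    with open(path, "rb") as f:
        loaded = pickle.load(f)
print(os.path.basename(path), loaded["epoch"])  # epoch003.pth 3
```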
 
 
- set_model_mode(mode: str) → None
- Set model mode to train or test. 
- terminate() → None
- Handle forceful termination, e.g., Ctrl+C.
- test() → dict[str, any]
- Test the embeddings predicted by the model on the test DataLoader.
- Compute Rank-N for every rank in self.topk_cmc, and compute the mean average precision (mAP) of the predicted target labels.
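Mean average precision rewards rankings that place all correct gallery matches early, not only the first one. The sketch below assumes a query x gallery distance matrix and identity labels; average precision per query is the mean of the precision at each correct match, and queries without any match are skipped. mean_average_precision is a hypothetical helper, not the engine's code.

```python
def mean_average_precision(distmat, query_ids, gallery_ids):
    """Mean over queries of the average precision of the distance-based ranking."""
    aps = []
    for q_idx, row in enumerate(distmat):
        order = sorted(range(len(row)), key=lambda g: row[g])  # nearest first
        hits = 0
        precisions = []
        for rank, g in enumerate(order, start=1):
            if gallery_ids[g] == query_ids[q_idx]:
                hits += 1
                precisions.append(hits / rank)  # precision at this correct match
        if precisions:
            aps.append(sum(precisions) / len(precisions))
    return sum(aps) / len(aps) if aps else 0.0


distmat = [
    [0.1, 0.5, 0.3, 0.9],  # query id 1: matches at ranks 1, 2, 4 -> AP = (1 + 1 + 3/4) / 3
    [0.4, 0.2, 0.6, 0.8],  # query id 2: single match at rank 1 -> AP = 1
]
m_ap = mean_average_precision(distmat, query_ids=[1, 2], gallery_ids=[1, 2, 1, 1])
```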
- train_model() → torch.optim.Optimizer
- Train the given model using the given loss function, optimizer, and learning-rate schedulers.
- After every epoch, the current model is tested and saved.
- Returns:
- The current optimizer after training. 
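The per-epoch "train, test, save" cycle can be sketched as a small loop with injected callbacks. This is a hypothetical skeleton: train_loop and its callback names are illustrative assumptions, 1-based epochs are assumed, and the real method additionally handles the loss, schedulers, and returning the optimizer.

```python
def train_loop(num_epochs, start_epoch, train_one_epoch, evaluate, save):
    """Run epochs start_epoch..num_epochs, testing and checkpointing after each one."""
    history = []
    for epoch in range(start_epoch, num_epochs + 1):
        train_one_epoch(epoch)
        metrics = evaluate()  # test the current model
        save(epoch, metrics)  # then save it together with its metrics
        history.append((epoch, metrics))
    return history


calls = []
history = train_loop(
    num_epochs=3,
    start_epoch=2,  # e.g. resumed from a checkpoint saved at epoch 1
    train_one_epoch=lambda e: calls.append(("train", e)),
    evaluate=lambda: {"mAP": 0.0},
    save=lambda e, m: calls.append(("save", e)),
)
# calls == [("train", 2), ("save", 2), ("train", 3), ("save", 3)]
```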
 
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None
- Given per-key validations, validate this module’s parameters.
- Throws exceptions on invalid or nonexistent params.
- Parameters:
- attrib_name – Name of the attribute to validate. Should be “params”, and “config” only for the base class.
- validations –
- Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
- The value for the validation can have multiple types:
- A lambda function or other type of callable 
- A string as reference to a predefined validation function with one argument 
- None for existence 
- A tuple with a string as reference to a predefined validation function with one additional argument 
- It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation. 
- Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data. 
 
 
 
- Example
- This example is an excerpt of the validation for the BaseModule configuration.
    >>> validations = {
    ...     "device": [
    ...         str,
    ...         ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)]),
    ...     ],
    ...     "print_prio": [("in", PRINT_PRIORITY)],
    ...     "callable": (lambda value: value == 1),
    ... }
- And within the class __init__() call:
    >>> self.validate_params(validations)
- Raises:
- InvalidParameterException – If one of the parameters is invalid. 
- ValidationException – If the validation list is invalid or contains an unknown validation. 
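To show how the validation mini-language described above can be evaluated, here is a self-contained sketch of a recursive checker. It is an illustrative re-implementation, not dgs code: the NAMED registry is a small hypothetical stand-in for the predefined validations, treating a bare type as an isinstance check is an assumption based on the example's str entry, "eq"/"neq" are omitted, and KeyError/ValueError stand in for InvalidParameterException and ValidationException.

```python
from typing import Any

# Hypothetical registry of predefined validations; dgs ships its own set.
NAMED = {
    "str": lambda v: isinstance(v, str),
    "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
    "in": lambda v, options: v in options,
    "instance": lambda v, cls: isinstance(v, cls),
    "gt": lambda v, bound: v > bound,
}


def _as_validations(data: Any) -> list:
    """Normalize data to a list of validations."""
    if isinstance(data, list):
        return data
    if isinstance(data, tuple) and data and not isinstance(data[0], str):
        return list(data)  # a tuple of (nested) validations
    return [data]  # a single validation, e.g. ("in", [...]) or a callable


def check(value: Any, validation: Any) -> bool:
    if validation is None:  # existence only; the caller already checked the key
        return True
    if isinstance(validation, type):  # bare class: isinstance check (assumption)
        return isinstance(value, validation)
    if isinstance(validation, str):
        if validation not in NAMED:
            raise ValueError(f"unknown validation {validation!r}")
        return NAMED[validation](value)
    if isinstance(validation, tuple):
        name, data = validation
        if name == "forall":  # apply the validations to every item of an iterable
            return all(check(item, v) for item in value for v in _as_validations(data))
        if name in ("any", "all", "not", "xor"):
            results = [check(value, v) for v in _as_validations(data)]
            if name == "any":
                return any(results)
            if name == "all":
                return all(results)
            if name == "not":
                return not all(results)
            return sum(results) == 1  # xor: exactly one validation holds
        if name not in NAMED:
            raise ValueError(f"unknown validation {name!r}")
        return NAMED[name](value, data)
    if callable(validation):
        return validation(value)
    raise ValueError(f"unsupported validation {validation!r}")


def validate_params(params: dict, validations: dict) -> None:
    for key, checks in validations.items():
        if key not in params:  # simplification: every listed key is required
            raise KeyError(f"missing parameter {key!r}")
        for validation in _as_validations(checks):
            if not check(params[key], validation):
                raise ValueError(f"invalid value for {key!r}: {params[key]!r}")


validations = {
    # A device name, or an int GPU index standing in for torch.device.
    "device": [("any", [("in", ["cuda", "cpu"]), ("instance", int)])],
    "topk": [("forall", ("all", [("instance", int), ("gt", 0)]))],
    "callable": (lambda value: value == 1),
}
validate_params({"device": "cuda", "topk": [1, 5], "callable": 1}, validations)  # passes
```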
 
 
- write_results(results: dict[str, any], prepend: str) → None
- Given a dictionary of results, use the writer to save the values.
- Attributes
- Get the device of this module.
- Get whether this module is set to training mode.
- Get the name of the module.
- Get the name of the module.
- Get the escaped name of the module, usable in file paths, by replacing spaces and underscores.
- Get the (floating-point) precision used in multiple parts of this module.
- Whether to load the image crops during training.
- The torch DataLoader containing the validation (query) data.
- A metric function used to compute the embedding distance.
- The torch DataLoader containing the test data.
- The torch DataLoader containing the training data.