dgs.models.engine.visual_sim_engine.VisualSimilarityEngine
- class dgs.models.engine.visual_sim_engine.VisualSimilarityEngine(*args: Any, **kwargs: Any)
An engine class for training and testing visual similarities using visual embeddings.
For this model:
- get_data() should return the image crop.
- get_target() should return the target class IDs.
- train_dl contains the training data as usual.
- test_dl contains the query data.
- val_dl contains the gallery data.
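The data roles above can be illustrated with a minimal sketch. It assumes the state is a plain dict per batch and that the crops and class IDs live under the hypothetical keys "image_crop" and "class_id"; the real engine reads from its own state object and uses image_key (default DEF_VAL.engine.visual.image_key).

```python
from typing import Any

# Hypothetical stand-in for the engine's state: a plain dict per batch.
# The keys "image_crop", "image" and "class_id" are assumptions for illustration.


def get_data(ds: dict[str, Any], image_key: str = "image_crop") -> Any:
    """Return the image crop (or another requested image) stored under image_key."""
    return ds[image_key]


def get_target(ds: dict[str, Any]) -> Any:
    """Return the target class IDs of the samples in the batch."""
    return ds["class_id"]


batch = {
    "image_crop": [[0.1, 0.2], [0.3, 0.4]],  # stand-in for a batch of cropped images
    "image": [[0.0] * 4, [1.0] * 4],  # stand-in for the full, uncropped images
    "class_id": [3, 7],
}

crops = get_data(batch)  # default key: the crops
full_images = get_data(batch, image_key="image")  # another image from the same state
targets = get_target(batch)  # [3, 7]
```

Swapping image_key is how a different image from the same state would be selected without changing get_data itself.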
Train Params
- nof_classes (int):
The number of classes in the training set.
Test Params
- metric (str|callable):
The name or class of the metric used during testing / evaluation. In the VisualSimilarityEngine, the metric is only used to compute the distance between the query and gallery embeddings, so a distance-based metric should be used. Additional initialization kwargs can be passed to the metric via the metric_kwargs parameter.
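A minimal sketch of how a metric given as a name or a class could be resolved, instantiated with metric_kwargs, and applied to query and gallery embeddings. The MinkowskiDistance class, the METRICS registry and get_metric are hypothetical; the real engine works on torch tensors and may resolve metrics differently.

```python
from typing import Callable, Optional, Sequence, Union

Embeddings = Sequence[Sequence[float]]


class MinkowskiDistance:
    """Hypothetical distance metric: pairwise Minkowski distance (p=2 is Euclidean)."""

    def __init__(self, p: float = 2.0) -> None:
        self.p = p  # an initialization kwarg, as could be supplied via metric_kwargs

    def __call__(self, query: Embeddings, gallery: Embeddings) -> list[list[float]]:
        # Result[i][j] is the distance between query i and gallery j; smaller means more similar.
        return [
            [sum(abs(q - g) ** self.p for q, g in zip(qe, ge)) ** (1.0 / self.p) for ge in gallery]
            for qe in query
        ]


# Hypothetical name-to-class registry used to resolve a metric given as a string.
METRICS: dict[str, Callable[..., Callable]] = {"MinkowskiDistance": MinkowskiDistance}


def get_metric(metric: Union[str, Callable[..., Callable]], metric_kwargs: Optional[dict] = None) -> Callable:
    """Resolve a metric name or class and instantiate it with metric_kwargs."""
    metric_cls = METRICS[metric] if isinstance(metric, str) else metric
    return metric_cls(**(metric_kwargs or {}))


query = [[0.0, 0.0], [1.0, 1.0]]
gallery = [[0.0, 0.0], [3.0, 4.0]]
dist = get_metric("MinkowskiDistance", {"p": 2})(query, gallery)
# dist[0] == [0.0, 5.0]
```

Because the result is a distance matrix, ranking the gallery for a query means sorting its row in ascending order; a similarity metric would invert that ordering, which is why a distance-based metric is expected here.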
Optional Train Params
- topk_acc (list[int], optional):
The values of k for the top-k accuracy evaluation during training. Default DEF_VAL.engine.visual.topk_acc.
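One common way to compute top-k accuracy, sketched under the assumption that the model outputs one score per class (nof_classes scores per sample) and that targets are the class IDs returned by get_target(). The function name topk_accuracy is hypothetical, and the engine itself presumably operates on torch tensors.

```python
from typing import Sequence


def topk_accuracy(
    logits: Sequence[Sequence[float]], targets: Sequence[int], ks: Sequence[int]
) -> dict[int, float]:
    """Fraction of samples whose target class ID is among the k highest-scoring classes."""
    results = {}
    for k in ks:
        hits = 0
        for scores, target in zip(logits, targets):
            # Class indices sorted by descending score; sorted() is stable, so ties keep the lower index first.
            ranking = sorted(range(len(scores)), key=lambda c: -scores[c])
            hits += target in ranking[:k]
        results[k] = hits / len(targets)
    return results


# Three samples, nof_classes = 4 (illustrative values).
logits = [
    [0.1, 0.7, 0.1, 0.1],
    [0.5, 0.1, 0.3, 0.1],
    [0.2, 0.2, 0.1, 0.5],
]
targets = [1, 2, 1]
acc = topk_accuracy(logits, targets, ks=[1, 2, 3])
# acc[1] == 1/3, acc[2] == 2/3, acc[3] == 1.0
```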
Optional Test Params
- metric_kwargs (dict, optional):
Specific kwargs for the metric. Default DEF_VAL.engine.visual.metric_kwargs.
- topk_cmc (list[int], optional):
The values of k for the top-k CMC evaluation during testing / evaluation. Default DEF_VAL.engine.visual.topk_cmc.
- write_embeds (list[bool, bool], optional):
Whether to write the embeddings of the query and gallery datasets to the TensorBoard writer. Only really feasible for small datasets of around 1k embeddings. Default DEF_VAL.engine.visual.write_embeds.
- image_key (str, optional):
Which key to use when loading the image from the state in get_data(). Default DEF_VAL.engine.visual.image_key.
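The top-k CMC (Cumulative Matching Characteristic) values requested via topk_cmc can be sketched from a query-by-gallery distance matrix: for each query, sort the gallery by distance and record the rank of the first entry with the same ID. This is the textbook definition in pure Python; the function name cmc is hypothetical, and the sketch ignores the camera-based filtering that some re-ID evaluation protocols apply.

```python
from typing import Optional, Sequence


def cmc(
    dist: Sequence[Sequence[float]],
    query_ids: Sequence[int],
    gallery_ids: Sequence[int],
    ks: Sequence[int],
) -> dict[int, float]:
    """CMC at rank k: share of queries whose first correct gallery match is within the k nearest entries."""
    first_hit: list[Optional[int]] = []
    for row, qid in zip(dist, query_ids):
        order = sorted(range(len(row)), key=lambda g: row[g])  # nearest gallery entry first
        # 1-based rank of the first gallery entry with the same ID, or None if there is none.
        rank = next((r for r, g in enumerate(order, start=1) if gallery_ids[g] == qid), None)
        first_hit.append(rank)
    n = len(query_ids)
    return {k: sum(r is not None and r <= k for r in first_hit) / n for k in ks}


dist = [  # rows: queries, columns: gallery entries
    [0.2, 0.9, 0.5],
    [0.1, 0.8, 0.4],
    [0.3, 0.7, 0.6],
]
scores = cmc(dist, query_ids=[1, 2, 3], gallery_ids=[1, 2, 3], ks=[1, 2, 3])
# scores == {1: 1/3, 2: 2/3, 3: 1.0}
```

CMC is non-decreasing in k by construction, so rank-1 is the strictest value and the largest k in topk_cmc the most lenient.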
- __init__(config: dict[str, any], model: TorchreidVisualSimilarity, test_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, *, train_loader: torch.utils.data.DataLoader | None = None, **kwargs)
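__init__ receives the parameters listed above through config. Where these keys sit inside the full configuration is not shown here, so the sketch below assumes a flat dict and checks it in the spirit of validate_params. The function validate_visual_params, the metric name, the state key and all values are hypothetical; optional keys are only checked when present, since the real defaults come from DEF_VAL.

```python
from typing import Any


def validate_visual_params(params: dict[str, Any]) -> None:
    """Raise ValueError if a hypothetical VisualSimilarityEngine parameter dict is malformed."""
    n = params.get("nof_classes")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(n, bool) or not isinstance(n, int) or n <= 0:
        raise ValueError("nof_classes must be a positive int")
    metric = params.get("metric")
    if not (isinstance(metric, str) or callable(metric)):
        raise ValueError("metric must be a str or a callable")
    # Optional keys are only checked when present; otherwise the engine's DEF_VAL defaults apply.
    for key in ("topk_acc", "topk_cmc"):
        if key in params and not all(
            isinstance(k, int) and not isinstance(k, bool) and k > 0 for k in params[key]
        ):
            raise ValueError(f"{key} must be a list of positive ints")
    if "write_embeds" in params:
        flags = params["write_embeds"]
        if len(flags) != 2 or not all(isinstance(f, bool) for f in flags):
            raise ValueError("write_embeds must be [bool, bool] for (query, gallery)")


config = {
    "nof_classes": 100,  # illustrative value
    "metric": "EuclideanDistance",  # hypothetical metric name
    "metric_kwargs": {},
    "topk_acc": [1, 5],
    "topk_cmc": [1, 5, 10],
    "write_embeds": [False, False],  # (query, gallery)
    "image_key": "image_crop",  # hypothetical state key
}
validate_visual_params(config)  # passes silently
```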
Methods
- configure_torch_module(module[, train]): Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.
- evaluate(): Run the tests defined in the sub-engine.
- get_data(ds): Get the image crop or other requested image from the state.
- Get the hyperparameters of the current engine.
- get_target(ds): Get the target pIDs from the data.
- Because the module might be set after the initial step, load the optimizer and scheduler at the start of the training.
- load_model(path): Load the model from a file.
- print_results(results): Given a dictionary of results, print them to the console if allowed.
- run(): Run the model.
- save_model(epoch, metrics, optimizer, lr_sched): Save the current model and other weights into a '.pth' file.
- set_model_mode(mode): Set the model mode to train or test.
- Handle forceful termination, e.g., Ctrl+C.
- Train the given model using the given loss function, optimizer, and learning-rate schedulers.
- validate_params(validations[, attrib_name]): Given per-key validations, validate this module's parameters.
- write_results(results, prepend): Given a dictionary of results, use the writer to save the values.
Attributes
- Get the device of this module.
- Get whether this module is set to training-mode.
- Get the name of the module.
- Get the name of the module.
- Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
- Get the (floating point) precision used in multiple parts of this module.
- Whether to load the image crops during training.
- The torch DataLoader containing the validation (query) data.
- A metric function used to compute the embedding distance.
- The torch DataLoader containing the test data.
- The torch DataLoader containing the training data.