dgs.models.engine.visual_sim_engine.VisualSimilarityEngine

class dgs.models.engine.visual_sim_engine.VisualSimilarityEngine(*args: Any, **kwargs: Any)[source]

An engine class for training and testing visual similarities using visual embeddings.

For this model:

  • get_data() should return the image crop

  • get_target() should return the target class IDs

  • train_dl contains the training data as usual

  • test_dl contains the query data

  • val_dl contains the gallery data

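For orientation, a minimal sketch of how the three loaders might be split; the tensor shapes, batch sizes, and dataset construction below are placeholders that only show which loader holds which split, not how dgs builds its own datasets:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Placeholder datasets: image crops plus integer class IDs (example shapes only).
    train_ds = TensorDataset(torch.rand(32, 3, 256, 128), torch.randint(0, 4, (32,)))
    query_ds = TensorDataset(torch.rand(8, 3, 256, 128), torch.randint(0, 4, (8,)))
    gallery_ds = TensorDataset(torch.rand(16, 3, 256, 128), torch.randint(0, 4, (16,)))

    train_dl = DataLoader(train_ds, batch_size=4)    # training data
    test_dl = DataLoader(query_ds, batch_size=4)     # query data
    val_dl = DataLoader(gallery_ds, batch_size=4)    # gallery data
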
Train Params

nof_classes (int):

The number of classes in the training set.

Test Params

metric (str|callable):

The name or class of the metric used during testing / evaluation. Within the VisualSimilarityEngine, the metric is used only to compute the distances between the query and gallery embeddings, so a distance-based metric should be used.

It is possible to pass additional initialization kwargs to the metric by adding them to the metric_kwargs parameter.

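As an illustration of what a distance-based metric looks like in this setting, here is a minimal sketch of such a metric as a plain callable; the function and its call signature (query and gallery embedding matrices in, a distance matrix out) are assumptions for illustration and not one of the metrics shipped with dgs:

    import torch

    def distance_metric(query: torch.Tensor, gallery: torch.Tensor) -> torch.Tensor:
        """Pairwise Euclidean distances between [Q, E] query and [G, E] gallery embeddings."""
        return torch.cdist(query, gallery, p=2)

    dists = distance_metric(torch.rand(8, 512), torch.rand(16, 512))  # resulting shape [Q, G]
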
Optional Train Params

topk_acc (list[int], optional):

The values for k for the top-k accuracy evaluation during training. Default DEF_VAL.engine.visual.topk_acc.

Optional Test Params

metric_kwargs (dict, optional):

Specific kwargs for the metric. Default DEF_VAL.engine.visual.metric_kwargs.

topk_cmc (list[int], optional):

The values for k for the top-k CMC evaluation during testing / evaluation. Default DEF_VAL.engine.visual.topk_cmc.

write_embeds (list[bool, bool], optional):

Whether to write the embeddings of the query and gallery datasets to the tensorboard writer. This is only really feasible for smaller datasets of roughly 1k embeddings. Default DEF_VAL.engine.visual.write_embeds.

image_key (str, optional):

Which key to use when loading the image from the state in get_data(). Default DEF_VAL.engine.visual.image_key.

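Put together, the parameters above might look like the following fragment; how these keys are nested inside a full dgs configuration, and the concrete values shown, are assumptions for illustration only:

    engine_params = {
        # train params
        "nof_classes": 751,              # number of classes in the training set (example value)
        "topk_acc": [1, 5, 10],          # optional, defaults to DEF_VAL.engine.visual.topk_acc
        # test params
        "metric": distance_metric,       # a metric name known to dgs or a callable, see the sketch above
        "metric_kwargs": {},             # optional init kwargs forwarded to the metric
        "topk_cmc": [1, 5, 10, 20],      # optional, defaults to DEF_VAL.engine.visual.topk_cmc
        "write_embeds": [False, False],  # optional: write query / gallery embeddings to tensorboard
        "image_key": "image_crop",       # optional: assumed key name, defaults to DEF_VAL.engine.visual.image_key
    }
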
__init__(config: dict[str, any], model: TorchreidVisualSimilarity, test_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, *, train_loader: torch.utils.data.DataLoader | None = None, **kwargs)[source]

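A hedged usage sketch of the constructor and run(): the config, the TorchreidVisualSimilarity model, and the DataLoaders are assumed to exist already (see the sketches above); only the wiring is shown.

    from dgs.models.engine.visual_sim_engine import VisualSimilarityEngine

    # `config` and `model` are assumed to be a valid dgs configuration and a
    # TorchreidVisualSimilarity instance; their construction is not shown here.
    engine = VisualSimilarityEngine(
        config=config,
        model=model,
        test_loader=test_dl,    # query data
        val_loader=val_dl,      # gallery data
        train_loader=train_dl,  # training data; optional, defaults to None
    )
    engine.run()
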
Methods

configure_torch_module(module[, train])

Set compute mode and send model to the device or multiple parallel devices if applicable.

evaluate()

Run the tests defined in the sub-engine.

get_data(ds)

Get the image crop or other requested image from the state.

get_hparam_dict()

Get the hyperparameters of the current engine.

get_target(ds)

Get the target pIDs from the data.

initialize_optimizer()

Because the module might be set after the initial step, load the optimizer and scheduler at the start of training.

load_model(path)

Load the model from a file.

predict

print_results(results)

Given a dictionary of results, print them to the console if allowed.

run()

Run the model.

save_model(epoch, metrics, optimizer, lr_sched)

Save the current model and other weights into a '.pth' file.

set_model_mode(mode)

Set model mode to train or test.

terminate()

Handle forceful termination, e.g., ctrl+c.

test

train_model()

Train the given model using the given loss function, optimizer, and learning-rate schedulers.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.

write_results(results, prepend)

Given a dictionary of results, use the writer to save the values.

Attributes

curr_epoch

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module usable in filepaths by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

train_load_image_crops

Whether to load the image crops during training.

val_dl

The torch DataLoader containing the validation (gallery) data.

model

metric

A metric function used to compute the embedding distance.

loss

writer

test_dl

The torch DataLoader containing the test (query) data.

train_dl

The torch DataLoader containing the training data.