dgs.models.similarity.similarity.SimilarityModule

class dgs.models.similarity.similarity.SimilarityModule(*args: Any, **kwargs: Any)

Abstract class for similarity functions.

Params

module_name (str):

The name of the similarity module.

Optional Params

softmax (bool, optional):

Whether to apply the softmax function to the (batched) output of the similarity function. Default: DEF_VAL.similarity.softmax.

train_key (str, optional):

The name of a State property used to retrieve the data during training, e.g., State.bbox_relative() instead of the regular bbox. If this value isn't set, the regular SimilarityModule.get_data() call is used.
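As a rough illustration of how these parameters fit together, the sketch below assumes that the module's settings live in a nested config dict that `path` indexes into, and that a missing `train_key` falls back to the regular data accessor. `resolve_params`, `DEFAULTS`, `FakeState`, `get_data`, `get_train_data`, and the `"my_box_similarity"` name are hypothetical stand-ins, not the dgs API; the real default for `softmax` is whatever `DEF_VAL.similarity.softmax` holds.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Hypothetical placeholder for DEF_VAL.similarity; the real defaults may differ.
DEFAULTS: dict[str, Any] = {"softmax": False, "train_key": None}


def resolve_params(config: dict[str, Any], path: list[str]) -> dict[str, Any]:
    """Walk `path` into `config`, require module_name, fill optional defaults (assumed behavior)."""
    node: Any = config
    for key in path:
        node = node[key]
    if "module_name" not in node:
        raise KeyError("module_name is required")
    return {**DEFAULTS, **node}


@dataclass
class FakeState:
    """Hypothetical State holding an absolute bounding box."""
    bbox: list[float] = field(default_factory=lambda: [10.0, 20.0, 30.0, 40.0])

    def bbox_relative(self) -> list[float]:
        # Pretend the image is 100 x 100 pixels.
        return [v / 100.0 for v in self.bbox]


def get_data(state: FakeState) -> list[float]:
    """Regular accessor: the plain bbox."""
    return state.bbox


def get_train_data(state: FakeState, train_key: Optional[str]) -> list[float]:
    """Use the State property named by train_key, else fall back to get_data()."""
    if train_key is None:
        return get_data(state)
    attr = getattr(state, train_key)
    # Properties may be methods (like bbox_relative) or plain attributes.
    return attr() if callable(attr) else attr


config = {"similarity": {"box_sim": {"module_name": "my_box_similarity", "train_key": "bbox_relative"}}}
params = resolve_params(config, ["similarity", "box_sim"])
state = FakeState()
print(params["softmax"])
print(get_train_data(state, params["train_key"]))
```

Keeping the fallback in one accessor means training code can always call the training accessor and only gets different data when `train_key` is configured.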

__init__(config: dict[str, any], path: list[str])

Methods

configure_torch_module(module[, train])

Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.

forward(data, target)

Compute the similarity between two input tensors.

get_data(ds)

Get the data used in this similarity module.

get_target(ds)

Get the target data used in this similarity module.

get_train_data(ds)

Get special data for training purposes, e.g., the State property named by train_key.

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.
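To make the `forward(data, target)` contract and the `softmax` option more concrete, here is a minimal pure-Python sketch. It assumes that `data` holds N embeddings and `target` holds M embeddings of the same length, that the output is an N x M similarity matrix, and that softmax (when enabled) is applied across each row, i.e., over the targets. Cosine similarity is only an example metric chosen here; `CosineSimilaritySketch` is hypothetical, and the concrete dgs modules operate on torch tensors and may use other metrics or a different softmax dimension.

```python
import math
from typing import Sequence

Matrix = list[list[float]]


def _softmax_row(row: Sequence[float]) -> list[float]:
    # Subtract the row maximum for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


class CosineSimilaritySketch:
    """Hypothetical similarity module: N x D data vs. M x D target -> N x M scores."""

    def __init__(self, softmax: bool = False) -> None:
        self.softmax = softmax

    def forward(self, data: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> Matrix:
        def norm(v: Sequence[float]) -> float:
            return math.sqrt(sum(x * x for x in v))

        # Assumes no zero-length vectors (otherwise the division fails).
        scores = [
            [sum(a * b for a, b in zip(d, t)) / (norm(d) * norm(t)) for t in target]
            for d in data
        ]
        # Optionally turn each row into a distribution over the targets.
        return [_softmax_row(r) for r in scores] if self.softmax else scores


data = [[1.0, 0.0], [0.0, 1.0]]
target = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
print(CosineSimilaritySketch().forward(data, target))
print(CosineSimilaritySketch(softmax=True).forward(data, target))
```

With softmax enabled, each row sums to 1 while keeping the ranking of the raw scores, which is convenient when the similarities are later treated as assignment probabilities.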

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get an escaped version of the module name, usable in file paths, created by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

softmax