dgs.models.similarity.similarity.SimilarityModule
- class dgs.models.similarity.similarity.SimilarityModule(*args: Any, **kwargs: Any)
Abstract class for similarity functions.
Params
- module_name (str):
The name of the similarity module.
Optional Params
- softmax (bool, optional):
Whether to apply the softmax function to the (batched) output of the similarity function. Default DEF_VAL.similarity.softmax.
- train_key (str, optional):
The name of a State property used to retrieve the data during training, e.g. State.bbox_relative() instead of the regular bbox. If this value isn’t set, the regular SimilarityModule.get_data() call is used.
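The two optional parameters can be illustrated with a small, dependency-free sketch. It is not the dgs implementation: a plain dict stands in for the State object, the helpers `select_data` and `maybe_softmax` are hypothetical, and applying the softmax row-wise (over all targets for each input) is an assumption about what "batched output" means.

```python
import math
from typing import Any, Optional


def select_data(state: dict[str, Any], training: bool,
                train_key: Optional[str], default_key: str) -> Any:
    """Return state[train_key] while training if train_key is set, else the regular entry.

    `state` is a plain dict standing in for the dgs State object (assumption).
    """
    if training and train_key is not None:
        return state[train_key]
    return state[default_key]


def maybe_softmax(scores: list[list[float]], apply_softmax: bool) -> list[list[float]]:
    """Optionally apply a numerically stable softmax to each row of a batch of scores."""
    if not apply_softmax:
        return scores
    result = []
    for row in scores:
        peak = max(row)  # subtract the row maximum to avoid overflow in exp()
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


# Made-up example values, for illustration only.
state = {"bbox": [10.0, 20.0, 50.0, 80.0], "bbox_relative": [0.1, 0.2, 0.5, 0.8]}
print(select_data(state, training=True, train_key="bbox_relative", default_key="bbox"))
print(maybe_softmax([[2.0, 1.0, 0.1]], apply_softmax=True))
```

In this sketch, leaving `train_key` unset (or running outside training) falls back to the regular entry, matching the behavior described for SimilarityModule.get_data().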
- __init__(config: dict[str, any], path: list[str])
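The constructor signature suggests that a module reads its parameters from a nested configuration dict, using path to locate its own section. The sketch below assumes exactly that; `get_module_params`, `read_similarity_params`, the example config keys, and `FALLBACK_SOFTMAX` are hypothetical, and the fallback value is an arbitrary stand-in for DEF_VAL.similarity.softmax.

```python
from typing import Any, Optional

# Arbitrary stand-in for DEF_VAL.similarity.softmax; the real default is not shown in these docs.
FALLBACK_SOFTMAX = False


def get_module_params(config: dict[str, Any], path: list[str]) -> dict[str, Any]:
    """Walk `config` along the keys in `path` and return the sub-dict found there (hypothetical helper)."""
    node: Any = config
    for key in path:
        if not isinstance(node, dict) or key not in node:
            raise KeyError(f"no config entry at {'.'.join(path)!r} (missing {key!r})")
        node = node[key]
    if not isinstance(node, dict):
        raise TypeError(f"config entry at {'.'.join(path)!r} is not a dict")
    return node


def read_similarity_params(config: dict[str, Any], path: list[str]) -> tuple[str, bool, Optional[str]]:
    """Return (module_name, softmax, train_key), enforcing the one required parameter."""
    params = get_module_params(config, path)
    if "module_name" not in params:
        raise KeyError("'module_name' is a required parameter")
    softmax = bool(params.get("softmax", FALLBACK_SOFTMAX))
    train_key = params.get("train_key")  # None means: use the regular get_data() call
    return params["module_name"], softmax, train_key


# Hypothetical config layout, for illustration only.
config = {"similarity": {"box_sim": {"module_name": "iou", "train_key": "bbox_relative"}}}
print(read_similarity_params(config, ["similarity", "box_sim"]))
```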
Methods
- configure_torch_module(module[, train]): Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.
- forward(data, target): Compute the similarity between two input tensors.
- get_data(ds): Get the data used in this similarity module.
- get_target(ds): Get the data used in this similarity module.
- get_train_data(ds): A custom function to get special data for training purposes.
- Terminate this module and all of its submodules.
- validate_params(validations[, attrib_name]): Given per-key validations, validate this module's parameters.
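The following sketch mirrors the split between get_data / get_target and forward listed above, using plain Python lists instead of torch tensors. `SimilaritySketch`, `CosineSimilaritySketch`, `_maybe_softmax`, and the `"embedding"` key are hypothetical, and cosine similarity is only one possible concrete metric, chosen here for illustration.

```python
import math
from abc import ABC, abstractmethod
from typing import Any

Matrix = list[list[float]]


class SimilaritySketch(ABC):
    """Hypothetical, dependency-free mirror of the method layout above (not dgs code)."""

    def __init__(self, softmax: bool = False) -> None:
        self.softmax = softmax

    def get_data(self, ds: dict[str, Any]) -> Matrix:
        # A dict replaces the dgs State object; the "embedding" key is an assumption.
        return ds["embedding"]

    def get_target(self, ds: dict[str, Any]) -> Matrix:
        return ds["embedding"]

    def _maybe_softmax(self, scores: Matrix) -> Matrix:
        """Row-wise softmax over the targets of each input, if enabled."""
        if not self.softmax:
            return scores
        out = []
        for row in scores:
            peak = max(row)  # stabilize exp() by subtracting the row maximum
            exps = [math.exp(v - peak) for v in row]
            total = sum(exps)
            out.append([e / total for e in exps])
        return out

    @abstractmethod
    def forward(self, data: Matrix, target: Matrix) -> Matrix:
        """Return a len(data) x len(target) matrix of similarity scores."""


class CosineSimilaritySketch(SimilaritySketch):
    def forward(self, data: Matrix, target: Matrix) -> Matrix:
        def norm(v: list[float]) -> float:
            # Clamp to avoid dividing by zero for all-zero vectors.
            return max(math.sqrt(sum(x * x for x in v)), 1e-12)

        scores = [
            [sum(a * b for a, b in zip(d, t)) / (norm(d) * norm(t)) for t in target]
            for d in data
        ]
        return self._maybe_softmax(scores)


detections = {"embedding": [[1.0, 0.0], [0.6, 0.8]]}
tracks = {"embedding": [[1.0, 0.0], [0.0, 1.0]]}
sim = CosineSimilaritySketch(softmax=False)
print(sim.forward(sim.get_data(detections), sim.get_target(tracks)))
```

In this sketch, keeping data retrieval out of forward lets the same similarity computation run on whichever fields the retrieval methods (or a train_key) select.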
Attributes
- Get the device of this module.
- Get whether this module is set to training mode.
- Get the name of the module.
- Get the name of the module.
- Get the escaped name of the module, usable in file paths, by replacing spaces and underscores.
- Get the (floating point) precision used in multiple parts of this module.