Source code for dgs.models.similarity.similarity
"""Base class for Modules that compute any similarity."""
from abc import abstractmethod
from typing import Any

import torch as t
from torch import nn

from dgs.models.module import enable_keyboard_interrupt
from dgs.models.modules.named import NamedModule
from dgs.utils.config import DEF_VAL
from dgs.utils.state import State
from dgs.utils.types import Config, NodePath, Validations
# Validation rules handed to ``validate_params`` by ``SimilarityModule.__init__``.
# Every entry is marked "optional", i.e. no key of this mapping has to be present in the config.
similarity_validations: Validations = {
    # whether a softmax is applied to the output of the similarity function
    "softmax": ["optional", bool],
    # name of the ``State`` attribute to read the data from during training
    "train_key": ["optional", str],
}
[docs]
class SimilarityModule(NamedModule, nn.Module):
    """Abstract class for similarity functions.

    Params
    ------

    module_name (str):
        The name of the similarity module.

    Optional Params
    ---------------

    softmax (bool, optional):
        Whether to apply the softmax function to the (batched) output of the similarity function.
        Default ``DEF_VAL.similarity.softmax``.
    train_key (str, optional):
        A name of a :class:`State` property to use to retrieve the data during training.
        E.g. usage of :meth:`State.bbox_relative` instead of the regular bbox.
        If this value isn't set, the regular :meth:`SimilarityModule.get_data` call is used.
    """

    # Registered in ``__init__``; contains a single ``nn.Softmax(dim=-1)`` if the softmax is enabled
    # and no layers at all otherwise.
    softmax: nn.Sequential

    def __init__(self, config: Config, path: NodePath):
        """Set up both base classes, validate the parameters, and register the softmax module.

        Args:
            config: The configuration, passed on to :class:`NamedModule`.
            path: The path to this module's parameters, passed on to :class:`NamedModule`.
        """
        # Both bases are initialized explicitly instead of through a single ``super().__init__()`` call.
        NamedModule.__init__(self, config, path)
        nn.Module.__init__(self)

        self.validate_params(similarity_validations)

        # Start with an empty Sequential, so ``self.softmax`` can always be called,
        # and only add the softmax over the last dimension if requested.
        softmax = nn.Sequential()
        if self.params.get("softmax", DEF_VAL["similarity"]["softmax"]):
            softmax.append(nn.Softmax(dim=-1))
        self.register_module(name="softmax", module=self.configure_torch_module(softmax))

    @property
    def module_type(self) -> str:
        """The type of this module, always ``"similarity"``."""
        return "similarity"

    def __call__(self, *args, **kwargs) -> t.Tensor:  # pragma: no cover
        """see self.forward()"""
        # NOTE(review): this calls ``forward`` directly instead of going through ``nn.Module.__call__``,
        # so hooks registered on this module are presumably not executed - confirm this is intended.
        return self.forward(*args, **kwargs)

    @abstractmethod
    def get_data(self, ds: State) -> Any:
        """Get the data used in this similarity module.

        Args:
            ds: The state to retrieve the data from.
        """
        raise NotImplementedError

    @abstractmethod
    def get_target(self, ds: State) -> Any:
        """Get the target data used in this similarity module.

        Args:
            ds: The state to retrieve the target data from.
        """
        raise NotImplementedError

    def get_train_data(self, ds: State) -> Any:
        """A custom function to get special data for training purposes.

        If "train_key" is not given, uses the regular :func:`get_data` function of this module.

        Args:
            ds: The state to retrieve the data from.

        Returns:
            The attribute of ``ds`` named by the ``train_key`` parameter if that parameter is set,
            otherwise the result of :meth:`get_data`.
        """
        if "train_key" in self.params:
            return getattr(ds, self.params["train_key"])
        return self.get_data(ds)

    @abstractmethod
    @enable_keyboard_interrupt
    def forward(self, data: State, target: State) -> t.Tensor:
        """Compute the similarity between two input tensors. Make sure to compute the softmax if ``softmax`` is True.

        Args:
            data: The state containing the data.
            target: The state containing the target to compare against.

        Returns:
            The computed similarity as tensor.
        """
        raise NotImplementedError