# Source code for dgs.models.dgs.dgs
"""
Base class for a torch module that contains the heart of the dynamically gated similarity tracker.
"""
from collections.abc import MutableSequence
from typing import Union
import torch as t
from torch import nn
from dgs.models.combine import get_combine_module
from dgs.models.combine.combine import CombineSimilaritiesModule
from dgs.models.modules.named import NamedModule
from dgs.models.similarity import get_similarity_module
from dgs.models.similarity.similarity import SimilarityModule
from dgs.utils.config import DEF_VAL, get_sub_config
from dgs.utils.state import State
from dgs.utils.types import Config, NodePath, Validations
# Validation rules that ``DGSModule.__init__`` passes to ``validate_params``.
dgs_validations: Validations = {
    # required parameters
    "names": ["NodePaths"],
    "combine": ["NodePath"],
    # optional parameters
    "new_track_weight": ["optional", float, ("within", (0.0, 1.0))],
}
[docs]
class DGSModule(NamedModule, nn.Module):
    """Torch module containing the code for the model called 'dynamically gated similarities'.

    The module runs every configured :class:`.SimilarityModule` on the current detections and the target,
    merges the individual results with a :class:`.CombineSimilaritiesModule`,
    and appends one constant-valued column per detection that represents the option of starting a new track.

    Params
    ------

    names (list[NodePath]):
        The names or :class:`NodePath` s of the keys within the current configuration
        which contain all the :class:`.SimilarityModule` s used in this module.
    combine (NodePath):
        The name or :class:`NodePath` of the key in the current configuration containing the parameters for the
        :class:`.CombineSimilaritiesModule` used to combine the similarities.

    Optional Params
    ---------------

    new_track_weight (float, optional):
        The weight of the new tracks as probability.
        "0.0" means, that existing tracks will always be preferred, while "1.0" means that new tracks are preferred.
        Default ``DEF_VAL.dgs.new_track_weight``.
    """

    # modules computing the individual similarities, one per entry in ``params["names"]``
    sim_mods: Union[nn.ModuleList, MutableSequence[SimilarityModule]]
    # module merging the individual similarities into a single matrix
    combine: CombineSimilaritiesModule
    # scalar tensor holding the value used to fill the new-track columns
    new_track_weight: t.Tensor

    def __init__(self, config: Config, path: NodePath):
        """Validate the parameters and build the similarity modules and the combine module.

        Args:
            config: The overall configuration.
            path: The path to the parameters of this module within ``config``.
        """
        NamedModule.__init__(self, config=config, path=path)
        nn.Module.__init__(self)

        self.validate_params(dgs_validations)

        # Build one similarity module per configured name.
        # The concrete class is looked up through the ``module_name`` key of the respective sub-configuration.
        # NOTE(review): every entry of ``names`` is used as path as-is, while ``combine`` below is wrapped in a list
        # - confirm that this difference is intended.
        similarity_modules = []
        for sim_path in self.params["names"]:
            sim_class = get_similarity_module(get_sub_config(config=config, path=sim_path)["module_name"])
            similarity_modules.append(self.configure_torch_module(sim_class(config=config, path=sim_path)))
        self.sim_mods = nn.ModuleList(similarity_modules)
        self.configure_torch_module(self.sim_mods)

        # module for combining multiple similarities
        combine_name = self.params["combine"]
        combine_class = get_combine_module(name=get_sub_config(config=config, path=[combine_name])["module_name"])
        combine: CombineSimilaritiesModule = combine_class(config=config, path=[combine_name])
        self.register_module(name="combine", module=self.configure_torch_module(combine))

        # The weight of new tracks is kept as a tensor with the precision and on the device of this module.
        self.new_track_weight: t.Tensor = t.tensor(
            self.params.get("new_track_weight", DEF_VAL["dgs"]["new_track_weight"]),
            dtype=self.precision,
            device=self.device,
        )

    @property
    def module_type(self) -> str:
        """The type of this module, always ``"dgs"``."""
        return "dgs"

    def __call__(self, *args, **kwargs) -> t.Tensor:  # pragma: no cover
        """Call :meth:`forward` directly with the given arguments."""
        return self.forward(*args, **kwargs)

    def forward(self, ds: State, target: State, **kwargs) -> t.Tensor:
        """Given a State containing the current detections and a target, compute the similarity between every pair.

        Args:
            ds: The State containing the current detections.
            target: The State to compare the detections against.
            **kwargs: Additional keyword arguments forwarded to the combine module.
                If the key ``"s"`` is missing, ``ds`` is passed as ``s``.

        Returns:
            The combined similarity matrix as tensor of shape ``[nof_detections x (nof_tracks + nof_detections)]``.
        """
        nof_det = len(ds)

        # compute similarity for every module and possibly compute the softmax
        results = [sim_module(ds, target) for sim_module in self.sim_mods]

        # add updated ds (potentially including embeddings) as s to kwargs, without overriding a given value
        kwargs.setdefault("s", ds)

        # combine and possibly compute softmax
        combined: t.Tensor = self.combine(*results, **kwargs)

        # Add a number of columns for the empty / new tracks equal to the length of the input,
        # because every input should be allowed to get assigned to a new track.
        # The value of these columns can be set through the ``new_track_weight`` parameter.
        new_track = t.ones((nof_det, nof_det), dtype=self.precision, device=self.device) * self.new_track_weight

        return t.cat([combined, new_track], dim=-1)

    def terminate(self) -> None:
        """Terminate the DGS module and delete the torch modules.

        Calling this method more than once is safe, modules that are already deleted are skipped.
        """
        for name in ("sim_mods", "combine"):
            if hasattr(self, name):
                delattr(self, name)