Source code for dgs.models.combine.dynamic

"""
Implementation of modules that use dynamic weights to combine multiple similarities.
"""

import torch as t
from torch import nn

from dgs.models.combine.combine import CombineSimilaritiesModule
from dgs.models.loader import module_loader
from dgs.utils.state import State
from dgs.utils.torchtools import configure_torch_module
from dgs.utils.types import Config, NodePath, Validations

# Validation rules for ``AlphaCombine``: the module has no required params.
alpha_combine_validation: Validations = {}

# Validation rules for ``DynamicAlphaCombine``:
# ``alpha_modules`` must be a list of NodePaths with at least one entry,
# one path per alpha module to load.
dynamic_alpha_validation: Validations = {
    "alpha_modules": ["NodePaths", ("longer eq", 1)],
}


@configure_torch_module
class DynamicAlphaCombine(CombineSimilaritiesModule):
    r"""Use inputs and multiple alpha modules to weight the similarity matrices.

    Notes:
        The models for computing the per-similarity alpha has to be set manually after the initialization.

    Given ``N`` inputs to the alpha module (e.g. the visual embeddings of ``N`` images, or ``N`` different sized
    inputs like the bbox, pose, and visual embedding of a single crop), compute the alpha weights for the
    similarity matrices.
    Then use :math:`\alpha_i` to compute the weighted sum of all the similarity matrices :math:`S_i` as
    :math:`S = \sum_N \alpha_i \cdot S_i`.
    Every :math:`\alpha_i` can either be a single float value in range :math:`[0, 1]` or a (float-) tensor of the
    same shape as :math:`S_i` again with values in :math:`[0,1]`.

    Params
    ------

    alpha_modules (list[NodePath]):
        A list containing paths to multiple :class:`BaseAlphaModule` s.

    Optional Params
    ---------------

    """

    alpha_models: nn.ModuleList
    """The model that computes the alpha values from given inputs."""

    def __init__(self, config: Config, path: NodePath) -> None:
        super().__init__(config, path)
        self.validate_params(dynamic_alpha_validation)

        # Load one alpha module per configured NodePath and move them all to this
        # module's device. (Renamed the loop variable so it no longer shadows ``path``.)
        self.alpha_models = nn.ModuleList(
            [
                module_loader(config=config, module_type="alpha", key=alpha_path)
                for alpha_path in self.params["alpha_modules"]
            ]
        ).to(device=self.device)

    def forward(
        self,
        *tensors: t.Tensor,
        s: State = None,
        **_kwargs,
    ) -> t.Tensor:
        r"""The forward call of this module combines an arbitrary number of similarity matrices using an importance
        weight :math:`\alpha`.
        :math:`\alpha_i` describes how important the similarity :math:`s_i` is.
        The sum of all :math:`\alpha_i` should be 1 by definition given the last layer is a softmax layer.
        :math:`\alpha` is computed using the respective :class:`BaseAlphaModule` and the given :class:`State`.
        All tensors should be on the same device and should have the same shape.

        Args:
            tensors: A tuple of tensors describing similarities between the detections and tracks.
                All ``S`` similarity matrices of this iterable should have values in range ``[0,1]``,
                be of the same shape ``[D x T]``, and be on the same device.
                If ``tensors`` is a single tensor, it should have the shape ``[S x D x T]``.
                ``S`` can be any number of similarity matrices greater than 0,
                even though only values greater than 1 really make sense.
            s: A :class:`State` containing the batched input data for the alpha models.
                The state should be on the same device as ``tensors``.

        Returns:
            torch.Tensor: The weighted similarity matrix as tensor of shape ``[D x T]``.

        Raises:
            ValueError: If alpha or the matrices have invalid shapes.
            RuntimeError: If one of the tensors is not on the correct device.
            TypeError: If one of the tensors or one of the alpha inputs is not of type class:`torch.Tensor`.
        """
        # pylint: disable=too-many-branches

        # validate tensors
        if not isinstance(tensors, tuple):  # pragma: no cover # redundancy
            raise TypeError("tensors should be a tuple containing (float) tensors.")
        if any(not isinstance(tensor, t.Tensor) for tensor in tensors):
            raise TypeError("All similarity matrices should be (float) tensors.")
        if any(tensor.shape != tensors[0].shape for tensor in tensors):
            raise ValueError("All similarity matrices should have the same shape.")
        if any(tensor.device != tensors[0].device for tensor in tensors):
            raise RuntimeError("All tensors should be on the same device.")

        # force tensors to be shape [S x D x T]
        if len(tensors) == 1:
            tensors = tensors[0]
        else:
            tensors = t.stack(tensors, dim=-3)

        if tensors.ndim != 3:
            raise ValueError(f"Expected a 3D tensor [S x D x T], but got a tensor with shape {tensors.shape}")
        S = tensors.size(-3)

        # Bugfix: validate the number of alpha models against ``S`` (the number of similarity
        # matrices) AFTER stacking, not against the number of positional arguments.
        # The old pre-stack check ``len(self.alpha_models) != len(tensors)`` rejected the
        # documented single-tensor mode, where one pre-stacked [S x D x T] tensor is passed
        # (len(tensors) == 1 while S alpha models exist and range(S) is used below).
        if len(self.alpha_models) != S:
            raise ValueError(f"There should be as many alpha models {len(self.alpha_models)} as tensors {S}.")

        # validate alpha inputs
        if s is None:
            raise ValueError("The state should be given.")
        if not isinstance(s, State):
            raise TypeError(f"s should be a State. Got: {s}")
        if s.device != tensors.device:
            raise RuntimeError("s should be on the same device as tensors.")
        if tensors.size(-2) != s.B:
            raise ValueError(f"The states batch size ({s.B}) should equal D of tensors ({tensors.size(-2)}).")

        # [D x S] with softmax over S dimension
        alpha = nn.functional.softmax(t.cat([self.alpha_models[i].forward(s) for i in range(S)], dim=1), dim=-1)

        # [S x D x 1] hadamard [S x D x T] -> [S x D x T] -> sum over all S => [D x T]
        return t.mul(alpha.T.unsqueeze(-1), tensors).sum(dim=0)

    def terminate(self) -> None:  # pragma: no cover
        # Drop the sub-modules so their parameters can be garbage-collected.
        if hasattr(self, "alpha_models"):
            del self.alpha_models
@configure_torch_module
class AlphaCombine(CombineSimilaritiesModule):
    r"""Combine multiple similarity matrices as a weighted sum with given alpha weights.

    Given a similarity matrix / tensor with shape ``[N x T]`` and one alpha value per similarity,
    compute the weighted sum of all the similarity matrices.
    The module will make sure, that :math:`\sum_N \alpha_i = 1`.

    Params
    ------

    Optional Params
    ---------------

    """

    def __init__(self, config: Config, path: NodePath) -> None:
        super().__init__(config, path)
        self.validate_params(alpha_combine_validation)

    def forward(self, *tensors: t.Tensor, alpha: t.Tensor = None, **_kwargs) -> t.Tensor:
        r"""Combine an arbitrary number of similarity matrices using the importance weights :math:`\alpha`.

        Args:
            tensors: ``N`` similarity matrices as a tuple of tensors.
                Every tensor should contain values in range ``[0,1]``, have the same shape ``[D x T]``,
                and live on the same device.
            alpha: A tensor containing weights in range ``[0,1]`` with shape ``[N]`` or ``[N x D]``,
                on the same device as ``tensors``.

        Returns:
            torch.Tensor: The weighted similarity matrix.

        Raises:
            ValueError: If alpha or the matrices have invalid shapes.
            RuntimeError: If the tensors are not on the same device.
            TypeError: If one of the tensors or alpha is not of type class:`torch.Tensor`.
        """
        # -- validate the similarity matrices --
        if not isinstance(tensors, tuple):
            raise TypeError(f"tensors should be a tuple containing (float) tensors, got {type(tensors)}.")
        for mat in tensors:
            if not isinstance(mat, t.Tensor):
                raise TypeError("All similarity matrices should be (float) tensors.")
        for mat in tensors:
            if mat.device != tensors[0].device:
                raise RuntimeError("All tensors should be on the same device.")

        stacked = t.stack(tensors)  # [N x D x T]

        # -- validate alpha --
        if alpha is None:
            raise ValueError("Alpha should be given.")
        if not isinstance(alpha, t.Tensor):
            raise TypeError("alpha should be a (float) tensor.")

        # -- validate alpha against the stacked matrices --
        if len(alpha) != len(stacked):
            raise ValueError(f"Alpha {len(alpha)} should have the same length as the tensors {len(stacked)}.")
        if alpha.device != stacked.device:
            raise RuntimeError("alpha should be on the same device as the tensors.")
        n_dims = alpha.ndim
        if n_dims == 2 and alpha.shape[-1] != stacked.shape[-2]:
            raise ValueError(f"alpha should have shape [N x D], but got {alpha.shape}")
        if n_dims == 3 and (alpha.shape[-2] != stacked.shape[-2] or alpha.shape[-1] != stacked.shape[-1]):
            raise ValueError(f"alpha should have shape [N x D x T], but got {alpha.shape}")

        if n_dims != 1:
            # fixme add dims for [NxD] and [NxDxT]
            raise NotImplementedError("Alpha with shape [N x D] or [N x D x T] is not yet implemented.")

        # [N] dot [N x D x T] -> [D x T]
        combined = t.tensordot(alpha, stacked, dims=([0], [0]))
        return self.softmax(combined)