dgs.models.combine.dynamic.DynamicAlphaCombine.forward
- DynamicAlphaCombine.forward(*tensors: torch.Tensor, s: State | None = None, **_kwargs) → torch.Tensor
The forward call of this module combines an arbitrary number of similarity matrices into a single matrix using importance weights \(\alpha\).
\(\alpha_i\) describes how important the similarity \(s_i\) is. By definition, the \(\alpha_i\) sum to 1, because the last layer is a softmax layer. \(\alpha\) is computed using the respective BaseAlphaModule and the given State. All tensors should be on the same device and have the same shape.
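The combination step can be sketched as follows. This is a minimal illustration, not the dgs implementation: it assumes the result is the \(\alpha\)-weighted sum \(\sum_i \alpha_i s_i\) and that alpha holds one softmax-normalized weight vector per detection, i.e. shape [D x S]. The helper `combine_with_alpha` is hypothetical, and the random alpha scores stand in for the outputs of the BaseAlphaModule instances.

```python
import torch


def combine_with_alpha(sims: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: weight S similarity matrices by alpha and sum them.

    sims:  [S x D x T] stacked similarity matrices with values in [0, 1]
    alpha: [D x S] per-detection weights, each row summing to 1 (assumed layout)
    returns: [D x T] combined similarity matrix
    """
    # For every detection d and track t: sum_i alpha[d, i] * sims[i, d, t]
    return torch.einsum("sdt,ds->dt", sims, alpha)


S, D, T = 3, 4, 5
sims = torch.rand(S, D, T)  # random similarities in [0, 1)
# Stand-in for the alpha models: raw scores followed by a softmax over S
alpha = torch.softmax(torch.randn(D, S), dim=-1)
combined = combine_with_alpha(sims, alpha)
print(combined.shape)  # torch.Size([4, 5])
```

Because each row of alpha is a convex weighting, every entry of the combined matrix stays within the [0, 1] range of its inputs under these assumptions.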
- Parameters:
tensors – A tuple of tensors describing similarities between the detections and tracks. All S similarity matrices in this tuple should have values in the range [0, 1], be of the same shape [D x T], and be on the same device. If tensors is a single tensor, it should have the shape [S x D x T]. S can be any number of similarity matrices greater than 0, even though only values greater than 1 are really meaningful.
s – A State containing the batched input data for the alpha models. The state should be on the same device as tensors.
- Returns:
The weighted similarity matrix as a tensor of shape [D x T].
- Return type:
torch.Tensor
- Raises:
ValueError – If alpha or the matrices have invalid shapes.
RuntimeError – If one of the tensors is not on the correct device.
TypeError – If one of the tensors or one of the alpha inputs is not of type torch.Tensor.
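The input rules from the parameter list and the exceptions listed under Raises can be illustrated with a small validation helper. This is a hypothetical sketch, not the dgs source: `stack_similarities` accepts either S separate [D x T] tensors or a single [S x D x T] tensor and raises the same exception types as documented. Comparing devices against the first tensor is an assumption about what "the correct device" means.

```python
import torch


def stack_similarities(*tensors: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return the similarities as one [S x D x T] tensor."""
    if len(tensors) == 0:
        raise ValueError("At least one similarity tensor is required.")
    for t in tensors:
        if not isinstance(t, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(t).__name__}.")

    if len(tensors) == 1:
        # A single tensor must already be stacked as [S x D x T].
        sims = tensors[0]
        if sims.ndim != 3:
            raise ValueError(f"A single tensor must have shape [S x D x T], got {tuple(sims.shape)}.")
        return sims

    # Assumption: the first tensor defines the expected shape and device.
    device, shape = tensors[0].device, tensors[0].shape
    for t in tensors:
        if t.ndim != 2 or t.shape != shape:
            raise ValueError(f"All matrices must share one [D x T] shape, got {tuple(t.shape)}.")
        if t.device != device:
            raise RuntimeError(f"Expected all tensors on {device}, got {t.device}.")
    # Stack the S matrices of shape [D x T] into [S x D x T].
    return torch.stack(tensors)


a, b = torch.rand(4, 5), torch.rand(4, 5)
print(stack_similarities(a, b).shape)  # torch.Size([2, 4, 5])
```

Normalizing both call styles to one stacked tensor lets the weighting step work on a single [S x D x T] layout regardless of how the caller passed the matrices.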