dgs.models.combine.dynamic.AlphaCombine.forward
- AlphaCombine.forward(*tensors: torch.Tensor, alpha: torch.Tensor | None = None, **_kwargs) → torch.Tensor
The forward call of this module combines an arbitrary number of similarity matrices using an importance weight \(\alpha\).
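This page does not state exactly how \(\alpha\) is applied or whether it is normalized. One plausible reading, stated here as an assumption, is a plain weighted sum. Let \(S^{(1)}, \dots, S^{(N)}\) be the input matrices of shape [D x T]. The combined matrix \(S\) for each of the two supported shapes of \(\alpha\) would then be:

\[
S_{d,t} = \sum_{n=1}^{N} \alpha_{n} \, S^{(n)}_{d,t} \quad \text{for } \alpha \in [0,1]^{N},
\qquad
S_{d,t} = \sum_{n=1}^{N} \alpha_{n,d} \, S^{(n)}_{d,t} \quad \text{for } \alpha \in [0,1]^{N \times D}.
\]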
- Parameters:
tensors – N similarity matrices as a tuple of tensors. All tensors should have values in range [0, 1], be of the same shape [D x T], and be on the same device.
alpha – A tensor containing weights in range [0, 1]. Alpha can have one of the following shapes: [N] or [N x D]. The alpha tensor should be on the same device as the other tensors.
- Returns:
The weighted similarity matrix.
- Return type:
torch.Tensor
- Raises:
ValueError – If alpha or the matrices have invalid shapes.
RuntimeError – If the tensors are not on the same device.
TypeError – If one of the tensors or alpha is not of type torch.Tensor.
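The following sketch shows how a call with this contract might behave. It uses a hypothetical stand-alone function, `alpha_combine`, written for illustration only. It is not the library's implementation. The sketch makes three assumptions:

- The combination is a plain weighted sum over the N matrices.
- An alpha of shape [N] gives one weight per matrix.
- An alpha of shape [N x D] gives one weight per matrix and per row.

The sketch mirrors the documented ValueError, RuntimeError and TypeError conditions. Treating a missing alpha as a TypeError is also an assumption.

```python
from typing import Optional

import torch


def alpha_combine(*tensors: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical sketch: weighted sum of N similarity matrices of shape [D x T]."""
    # TypeError: every matrix and alpha must be a torch.Tensor.
    # Assumption: a missing alpha (None) is also rejected here.
    if not all(isinstance(t, torch.Tensor) for t in tensors) or not isinstance(alpha, torch.Tensor):
        raise TypeError("tensors and alpha must all be torch.Tensor")

    # ValueError: need at least one matrix, all 2-D and of identical shape [D x T].
    n = len(tensors)
    if n == 0 or any(t.ndim != 2 for t in tensors):
        raise ValueError("expected at least one similarity matrix of shape [D x T]")
    if any(t.shape != tensors[0].shape for t in tensors):
        raise ValueError("all similarity matrices must share the same shape [D x T]")
    d = tensors[0].shape[0]

    # RuntimeError: all matrices and alpha must live on one device.
    device = tensors[0].device
    if any(t.device != device for t in tensors) or alpha.device != device:
        raise RuntimeError("all tensors must be on the same device")

    stacked = torch.stack(tensors)  # shape [N x D x T]
    if alpha.shape == (n,):
        weights = alpha.reshape(n, 1, 1)  # one weight per matrix, broadcast over D and T
    elif alpha.shape == (n, d):
        weights = alpha.reshape(n, d, 1)  # one weight per matrix and row, broadcast over T
    else:
        raise ValueError(f"alpha must have shape [{n}] or [{n} x {d}], got {list(alpha.shape)}")

    # Values are expected in [0, 1] but, as in the docs above, this is not enforced.
    return (weights * stacked).sum(dim=0)  # shape [D x T]


if __name__ == "__main__":
    s1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]])
    s2 = torch.tensor([[0.0, 1.0], [0.5, 0.5]])
    # One weight per matrix: 0.25 * s1 + 0.75 * s2
    print(alpha_combine(s1, s2, alpha=torch.tensor([0.25, 0.75])))
    # One weight per matrix and row: row 0 taken from s1, row 1 taken from s2
    print(alpha_combine(s1, s2, alpha=torch.tensor([[1.0, 0.0], [0.0, 1.0]])))
```

The per-row form lets each of the D rows weight the N inputs differently. The [N] form applies the same mixture to the whole matrix.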