dgs.models.combine.dynamic.AlphaCombine.forward

AlphaCombine.forward(*tensors: torch.Tensor, alpha: torch.Tensor | None = None, **_kwargs) → torch.Tensor

The forward call of this module combines an arbitrary number of similarity matrices using an importance weight \(\alpha\).

Parameters:
  • tensors – N similarity matrices as a tuple of tensors. All tensors should have values in the range [0,1], have the same shape [D x T], and be on the same device.

  • alpha – A tensor containing weights in the range [0,1]. Alpha can have one of the following shapes: [N] or [N x D]. It should be on the same device as the other tensors.

Returns:

The weighted similarity matrix.

Return type:

torch.Tensor

Raises:
  • ValueError – If alpha or the matrices have invalid shapes.

  • RuntimeError – If the tensors are not on the same device.

  • TypeError – If one of the tensors or alpha is not of type torch.Tensor.
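
Example:

The following is a minimal sketch of one plausible reading of the description above, not the module's actual implementation. It assumes that "combining" means a weighted sum over the N matrices, that an alpha of shape [N] supplies one weight per matrix, and that an alpha of shape [N x D] supplies one weight per matrix and row. The function name alpha_combine_sketch is hypothetical, and the sketch requires alpha explicitly instead of handling the None default.

```python
import torch


def alpha_combine_sketch(*tensors: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for AlphaCombine.forward (assumed weighted sum)."""
    # Mirror the documented TypeError for non-tensor inputs.
    if not isinstance(alpha, torch.Tensor) or not all(
        isinstance(t, torch.Tensor) for t in tensors
    ):
        raise TypeError("All similarity matrices and alpha must be torch.Tensor.")
    if len(tensors) == 0:
        raise ValueError("Expected at least one similarity matrix.")
    # Mirror the documented ValueError for mismatched matrix shapes.
    if any(t.ndim != 2 or t.shape != tensors[0].shape for t in tensors):
        raise ValueError("All similarity matrices must share the shape [D x T].")
    # Mirror the documented RuntimeError for mixed devices.
    if any(t.device != alpha.device for t in tensors):
        raise RuntimeError("All tensors and alpha must be on the same device.")

    n = len(tensors)
    d = tensors[0].shape[0]
    if alpha.shape == (n,):
        weights = alpha.reshape(n, 1, 1)  # one weight per matrix
    elif alpha.shape == (n, d):
        weights = alpha.reshape(n, d, 1)  # one weight per matrix and row
    else:
        raise ValueError(f"alpha must have shape [N] or [N x D], got {tuple(alpha.shape)}.")

    stacked = torch.stack(tensors)  # [N x D x T]
    # Assumption: the combination is a weighted sum over the N matrices.
    return (weights * stacked).sum(dim=0)  # [D x T]


s1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]])
s2 = torch.tensor([[0.0, 1.0], [0.5, 0.5]])

# Shape [N]: the same weight is applied to every entry of a matrix.
combined = alpha_combine_sketch(s1, s2, alpha=torch.tensor([0.75, 0.25]))

# Shape [N x D]: each row of each matrix gets its own weight.
row_alpha = torch.tensor([[1.0, 0.5], [0.0, 0.5]])
combined_rows = alpha_combine_sketch(s1, s2, alpha=row_alpha)
```

Under these assumptions, weights that sum to 1 across the N matrices keep the combined values in [0,1] whenever the inputs are in [0,1].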