dgs.models.combine.static.StaticAlphaCombine.forward

StaticAlphaCombine.forward(*tensors, **_kwargs) → torch.Tensor

Given the alpha weights from the configuration file and the same number of input tensors, multiply each tensor by its corresponding alpha and return the sum of the products.
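Written out, the computation is a plain weighted sum. The symbols below are introduced here for illustration and are not taken from the source: N is the number of alpha values, α_i is the i-th weight, and S_i is the i-th similarity tensor.

```latex
S = \sum_{i=1}^{N} \alpha_i \, S_i
```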

Parameters:

tensors (tuple[torch.Tensor, ...]) – The similarity tensors to combine. The number of tensors should equal the number of values in alpha, and all tensors should have the same size.

Returns:

The weighted similarity matrix as a FloatTensor.

Raises:
  • ValueError – If the tensors argument has the wrong shape.

  • TypeError – If the tensors argument contains an object that is not a torch.Tensor.
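
The behaviour described above can be sketched as a standalone function. This is a minimal illustration, not the dgs implementation: the function name `static_alpha_combine`, the explicit `alpha` argument (the real module reads alpha from its configuration), and the exact conditions treated as a "wrong shape" are assumptions made for the example.

```python
import torch


def static_alpha_combine(alpha: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the weighted sum; not the dgs source code."""
    # Reject anything that is not a tensor before touching shapes.
    if not all(isinstance(t, torch.Tensor) for t in tensors):
        raise TypeError("every element of tensors must be a torch.Tensor")
    # Assumption: a count mismatch is reported as a ValueError.
    if len(tensors) != len(alpha):
        raise ValueError(f"expected {len(alpha)} tensors, got {len(tensors)}")
    # Assumption: differently sized tensors are also reported as a ValueError.
    if any(t.shape != tensors[0].shape for t in tensors):
        raise ValueError("all tensors must have the same size")
    # Stack to shape [N, ...] and contract the leading axis with the N weights.
    stacked = torch.stack(tensors).float()
    return torch.tensordot(alpha.float(), stacked, dims=1)


alpha = torch.tensor([0.7, 0.3])
s1 = torch.ones(2, 3)
s2 = torch.zeros(2, 3)
combined = static_alpha_combine(alpha, s1, s2)  # every entry is 0.7 * 1 + 0.3 * 0
```

Stacking and contracting with `torch.tensordot` avoids a Python loop over the inputs; an explicit `sum(a * t for a, t in zip(alpha, tensors))` would give the same result.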