dgs.models.metric.metric.PairwiseDistanceMetric

class dgs.models.metric.metric.PairwiseDistanceMetric(*args: Any, **kwargs: Any)

Class to compute the pairwise distance between corresponding rows of two tensors. For more details, see torch.nn.PairwiseDistance.
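According to the PyTorch documentation, torch.nn.PairwiseDistance computes, for each row pair, the p-norm of the difference, ||x - y + eps||_p, with p = 2.0 and eps = 1e-6 by default. The plain-Python sketch below reproduces that formula without torch to show what each output entry means. The function name pairwise_distance is hypothetical and not part of dgs, and the sketch only handles finite p > 0.

```python
import math
from typing import List, Sequence


def pairwise_distance(
    input1: Sequence[Sequence[float]],
    input2: Sequence[Sequence[float]],
    p: float = 2.0,
    eps: float = 1e-6,
) -> List[float]:
    """Hypothetical sketch: row-wise ||x - y + eps||_p, the formula
    documented for torch.nn.PairwiseDistance (finite p > 0 only)."""
    if len(input1) != len(input2):
        raise ValueError("inputs must have the same number of rows (a)")
    distances = []
    for row1, row2 in zip(input1, input2):
        # The second dimension (E) has to match for every row pair.
        if len(row1) != len(row2):
            raise ValueError("the second dimension (E) must match")
        # Sum |x_j - y_j + eps|^p over the E features, then take the p-th root.
        total = sum(abs(a - b + eps) ** p for a, b in zip(row1, row2))
        distances.append(total ** (1.0 / p))
    return distances
```

For example, comparing the rows [0, 0] and [3, 4] gives a Euclidean distance of about 5, while two identical rows give a value close to, but not exactly, 0 because of eps.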

Methods

__init__(*args, **kwargs)
forward(input1: torch.Tensor, input2: torch.Tensor) → torch.Tensor

Compute the distance between each pair of corresponding rows of the two inputs. The second dimension, E, has to match.

Parameters:
  • input1 – tensor of shape [a x E]

  • input2 – tensor of shape [a x E]; must have the same shape as input1.

Returns:

tensor of shape [a] or [a x 1] containing one distance per row pair.
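The constructor arguments of PairwiseDistanceMetric are not documented on this page, so the usage sketch below calls the underlying torch.nn.PairwiseDistance module directly to illustrate the [a x E] → [a] shape contract. Whether the dgs wrapper passes keepdim=True (which yields the [a x 1] variant) is an assumption not confirmed here; both cases are shown.

```python
import torch

torch.manual_seed(0)

# Two batches of a = 3 rows with E = 4 features each; shapes must match.
input1 = torch.randn(3, 4)
input2 = torch.randn(3, 4)

# Default reduction drops the feature dimension: output shape [a].
dist = torch.nn.PairwiseDistance(p=2.0)
out = dist(input1, input2)
print(out.shape)  # torch.Size([3])

# keepdim=True keeps the reduced dimension: output shape [a x 1].
dist_keep = torch.nn.PairwiseDistance(p=2.0, keepdim=True)
out_keep = dist_keep(input1, input2)
print(out_keep.shape)  # torch.Size([3, 1])
```

Both variants hold the same values; keepdim only changes whether the trailing dimension of size 1 is retained.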