dgs.models.metric.register_metric

dgs.models.metric.register_metric(name: str, new_metric: Type[torch.nn.Module]) → None

Register a new metric to be used with custom configs.

Parameters:
  • name – Name of the new metric, e.g. “CustomDistance”. The name must not already be a key in METRICS.

  • new_metric – The metric class (a subclass of torch.nn.Module) to register.

Raises:

ValueError – If name is already in METRICS.keys() or new_metric is not a valid metric class.
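The registration rules above can be sketched as a simple name-to-class registry. This is a minimal, hypothetical illustration, not the dgs implementation: it assumes METRICS behaves like a plain dict, and register_metric_sketch, get_metric_sketch and the "must be a class" validity check are assumptions made for this example (the real function may check more, e.g. that the class subclasses torch.nn.Module).

```python
from typing import Dict, Type

# Hypothetical stand-in for the library's METRICS registry: name -> metric class.
METRICS: Dict[str, Type] = {}


def register_metric_sketch(name: str, new_metric: Type) -> None:
    """Add new_metric to METRICS under name, rejecting duplicates and non-classes."""
    if name in METRICS:
        raise ValueError(f"Metric '{name}' is already registered.")
    # Assumed validity check: the object must be a class. The real function
    # may apply stricter checks, e.g. requiring a torch.nn.Module subclass.
    if not isinstance(new_metric, type):
        raise ValueError(f"Metric '{name}' must be a class, got {new_metric!r}.")
    METRICS[name] = new_metric


def get_metric_sketch(name: str) -> Type:
    """Resolve a metric name (e.g. read from a config) to its registered class."""
    if name not in METRICS:
        raise KeyError(f"Unknown metric '{name}'.")
    return METRICS[name]


class CustomDistance:
    """Placeholder metric class used only to exercise the registry."""


register_metric_sketch("CustomDistance", CustomDistance)
```

Looking the name up afterwards returns the registered class, while a second registration under "CustomDistance" raises ValueError, mirroring the documented behaviour.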

Examples:

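import torch
from torch import nn

from dgs.models.metric import register_metric

class CustomDistance(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Illustrative body: pairwise Euclidean distances between rows of input and target
        return torch.cdist(input, target)

register_metric("CustomDistance", CustomDistance)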