dgs.models.loss.register_loss_function

dgs.models.loss.register_loss_function(name: str, new_loss: Type[torch.nn.Module]) → None

Register a new loss function so that it can be referenced by name in custom configs.

Parameters:
  • name – Name of the new loss function, e.g. “CustomNNLLoss”. The name must not already be a key in LOSS_FUNCTIONS.

  • new_loss – The loss class to register, i.e. a subclass of torch.nn.Module (the class itself, not an instance).

Raises:

ValueError – If name is already a key in LOSS_FUNCTIONS, or if new_loss is not a valid loss class.

Examples:

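import torch
from torch import nn
from dgs.models.loss import register_loss_function

class CustomNNLLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.nll = nn.NLLLoss()  # wrap PyTorch's built-in negative log-likelihood loss

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.nll(input, target)

# register the class (not an instance) under a name that configs can refer to
register_loss_function("CustomNNLLoss", CustomNNLLoss)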
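To illustrate the registration rules described above, here is a minimal, hypothetical sketch of a name-to-class registry with the same validation behaviour. It is not dgs's implementation: the module-level LOSS_FUNCTIONS dict shown here, the get_loss_function lookup helper, and the Module stand-in class (used in place of torch.nn.Module so the sketch runs without PyTorch installed) are all assumptions made for illustration.

```python
from typing import Dict, Type


class Module:
    """Hypothetical stand-in for torch.nn.Module so this sketch runs without torch."""


# Hypothetical registry playing the role of dgs's LOSS_FUNCTIONS mapping.
LOSS_FUNCTIONS: Dict[str, Type[Module]] = {}


def register_loss_function(name: str, new_loss: Type[Module]) -> None:
    """Store new_loss under name, rejecting duplicate names and non-Module classes."""
    if name in LOSS_FUNCTIONS:
        raise ValueError(f"The name '{name}' is already registered in LOSS_FUNCTIONS.")
    # new_loss must be a class (not an instance) deriving from the base type
    if not (isinstance(new_loss, type) and issubclass(new_loss, Module)):
        raise ValueError(f"Expected a subclass of Module, got {new_loss!r}.")
    LOSS_FUNCTIONS[name] = new_loss


def get_loss_function(name: str) -> Type[Module]:
    """Hypothetical lookup, as a config loader might resolve a loss by its name."""
    try:
        return LOSS_FUNCTIONS[name]
    except KeyError:
        raise ValueError(f"Unknown loss function '{name}'.") from None


class CustomNNLLoss(Module):
    """Toy loss class used only to exercise the registry."""


register_loss_function("CustomNNLLoss", CustomNNLLoss)
assert get_loss_function("CustomNNLLoss") is CustomNNLLoss
```

Keying the registry on plain strings is what lets a config name a loss such as “CustomNNLLoss” without importing the class, and rejecting duplicate names keeps an earlier registration from being silently overwritten.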