dgs.models.loss.register_loss_function
- dgs.models.loss.register_loss_function(name: str, new_loss: Type[torch.nn.Module]) → None
Register a new loss function to be used with custom configs.
- Parameters:
name – Name of the new loss function, e.g. “CustomNNLLoss”. The name cannot be one that is already in LOSS_FUNCTIONS.
new_loss – The type of loss function to register.
- Raises:
ValueError – If name is already in LOSS_FUNCTIONS.keys() or new_loss is invalid.
Examples:
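import torch
from torch import nn

from dgs.models.loss import register_loss_function

# The signature expects a subclass of torch.nn.Module.
class CustomNNLLoss(nn.Module):
    def __init__(...):
        ...

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        return ...

register_loss_function("CustomNNLLoss", CustomNNLLoss)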
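
The registration behavior described above (a name-to-type lookup that rejects duplicates and invalid types with a ValueError) can be sketched roughly as follows. This is a minimal illustration, not the dgs implementation: the Module class is a stand-in for torch.nn.Module so the sketch runs without PyTorch, the registry starts empty, and the exact validity check is an assumption.

```python
from typing import Dict, Type


class Module:
    """Stand-in for torch.nn.Module (assumption, keeps the sketch dependency-free)."""


# Hypothetical registry; the real LOSS_FUNCTIONS in dgs is pre-populated.
LOSS_FUNCTIONS: Dict[str, Type[Module]] = {}


def register_loss_function(name: str, new_loss: Type[Module]) -> None:
    """Store new_loss under name, rejecting duplicate names and non-Module types."""
    if name in LOSS_FUNCTIONS:
        raise ValueError(f"Loss function '{name}' is already registered.")
    # Assumed validity check: new_loss must be a class derived from Module.
    if not (isinstance(new_loss, type) and issubclass(new_loss, Module)):
        raise ValueError(f"{new_loss!r} is not a valid loss type.")
    LOSS_FUNCTIONS[name] = new_loss


class CustomNNLLoss(Module):
    def forward(self, input, target):
        return 0.0  # placeholder result for the sketch


register_loss_function("CustomNNLLoss", CustomNNLLoss)

# Registering the same name a second time is expected to fail.
try:
    register_loss_function("CustomNNLLoss", CustomNNLLoss)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True

# Passing an instance instead of a type is expected to fail as well.
try:
    register_loss_function("InstanceLoss", CustomNNLLoss())
    instance_rejected = False
except ValueError:
    instance_rejected = True
```

Checking the name before the type means a duplicate registration is reported even when the new type itself is valid, which matches the order in which the Raises entry lists the two conditions.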