dgs.models.loss.get_loss_function

dgs.models.loss.get_loss_function(instance: str | type) → Type[torch.nn.Module]

Given the name or the class of a loss function, return the respective loss function class.

Parameters:

instance – Either the name of the loss function, which has to be in LOSS_FUNCTIONS, or a subclass of Loss.

Raises:

ValueError – If the instance has the wrong type.

Returns:

The class of the given loss function.
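A minimal sketch of the name-or-class resolution described above, assuming LOSS_FUNCTIONS is a dict that maps names to classes. The stand-in Loss base class and the registry entries L1Loss and MSELoss are hypothetical, chosen so the sketch runs without PyTorch. Raising ValueError for an unknown name is also an assumption; the documentation only states that ValueError is raised when the argument has the wrong type.

```python
from typing import Dict, Type, Union


class Loss:
    """Hypothetical stand-in for the Loss base class (the real one is a torch.nn.Module)."""


class L1Loss(Loss):
    """Hypothetical registry entry."""


class MSELoss(Loss):
    """Hypothetical registry entry."""


# Hypothetical registry: the actual contents of LOSS_FUNCTIONS are not documented here.
LOSS_FUNCTIONS: Dict[str, Type[Loss]] = {
    "L1Loss": L1Loss,
    "MSELoss": MSELoss,
}


def get_loss_function(instance: Union[str, type]) -> Type[Loss]:
    """Resolve a loss function name or class to the class itself (sketch)."""
    if isinstance(instance, str):
        # A name: look it up in the registry (unknown-name handling is an assumption).
        if instance not in LOSS_FUNCTIONS:
            raise ValueError(f"Loss function {instance!r} is not in LOSS_FUNCTIONS.")
        return LOSS_FUNCTIONS[instance]
    if isinstance(instance, type) and issubclass(instance, Loss):
        # Already a Loss subclass: return it unchanged.
        return instance
    # Anything else has the wrong type.
    raise ValueError(f"Expected a name or a subclass of Loss, got {instance!r}.")


# The function returns a class, so the caller still has to instantiate it.
loss_cls = get_loss_function("MSELoss")
loss_fn = loss_cls()
```

Returning the class rather than an object leaves construction, and any constructor arguments, to the caller.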