Source code for dgs.models.loss.loss

"""
Custom loss functions.
"""

from typing import Any

import torch as t
from torch import nn

from dgs.utils.config import DEF_VAL
from dgs.utils.types import Loss


class CrossEntropyLoss(Loss):
    """Cross-entropy loss configured from the project defaults.

    Thin wrapper around :class:`torch.nn.CrossEntropyLoss`. The keyword arguments
    stored under ``DEF_VAL["cross_entropy_loss"]`` are used as a baseline and can be
    overridden individually through ``**kwargs``.

    Notes:
        ``torch.nn.CrossEntropyLoss`` applies the log-softmax internally, so the
        inputs passed to :meth:`forward` must be raw (unnormalized) logits.
        No separate ``LogSoftmax`` layer is applied by this class.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the underlying ``nn.CrossEntropyLoss``.

        Args:
            **kwargs: Keyword arguments forwarded to ``torch.nn.CrossEntropyLoss``.
                They take precedence over the values in
                ``DEF_VAL["cross_entropy_loss"]``.
        """
        super().__init__()

        # Copy first so that the shared default configuration is never mutated.
        default_kwargs: dict[str, Any] = DEF_VAL["cross_entropy_loss"].copy()
        default_kwargs.update(kwargs)

        self.cross_entropy_loss = nn.CrossEntropyLoss(**default_kwargs)

    def forward(self, inputs: t.Tensor, targets: t.Tensor) -> t.Tensor:
        """Compute the cross-entropy loss between predictions and targets.

        Args:
            inputs: Predictions (raw logits) of shape ``[B x nof_classes]``.
            targets: Targets of shape ``[B]``.

        Returns:
            The result of calling the wrapped ``nn.CrossEntropyLoss`` on
            ``inputs`` and ``targets``.
        """
        return self.cross_entropy_loss(inputs, targets)