# Source code for dgs.models.loss

"""
Functions to load and manage torch and custom loss functions.
"""

import warnings
from typing import Type

from torch import nn

from dgs.utils.loader import get_instance, register_instance
from dgs.utils.types import Instance, Loss
from .loss import CrossEntropyLoss

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Cython evaluation.*is unavailable", category=UserWarning)
    try:
        # If torchreid is installed using `./dependencies/torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.losses import CrossEntropyLoss as TorchreidCEL, TripletLoss as TorchreidTL
    except ModuleNotFoundError:
        # if torchreid is installed using `pip install torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.reid.losses import CrossEntropyLoss as TorchreidCEL, TripletLoss as TorchreidTL

__all__ = ["LOSS_FUNCTIONS", "register_loss_function", "get_loss_function"]


# Registry mapping the config-facing name of a loss to its class.
# Entries are inserted in a fixed order: package-own losses, then PyTorch, then torchreid.
LOSS_FUNCTIONS: dict[str, Type[Loss]] = {
    # losses implemented in this package
    "CrossEntropyLoss": CrossEntropyLoss,
    # built-in PyTorch losses; each key is "Torch" followed by the class name in ``torch.nn``
    # (e.g. "TorchL1Loss" -> nn.L1Loss, "TorchCTCLoss" -> nn.CTCLoss)
    **{
        f"Torch{cls_name}": getattr(nn, cls_name)
        for cls_name in (
            "L1Loss",
            "NLLLoss",
            "PoissonNLLLoss",
            "GaussianNLLLoss",
            "KLDivLoss",
            "MSELoss",
            "BCELoss",
            "BCEWithLogitsLoss",
            "HingeEmbeddingLoss",
            "MultiLabelMarginLoss",
            "SmoothL1Loss",
            "HuberLoss",
            "SoftMarginLoss",
            "CrossEntropyLoss",
            "MultiLabelSoftMarginLoss",
            "CosineEmbeddingLoss",
            "MarginRankingLoss",
            "MultiMarginLoss",
            "TripletMarginLoss",
            "TripletMarginWithDistanceLoss",
            "CTCLoss",
        )
    },
    # losses provided by torchreid (imported above under local aliases)
    "TorchreidTripletLoss": TorchreidTL,
    "TorchreidCrossEntropyLoss": TorchreidCEL,
}


def register_loss_function(name: str, new_loss: Type[Loss]) -> None:
    """Register a new loss function to be used with custom configs.

    The registration is delegated to ``register_instance`` with :data:`LOSS_FUNCTIONS` as the
    target registry and ``Loss`` as the required base class.

    Args:
        name: Name of the new loss function, e.g. "CustomNNLLoss".
            The name cannot be a value that is already in :data:`LOSS_FUNCTIONS`.
        new_loss: The type (class) of loss function to register.

    Raises:
        ValueError: If ``name`` is already in :data:`LOSS_FUNCTIONS` or the instance is invalid.

    Examples::

        import torch

        from dgs.utils.types import Loss

        class CustomNNLLoss(Loss):
            def __init__(...):
                ...

            def forward(self, input: torch.Tensor, target: torch.Tensor):
                return ...

        register_loss_function("CustomNNLLoss", CustomNNLLoss)
    """
    register_instance(name=name, instance=new_loss, instances=LOSS_FUNCTIONS, inst_class=Loss)


def get_loss_function(instance: Instance) -> Type[Loss]:
    """Given the name or the class of a loss function, return the respective loss-function class.

    The lookup is delegated to ``get_instance`` with :data:`LOSS_FUNCTIONS` as the registry and
    ``Loss`` as the required base class.

    Args:
        instance: Either the name of the loss function, which has to be in :data:`LOSS_FUNCTIONS`,
            or a subclass of :class:`~.Loss`.

    Raises:
        ValueError: If the instance has the wrong type.

    Returns:
        The class of the given loss function.
    """
    return get_instance(instance=instance, instances=LOSS_FUNCTIONS, inst_class=Loss)