"""
Functions to load and manage torch and custom loss functions.
"""
import warnings
from typing import Type
from torch import nn
from dgs.utils.loader import get_instance, register_instance
from dgs.utils.types import Instance, Loss
from .loss import CrossEntropyLoss
# Import the torchreid loss classes under local aliases.
# Inside this ``with`` block, every UserWarning whose message matches
# "Cython evaluation.*is unavailable" is ignored (presumably emitted while importing torchreid).
# ``warnings.catch_warnings`` restores the previous warning filters when the block exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Cython evaluation.*is unavailable", category=UserWarning)
    try:
        # If torchreid is installed using `./dependencies/torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.losses import CrossEntropyLoss as TorchreidCEL, TripletLoss as TorchreidTL
    except ModuleNotFoundError:
        # If torchreid is installed using `pip install torchreid`, the losses live in `torchreid.reid.losses`
        # noinspection PyUnresolvedReferences
        from torchreid.reid.losses import CrossEntropyLoss as TorchreidCEL, TripletLoss as TorchreidTL
# Names exported by this module.
__all__ = ["LOSS_FUNCTIONS", "register_loss_function", "get_loss_function"]

# Registry that maps the name of a loss function to its class.
# It is the lookup table handed to ``register_instance`` and ``get_instance`` by
# ``register_loss_function`` and ``get_loss_function``.
# Naming convention of the keys: losses from ``torch.nn`` carry the prefix "Torch",
# losses from torchreid carry the prefix "Torchreid", and the project's own loss has no prefix.
LOSS_FUNCTIONS: dict[str, Type[Loss]] = {
    # own loss, imported from the sibling module ``.loss``
    "CrossEntropyLoss": CrossEntropyLoss,
    # losses shipped with PyTorch (``torch.nn``)
    "TorchL1Loss": nn.L1Loss,
    "TorchNLLLoss": nn.NLLLoss,
    "TorchPoissonNLLLoss": nn.PoissonNLLLoss,
    "TorchGaussianNLLLoss": nn.GaussianNLLLoss,
    "TorchKLDivLoss": nn.KLDivLoss,
    "TorchMSELoss": nn.MSELoss,
    "TorchBCELoss": nn.BCELoss,
    "TorchBCEWithLogitsLoss": nn.BCEWithLogitsLoss,
    "TorchHingeEmbeddingLoss": nn.HingeEmbeddingLoss,
    "TorchMultiLabelMarginLoss": nn.MultiLabelMarginLoss,
    "TorchSmoothL1Loss": nn.SmoothL1Loss,
    "TorchHuberLoss": nn.HuberLoss,
    "TorchSoftMarginLoss": nn.SoftMarginLoss,
    "TorchCrossEntropyLoss": nn.CrossEntropyLoss,
    "TorchMultiLabelSoftMarginLoss": nn.MultiLabelSoftMarginLoss,
    "TorchCosineEmbeddingLoss": nn.CosineEmbeddingLoss,
    "TorchMarginRankingLoss": nn.MarginRankingLoss,
    "TorchMultiMarginLoss": nn.MultiMarginLoss,
    "TorchTripletMarginLoss": nn.TripletMarginLoss,
    "TorchTripletMarginWithDistanceLoss": nn.TripletMarginWithDistanceLoss,
    "TorchCTCLoss": nn.CTCLoss,
    # losses provided by torchreid (imported above under the aliases TorchreidTL / TorchreidCEL)
    "TorchreidTripletLoss": TorchreidTL,
    "TorchreidCrossEntropyLoss": TorchreidCEL,
}
[docs]
def register_loss_function(name: str, new_loss: Type[Loss]) -> None:
    """Add a custom loss class to :data:`LOSS_FUNCTIONS` so it can be used with custom configs.

    The registration itself, including the checks listed under *Raises*, is delegated to
    ``register_instance`` with :data:`LOSS_FUNCTIONS` as the registry and :class:`~.Loss`
    as the expected class.

    Args:
        name: Key under which the new loss function is stored, e.g. "CustomNNLLoss".
            It must not already be a key of :data:`LOSS_FUNCTIONS`.
        new_loss: The type (class) of the loss function to register.

    Raises:
        ValueError: If ``name`` is already a key of :data:`LOSS_FUNCTIONS` or ``new_loss`` is invalid.

    Examples::

        import torch
        from torch import nn

        class CustomNNLLoss(Loss):
            def __init__(...):
                ...

            def forward(self, input: torch.Tensor, target: torch.Tensor):
                return ...

        register_loss_function("CustomNNLLoss", CustomNNLLoss)
    """
    register_instance(
        instances=LOSS_FUNCTIONS,
        inst_class=Loss,
        name=name,
        instance=new_loss,
    )
[docs]
def get_loss_function(instance: Instance) -> Type[Loss]:
    """Resolve the name or the class of a loss function to the respective loss class.

    The lookup is delegated to ``get_instance`` with :data:`LOSS_FUNCTIONS` as the registry
    and :class:`~.Loss` as the expected class.

    Args:
        instance: Either the name of a loss function, which has to be a key of
            :data:`LOSS_FUNCTIONS`, or a subclass of :class:`~.Loss`.

    Returns:
        The class of the requested loss function.

    Raises:
        ValueError: If ``instance`` has the wrong type.
    """
    loss_class = get_instance(
        instances=LOSS_FUNCTIONS,
        inst_class=Loss,
        instance=instance,
    )
    return loss_class