Source code for dgs.models.optimizer

"""
Load, register, and initialize different optimizers.
"""

from typing import Type

from torch import optim
from torch.optim import Optimizer

from dgs.utils.loader import get_instance, register_instance
from dgs.utils.types import Instance

__all__ = ["OPTIMIZERS", "register_optimizer", "get_optimizer"]

# Registry of optimizer classes that can be selected by name, e.g. from a config file.
# The keys are exactly the attribute names inside ``torch.optim``, so every entry is
# looked up from that module in the order listed below.
# ``LBFGS`` is deliberately not part of the registry: its ``step()`` requires a closure,
# which is not handled here.
OPTIMIZERS: dict[str, Type[Optimizer]] = {
    _optim_name: getattr(optim, _optim_name)
    for _optim_name in (
        "Adadelta",
        "Adagrad",
        "Adam",
        "AdamW",
        "SparseAdam",
        "Adamax",
        "ASGD",
        "NAdam",
        "RAdam",
        "RMSprop",
        "Rprop",
        "SGD",
    )
}


def register_optimizer(name: str, new_optimizer: Type[Optimizer]) -> None:
    """Register a new optimizer to be used with custom configs.

    The optimizer is added to :data:`OPTIMIZERS` via ``register_instance``, so it can
    afterward be retrieved by name with :func:`get_optimizer`.

    Args:
        name: Name of the new optimizer, e.g. "CustomAdam".
            The name cannot be a value already present in :data:`OPTIMIZERS`.
        new_optimizer: The type / class of the optimizer to register.

    Raises:
        ValueError: If ``name`` is already a key of :data:`OPTIMIZERS` or the instance is invalid.

    Examples::

        from torch import optim

        class CustomAdam(optim.Optimizer):
            ...

        register_optimizer("CustomAdam", CustomAdam)
    """
    # Validation (duplicate name, non-Optimizer class) is delegated to the shared loader helper.
    register_instance(
        name=name, instance=new_optimizer, instances=OPTIMIZERS, inst_class=Optimizer
    )
def get_optimizer(instance: Instance) -> Type[Optimizer]:
    """Given the name or the class of an optimizer, return the respective optimizer class.

    Args:
        instance: Either the name of the optimizer, which has to be a key of
            :data:`OPTIMIZERS`, or a subclass of :class:`~torch.optim.Optimizer`.

    Raises:
        ValueError: If the instance has the wrong type.

    Returns:
        The class of the given optimizer (not an initialized optimizer object).
    """
    # Name lookup and type validation are delegated to the shared loader helper.
    return get_instance(instance=instance, instances=OPTIMIZERS, inst_class=Optimizer)