Source code for dgs.utils.loader
"""
Utility functions for loading instances.
This module does not provide functionality for loading Modules.
"""
from typing import Type, TypeVar
from dgs.utils.constants import MODULE_TYPES
from dgs.utils.exceptions import InvalidParameterException
from dgs.utils.types import Instance
I = TypeVar("I")
[docs]
def register_instance(
    name: str, instance: Type[I], instances: dict[str, Type[I]], inst_class: type, call: bool = True
) -> None:
    """Register a class under a new name in a name-to-class mapping.

    The checks run in a fixed order (name, callability, subclass relation),
    and ``instances`` is only modified if all of them pass.

    Args:
        name: The key under which ``instance`` is stored.
            Must not already be a key of ``instances``.
        instance: The class that should be added to ``instances``.
        instances: A dictionary mapping instance names to instance classes. Modified in place.
        inst_class: The class ``instance`` has to be a subclass of.
        call: Whether ``instance`` is additionally required to be callable.
            Default True.

    Raises:
        KeyError: If ``name`` is already a key of ``instances``.
        TypeError: If ``call`` is set and ``instance`` is not callable,
            or if ``instance`` is not a class that is a subclass of ``inst_class``.
    """
    if name in instances:
        taken_names = list(instances.keys())
        raise KeyError(
            f"The given name '{name}' already exists within the registered instances. "
            f"Please choose another name excluding '{taken_names}'."
        )

    if call and not callable(instance):
        raise TypeError("The given instance is not callable.")

    # Objects that are not classes are rejected before ``issubclass`` is evaluated.
    if not isinstance(instance, type) or not issubclass(instance, inst_class):
        raise TypeError(f"The given instance is not a valid subclass of type '{inst_class}'. Got: {instance}")

    instances[name] = instance
[docs]
def get_instance_from_name(name: str, instances: dict[str, Type[I]]) -> Type[I]:
    """Look up the class that is registered under a given name.

    Args:
        name: The name of the instance to look up in ``instances``.
        instances: A dictionary mapping instance names to instance classes.

    Returns:
        The class stored under ``name``.

    Raises:
        KeyError: If ``name`` is not a key of ``instances``.
    """
    if name in instances:
        return instances[name]

    known_names = list(instances.keys())
    raise KeyError(f"Instance '{name}' is not defined in '{known_names}'.")
[docs]
def get_instance(instance: Instance, instances: dict[str, Type[I]], inst_class: type) -> Type[I]:
    """Resolve either a registered name or a class to the class itself.

    Args:
        instance: Either the name of the instance, which has to be in ``instances``,
            or a class that is a subclass of ``inst_class``.
        instances: A dictionary mapping instance names to instance classes.
            Only consulted if ``instance`` is a string.
        inst_class: The class the instance should be a subclass of.
            Only checked if ``instance`` is not a string.

    Returns:
        The class looked up by name, or ``instance`` itself if it already is a valid class.

    Raises:
        KeyError: If ``instance`` is a string that is not a key of ``instances``
            (raised by :func:`get_instance_from_name`).
        InvalidParameterException: If ``instance`` is neither a string nor a subclass of ``inst_class``.
    """
    if isinstance(instance, str):
        return get_instance_from_name(name=str(instance), instances=instances)

    is_class = isinstance(instance, type)
    if is_class and issubclass(instance, inst_class):
        return instance

    raise InvalidParameterException(f"Instance {instance} is neither string nor a subclass of '{inst_class}'")
[docs]
def get_registered_classes(module_type: str) -> dict[str, type]:
    """Get the registry of one type of module.

    The registry is imported inside the function, at the time of the call.
    NOTE(review): presumably this avoids import cycles between ``dgs.utils`` and ``dgs.models``,
    as hinted by the ``cyclic-import`` pylint exemption - confirm.

    Args:
        module_type: The type of module to get the registry from.
            Has to be a value in ``MODULE_TYPES``.

    Returns:
        The registry object imported from the respective ``dgs.models`` submodule.
        The imported object is returned as is, not as a copy.

    Raises:
        ValueError: If ``module_type`` is not in ``MODULE_TYPES``, or if it is ``"dataloader"``.
        NotImplementedError: If ``module_type`` is in ``MODULE_TYPES`` but is not handled here.
    """
    # pylint: disable=too-many-branches,too-many-return-statements,import-outside-toplevel,cyclic-import
    if module_type not in MODULE_TYPES:
        raise ValueError(f"The instance class name '{module_type}' could not be found.")

    if module_type == "alpha":
        from dgs.models.alpha import ALPHA_MODULES

        return ALPHA_MODULES
    if module_type == "combine":
        from dgs.models.combine import COMBINE_MODULES

        return COMBINE_MODULES
    if module_type == "dataset":
        from dgs.models.dataset import DATASETS

        return DATASETS
    if module_type == "dataloader":
        raise ValueError("dataloaders can not be registered. Did you mean dataset?")
    if module_type == "dgs":
        from dgs.models.dgs import DGS_MODULES

        return DGS_MODULES
    if module_type == "embedding_generator":
        from dgs.models.embedding_generator import EMBEDDING_GENERATORS

        return EMBEDDING_GENERATORS
    if module_type == "engine":
        from dgs.models.engine import ENGINES

        return ENGINES
    if module_type == "loss":
        from dgs.models.loss import LOSS_FUNCTIONS

        return LOSS_FUNCTIONS
    if module_type == "metric":
        from dgs.models.metric import METRICS

        return METRICS
    if module_type == "optimizer":
        from dgs.models.optimizer import OPTIMIZERS

        return OPTIMIZERS
    if module_type == "similarity":
        from dgs.models.similarity import SIMILARITIES

        return SIMILARITIES
    if module_type == "submission":
        from dgs.models.submission import SUBMISSION_FORMATS

        return SUBMISSION_FORMATS

    # The type is listed in MODULE_TYPES, but there is no registry for it in this function.
    raise NotImplementedError
[docs]
def get_registered_class_names(module_type: str) -> set[str]:
    """Collect the names under which classes are registered for one type of module.

    Args:
        module_type: The type of module to get all the registered names from.
            Passed on unchanged to :func:`get_registered_classes`.

    Returns:
        A new set containing the keys of the registry of ``module_type``.
    """
    registry = get_registered_classes(module_type=module_type)
    return set(registry.keys())
[docs]
def get_registered_class_types(module_type: str) -> set[type]:
    """Collect the classes that are registered for one type of module.

    Args:
        module_type: The type of module to get all the registered types from.
            Passed on unchanged to :func:`get_registered_classes`.

    Returns:
        A new set containing the values of the registry of ``module_type``.
        Values registered under several names appear only once.
    """
    registry = get_registered_classes(module_type=module_type)
    return set(registry.values())