Source code for dgs.models.loader

"""
Load and register modules.
"""

from typing import Any, Union

from torch.utils.data import DataLoader as TDataLoader, Dataset as TDataset

from dgs.models.dataset.dataset import BaseDataset, dataloader_validations
from dgs.utils.config import DEF_VAL, get_sub_config
from dgs.utils.constants import MODULE_TYPES
from dgs.utils.exceptions import InvalidConfigException
from dgs.utils.state import collate_list_of_history, collate_lists, collate_states
from dgs.utils.types import Config, NodePath

__all__ = ["module_loader", "register_module", "get_data_loader"]


def module_loader(config: Config, module_type: str, key: Union[NodePath, str], *_, **kwargs) -> Any:
    """Load a given module and pass down the configuration.

    Args:
        config: The configuration of the current tracker.
        module_type: The type of the module.
        key: Name of the key in the configuration that contains the parameters for this module.
            Can be a list of names for nested configurations.
        kwargs: Additional keyword arguments passed to the module during initialization.

    Returns:
        Initialized instance of the submodule with its config.
    """
    # This function will have one branch for every module type
    # pylint: disable=too-many-branches,import-outside-toplevel,cyclic-import

    if module_type not in MODULE_TYPES:
        raise ValueError(f"The module type: '{module_type}' could not be found.")

    path: NodePath = [key] if isinstance(key, str) else key
    sub_cfg: Config = get_sub_config(config, path)
    if "module_name" not in sub_cfg:
        raise InvalidConfigException(f"Module at path '{path}' does not contain a module name.")
    module_name: str = sub_cfg["module_name"]

    # Module import and initialization
    if module_type == "alpha":
        from dgs.models.alpha import get_alpha_module

        m = get_alpha_module(module_name)
    elif module_type == "combine":
        from dgs.models.combine import get_combine_module

        m = get_combine_module(module_name)
    elif module_type == "dataloader":
        return get_data_loader(config=config, path=path)
    elif module_type == "dataset":
        # special case: a generally concatenated dataset
        if module_name.startswith("Concat_"):
            from dgs.models.dataset import get_multi_dataset

            return get_multi_dataset(config=config, path=path, ds_name=module_name[7:], concat=True)
        if module_name.startswith("List_"):
            from dgs.models.dataset import get_multi_dataset

            return get_multi_dataset(config=config, path=path, ds_name=module_name[5:], concat=False)
        from dgs.models.dataset import get_dataset

        m = get_dataset(module_name)
    elif module_type == "dgs":
        from dgs.models.dgs import get_dgs_module

        m = get_dgs_module(module_name)
    elif module_type == "embedding_generator":
        from dgs.models.embedding_generator import get_embedding_generator

        m = get_embedding_generator(module_name)
    elif module_type == "engine":
        from dgs.models.engine import get_engine

        m = get_engine(module_name)
    elif module_type == "similarity":
        from dgs.models.similarity import get_similarity_module

        m = get_similarity_module(module_name)
    elif module_type == "submission":
        from dgs.models.submission import get_submission

        m = get_submission(module_name)
    else:
        raise NotImplementedError(f"Something went wrong while loading the module '{module_type}'")

    # instantiate the module with its configuration and path
    return m(config=config, path=path, **kwargs)
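
# Usage sketch (not part of the original module): a minimal, hedged example of how
# ``module_loader`` might be called. The config structure below and the registered
# name "cosine" are illustrative assumptions; valid ``module_name`` values depend on
# what each submodule package has registered, and the real ``Config`` type may be a
# richer structure than a plain dict.
#
#     from dgs.models.loader import module_loader
#
#     cfg = {
#         "name": "example-tracker",
#         "similarity": {
#             "module_name": "cosine",  # hypothetical registered similarity name
#         },
#     }
#     # resolves cfg["similarity"], looks up the registered class, and instantiates it
#     sim = module_loader(config=cfg, module_type="similarity", key="similarity")
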
def register_module(name: str, new_module: type, module_type: str) -> None:
    r"""Register a new module.

    Args:
        name: The name under which to register the new module.
        new_module: The type of the new module to register.
        module_type: The type of module instance to register. Has to be in :data:`MODULE_TYPES`.

    Raises:
        ValueError: If the module type is invalid.

    Examples::

        import torch
        from torch import nn

        from dgs.models import register_module

        class CustomNNLLoss(nn.Module):
            def __init__(...):
                ...

            def forward(self, input: torch.Tensor, target: torch.Tensor):
                return ...

        register_module(name="CustomNNLLoss", new_module=CustomNNLLoss, module_type="loss")
    """
    # pylint: disable=too-many-branches,import-outside-toplevel

    if module_type not in MODULE_TYPES:
        raise ValueError(f"The module type '{module_type}' could not be found.")

    if module_type == "combine":
        from dgs.models.combine import register_combine_module

        register_combine_module(name=name, new_combine=new_module)
    elif module_type == "dataloader":
        raise ValueError("You can not register a new dataloader. Did you want to register a dataset?")
    elif module_type == "dataset":
        from dgs.models.dataset import register_dataset

        register_dataset(name=name, new_ds=new_module)
    elif module_type == "dgs":
        from dgs.models.dgs import register_dgs_module

        register_dgs_module(name=name, new_dgs=new_module)
    elif module_type == "embedding_generator":
        from dgs.models.embedding_generator import register_embedding_generator

        register_embedding_generator(name=name, new_eg=new_module)
    elif module_type == "engine":
        from dgs.models.engine import register_engine

        register_engine(name=name, new_engine=new_module)
    elif module_type == "loss":
        from dgs.models.loss import register_loss_function

        register_loss_function(name=name, new_loss=new_module)
    elif module_type == "metric":
        from dgs.models.metric import register_metric

        register_metric(name=name, new_metric=new_module)
    elif module_type == "optimizer":
        from dgs.models.optimizer import register_optimizer

        register_optimizer(name=name, new_optimizer=new_module)
    elif module_type == "similarity":
        from dgs.models.similarity import register_similarity_module

        register_similarity_module(name=name, new_similarity=new_module)
    elif module_type == "submission":
        from dgs.models.submission import register_submission

        register_submission(name=name, new_sub=new_module)
    else:
        raise NotImplementedError
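
# Usage sketch (illustrative, not part of the original module): registering a custom
# dataset under a new name so that configs can reference it via ``module_name``.
# ``MyDataset`` is a hypothetical subclass of ``BaseDataset``; its body is elided.
#
#     from dgs.models.dataset.dataset import BaseDataset
#     from dgs.models.loader import register_module
#
#     class MyDataset(BaseDataset):
#         ...
#
#     register_module(name="MyDataset", new_module=MyDataset, module_type="dataset")
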
def get_data_loader(config: Config, path: NodePath) -> TDataLoader:
    """Set up a torch data loader using parameters from the config.

    Args:
        config: The overall configuration of the algorithm.
        path: The node path to the params of this DataLoader.

    Returns:
        A :class:`~torch.utils.data.DataLoader` object for the given dataset.
    """
    ds: BaseDataset = module_loader(config=config, module_type="dataset", key=path)

    # Validate the data loader params on a regular BaseDataset;
    # a concatenated dataset should have been validated elsewhere.
    if isinstance(ds, BaseDataset):
        ds.validate_params(dataloader_validations)
        params = ds.params
    else:
        assert isinstance(ds, TDataset)
        params = get_sub_config(config=config, path=path)

    batch_size: int = params.get("batch_size", DEF_VAL["dataloader"]["batch_size"])
    drop_last: bool = params.get("drop_last", DEF_VAL["dataloader"]["drop_last"])
    shuffle: bool = params.get("shuffle", DEF_VAL["dataloader"]["shuffle"])

    # Resolve the collate function; collating full states is the default.
    if "collate_fn" not in params:
        collate_fn = collate_states
    elif params["collate_fn"] == "lists" or params["collate_fn"] is None:
        collate_fn = collate_lists
    elif params["collate_fn"] == "states":
        collate_fn = collate_states
    elif params["collate_fn"] == "history":
        collate_fn = collate_list_of_history
    else:
        raise NotImplementedError(f"Collate function '{params['collate_fn']}' not implemented.")

    return TDataLoader(
        dataset=ds,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=params.get("workers", DEF_VAL["dataloader"]["workers"]),
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
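
# Configuration sketch (illustrative assumptions, not part of the original module):
# the DataLoader-related keys read by ``get_data_loader``. Missing keys fall back to
# ``DEF_VAL["dataloader"]``; ``collate_fn`` accepts "lists", "states", "history",
# or ``None`` (which maps to list collation). The values below are examples only.
#
#     dataloader_cfg = {
#         "module_name": "...",   # name of a registered dataset (left elided)
#         "batch_size": 8,
#         "drop_last": False,
#         "shuffle": True,
#         "workers": 4,
#         "collate_fn": "states",
#     }
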